core: refactor provider and model system (#5033)

Co-authored-by: opencode-agent[bot] <opencode-agent[bot]@users.noreply.github.com>
Co-authored-by: thdxr <thdxr@users.noreply.github.com>
This commit is contained in:
Dax
2025-12-03 21:09:03 -05:00
committed by GitHub
parent ee4437ff32
commit 6d3fc63658
20 changed files with 892 additions and 720 deletions

View File

@@ -11,7 +11,6 @@ import { Agent } from "../agent/agent"
import { Provider } from "../provider/provider"
import {
generateText,
streamText,
type ModelMessage,
type Tool as AITool,
tool,
@@ -288,6 +287,7 @@ export namespace SessionPrompt {
})
const model = await Provider.getModel(lastUser.model.providerID, lastUser.model.modelID)
const language = await Provider.getLanguage(model)
const task = tasks.pop()
// pending subtask
@@ -311,7 +311,7 @@ export namespace SessionPrompt {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: model.modelID,
modelID: model.id,
providerID: model.providerID,
time: {
created: Date.now(),
@@ -408,7 +408,7 @@ export namespace SessionPrompt {
agent: lastUser.agent,
model: {
providerID: model.providerID,
modelID: model.modelID,
modelID: model.id,
},
sessionID,
auto: task.auto,
@@ -421,7 +421,7 @@ export namespace SessionPrompt {
if (
lastFinished &&
lastFinished.summary !== true &&
SessionCompaction.isOverflow({ tokens: lastFinished.tokens, model: model.info })
SessionCompaction.isOverflow({ tokens: lastFinished.tokens, model })
) {
await SessionCompaction.create({
sessionID,
@@ -455,7 +455,7 @@ export namespace SessionPrompt {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: model.modelID,
modelID: model.id,
providerID: model.providerID,
time: {
created: Date.now(),
@@ -463,20 +463,18 @@ export namespace SessionPrompt {
sessionID,
})) as MessageV2.Assistant,
sessionID: sessionID,
model: model.info,
providerID: model.providerID,
model,
abort,
})
const system = await resolveSystemPrompt({
providerID: model.providerID,
modelID: model.info.id,
model,
agent,
system: lastUser.system,
})
const tools = await resolveTools({
agent,
sessionID,
model: lastUser.model,
model,
tools: lastUser.tools,
processor,
})
@@ -486,21 +484,19 @@ export namespace SessionPrompt {
{
sessionID: sessionID,
agent: lastUser.agent,
model: model.info,
model: model,
provider,
message: lastUser,
},
{
temperature: model.info.temperature
? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
temperature: model.capabilities.temperature
? (agent.temperature ?? ProviderTransform.temperature(model))
: undefined,
topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
topP: agent.topP ?? ProviderTransform.topP(model),
options: pipe(
{},
mergeDeep(
ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", sessionID, provider?.options),
),
mergeDeep(model.info.options),
mergeDeep(ProviderTransform.options(model, sessionID, provider?.options)),
mergeDeep(model.options),
mergeDeep(agent.options),
),
},
@@ -513,113 +509,111 @@ export namespace SessionPrompt {
})
}
const result = await processor.process(() =>
streamText({
onError(error) {
log.error("stream error", {
error,
const result = await processor.process({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
})
return {
...input.toolCall,
toolName: lower,
}
}
return {
...input.toolCall,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
toolName: lower,
}
},
headers: {
...(model.providerID.startsWith("opencode")
? {
"x-opencode-project": Instance.project.id,
"x-opencode-session": sessionID,
"x-opencode-request": lastUser.id,
}
: undefined),
...model.info.headers,
},
// set to 0, we handle loop
maxRetries: 0,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.providerID,
params.options,
model.info.limit.output,
OUTPUT_TOKEN_MAX,
}
return {
...input.toolCall,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
}
},
headers: {
...(model.providerID.startsWith("opencode")
? {
"x-opencode-project": Instance.project.id,
"x-opencode-session": sessionID,
"x-opencode-request": lastUser.id,
}
: undefined),
...model.headers,
},
// set to 0, we handle loop
maxRetries: 0,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.api.npm,
params.options,
model.limit.output,
OUTPUT_TOKEN_MAX,
),
abortSignal: abort,
providerOptions: ProviderTransform.providerOptions(model.api.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
abortSignal: abort,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
],
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
}
// Transform tool schemas for provider compatibility
if (args.params.tools && Array.isArray(args.params.tools)) {
args.params.tools = args.params.tools.map((tool: any) => {
// Tools at middleware level have inputSchema, not parameters
if (tool.inputSchema && typeof tool.inputSchema === "object") {
// Transform the inputSchema for provider compatibility
return {
...tool,
inputSchema: ProviderTransform.schema(model.providerID, model.modelID, tool.inputSchema),
}
return false
}),
),
],
tools: model.capabilities.toolcall === false ? undefined : tools,
model: wrapLanguageModel({
model: language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error - prompt types are compatible at runtime
args.params.prompt = ProviderTransform.message(args.params.prompt, model)
}
// Transform tool schemas for provider compatibility
if (args.params.tools && Array.isArray(args.params.tools)) {
args.params.tools = args.params.tools.map((tool: any) => {
// Tools at middleware level have inputSchema, not parameters
if (tool.inputSchema && typeof tool.inputSchema === "object") {
// Transform the inputSchema for provider compatibility
return {
...tool,
inputSchema: ProviderTransform.schema(model, tool.inputSchema),
}
// If no inputSchema, return tool unchanged
return tool
})
}
return args.params
},
}
// If no inputSchema, return tool unchanged
return tool
})
}
return args.params
},
],
}),
},
],
}),
)
})
if (result === "stop") break
continue
}
@@ -642,18 +636,13 @@ export namespace SessionPrompt {
return Provider.defaultModel()
}
async function resolveSystemPrompt(input: {
system?: string
agent: Agent.Info
providerID: string
modelID: string
}) {
let system = SystemPrompt.header(input.providerID)
async function resolveSystemPrompt(input: { system?: string; agent: Agent.Info; model: Provider.Model }) {
let system = SystemPrompt.header(input.model.providerID)
system.push(
...(() => {
if (input.system) return [input.system]
if (input.agent.prompt) return [input.agent.prompt]
return SystemPrompt.provider(input.modelID)
return SystemPrompt.provider(input.model)
})(),
)
system.push(...(await SystemPrompt.environment()))
@@ -666,10 +655,7 @@ export namespace SessionPrompt {
async function resolveTools(input: {
agent: Agent.Info
model: {
providerID: string
modelID: string
}
model: Provider.Model
sessionID: string
tools?: Record<string, boolean>
processor: SessionProcessor.Info
@@ -677,16 +663,12 @@ export namespace SessionPrompt {
const tools: Record<string, AITool> = {}
const enabledTools = pipe(
input.agent.tools,
mergeDeep(await ToolRegistry.enabled(input.model.providerID, input.model.modelID, input.agent)),
mergeDeep(await ToolRegistry.enabled(input.agent)),
mergeDeep(input.tools ?? {}),
)
for (const item of await ToolRegistry.tools(input.model.providerID, input.model.modelID)) {
for (const item of await ToolRegistry.tools(input.model.providerID)) {
if (Wildcard.all(item.id, enabledTools) === false) continue
const schema = ProviderTransform.schema(
input.model.providerID,
input.model.modelID,
z.toJSONSchema(item.parameters),
)
const schema = ProviderTransform.schema(input.model, z.toJSONSchema(item.parameters))
tools[item.id] = tool({
id: item.id as any,
description: item.description,
@@ -1437,25 +1419,18 @@ export namespace SessionPrompt {
if (!isFirst) return
const small =
(await Provider.getSmallModel(input.providerID)) ?? (await Provider.getModel(input.providerID, input.modelID))
const language = await Provider.getLanguage(small)
const provider = await Provider.getProvider(small.providerID)
const options = pipe(
{},
mergeDeep(
ProviderTransform.options(
small.providerID,
small.modelID,
small.npm ?? "",
input.session.id,
provider?.options,
),
),
mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
mergeDeep(small.info.options),
mergeDeep(ProviderTransform.options(small, input.session.id, provider?.options)),
mergeDeep(ProviderTransform.smallOptions(small)),
mergeDeep(small.options),
)
await generateText({
// use higher # for reasoning models since reasoning tokens eat up a lot of the budget
maxOutputTokens: small.info.reasoning ? 3000 : 20,
providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
maxOutputTokens: small.capabilities.reasoning ? 3000 : 20,
providerOptions: ProviderTransform.providerOptions(small.api.npm, small.providerID, options),
messages: [
...SystemPrompt.title(small.providerID).map(
(x): ModelMessage => ({
@@ -1486,8 +1461,8 @@ export namespace SessionPrompt {
},
]),
],
headers: small.info.headers,
model: small.language,
headers: small.headers,
model: language,
})
.then((result) => {
if (result.text)
@@ -1504,7 +1479,7 @@ export namespace SessionPrompt {
})
})
.catch((error) => {
log.error("failed to generate title", { error, model: small.info.id })
log.error("failed to generate title", { error, model: small.id })
})
}
}