mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-02 07:03:45 +00:00
core: refactor provider and model system (#5033)
Co-authored-by: opencode-agent[bot] <opencode-agent[bot]@users.noreply.github.com> Co-authored-by: thdxr <thdxr@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { streamText, wrapLanguageModel, type ModelMessage } from "ai"
|
||||
import { wrapLanguageModel, type ModelMessage } from "ai"
|
||||
import { Session } from "."
|
||||
import { Identifier } from "../id/id"
|
||||
import { Instance } from "../project/instance"
|
||||
@@ -7,7 +7,6 @@ import { MessageV2 } from "./message-v2"
|
||||
import { SystemPrompt } from "./system"
|
||||
import { Bus } from "../bus"
|
||||
import z from "zod"
|
||||
import type { ModelsDev } from "../provider/models"
|
||||
import { SessionPrompt } from "./prompt"
|
||||
import { Flag } from "../flag/flag"
|
||||
import { Token } from "../util/token"
|
||||
@@ -29,7 +28,7 @@ export namespace SessionCompaction {
|
||||
),
|
||||
}
|
||||
|
||||
export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
|
||||
export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: Provider.Model }) {
|
||||
if (Flag.OPENCODE_DISABLE_AUTOCOMPACT) return false
|
||||
const context = input.model.limit.context
|
||||
if (context === 0) return false
|
||||
@@ -98,6 +97,7 @@ export namespace SessionCompaction {
|
||||
auto: boolean
|
||||
}) {
|
||||
const model = await Provider.getModel(input.model.providerID, input.model.modelID)
|
||||
const language = await Provider.getLanguage(model)
|
||||
const system = [...SystemPrompt.compaction(model.providerID)]
|
||||
const msg = (await Session.updateMessage({
|
||||
id: Identifier.ascending("message"),
|
||||
@@ -126,79 +126,72 @@ export namespace SessionCompaction {
|
||||
const processor = SessionProcessor.create({
|
||||
assistantMessage: msg,
|
||||
sessionID: input.sessionID,
|
||||
providerID: input.model.providerID,
|
||||
model: model.info,
|
||||
model: model,
|
||||
abort: input.abort,
|
||||
})
|
||||
const result = await processor.process(() =>
|
||||
streamText({
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
providerOptions: ProviderTransform.providerOptions(
|
||||
model.npm,
|
||||
model.providerID,
|
||||
pipe(
|
||||
{},
|
||||
mergeDeep(ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", input.sessionID)),
|
||||
mergeDeep(model.info.options),
|
||||
),
|
||||
const result = await processor.process({
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
providerOptions: ProviderTransform.providerOptions(
|
||||
model.api.npm,
|
||||
model.providerID,
|
||||
pipe({}, mergeDeep(ProviderTransform.options(model, input.sessionID)), mergeDeep(model.options)),
|
||||
),
|
||||
headers: model.headers,
|
||||
abortSignal: input.abort,
|
||||
tools: model.capabilities.toolcall ? {} : undefined,
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
headers: model.info.headers,
|
||||
abortSignal: input.abort,
|
||||
tools: model.info.tool_call ? {} : undefined,
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(
|
||||
input.messages.filter((m) => {
|
||||
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
||||
return true
|
||||
}
|
||||
if (
|
||||
MessageV2.AbortedError.isInstance(m.info.error) &&
|
||||
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
||||
) {
|
||||
return true
|
||||
}
|
||||
...MessageV2.toModelMessage(
|
||||
input.messages.filter((m) => {
|
||||
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
||||
return true
|
||||
}
|
||||
if (
|
||||
MessageV2.AbortedError.isInstance(m.info.error) &&
|
||||
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}),
|
||||
),
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
model: wrapLanguageModel({
|
||||
model: model.language,
|
||||
middleware: [
|
||||
return false
|
||||
}),
|
||||
),
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
async transformParams(args) {
|
||||
if (args.type === "stream") {
|
||||
// @ts-expect-error
|
||||
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
||||
}
|
||||
return args.params
|
||||
},
|
||||
type: "text",
|
||||
text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.",
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
],
|
||||
model: wrapLanguageModel({
|
||||
model: language,
|
||||
middleware: [
|
||||
{
|
||||
async transformParams(args) {
|
||||
if (args.type === "stream") {
|
||||
// @ts-expect-error
|
||||
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
||||
}
|
||||
return args.params
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
)
|
||||
})
|
||||
if (result === "continue" && input.auto) {
|
||||
const continueMsg = await Session.updateMessage({
|
||||
id: Identifier.ascending("message"),
|
||||
|
||||
@@ -6,8 +6,7 @@ import { Config } from "../config/config"
|
||||
import { Flag } from "../flag/flag"
|
||||
import { Identifier } from "../id/id"
|
||||
import { Installation } from "../installation"
|
||||
import type { ModelsDev } from "../provider/models"
|
||||
import { Share } from "../share/share"
|
||||
|
||||
import { Storage } from "../storage/storage"
|
||||
import { Log } from "../util/log"
|
||||
import { MessageV2 } from "./message-v2"
|
||||
@@ -16,7 +15,8 @@ import { SessionPrompt } from "./prompt"
|
||||
import { fn } from "@/util/fn"
|
||||
import { Command } from "../command"
|
||||
import { Snapshot } from "@/snapshot"
|
||||
import { ShareNext } from "@/share/share-next"
|
||||
|
||||
import type { Provider } from "@/provider/provider"
|
||||
|
||||
export namespace Session {
|
||||
const log = Log.create({ service: "session" })
|
||||
@@ -223,6 +223,7 @@ export namespace Session {
|
||||
}
|
||||
|
||||
if (cfg.enterprise?.url) {
|
||||
const { ShareNext } = await import("@/share/share-next")
|
||||
const share = await ShareNext.create(id)
|
||||
await update(id, (draft) => {
|
||||
draft.share = {
|
||||
@@ -233,6 +234,7 @@ export namespace Session {
|
||||
|
||||
const session = await get(id)
|
||||
if (session.share) return session.share
|
||||
const { Share } = await import("../share/share")
|
||||
const share = await Share.create(id)
|
||||
await update(id, (draft) => {
|
||||
draft.share = {
|
||||
@@ -253,6 +255,7 @@ export namespace Session {
|
||||
export const unshare = fn(Identifier.schema("session"), async (id) => {
|
||||
const cfg = await Config.get()
|
||||
if (cfg.enterprise?.url) {
|
||||
const { ShareNext } = await import("@/share/share-next")
|
||||
await ShareNext.remove(id)
|
||||
await update(id, (draft) => {
|
||||
draft.share = undefined
|
||||
@@ -264,6 +267,7 @@ export namespace Session {
|
||||
await update(id, (draft) => {
|
||||
draft.share = undefined
|
||||
})
|
||||
const { Share } = await import("../share/share")
|
||||
await Share.remove(id, share.secret)
|
||||
})
|
||||
|
||||
@@ -389,7 +393,7 @@ export namespace Session {
|
||||
|
||||
export const getUsage = fn(
|
||||
z.object({
|
||||
model: z.custom<ModelsDev.Model>(),
|
||||
model: z.custom<Provider.Model>(),
|
||||
usage: z.custom<LanguageModelUsage>(),
|
||||
metadata: z.custom<ProviderMetadata>().optional(),
|
||||
}),
|
||||
@@ -420,16 +424,16 @@ export namespace Session {
|
||||
}
|
||||
|
||||
const costInfo =
|
||||
input.model.cost?.context_over_200k && tokens.input + tokens.cache.read > 200_000
|
||||
? input.model.cost.context_over_200k
|
||||
input.model.cost?.experimentalOver200K && tokens.input + tokens.cache.read > 200_000
|
||||
? input.model.cost.experimentalOver200K
|
||||
: input.model.cost
|
||||
return {
|
||||
cost: safe(
|
||||
new Decimal(0)
|
||||
.add(new Decimal(tokens.input).mul(costInfo?.input ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.output).mul(costInfo?.output ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.read).mul(costInfo?.cache_read ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.write).mul(costInfo?.cache_write ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.read).mul(costInfo?.cache.read ?? 0).div(1_000_000))
|
||||
.add(new Decimal(tokens.cache.write).mul(costInfo?.cache.write ?? 0).div(1_000_000))
|
||||
// TODO: update models.dev to have better pricing model, for now:
|
||||
// charge reasoning tokens at the same rate as output tokens
|
||||
.add(new Decimal(tokens.reasoning).mul(costInfo?.output ?? 0).div(1_000_000))
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { ModelsDev } from "@/provider/models"
|
||||
import { MessageV2 } from "./message-v2"
|
||||
import { type StreamTextResult, type Tool as AITool, APICallError } from "ai"
|
||||
import { streamText } from "ai"
|
||||
import { Log } from "@/util/log"
|
||||
import { Identifier } from "@/id/id"
|
||||
import { Session } from "."
|
||||
@@ -11,6 +10,7 @@ import { SessionSummary } from "./summary"
|
||||
import { Bus } from "@/bus"
|
||||
import { SessionRetry } from "./retry"
|
||||
import { SessionStatus } from "./status"
|
||||
import type { Provider } from "@/provider/provider"
|
||||
|
||||
export namespace SessionProcessor {
|
||||
const DOOM_LOOP_THRESHOLD = 3
|
||||
@@ -19,11 +19,19 @@ export namespace SessionProcessor {
|
||||
export type Info = Awaited<ReturnType<typeof create>>
|
||||
export type Result = Awaited<ReturnType<Info["process"]>>
|
||||
|
||||
export type StreamInput = Parameters<typeof streamText>[0]
|
||||
|
||||
export type TBD = {
|
||||
model: {
|
||||
modelID: string
|
||||
providerID: string
|
||||
}
|
||||
}
|
||||
|
||||
export function create(input: {
|
||||
assistantMessage: MessageV2.Assistant
|
||||
sessionID: string
|
||||
providerID: string
|
||||
model: ModelsDev.Model
|
||||
model: Provider.Model
|
||||
abort: AbortSignal
|
||||
}) {
|
||||
const toolcalls: Record<string, MessageV2.ToolPart> = {}
|
||||
@@ -38,13 +46,13 @@ export namespace SessionProcessor {
|
||||
partFromToolCall(toolCallID: string) {
|
||||
return toolcalls[toolCallID]
|
||||
},
|
||||
async process(fn: () => StreamTextResult<Record<string, AITool>, never>) {
|
||||
async process(streamInput: StreamInput) {
|
||||
log.info("process")
|
||||
while (true) {
|
||||
try {
|
||||
let currentText: MessageV2.TextPart | undefined
|
||||
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
|
||||
const stream = fn()
|
||||
const stream = streamText(streamInput)
|
||||
|
||||
for await (const value of stream.fullStream) {
|
||||
input.abort.throwIfAborted()
|
||||
@@ -328,11 +336,12 @@ export namespace SessionProcessor {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
} catch (e: any) {
|
||||
log.error("process", {
|
||||
error: e,
|
||||
stack: JSON.stringify(e.stack),
|
||||
})
|
||||
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||
const error = MessageV2.fromError(e, { providerID: input.sessionID })
|
||||
const retry = SessionRetry.retryable(error)
|
||||
if (retry !== undefined) {
|
||||
attempt++
|
||||
|
||||
@@ -11,7 +11,6 @@ import { Agent } from "../agent/agent"
|
||||
import { Provider } from "../provider/provider"
|
||||
import {
|
||||
generateText,
|
||||
streamText,
|
||||
type ModelMessage,
|
||||
type Tool as AITool,
|
||||
tool,
|
||||
@@ -288,6 +287,7 @@ export namespace SessionPrompt {
|
||||
})
|
||||
|
||||
const model = await Provider.getModel(lastUser.model.providerID, lastUser.model.modelID)
|
||||
const language = await Provider.getLanguage(model)
|
||||
const task = tasks.pop()
|
||||
|
||||
// pending subtask
|
||||
@@ -311,7 +311,7 @@ export namespace SessionPrompt {
|
||||
reasoning: 0,
|
||||
cache: { read: 0, write: 0 },
|
||||
},
|
||||
modelID: model.modelID,
|
||||
modelID: model.id,
|
||||
providerID: model.providerID,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
@@ -408,7 +408,7 @@ export namespace SessionPrompt {
|
||||
agent: lastUser.agent,
|
||||
model: {
|
||||
providerID: model.providerID,
|
||||
modelID: model.modelID,
|
||||
modelID: model.id,
|
||||
},
|
||||
sessionID,
|
||||
auto: task.auto,
|
||||
@@ -421,7 +421,7 @@ export namespace SessionPrompt {
|
||||
if (
|
||||
lastFinished &&
|
||||
lastFinished.summary !== true &&
|
||||
SessionCompaction.isOverflow({ tokens: lastFinished.tokens, model: model.info })
|
||||
SessionCompaction.isOverflow({ tokens: lastFinished.tokens, model })
|
||||
) {
|
||||
await SessionCompaction.create({
|
||||
sessionID,
|
||||
@@ -455,7 +455,7 @@ export namespace SessionPrompt {
|
||||
reasoning: 0,
|
||||
cache: { read: 0, write: 0 },
|
||||
},
|
||||
modelID: model.modelID,
|
||||
modelID: model.id,
|
||||
providerID: model.providerID,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
@@ -463,20 +463,18 @@ export namespace SessionPrompt {
|
||||
sessionID,
|
||||
})) as MessageV2.Assistant,
|
||||
sessionID: sessionID,
|
||||
model: model.info,
|
||||
providerID: model.providerID,
|
||||
model,
|
||||
abort,
|
||||
})
|
||||
const system = await resolveSystemPrompt({
|
||||
providerID: model.providerID,
|
||||
modelID: model.info.id,
|
||||
model,
|
||||
agent,
|
||||
system: lastUser.system,
|
||||
})
|
||||
const tools = await resolveTools({
|
||||
agent,
|
||||
sessionID,
|
||||
model: lastUser.model,
|
||||
model,
|
||||
tools: lastUser.tools,
|
||||
processor,
|
||||
})
|
||||
@@ -486,21 +484,19 @@ export namespace SessionPrompt {
|
||||
{
|
||||
sessionID: sessionID,
|
||||
agent: lastUser.agent,
|
||||
model: model.info,
|
||||
model: model,
|
||||
provider,
|
||||
message: lastUser,
|
||||
},
|
||||
{
|
||||
temperature: model.info.temperature
|
||||
? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
|
||||
temperature: model.capabilities.temperature
|
||||
? (agent.temperature ?? ProviderTransform.temperature(model))
|
||||
: undefined,
|
||||
topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
|
||||
topP: agent.topP ?? ProviderTransform.topP(model),
|
||||
options: pipe(
|
||||
{},
|
||||
mergeDeep(
|
||||
ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", sessionID, provider?.options),
|
||||
),
|
||||
mergeDeep(model.info.options),
|
||||
mergeDeep(ProviderTransform.options(model, sessionID, provider?.options)),
|
||||
mergeDeep(model.options),
|
||||
mergeDeep(agent.options),
|
||||
),
|
||||
},
|
||||
@@ -513,113 +509,111 @@ export namespace SessionPrompt {
|
||||
})
|
||||
}
|
||||
|
||||
const result = await processor.process(() =>
|
||||
streamText({
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
const result = await processor.process({
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
async experimental_repairToolCall(input) {
|
||||
const lower = input.toolCall.toolName.toLowerCase()
|
||||
if (lower !== input.toolCall.toolName && tools[lower]) {
|
||||
log.info("repairing tool call", {
|
||||
tool: input.toolCall.toolName,
|
||||
repaired: lower,
|
||||
})
|
||||
},
|
||||
async experimental_repairToolCall(input) {
|
||||
const lower = input.toolCall.toolName.toLowerCase()
|
||||
if (lower !== input.toolCall.toolName && tools[lower]) {
|
||||
log.info("repairing tool call", {
|
||||
tool: input.toolCall.toolName,
|
||||
repaired: lower,
|
||||
})
|
||||
return {
|
||||
...input.toolCall,
|
||||
toolName: lower,
|
||||
}
|
||||
}
|
||||
return {
|
||||
...input.toolCall,
|
||||
input: JSON.stringify({
|
||||
tool: input.toolCall.toolName,
|
||||
error: input.error.message,
|
||||
}),
|
||||
toolName: "invalid",
|
||||
toolName: lower,
|
||||
}
|
||||
},
|
||||
headers: {
|
||||
...(model.providerID.startsWith("opencode")
|
||||
? {
|
||||
"x-opencode-project": Instance.project.id,
|
||||
"x-opencode-session": sessionID,
|
||||
"x-opencode-request": lastUser.id,
|
||||
}
|
||||
: undefined),
|
||||
...model.info.headers,
|
||||
},
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
||||
maxOutputTokens: ProviderTransform.maxOutputTokens(
|
||||
model.providerID,
|
||||
params.options,
|
||||
model.info.limit.output,
|
||||
OUTPUT_TOKEN_MAX,
|
||||
}
|
||||
return {
|
||||
...input.toolCall,
|
||||
input: JSON.stringify({
|
||||
tool: input.toolCall.toolName,
|
||||
error: input.error.message,
|
||||
}),
|
||||
toolName: "invalid",
|
||||
}
|
||||
},
|
||||
headers: {
|
||||
...(model.providerID.startsWith("opencode")
|
||||
? {
|
||||
"x-opencode-project": Instance.project.id,
|
||||
"x-opencode-session": sessionID,
|
||||
"x-opencode-request": lastUser.id,
|
||||
}
|
||||
: undefined),
|
||||
...model.headers,
|
||||
},
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
||||
maxOutputTokens: ProviderTransform.maxOutputTokens(
|
||||
model.api.npm,
|
||||
params.options,
|
||||
model.limit.output,
|
||||
OUTPUT_TOKEN_MAX,
|
||||
),
|
||||
abortSignal: abort,
|
||||
providerOptions: ProviderTransform.providerOptions(model.api.npm, model.providerID, params.options),
|
||||
stopWhen: stepCountIs(1),
|
||||
temperature: params.temperature,
|
||||
topP: params.topP,
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
abortSignal: abort,
|
||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
|
||||
stopWhen: stepCountIs(1),
|
||||
temperature: params.temperature,
|
||||
topP: params.topP,
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(
|
||||
msgs.filter((m) => {
|
||||
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
||||
return true
|
||||
}
|
||||
if (
|
||||
MessageV2.AbortedError.isInstance(m.info.error) &&
|
||||
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
||||
) {
|
||||
return true
|
||||
}
|
||||
...MessageV2.toModelMessage(
|
||||
msgs.filter((m) => {
|
||||
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
||||
return true
|
||||
}
|
||||
if (
|
||||
MessageV2.AbortedError.isInstance(m.info.error) &&
|
||||
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}),
|
||||
),
|
||||
],
|
||||
tools: model.info.tool_call === false ? undefined : tools,
|
||||
model: wrapLanguageModel({
|
||||
model: model.language,
|
||||
middleware: [
|
||||
{
|
||||
async transformParams(args) {
|
||||
if (args.type === "stream") {
|
||||
// @ts-expect-error
|
||||
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
||||
}
|
||||
// Transform tool schemas for provider compatibility
|
||||
if (args.params.tools && Array.isArray(args.params.tools)) {
|
||||
args.params.tools = args.params.tools.map((tool: any) => {
|
||||
// Tools at middleware level have inputSchema, not parameters
|
||||
if (tool.inputSchema && typeof tool.inputSchema === "object") {
|
||||
// Transform the inputSchema for provider compatibility
|
||||
return {
|
||||
...tool,
|
||||
inputSchema: ProviderTransform.schema(model.providerID, model.modelID, tool.inputSchema),
|
||||
}
|
||||
return false
|
||||
}),
|
||||
),
|
||||
],
|
||||
tools: model.capabilities.toolcall === false ? undefined : tools,
|
||||
model: wrapLanguageModel({
|
||||
model: language,
|
||||
middleware: [
|
||||
{
|
||||
async transformParams(args) {
|
||||
if (args.type === "stream") {
|
||||
// @ts-expect-error - prompt types are compatible at runtime
|
||||
args.params.prompt = ProviderTransform.message(args.params.prompt, model)
|
||||
}
|
||||
// Transform tool schemas for provider compatibility
|
||||
if (args.params.tools && Array.isArray(args.params.tools)) {
|
||||
args.params.tools = args.params.tools.map((tool: any) => {
|
||||
// Tools at middleware level have inputSchema, not parameters
|
||||
if (tool.inputSchema && typeof tool.inputSchema === "object") {
|
||||
// Transform the inputSchema for provider compatibility
|
||||
return {
|
||||
...tool,
|
||||
inputSchema: ProviderTransform.schema(model, tool.inputSchema),
|
||||
}
|
||||
// If no inputSchema, return tool unchanged
|
||||
return tool
|
||||
})
|
||||
}
|
||||
return args.params
|
||||
},
|
||||
}
|
||||
// If no inputSchema, return tool unchanged
|
||||
return tool
|
||||
})
|
||||
}
|
||||
return args.params
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
],
|
||||
}),
|
||||
)
|
||||
})
|
||||
if (result === "stop") break
|
||||
continue
|
||||
}
|
||||
@@ -642,18 +636,13 @@ export namespace SessionPrompt {
|
||||
return Provider.defaultModel()
|
||||
}
|
||||
|
||||
async function resolveSystemPrompt(input: {
|
||||
system?: string
|
||||
agent: Agent.Info
|
||||
providerID: string
|
||||
modelID: string
|
||||
}) {
|
||||
let system = SystemPrompt.header(input.providerID)
|
||||
async function resolveSystemPrompt(input: { system?: string; agent: Agent.Info; model: Provider.Model }) {
|
||||
let system = SystemPrompt.header(input.model.providerID)
|
||||
system.push(
|
||||
...(() => {
|
||||
if (input.system) return [input.system]
|
||||
if (input.agent.prompt) return [input.agent.prompt]
|
||||
return SystemPrompt.provider(input.modelID)
|
||||
return SystemPrompt.provider(input.model)
|
||||
})(),
|
||||
)
|
||||
system.push(...(await SystemPrompt.environment()))
|
||||
@@ -666,10 +655,7 @@ export namespace SessionPrompt {
|
||||
|
||||
async function resolveTools(input: {
|
||||
agent: Agent.Info
|
||||
model: {
|
||||
providerID: string
|
||||
modelID: string
|
||||
}
|
||||
model: Provider.Model
|
||||
sessionID: string
|
||||
tools?: Record<string, boolean>
|
||||
processor: SessionProcessor.Info
|
||||
@@ -677,16 +663,12 @@ export namespace SessionPrompt {
|
||||
const tools: Record<string, AITool> = {}
|
||||
const enabledTools = pipe(
|
||||
input.agent.tools,
|
||||
mergeDeep(await ToolRegistry.enabled(input.model.providerID, input.model.modelID, input.agent)),
|
||||
mergeDeep(await ToolRegistry.enabled(input.agent)),
|
||||
mergeDeep(input.tools ?? {}),
|
||||
)
|
||||
for (const item of await ToolRegistry.tools(input.model.providerID, input.model.modelID)) {
|
||||
for (const item of await ToolRegistry.tools(input.model.providerID)) {
|
||||
if (Wildcard.all(item.id, enabledTools) === false) continue
|
||||
const schema = ProviderTransform.schema(
|
||||
input.model.providerID,
|
||||
input.model.modelID,
|
||||
z.toJSONSchema(item.parameters),
|
||||
)
|
||||
const schema = ProviderTransform.schema(input.model, z.toJSONSchema(item.parameters))
|
||||
tools[item.id] = tool({
|
||||
id: item.id as any,
|
||||
description: item.description,
|
||||
@@ -1437,25 +1419,18 @@ export namespace SessionPrompt {
|
||||
if (!isFirst) return
|
||||
const small =
|
||||
(await Provider.getSmallModel(input.providerID)) ?? (await Provider.getModel(input.providerID, input.modelID))
|
||||
const language = await Provider.getLanguage(small)
|
||||
const provider = await Provider.getProvider(small.providerID)
|
||||
const options = pipe(
|
||||
{},
|
||||
mergeDeep(
|
||||
ProviderTransform.options(
|
||||
small.providerID,
|
||||
small.modelID,
|
||||
small.npm ?? "",
|
||||
input.session.id,
|
||||
provider?.options,
|
||||
),
|
||||
),
|
||||
mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
|
||||
mergeDeep(small.info.options),
|
||||
mergeDeep(ProviderTransform.options(small, input.session.id, provider?.options)),
|
||||
mergeDeep(ProviderTransform.smallOptions(small)),
|
||||
mergeDeep(small.options),
|
||||
)
|
||||
await generateText({
|
||||
// use higher # for reasoning models since reasoning tokens eat up a lot of the budget
|
||||
maxOutputTokens: small.info.reasoning ? 3000 : 20,
|
||||
providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
|
||||
maxOutputTokens: small.capabilities.reasoning ? 3000 : 20,
|
||||
providerOptions: ProviderTransform.providerOptions(small.api.npm, small.providerID, options),
|
||||
messages: [
|
||||
...SystemPrompt.title(small.providerID).map(
|
||||
(x): ModelMessage => ({
|
||||
@@ -1486,8 +1461,8 @@ export namespace SessionPrompt {
|
||||
},
|
||||
]),
|
||||
],
|
||||
headers: small.info.headers,
|
||||
model: small.language,
|
||||
headers: small.headers,
|
||||
model: language,
|
||||
})
|
||||
.then((result) => {
|
||||
if (result.text)
|
||||
@@ -1504,7 +1479,7 @@ export namespace SessionPrompt {
|
||||
})
|
||||
})
|
||||
.catch((error) => {
|
||||
log.error("failed to generate title", { error, model: small.info.id })
|
||||
log.error("failed to generate title", { error, model: small.id })
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,19 +76,20 @@ export namespace SessionSummary {
|
||||
const small =
|
||||
(await Provider.getSmallModel(assistantMsg.providerID)) ??
|
||||
(await Provider.getModel(assistantMsg.providerID, assistantMsg.modelID))
|
||||
const language = await Provider.getLanguage(small)
|
||||
|
||||
const options = pipe(
|
||||
{},
|
||||
mergeDeep(ProviderTransform.options(small.providerID, small.modelID, small.npm ?? "", assistantMsg.sessionID)),
|
||||
mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
|
||||
mergeDeep(small.info.options),
|
||||
mergeDeep(ProviderTransform.options(small, assistantMsg.sessionID)),
|
||||
mergeDeep(ProviderTransform.smallOptions(small)),
|
||||
mergeDeep(small.options),
|
||||
)
|
||||
|
||||
const textPart = msgWithParts.parts.find((p) => p.type === "text" && !p.synthetic) as MessageV2.TextPart
|
||||
if (textPart && !userMsg.summary?.title) {
|
||||
const result = await generateText({
|
||||
maxOutputTokens: small.info.reasoning ? 1500 : 20,
|
||||
providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
|
||||
maxOutputTokens: small.capabilities.reasoning ? 1500 : 20,
|
||||
providerOptions: ProviderTransform.providerOptions(small.api.npm, small.providerID, options),
|
||||
messages: [
|
||||
...SystemPrompt.title(small.providerID).map(
|
||||
(x): ModelMessage => ({
|
||||
@@ -106,8 +107,8 @@ export namespace SessionSummary {
|
||||
`,
|
||||
},
|
||||
],
|
||||
headers: small.info.headers,
|
||||
model: small.language,
|
||||
headers: small.headers,
|
||||
model: language,
|
||||
})
|
||||
log.info("title", { title: result.text })
|
||||
userMsg.summary.title = result.text
|
||||
@@ -132,9 +133,9 @@ export namespace SessionSummary {
|
||||
}
|
||||
}
|
||||
const result = await generateText({
|
||||
model: small.language,
|
||||
model: language,
|
||||
maxOutputTokens: 100,
|
||||
providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
|
||||
providerOptions: ProviderTransform.providerOptions(small.api.npm, small.providerID, options),
|
||||
messages: [
|
||||
...SystemPrompt.summarize(small.providerID).map(
|
||||
(x): ModelMessage => ({
|
||||
@@ -148,7 +149,7 @@ export namespace SessionSummary {
|
||||
content: `Summarize the above conversation according to your system prompts.`,
|
||||
},
|
||||
],
|
||||
headers: small.info.headers,
|
||||
headers: small.headers,
|
||||
}).catch(() => {})
|
||||
if (result) summary = result.text
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import PROMPT_COMPACTION from "./prompt/compaction.txt"
|
||||
import PROMPT_SUMMARIZE from "./prompt/summarize.txt"
|
||||
import PROMPT_TITLE from "./prompt/title.txt"
|
||||
import PROMPT_CODEX from "./prompt/codex.txt"
|
||||
import type { Provider } from "@/provider/provider"
|
||||
|
||||
export namespace SystemPrompt {
|
||||
export function header(providerID: string) {
|
||||
@@ -24,12 +25,13 @@ export namespace SystemPrompt {
|
||||
return []
|
||||
}
|
||||
|
||||
export function provider(modelID: string) {
|
||||
if (modelID.includes("gpt-5")) return [PROMPT_CODEX]
|
||||
if (modelID.includes("gpt-") || modelID.includes("o1") || modelID.includes("o3")) return [PROMPT_BEAST]
|
||||
if (modelID.includes("gemini-")) return [PROMPT_GEMINI]
|
||||
if (modelID.includes("claude")) return [PROMPT_ANTHROPIC]
|
||||
if (modelID.includes("polaris-alpha")) return [PROMPT_POLARIS]
|
||||
export function provider(model: Provider.Model) {
|
||||
if (model.api.id.includes("gpt-5")) return [PROMPT_CODEX]
|
||||
if (model.api.id.includes("gpt-") || model.api.id.includes("o1") || model.api.id.includes("o3"))
|
||||
return [PROMPT_BEAST]
|
||||
if (model.api.id.includes("gemini-")) return [PROMPT_GEMINI]
|
||||
if (model.api.id.includes("claude")) return [PROMPT_ANTHROPIC]
|
||||
if (model.api.id.includes("polaris-alpha")) return [PROMPT_POLARIS]
|
||||
return [PROMPT_ANTHROPIC_WITHOUT_TODO]
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user