mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-07 09:18:41 +00:00
feat: retry parts (#3369)
This commit is contained in:
@@ -17,7 +17,6 @@ import {
|
||||
tool,
|
||||
wrapLanguageModel,
|
||||
type StreamTextResult,
|
||||
LoadAPIKeyError,
|
||||
stepCountIs,
|
||||
jsonSchema,
|
||||
} from "ai"
|
||||
@@ -28,6 +27,7 @@ import { Bus } from "../bus"
|
||||
import { ProviderTransform } from "../provider/transform"
|
||||
import { SystemPrompt } from "./system"
|
||||
import { Plugin } from "../plugin"
|
||||
import { SessionRetry } from "./retry"
|
||||
|
||||
import PROMPT_PLAN from "../session/prompt/plan.txt"
|
||||
import BUILD_SWITCH from "../session/prompt/build-switch.txt"
|
||||
@@ -44,7 +44,6 @@ import { TaskTool } from "../tool/task"
|
||||
import { FileTime } from "../file/time"
|
||||
import { Permission } from "../permission"
|
||||
import { Snapshot } from "../snapshot"
|
||||
import { NamedError } from "../util/error"
|
||||
import { ulid } from "ulid"
|
||||
import { spawn } from "child_process"
|
||||
import { Command } from "../command"
|
||||
@@ -55,6 +54,7 @@ import { MessageSummary } from "./summary"
|
||||
export namespace SessionPrompt {
|
||||
const log = Log.create({ service: "session.prompt" })
|
||||
export const OUTPUT_TOKEN_MAX = 32_000
|
||||
const MAX_RETRIES = 10
|
||||
|
||||
export const Event = {
|
||||
Idle: Bus.event(
|
||||
@@ -240,93 +240,145 @@ export namespace SessionPrompt {
|
||||
await using _ = defer(async () => {
|
||||
await processor.end()
|
||||
})
|
||||
const stream = streamText({
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
async experimental_repairToolCall(input) {
|
||||
const lower = input.toolCall.toolName.toLowerCase()
|
||||
if (lower !== input.toolCall.toolName && tools[lower]) {
|
||||
log.info("repairing tool call", {
|
||||
tool: input.toolCall.toolName,
|
||||
repaired: lower,
|
||||
const doStream = () =>
|
||||
streamText({
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
async experimental_repairToolCall(input) {
|
||||
const lower = input.toolCall.toolName.toLowerCase()
|
||||
if (lower !== input.toolCall.toolName && tools[lower]) {
|
||||
log.info("repairing tool call", {
|
||||
tool: input.toolCall.toolName,
|
||||
repaired: lower,
|
||||
})
|
||||
return {
|
||||
...input.toolCall,
|
||||
toolName: lower,
|
||||
}
|
||||
}
|
||||
return {
|
||||
...input.toolCall,
|
||||
toolName: lower,
|
||||
input: JSON.stringify({
|
||||
tool: input.toolCall.toolName,
|
||||
error: input.error.message,
|
||||
}),
|
||||
toolName: "invalid",
|
||||
}
|
||||
}
|
||||
return {
|
||||
...input.toolCall,
|
||||
input: JSON.stringify({
|
||||
tool: input.toolCall.toolName,
|
||||
error: input.error.message,
|
||||
}),
|
||||
toolName: "invalid",
|
||||
}
|
||||
},
|
||||
headers:
|
||||
model.providerID === "opencode"
|
||||
? {
|
||||
"x-opencode-session": input.sessionID,
|
||||
"x-opencode-request": userMsg.info.id,
|
||||
}
|
||||
: undefined,
|
||||
maxRetries: 10,
|
||||
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
||||
maxOutputTokens: ProviderTransform.maxOutputTokens(
|
||||
model.providerID,
|
||||
params.options,
|
||||
model.info.limit.output,
|
||||
OUTPUT_TOKEN_MAX,
|
||||
),
|
||||
abortSignal: abort.signal,
|
||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
|
||||
stopWhen: stepCountIs(1),
|
||||
temperature: params.temperature,
|
||||
topP: params.topP,
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(
|
||||
msgs.filter((m) => {
|
||||
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
||||
return true
|
||||
}
|
||||
if (
|
||||
MessageV2.AbortedError.isInstance(m.info.error) &&
|
||||
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}),
|
||||
),
|
||||
],
|
||||
tools: model.info.tool_call === false ? undefined : tools,
|
||||
model: wrapLanguageModel({
|
||||
model: model.language,
|
||||
middleware: [
|
||||
{
|
||||
async transformParams(args) {
|
||||
if (args.type === "stream") {
|
||||
// @ts-expect-error
|
||||
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
||||
},
|
||||
headers:
|
||||
model.providerID === "opencode"
|
||||
? {
|
||||
"x-opencode-session": input.sessionID,
|
||||
"x-opencode-request": userMsg.info.id,
|
||||
}
|
||||
return args.params
|
||||
},
|
||||
},
|
||||
: undefined,
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
||||
maxOutputTokens: ProviderTransform.maxOutputTokens(
|
||||
model.providerID,
|
||||
params.options,
|
||||
model.info.limit.output,
|
||||
OUTPUT_TOKEN_MAX,
|
||||
),
|
||||
abortSignal: abort.signal,
|
||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
|
||||
stopWhen: stepCountIs(1),
|
||||
temperature: params.temperature,
|
||||
topP: params.topP,
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(
|
||||
msgs.filter((m) => {
|
||||
if (m.info.role !== "assistant" || m.info.error === undefined) {
|
||||
return true
|
||||
}
|
||||
if (
|
||||
MessageV2.AbortedError.isInstance(m.info.error) &&
|
||||
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}),
|
||||
),
|
||||
],
|
||||
}),
|
||||
tools: model.info.tool_call === false ? undefined : tools,
|
||||
model: wrapLanguageModel({
|
||||
model: model.language,
|
||||
middleware: [
|
||||
{
|
||||
async transformParams(args) {
|
||||
if (args.type === "stream") {
|
||||
// @ts-expect-error
|
||||
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
|
||||
}
|
||||
return args.params
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
})
|
||||
|
||||
let stream = doStream()
|
||||
let result = await processor.process(stream, {
|
||||
count: 0,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
const result = await processor.process(stream)
|
||||
if (result.shouldRetry) {
|
||||
for (let retry = 1; retry < MAX_RETRIES; retry++) {
|
||||
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
|
||||
|
||||
if (lastRetryPart) {
|
||||
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
|
||||
|
||||
log.info("retrying with backoff", {
|
||||
attempt: retry,
|
||||
delayMs,
|
||||
})
|
||||
|
||||
const stop = await SessionRetry.sleep(delayMs, abort.signal)
|
||||
.then(() => false)
|
||||
.catch((error) => {
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
const err = new MessageV2.AbortedError(
|
||||
{ message: error.message },
|
||||
{
|
||||
cause: error,
|
||||
},
|
||||
).toObject()
|
||||
result.info.error = err
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: result.info.sessionID,
|
||||
error: result.info.error,
|
||||
})
|
||||
return true
|
||||
}
|
||||
throw error
|
||||
})
|
||||
|
||||
if (stop) break
|
||||
}
|
||||
|
||||
stream = doStream()
|
||||
result = await processor.process(stream, {
|
||||
count: retry,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
if (!result.shouldRetry) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
await processor.end()
|
||||
|
||||
const queued = state().queued.get(input.sessionID) ?? []
|
||||
@@ -959,9 +1011,10 @@ export namespace SessionPrompt {
|
||||
partFromToolCall(toolCallID: string) {
|
||||
return toolcalls[toolCallID]
|
||||
},
|
||||
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
||||
async process(stream: StreamTextResult<Record<string, AITool>, never>, retries: { count: number; max: number }) {
|
||||
log.info("process")
|
||||
if (!assistantMsg) throw new Error("call next() first before processing")
|
||||
let shouldRetry = false
|
||||
try {
|
||||
let currentText: MessageV2.TextPart | undefined
|
||||
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
|
||||
@@ -1314,37 +1367,27 @@ export namespace SessionPrompt {
|
||||
log.error("process", {
|
||||
error: e,
|
||||
})
|
||||
switch (true) {
|
||||
case e instanceof DOMException && e.name === "AbortError":
|
||||
assistantMsg.error = new MessageV2.AbortedError(
|
||||
{ message: e.message },
|
||||
{
|
||||
cause: e,
|
||||
},
|
||||
).toObject()
|
||||
break
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
assistantMsg.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
assistantMsg.error = new MessageV2.AuthError(
|
||||
{
|
||||
providerID: input.providerID,
|
||||
message: e.message,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
break
|
||||
case e instanceof Error:
|
||||
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
break
|
||||
default:
|
||||
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
||||
shouldRetry = true
|
||||
await Session.updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID: assistantMsg.sessionID,
|
||||
type: "retry",
|
||||
attempt: retries.count + 1,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
error,
|
||||
})
|
||||
} else {
|
||||
assistantMsg.error = error
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: assistantMsg.sessionID,
|
||||
error: assistantMsg.error,
|
||||
})
|
||||
}
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: assistantMsg.sessionID,
|
||||
error: assistantMsg.error,
|
||||
})
|
||||
}
|
||||
const p = await Session.getParts(assistantMsg.id)
|
||||
for (const part of p) {
|
||||
@@ -1363,9 +1406,11 @@ export namespace SessionPrompt {
|
||||
})
|
||||
}
|
||||
}
|
||||
assistantMsg.time.completed = Date.now()
|
||||
if (!shouldRetry) {
|
||||
assistantMsg.time.completed = Date.now()
|
||||
}
|
||||
await Session.updateMessage(assistantMsg)
|
||||
return { info: assistantMsg, parts: p, blocked }
|
||||
return { info: assistantMsg, parts: p, blocked, shouldRetry }
|
||||
},
|
||||
}
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user