feat: retry parts (#3369)

This commit is contained in:
Aiden Cline
2025-10-22 18:31:36 -05:00
committed by GitHub
parent 9def7cff2d
commit 7c7ebb0a9d
5 changed files with 487 additions and 214 deletions

View File

@@ -17,7 +17,6 @@ import {
tool,
wrapLanguageModel,
type StreamTextResult,
LoadAPIKeyError,
stepCountIs,
jsonSchema,
} from "ai"
@@ -28,6 +27,7 @@ import { Bus } from "../bus"
import { ProviderTransform } from "../provider/transform"
import { SystemPrompt } from "./system"
import { Plugin } from "../plugin"
import { SessionRetry } from "./retry"
import PROMPT_PLAN from "../session/prompt/plan.txt"
import BUILD_SWITCH from "../session/prompt/build-switch.txt"
@@ -44,7 +44,6 @@ import { TaskTool } from "../tool/task"
import { FileTime } from "../file/time"
import { Permission } from "../permission"
import { Snapshot } from "../snapshot"
import { NamedError } from "../util/error"
import { ulid } from "ulid"
import { spawn } from "child_process"
import { Command } from "../command"
@@ -55,6 +54,7 @@ import { MessageSummary } from "./summary"
export namespace SessionPrompt {
const log = Log.create({ service: "session.prompt" })
export const OUTPUT_TOKEN_MAX = 32_000
const MAX_RETRIES = 10
export const Event = {
Idle: Bus.event(
@@ -240,93 +240,145 @@ export namespace SessionPrompt {
await using _ = defer(async () => {
await processor.end()
})
const stream = streamText({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
const doStream = () =>
streamText({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
})
return {
...input.toolCall,
toolName: lower,
}
}
return {
...input.toolCall,
toolName: lower,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
}
}
return {
...input.toolCall,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
}
},
headers:
model.providerID === "opencode"
? {
"x-opencode-session": input.sessionID,
"x-opencode-request": userMsg.info.id,
}
: undefined,
maxRetries: 10,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.providerID,
params.options,
model.info.limit.output,
OUTPUT_TOKEN_MAX,
),
abortSignal: abort.signal,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
],
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
},
headers:
model.providerID === "opencode"
? {
"x-opencode-session": input.sessionID,
"x-opencode-request": userMsg.info.id,
}
return args.params
},
},
: undefined,
// set to 0, we handle loop
maxRetries: 0,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.providerID,
params.options,
model.info.limit.output,
OUTPUT_TOKEN_MAX,
),
abortSignal: abort.signal,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
],
}),
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
}
return args.params
},
},
],
}),
})
let stream = doStream()
let result = await processor.process(stream, {
count: 0,
max: MAX_RETRIES,
})
const result = await processor.process(stream)
if (result.shouldRetry) {
for (let retry = 1; retry < MAX_RETRIES; retry++) {
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
if (lastRetryPart) {
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
log.info("retrying with backoff", {
attempt: retry,
delayMs,
})
const stop = await SessionRetry.sleep(delayMs, abort.signal)
.then(() => false)
.catch((error) => {
if (error instanceof DOMException && error.name === "AbortError") {
const err = new MessageV2.AbortedError(
{ message: error.message },
{
cause: error,
},
).toObject()
result.info.error = err
Bus.publish(Session.Event.Error, {
sessionID: result.info.sessionID,
error: result.info.error,
})
return true
}
throw error
})
if (stop) break
}
stream = doStream()
result = await processor.process(stream, {
count: retry,
max: MAX_RETRIES,
})
if (!result.shouldRetry) {
break
}
}
}
await processor.end()
const queued = state().queued.get(input.sessionID) ?? []
@@ -959,9 +1011,10 @@ export namespace SessionPrompt {
partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID]
},
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
async process(stream: StreamTextResult<Record<string, AITool>, never>, retries: { count: number; max: number }) {
log.info("process")
if (!assistantMsg) throw new Error("call next() first before processing")
let shouldRetry = false
try {
let currentText: MessageV2.TextPart | undefined
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
@@ -1314,37 +1367,27 @@ export namespace SessionPrompt {
log.error("process", {
error: e,
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
assistantMsg.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
assistantMsg.error = e
break
case LoadAPIKeyError.isInstance(e):
assistantMsg.error = new MessageV2.AuthError(
{
providerID: input.providerID,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
shouldRetry = true
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "retry",
attempt: retries.count + 1,
time: {
created: Date.now(),
},
error,
})
} else {
assistantMsg.error = error
Bus.publish(Session.Event.Error, {
sessionID: assistantMsg.sessionID,
error: assistantMsg.error,
})
}
Bus.publish(Session.Event.Error, {
sessionID: assistantMsg.sessionID,
error: assistantMsg.error,
})
}
const p = await Session.getParts(assistantMsg.id)
for (const part of p) {
@@ -1363,9 +1406,11 @@ export namespace SessionPrompt {
})
}
}
assistantMsg.time.completed = Date.now()
if (!shouldRetry) {
assistantMsg.time.completed = Date.now()
}
await Session.updateMessage(assistantMsg)
return { info: assistantMsg, parts: p, blocked }
return { info: assistantMsg, parts: p, blocked, shouldRetry }
},
}
return result