Refactor agent loop (#4412)

This commit is contained in:
Dax
2025-11-17 10:57:18 -05:00
committed by GitHub
parent 9fd43ec616
commit a1214fff2e
22 changed files with 1297 additions and 1324 deletions

View File

@@ -1,9 +1,8 @@
import { streamText, type ModelMessage, type StreamTextResult, type Tool as AITool } from "ai"
import { streamText, type ModelMessage } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
import { Provider } from "../provider/provider"
import { defer } from "../util/defer"
import { MessageV2 } from "./message-v2"
import { SystemPrompt } from "./system"
import { Bus } from "../bus"
@@ -13,10 +12,9 @@ import { SessionPrompt } from "./prompt"
import { Flag } from "../flag/flag"
import { Token } from "../util/token"
import { Log } from "../util/log"
import { SessionLock } from "./lock"
import { ProviderTransform } from "@/provider/transform"
import { SessionRetry } from "./retry"
import { Config } from "@/config/config"
import { SessionProcessor } from "./processor"
import { fn } from "@/util/fn"
export namespace SessionCompaction {
const log = Log.create({ service: "session.compaction" })
@@ -42,7 +40,6 @@ export namespace SessionCompaction {
export const PRUNE_MINIMUM = 20_000
export const PRUNE_PROTECT = 40_000
const MAX_RETRIES = 10
// goes backwards through parts until there are 40_000 tokens worth of tool
// calls. then erases output of previous tool calls. idea is to throw away old
@@ -87,38 +84,29 @@ export namespace SessionCompaction {
}
}
export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
const signal = input.signal ?? lock!.signal
await Session.update(input.sessionID, (draft) => {
draft.time.compacting = Date.now()
})
await using _ = defer(async () => {
await Session.update(input.sessionID, (draft) => {
draft.time.compacting = undefined
})
})
const toSummarize = await MessageV2.filterCompacted(MessageV2.stream(input.sessionID))
const model = await Provider.getModel(input.providerID, input.modelID)
const system = [
...SystemPrompt.summarize(model.providerID),
...(await SystemPrompt.environment()),
...(await SystemPrompt.custom()),
]
export async function process(input: {
parentID: string
messages: MessageV2.WithParts[]
sessionID: string
model: {
providerID: string
modelID: string
}
abort: AbortSignal
}) {
const model = await Provider.getModel(input.model.providerID, input.model.modelID)
const system = [...SystemPrompt.summarize(model.providerID)]
const msg = (await Session.updateMessage({
id: Identifier.ascending("message"),
role: "assistant",
parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!,
parentID: input.parentID,
sessionID: input.sessionID,
mode: "build",
summary: true,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
summary: true,
cost: 0,
tokens: {
output: 0,
@@ -126,37 +114,27 @@ export namespace SessionCompaction {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: input.modelID,
modelID: input.model.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
},
})) as MessageV2.Assistant
const part = (await Session.updatePart({
type: "text",
const processor = SessionProcessor.create({
assistantMessage: msg,
sessionID: input.sessionID,
messageID: msg.id,
id: Identifier.ascending("part"),
text: "",
time: {
start: Date.now(),
},
})) as MessageV2.TextPart
const doStream = () =>
providerID: input.model.providerID,
model: model.info,
abort: input.abort,
})
const result = await processor.process(() =>
streamText({
// set to 0, we handle loop
maxRetries: 0,
model: model.language,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
headers: model.info.headers,
abortSignal: signal,
onError(error) {
log.error("stream error", {
error,
})
},
abortSignal: input.abort,
tools: model.info.tool_call ? {} : undefined,
messages: [
...system.map(
@@ -165,7 +143,7 @@ export namespace SessionCompaction {
content: x,
}),
),
...MessageV2.toModelMessage(toSummarize),
...MessageV2.toModelMessage(input.messages),
{
role: "user",
content: [
@@ -176,168 +154,60 @@ export namespace SessionCompaction {
],
},
],
})
// TODO: reduce duplication between compaction.ts & prompt.ts
const process = async (
stream: StreamTextResult<Record<string, AITool>, never>,
retries: { count: number; max: number },
) => {
let shouldRetry = false
try {
for await (const value of stream.fullStream) {
signal.throwIfAborted()
switch (value.type) {
case "text-delta":
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text)
await Session.updatePart({
part,
delta: value.text,
})
continue
case "text-end": {
part.text = part.text.trimEnd()
part.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
continue
}
case "finish-step": {
const usage = Session.getUsage({
model: model.info,
usage: value.usage,
metadata: value.providerMetadata,
})
msg.cost += usage.cost
msg.tokens = usage.tokens
await Session.updateMessage(msg)
continue
}
case "error":
throw value.error
default:
continue
}
}
} catch (e) {
log.error("compaction error", {
error: e,
})
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
shouldRetry = true
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: msg.id,
sessionID: msg.sessionID,
type: "retry",
attempt: retries.count + 1,
time: {
created: Date.now(),
},
error,
})
} else {
msg.error = error
Bus.publish(Session.Event.Error, {
sessionID: msg.sessionID,
error: msg.error,
})
}
}
const parts = await MessageV2.parts(msg.id)
return {
info: msg,
parts,
shouldRetry,
}
}
let stream = doStream()
const cfg = await Config.get()
const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES
let result = await process(stream, {
count: 0,
max: maxRetries,
})
if (result.shouldRetry) {
const start = Date.now()
for (let retry = 1; retry < maxRetries; retry++) {
const lastRetryPart = result.parts.findLast((p): p is MessageV2.RetryPart => p.type === "retry")
if (lastRetryPart) {
const delayMs = SessionRetry.getBoundedDelay({
error: lastRetryPart.error,
attempt: retry,
startTime: start,
})
if (!delayMs) {
break
}
log.info("retrying with backoff", {
attempt: retry,
delayMs,
elapsed: Date.now() - start,
})
const stop = await SessionRetry.sleep(delayMs, signal)
.then(() => false)
.catch((error) => {
if (error instanceof DOMException && error.name === "AbortError") {
const err = new MessageV2.AbortedError(
{ message: error.message },
{
cause: error,
},
).toObject()
result.info.error = err
Bus.publish(Session.Event.Error, {
sessionID: result.info.sessionID,
error: result.info.error,
})
return true
}
throw error
})
if (stop) break
}
stream = doStream()
result = await process(stream, {
count: retry,
max: maxRetries,
})
if (!result.shouldRetry) {
break
}
}
}
msg.time.completed = Date.now()
if (
!msg.error ||
(MessageV2.AbortedError.isInstance(msg.error) &&
result.parts.some((part): part is MessageV2.TextPart => part.type === "text" && part.text.length > 0))
) {
msg.summary = true
Bus.publish(Event.Compacted, {
}),
)
if (result === "continue") {
const continueMsg = await Session.updateMessage({
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
time: {
created: Date.now(),
},
agent: "build",
model: input.model,
})
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: continueMsg.id,
sessionID: input.sessionID,
type: "text",
synthetic: true,
text: "Continue if you have next steps",
time: {
start: Date.now(),
end: Date.now(),
},
})
}
await Session.updateMessage(msg)
return {
info: msg,
parts: result.parts,
}
return "continue"
}
export const create = fn(
z.object({
sessionID: Identifier.schema("session"),
model: z.object({
providerID: z.string(),
modelID: z.string(),
}),
}),
async (input) => {
const msg = await Session.updateMessage({
id: Identifier.ascending("message"),
role: "user",
model: input.model,
sessionID: input.sessionID,
agent: "build",
time: {
created: Date.now(),
},
})
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: msg.id,
sessionID: msg.sessionID,
type: "compaction",
})
},
)
}