mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-05 00:23:10 +00:00
Refactor agent loop (#4412)
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import { streamText, type ModelMessage, type StreamTextResult, type Tool as AITool } from "ai"
|
||||
import { streamText, type ModelMessage } from "ai"
|
||||
import { Session } from "."
|
||||
import { Identifier } from "../id/id"
|
||||
import { Instance } from "../project/instance"
|
||||
import { Provider } from "../provider/provider"
|
||||
import { defer } from "../util/defer"
|
||||
import { MessageV2 } from "./message-v2"
|
||||
import { SystemPrompt } from "./system"
|
||||
import { Bus } from "../bus"
|
||||
@@ -13,10 +12,9 @@ import { SessionPrompt } from "./prompt"
|
||||
import { Flag } from "../flag/flag"
|
||||
import { Token } from "../util/token"
|
||||
import { Log } from "../util/log"
|
||||
import { SessionLock } from "./lock"
|
||||
import { ProviderTransform } from "@/provider/transform"
|
||||
import { SessionRetry } from "./retry"
|
||||
import { Config } from "@/config/config"
|
||||
import { SessionProcessor } from "./processor"
|
||||
import { fn } from "@/util/fn"
|
||||
|
||||
export namespace SessionCompaction {
|
||||
const log = Log.create({ service: "session.compaction" })
|
||||
@@ -42,7 +40,6 @@ export namespace SessionCompaction {
|
||||
|
||||
export const PRUNE_MINIMUM = 20_000
|
||||
export const PRUNE_PROTECT = 40_000
|
||||
const MAX_RETRIES = 10
|
||||
|
||||
// goes backwards through parts until there are 40_000 tokens worth of tool
|
||||
// calls. then erases output of previous tool calls. idea is to throw away old
|
||||
@@ -87,38 +84,29 @@ export namespace SessionCompaction {
|
||||
}
|
||||
}
|
||||
|
||||
export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
|
||||
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
|
||||
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
|
||||
const signal = input.signal ?? lock!.signal
|
||||
|
||||
await Session.update(input.sessionID, (draft) => {
|
||||
draft.time.compacting = Date.now()
|
||||
})
|
||||
await using _ = defer(async () => {
|
||||
await Session.update(input.sessionID, (draft) => {
|
||||
draft.time.compacting = undefined
|
||||
})
|
||||
})
|
||||
const toSummarize = await MessageV2.filterCompacted(MessageV2.stream(input.sessionID))
|
||||
const model = await Provider.getModel(input.providerID, input.modelID)
|
||||
const system = [
|
||||
...SystemPrompt.summarize(model.providerID),
|
||||
...(await SystemPrompt.environment()),
|
||||
...(await SystemPrompt.custom()),
|
||||
]
|
||||
|
||||
export async function process(input: {
|
||||
parentID: string
|
||||
messages: MessageV2.WithParts[]
|
||||
sessionID: string
|
||||
model: {
|
||||
providerID: string
|
||||
modelID: string
|
||||
}
|
||||
abort: AbortSignal
|
||||
}) {
|
||||
const model = await Provider.getModel(input.model.providerID, input.model.modelID)
|
||||
const system = [...SystemPrompt.summarize(model.providerID)]
|
||||
const msg = (await Session.updateMessage({
|
||||
id: Identifier.ascending("message"),
|
||||
role: "assistant",
|
||||
parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!,
|
||||
parentID: input.parentID,
|
||||
sessionID: input.sessionID,
|
||||
mode: "build",
|
||||
summary: true,
|
||||
path: {
|
||||
cwd: Instance.directory,
|
||||
root: Instance.worktree,
|
||||
},
|
||||
summary: true,
|
||||
cost: 0,
|
||||
tokens: {
|
||||
output: 0,
|
||||
@@ -126,37 +114,27 @@ export namespace SessionCompaction {
|
||||
reasoning: 0,
|
||||
cache: { read: 0, write: 0 },
|
||||
},
|
||||
modelID: input.modelID,
|
||||
modelID: input.model.modelID,
|
||||
providerID: model.providerID,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
})) as MessageV2.Assistant
|
||||
|
||||
const part = (await Session.updatePart({
|
||||
type: "text",
|
||||
const processor = SessionProcessor.create({
|
||||
assistantMessage: msg,
|
||||
sessionID: input.sessionID,
|
||||
messageID: msg.id,
|
||||
id: Identifier.ascending("part"),
|
||||
text: "",
|
||||
time: {
|
||||
start: Date.now(),
|
||||
},
|
||||
})) as MessageV2.TextPart
|
||||
|
||||
const doStream = () =>
|
||||
providerID: input.model.providerID,
|
||||
model: model.info,
|
||||
abort: input.abort,
|
||||
})
|
||||
const result = await processor.process(() =>
|
||||
streamText({
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
model: model.language,
|
||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
||||
headers: model.info.headers,
|
||||
abortSignal: signal,
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
abortSignal: input.abort,
|
||||
tools: model.info.tool_call ? {} : undefined,
|
||||
messages: [
|
||||
...system.map(
|
||||
@@ -165,7 +143,7 @@ export namespace SessionCompaction {
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(toSummarize),
|
||||
...MessageV2.toModelMessage(input.messages),
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
@@ -176,168 +154,60 @@ export namespace SessionCompaction {
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
// TODO: reduce duplication between compaction.ts & prompt.ts
|
||||
const process = async (
|
||||
stream: StreamTextResult<Record<string, AITool>, never>,
|
||||
retries: { count: number; max: number },
|
||||
) => {
|
||||
let shouldRetry = false
|
||||
try {
|
||||
for await (const value of stream.fullStream) {
|
||||
signal.throwIfAborted()
|
||||
switch (value.type) {
|
||||
case "text-delta":
|
||||
part.text += value.text
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
if (part.text)
|
||||
await Session.updatePart({
|
||||
part,
|
||||
delta: value.text,
|
||||
})
|
||||
continue
|
||||
case "text-end": {
|
||||
part.text = part.text.trimEnd()
|
||||
part.time = {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
}
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
await Session.updatePart(part)
|
||||
continue
|
||||
}
|
||||
case "finish-step": {
|
||||
const usage = Session.getUsage({
|
||||
model: model.info,
|
||||
usage: value.usage,
|
||||
metadata: value.providerMetadata,
|
||||
})
|
||||
msg.cost += usage.cost
|
||||
msg.tokens = usage.tokens
|
||||
await Session.updateMessage(msg)
|
||||
continue
|
||||
}
|
||||
case "error":
|
||||
throw value.error
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
log.error("compaction error", {
|
||||
error: e,
|
||||
})
|
||||
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
||||
shouldRetry = true
|
||||
await Session.updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: msg.id,
|
||||
sessionID: msg.sessionID,
|
||||
type: "retry",
|
||||
attempt: retries.count + 1,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
error,
|
||||
})
|
||||
} else {
|
||||
msg.error = error
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: msg.sessionID,
|
||||
error: msg.error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const parts = await MessageV2.parts(msg.id)
|
||||
return {
|
||||
info: msg,
|
||||
parts,
|
||||
shouldRetry,
|
||||
}
|
||||
}
|
||||
|
||||
let stream = doStream()
|
||||
const cfg = await Config.get()
|
||||
const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES
|
||||
let result = await process(stream, {
|
||||
count: 0,
|
||||
max: maxRetries,
|
||||
})
|
||||
if (result.shouldRetry) {
|
||||
const start = Date.now()
|
||||
for (let retry = 1; retry < maxRetries; retry++) {
|
||||
const lastRetryPart = result.parts.findLast((p): p is MessageV2.RetryPart => p.type === "retry")
|
||||
|
||||
if (lastRetryPart) {
|
||||
const delayMs = SessionRetry.getBoundedDelay({
|
||||
error: lastRetryPart.error,
|
||||
attempt: retry,
|
||||
startTime: start,
|
||||
})
|
||||
if (!delayMs) {
|
||||
break
|
||||
}
|
||||
|
||||
log.info("retrying with backoff", {
|
||||
attempt: retry,
|
||||
delayMs,
|
||||
elapsed: Date.now() - start,
|
||||
})
|
||||
|
||||
const stop = await SessionRetry.sleep(delayMs, signal)
|
||||
.then(() => false)
|
||||
.catch((error) => {
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
const err = new MessageV2.AbortedError(
|
||||
{ message: error.message },
|
||||
{
|
||||
cause: error,
|
||||
},
|
||||
).toObject()
|
||||
result.info.error = err
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: result.info.sessionID,
|
||||
error: result.info.error,
|
||||
})
|
||||
return true
|
||||
}
|
||||
throw error
|
||||
})
|
||||
|
||||
if (stop) break
|
||||
}
|
||||
|
||||
stream = doStream()
|
||||
result = await process(stream, {
|
||||
count: retry,
|
||||
max: maxRetries,
|
||||
})
|
||||
if (!result.shouldRetry) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg.time.completed = Date.now()
|
||||
|
||||
if (
|
||||
!msg.error ||
|
||||
(MessageV2.AbortedError.isInstance(msg.error) &&
|
||||
result.parts.some((part): part is MessageV2.TextPart => part.type === "text" && part.text.length > 0))
|
||||
) {
|
||||
msg.summary = true
|
||||
Bus.publish(Event.Compacted, {
|
||||
}),
|
||||
)
|
||||
if (result === "continue") {
|
||||
const continueMsg = await Session.updateMessage({
|
||||
id: Identifier.ascending("message"),
|
||||
role: "user",
|
||||
sessionID: input.sessionID,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
agent: "build",
|
||||
model: input.model,
|
||||
})
|
||||
await Session.updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: continueMsg.id,
|
||||
sessionID: input.sessionID,
|
||||
type: "text",
|
||||
synthetic: true,
|
||||
text: "Continue if you have next steps",
|
||||
time: {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
},
|
||||
})
|
||||
}
|
||||
await Session.updateMessage(msg)
|
||||
|
||||
return {
|
||||
info: msg,
|
||||
parts: result.parts,
|
||||
}
|
||||
return "continue"
|
||||
}
|
||||
|
||||
export const create = fn(
|
||||
z.object({
|
||||
sessionID: Identifier.schema("session"),
|
||||
model: z.object({
|
||||
providerID: z.string(),
|
||||
modelID: z.string(),
|
||||
}),
|
||||
}),
|
||||
async (input) => {
|
||||
const msg = await Session.updateMessage({
|
||||
id: Identifier.ascending("message"),
|
||||
role: "user",
|
||||
model: input.model,
|
||||
sessionID: input.sessionID,
|
||||
agent: "build",
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
})
|
||||
await Session.updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: msg.id,
|
||||
sessionID: msg.sessionID,
|
||||
type: "compaction",
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user