Refactor agent loop (#4412)

This commit is contained in:
Dax
2025-11-17 10:57:18 -05:00
committed by GitHub
parent 9fd43ec616
commit a1214fff2e
22 changed files with 1297 additions and 1324 deletions

View File

@@ -1,9 +1,8 @@
import { streamText, type ModelMessage, type StreamTextResult, type Tool as AITool } from "ai"
import { streamText, type ModelMessage } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
import { Provider } from "../provider/provider"
import { defer } from "../util/defer"
import { MessageV2 } from "./message-v2"
import { SystemPrompt } from "./system"
import { Bus } from "../bus"
@@ -13,10 +12,9 @@ import { SessionPrompt } from "./prompt"
import { Flag } from "../flag/flag"
import { Token } from "../util/token"
import { Log } from "../util/log"
import { SessionLock } from "./lock"
import { ProviderTransform } from "@/provider/transform"
import { SessionRetry } from "./retry"
import { Config } from "@/config/config"
import { SessionProcessor } from "./processor"
import { fn } from "@/util/fn"
export namespace SessionCompaction {
const log = Log.create({ service: "session.compaction" })
@@ -42,7 +40,6 @@ export namespace SessionCompaction {
export const PRUNE_MINIMUM = 20_000
export const PRUNE_PROTECT = 40_000
const MAX_RETRIES = 10
// goes backwards through parts until there are 40_000 tokens worth of tool
// calls. then erases output of previous tool calls. idea is to throw away old
@@ -87,38 +84,29 @@ export namespace SessionCompaction {
}
}
export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
const signal = input.signal ?? lock!.signal
await Session.update(input.sessionID, (draft) => {
draft.time.compacting = Date.now()
})
await using _ = defer(async () => {
await Session.update(input.sessionID, (draft) => {
draft.time.compacting = undefined
})
})
const toSummarize = await MessageV2.filterCompacted(MessageV2.stream(input.sessionID))
const model = await Provider.getModel(input.providerID, input.modelID)
const system = [
...SystemPrompt.summarize(model.providerID),
...(await SystemPrompt.environment()),
...(await SystemPrompt.custom()),
]
export async function process(input: {
parentID: string
messages: MessageV2.WithParts[]
sessionID: string
model: {
providerID: string
modelID: string
}
abort: AbortSignal
}) {
const model = await Provider.getModel(input.model.providerID, input.model.modelID)
const system = [...SystemPrompt.summarize(model.providerID)]
const msg = (await Session.updateMessage({
id: Identifier.ascending("message"),
role: "assistant",
parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!,
parentID: input.parentID,
sessionID: input.sessionID,
mode: "build",
summary: true,
path: {
cwd: Instance.directory,
root: Instance.worktree,
},
summary: true,
cost: 0,
tokens: {
output: 0,
@@ -126,37 +114,27 @@ export namespace SessionCompaction {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: input.modelID,
modelID: input.model.modelID,
providerID: model.providerID,
time: {
created: Date.now(),
},
})) as MessageV2.Assistant
const part = (await Session.updatePart({
type: "text",
const processor = SessionProcessor.create({
assistantMessage: msg,
sessionID: input.sessionID,
messageID: msg.id,
id: Identifier.ascending("part"),
text: "",
time: {
start: Date.now(),
},
})) as MessageV2.TextPart
const doStream = () =>
providerID: input.model.providerID,
model: model.info,
abort: input.abort,
})
const result = await processor.process(() =>
streamText({
// set to 0, we handle loop
maxRetries: 0,
model: model.language,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
headers: model.info.headers,
abortSignal: signal,
onError(error) {
log.error("stream error", {
error,
})
},
abortSignal: input.abort,
tools: model.info.tool_call ? {} : undefined,
messages: [
...system.map(
@@ -165,7 +143,7 @@ export namespace SessionCompaction {
content: x,
}),
),
...MessageV2.toModelMessage(toSummarize),
...MessageV2.toModelMessage(input.messages),
{
role: "user",
content: [
@@ -176,168 +154,60 @@ export namespace SessionCompaction {
],
},
],
})
// TODO: reduce duplication between compaction.ts & prompt.ts
const process = async (
stream: StreamTextResult<Record<string, AITool>, never>,
retries: { count: number; max: number },
) => {
let shouldRetry = false
try {
for await (const value of stream.fullStream) {
signal.throwIfAborted()
switch (value.type) {
case "text-delta":
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text)
await Session.updatePart({
part,
delta: value.text,
})
continue
case "text-end": {
part.text = part.text.trimEnd()
part.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
continue
}
case "finish-step": {
const usage = Session.getUsage({
model: model.info,
usage: value.usage,
metadata: value.providerMetadata,
})
msg.cost += usage.cost
msg.tokens = usage.tokens
await Session.updateMessage(msg)
continue
}
case "error":
throw value.error
default:
continue
}
}
} catch (e) {
log.error("compaction error", {
error: e,
})
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
shouldRetry = true
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: msg.id,
sessionID: msg.sessionID,
type: "retry",
attempt: retries.count + 1,
time: {
created: Date.now(),
},
error,
})
} else {
msg.error = error
Bus.publish(Session.Event.Error, {
sessionID: msg.sessionID,
error: msg.error,
})
}
}
const parts = await MessageV2.parts(msg.id)
return {
info: msg,
parts,
shouldRetry,
}
}
let stream = doStream()
const cfg = await Config.get()
const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES
let result = await process(stream, {
count: 0,
max: maxRetries,
})
if (result.shouldRetry) {
const start = Date.now()
for (let retry = 1; retry < maxRetries; retry++) {
const lastRetryPart = result.parts.findLast((p): p is MessageV2.RetryPart => p.type === "retry")
if (lastRetryPart) {
const delayMs = SessionRetry.getBoundedDelay({
error: lastRetryPart.error,
attempt: retry,
startTime: start,
})
if (!delayMs) {
break
}
log.info("retrying with backoff", {
attempt: retry,
delayMs,
elapsed: Date.now() - start,
})
const stop = await SessionRetry.sleep(delayMs, signal)
.then(() => false)
.catch((error) => {
if (error instanceof DOMException && error.name === "AbortError") {
const err = new MessageV2.AbortedError(
{ message: error.message },
{
cause: error,
},
).toObject()
result.info.error = err
Bus.publish(Session.Event.Error, {
sessionID: result.info.sessionID,
error: result.info.error,
})
return true
}
throw error
})
if (stop) break
}
stream = doStream()
result = await process(stream, {
count: retry,
max: maxRetries,
})
if (!result.shouldRetry) {
break
}
}
}
msg.time.completed = Date.now()
if (
!msg.error ||
(MessageV2.AbortedError.isInstance(msg.error) &&
result.parts.some((part): part is MessageV2.TextPart => part.type === "text" && part.text.length > 0))
) {
msg.summary = true
Bus.publish(Event.Compacted, {
}),
)
if (result === "continue") {
const continueMsg = await Session.updateMessage({
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
time: {
created: Date.now(),
},
agent: "build",
model: input.model,
})
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: continueMsg.id,
sessionID: input.sessionID,
type: "text",
synthetic: true,
text: "Continue if you have next steps",
time: {
start: Date.now(),
end: Date.now(),
},
})
}
await Session.updateMessage(msg)
return {
info: msg,
parts: result.parts,
}
return "continue"
}
export const create = fn(
z.object({
sessionID: Identifier.schema("session"),
model: z.object({
providerID: z.string(),
modelID: z.string(),
}),
}),
async (input) => {
const msg = await Session.updateMessage({
id: Identifier.ascending("message"),
role: "user",
model: input.model,
sessionID: input.sessionID,
agent: "build",
time: {
created: Date.now(),
},
})
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: msg.id,
sessionID: msg.sessionID,
type: "compaction",
})
},
)
}

View File

@@ -1,7 +1,6 @@
import { Decimal } from "decimal.js"
import z from "zod"
import { type LanguageModelUsage, type ProviderMetadata } from "ai"
import { Bus } from "../bus"
import { Config } from "../config/config"
import { Flag } from "../flag/flag"

View File

@@ -1,97 +0,0 @@
import z from "zod"
import { Instance } from "../project/instance"
import { Log } from "../util/log"
import { NamedError } from "../util/error"
export namespace SessionLock {
const log = Log.create({ service: "session.lock" })
export const LockedError = NamedError.create(
"SessionLockedError",
z.object({
sessionID: z.string(),
message: z.string(),
}),
)
type LockState = {
controller: AbortController
created: number
}
const state = Instance.state(
() => {
const locks = new Map<string, LockState>()
return {
locks,
}
},
async (current) => {
for (const [sessionID, lock] of current.locks) {
log.info("force abort", { sessionID })
lock.controller.abort()
}
current.locks.clear()
},
)
function get(sessionID: string) {
return state().locks.get(sessionID)
}
function unset(input: { sessionID: string; controller: AbortController }) {
const lock = get(input.sessionID)
if (!lock) return false
if (lock.controller !== input.controller) return false
state().locks.delete(input.sessionID)
return true
}
export function acquire(input: { sessionID: string }) {
const lock = get(input.sessionID)
if (lock) {
throw new LockedError({
sessionID: input.sessionID,
message: `Session ${input.sessionID} is locked`,
})
}
const controller = new AbortController()
state().locks.set(input.sessionID, {
controller,
created: Date.now(),
})
log.info("locked", { sessionID: input.sessionID })
return {
signal: controller.signal,
abort() {
controller.abort()
unset({ sessionID: input.sessionID, controller })
},
async [Symbol.dispose]() {
const removed = unset({ sessionID: input.sessionID, controller })
if (removed) {
log.info("unlocked", { sessionID: input.sessionID })
}
},
}
}
export function abort(sessionID: string) {
const lock = get(sessionID)
if (!lock) return false
log.info("abort", { sessionID })
lock.controller.abort()
state().locks.delete(sessionID)
return true
}
export function isLocked(sessionID: string) {
return get(sessionID) !== undefined
}
export function assertUnlocked(sessionID: string) {
const lock = get(sessionID)
if (!lock) return
throw new LockedError({ sessionID, message: `Session ${sessionID} is locked` })
}
}

View File

@@ -142,6 +142,21 @@ export namespace MessageV2 {
})
export type AgentPart = z.infer<typeof AgentPart>
export const CompactionPart = PartBase.extend({
type: z.literal("compaction"),
}).meta({
ref: "CompactionPart",
})
export type CompactionPart = z.infer<typeof CompactionPart>
export const SubtaskPart = PartBase.extend({
type: z.literal("subtask"),
prompt: z.string(),
description: z.string(),
agent: z.string(),
})
export type SubtaskPart = z.infer<typeof SubtaskPart>
export const RetryPart = PartBase.extend({
type: z.literal("retry"),
attempt: z.number(),
@@ -277,6 +292,13 @@ export namespace MessageV2 {
diffs: Snapshot.FileDiff.array(),
})
.optional(),
agent: z.string(),
model: z.object({
providerID: z.string(),
modelID: z.string(),
}),
system: z.string().optional(),
tools: z.record(z.string(), z.boolean()).optional(),
}).meta({
ref: "UserMessage",
})
@@ -285,6 +307,7 @@ export namespace MessageV2 {
export const Part = z
.discriminatedUnion("type", [
TextPart,
SubtaskPart,
ReasoningPart,
FilePart,
ToolPart,
@@ -294,6 +317,7 @@ export namespace MessageV2 {
PatchPart,
AgentPart,
RetryPart,
CompactionPart,
])
.meta({
ref: "Part",
@@ -334,6 +358,7 @@ export namespace MessageV2 {
write: z.number(),
}),
}),
finish: z.string().optional(),
}).meta({
ref: "AssistantMessage",
})
@@ -482,6 +507,11 @@ export namespace MessageV2 {
time: {
created: v1.metadata.time.created,
},
agent: "build",
model: {
providerID: "opencode",
modelID: "opencode",
},
}
const parts = v1.parts.flatMap((part): Part[] => {
const base = {
@@ -529,107 +559,107 @@ export namespace MessageV2 {
if (msg.parts.length === 0) continue
if (msg.info.role === "user") {
result.push({
const userMessage: UIMessage = {
id: msg.info.id,
role: "user",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
return [
{
type: "text",
text: part.text,
},
]
// text/plain and directory files are converted into text parts, ignore them
if (part.type === "file" && part.mime !== "text/plain" && part.mime !== "application/x-directory")
return [
{
type: "file",
url: part.url,
mediaType: part.mime,
filename: part.filename,
},
]
return []
}),
})
parts: [],
}
result.push(userMessage)
for (const part of msg.parts) {
if (part.type === "text")
userMessage.parts.push({
type: "text",
text: part.text,
})
// text/plain and directory files are converted into text parts, ignore them
if (part.type === "file" && part.mime !== "text/plain" && part.mime !== "application/x-directory")
userMessage.parts.push({
type: "file",
url: part.url,
mediaType: part.mime,
filename: part.filename,
})
if (part.type === "compaction") {
userMessage.parts.push({
type: "text",
text: "What did we do so far?",
})
}
if (part.type === "subtask") {
userMessage.parts.push({
type: "text",
text: "The following tool was executed by the user",
})
}
}
}
if (msg.info.role === "assistant") {
result.push({
const assistantMessage: UIMessage = {
id: msg.info.id,
role: "assistant",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
return [
{
type: "text",
text: part.text,
providerMetadata: part.metadata,
},
]
if (part.type === "step-start")
return [
{
type: "step-start",
},
]
if (part.type === "tool") {
if (part.state.status === "completed") {
if (part.state.attachments?.length) {
result.push({
id: Identifier.ascending("message"),
role: "user",
parts: [
{
type: "text",
text: `Tool ${part.tool} returned an attachment:`,
},
...part.state.attachments.map((attachment) => ({
type: "file" as const,
url: attachment.url,
mediaType: attachment.mime,
filename: attachment.filename,
})),
],
})
}
return [
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-available",
toolCallId: part.callID,
input: part.state.input,
output: part.state.time.compacted ? "[Old tool result content cleared]" : part.state.output,
callProviderMetadata: part.metadata,
},
]
parts: [],
}
result.push(assistantMessage)
for (const part of msg.parts) {
if (part.type === "text")
assistantMessage.parts.push({
type: "text",
text: part.text,
providerMetadata: part.metadata,
})
if (part.type === "step-start")
assistantMessage.parts.push({
type: "step-start",
})
if (part.type === "tool") {
if (part.state.status === "completed") {
if (part.state.attachments?.length) {
result.push({
id: Identifier.ascending("message"),
role: "user",
parts: [
{
type: "text",
text: `Tool ${part.tool} returned an attachment:`,
},
...part.state.attachments.map((attachment) => ({
type: "file" as const,
url: attachment.url,
mediaType: attachment.mime,
filename: attachment.filename,
})),
],
})
}
if (part.state.status === "error")
return [
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-error",
toolCallId: part.callID,
input: part.state.input,
errorText: part.state.error,
callProviderMetadata: part.metadata,
},
]
assistantMessage.parts.push({
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-available",
toolCallId: part.callID,
input: part.state.input,
output: part.state.time.compacted ? "[Old tool result content cleared]" : part.state.output,
callProviderMetadata: part.metadata,
})
}
if (part.type === "reasoning") {
return [
{
type: "reasoning",
text: part.text,
providerMetadata: part.metadata,
},
]
}
return []
}),
})
if (part.state.status === "error")
assistantMessage.parts.push({
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-error",
toolCallId: part.callID,
input: part.state.input,
errorText: part.state.error,
callProviderMetadata: part.metadata,
})
}
if (part.type === "reasoning") {
assistantMessage.parts.push({
type: "reasoning",
text: part.text,
providerMetadata: part.metadata,
})
}
}
}
}
@@ -671,9 +701,16 @@ export namespace MessageV2 {
export async function filterCompacted(stream: AsyncIterable<MessageV2.WithParts>) {
const result = [] as MessageV2.WithParts[]
const completed = new Set<string>()
for await (const msg of stream) {
result.push(msg)
if (msg.info.role === "assistant" && msg.info.summary === true) break
if (
msg.info.role === "user" &&
completed.has(msg.info.id) &&
msg.parts.some((part) => part.type === "compaction")
)
break
if (msg.info.role === "assistant" && msg.info.summary && msg.info.finish) completed.add(msg.info.parentID)
}
result.reverse()
return result

View File

@@ -0,0 +1,372 @@
import type { ModelsDev } from "@/provider/models"
import { MessageV2 } from "./message-v2"
import { type StreamTextResult, type Tool as AITool, APICallError } from "ai"
import { Log } from "@/util/log"
import { Identifier } from "@/id/id"
import { Session } from "."
import { Agent } from "@/agent/agent"
import { Permission } from "@/permission"
import { Snapshot } from "@/snapshot"
import { SessionSummary } from "./summary"
import { Bus } from "@/bus"
import { SessionRetry } from "./retry"
import { SessionStatus } from "./status"
export namespace SessionProcessor {
const DOOM_LOOP_THRESHOLD = 3
const log = Log.create({ service: "session.processor" })
export type Info = Awaited<ReturnType<typeof create>>
export type Result = Awaited<ReturnType<Info["process"]>>
export function create(input: {
assistantMessage: MessageV2.Assistant
sessionID: string
providerID: string
model: ModelsDev.Model
abort: AbortSignal
}) {
const toolcalls: Record<string, MessageV2.ToolPart> = {}
let snapshot: string | undefined
let blocked = false
let attempt = 0
const result = {
get message() {
return input.assistantMessage
},
partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID]
},
async process(fn: () => StreamTextResult<Record<string, AITool>, never>) {
log.info("process")
while (true) {
try {
let currentText: MessageV2.TextPart | undefined
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
const stream = fn()
for await (const value of stream.fullStream) {
input.abort.throwIfAborted()
switch (value.type) {
case "start":
SessionStatus.set(input.sessionID, { type: "busy" })
break
case "reasoning-start":
if (value.id in reasoningMap) {
continue
}
reasoningMap[value.id] = {
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "reasoning",
text: "",
time: {
start: Date.now(),
},
metadata: value.providerMetadata,
}
break
case "reasoning-delta":
if (value.id in reasoningMap) {
const part = reasoningMap[value.id]
part.text += value.text
if (value.providerMetadata) part.metadata = value.providerMetadata
if (part.text) await Session.updatePart({ part, delta: value.text })
}
break
case "reasoning-end":
if (value.id in reasoningMap) {
const part = reasoningMap[value.id]
part.text = part.text.trimEnd()
part.time = {
...part.time,
end: Date.now(),
}
if (value.providerMetadata) part.metadata = value.providerMetadata
await Session.updatePart(part)
delete reasoningMap[value.id]
}
break
case "tool-input-start":
const part = await Session.updatePart({
id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "tool",
tool: value.toolName,
callID: value.id,
state: {
status: "pending",
input: {},
raw: "",
},
})
toolcalls[value.id] = part as MessageV2.ToolPart
break
case "tool-input-delta":
break
case "tool-input-end":
break
case "tool-call": {
const match = toolcalls[value.toolCallId]
if (match) {
const part = await Session.updatePart({
...match,
tool: value.toolName,
state: {
status: "running",
input: value.input,
time: {
start: Date.now(),
},
},
metadata: value.providerMetadata,
})
toolcalls[value.toolCallId] = part as MessageV2.ToolPart
const parts = await MessageV2.parts(input.assistantMessage.id)
const lastThree = parts.slice(-DOOM_LOOP_THRESHOLD)
if (
lastThree.length === DOOM_LOOP_THRESHOLD &&
lastThree.every(
(p) =>
p.type === "tool" &&
p.tool === value.toolName &&
p.state.status !== "pending" &&
JSON.stringify(p.state.input) === JSON.stringify(value.input),
)
) {
const permission = await Agent.get(input.assistantMessage.mode).then((x) => x.permission)
if (permission.doom_loop === "ask") {
await Permission.ask({
type: "doom_loop",
pattern: value.toolName,
sessionID: input.assistantMessage.sessionID,
messageID: input.assistantMessage.id,
callID: value.toolCallId,
title: `Possible doom loop: "${value.toolName}" called ${DOOM_LOOP_THRESHOLD} times with identical arguments`,
metadata: {
tool: value.toolName,
input: value.input,
},
})
}
}
}
break
}
case "tool-result": {
const match = toolcalls[value.toolCallId]
if (match && match.state.status === "running") {
await Session.updatePart({
...match,
state: {
status: "completed",
input: value.input,
output: value.output.output,
metadata: value.output.metadata,
title: value.output.title,
time: {
start: match.state.time.start,
end: Date.now(),
},
attachments: value.output.attachments,
},
})
delete toolcalls[value.toolCallId]
}
break
}
case "tool-error": {
const match = toolcalls[value.toolCallId]
if (match && match.state.status === "running") {
await Session.updatePart({
...match,
state: {
status: "error",
input: value.input,
error: (value.error as any).toString(),
metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined,
time: {
start: match.state.time.start,
end: Date.now(),
},
},
})
if (value.error instanceof Permission.RejectedError) {
blocked = true
}
delete toolcalls[value.toolCallId]
}
break
}
case "error":
throw value.error
case "start-step":
snapshot = await Snapshot.track()
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.sessionID,
snapshot,
type: "step-start",
})
break
case "finish-step":
const usage = Session.getUsage({
model: input.model,
usage: value.usage,
metadata: value.providerMetadata,
})
input.assistantMessage.finish = value.finishReason
input.assistantMessage.cost += usage.cost
input.assistantMessage.tokens = usage.tokens
await Session.updatePart({
id: Identifier.ascending("part"),
reason: value.finishReason,
snapshot: await Snapshot.track(),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "step-finish",
tokens: usage.tokens,
cost: usage.cost,
})
await Session.updateMessage(input.assistantMessage)
if (snapshot) {
const patch = await Snapshot.patch(snapshot)
if (patch.files.length) {
await Session.updatePart({
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.sessionID,
type: "patch",
hash: patch.hash,
files: patch.files,
})
}
snapshot = undefined
}
SessionSummary.summarize({
sessionID: input.sessionID,
messageID: input.assistantMessage.parentID,
})
break
case "text-start":
currentText = {
id: Identifier.ascending("part"),
messageID: input.assistantMessage.id,
sessionID: input.assistantMessage.sessionID,
type: "text",
text: "",
time: {
start: Date.now(),
},
metadata: value.providerMetadata,
}
break
case "text-delta":
if (currentText) {
currentText.text += value.text
if (value.providerMetadata) currentText.metadata = value.providerMetadata
if (currentText.text)
await Session.updatePart({
part: currentText,
delta: value.text,
})
}
break
case "text-end":
if (currentText) {
currentText.text = currentText.text.trimEnd()
currentText.time = {
start: Date.now(),
end: Date.now(),
}
if (value.providerMetadata) currentText.metadata = value.providerMetadata
await Session.updatePart(currentText)
}
currentText = undefined
break
case "finish":
input.assistantMessage.time.completed = Date.now()
await Session.updateMessage(input.assistantMessage)
break
default:
log.info("unhandled", {
...value,
})
continue
}
}
} catch (e) {
log.error("process", {
error: e,
})
const error = MessageV2.fromError(e, { providerID: input.providerID })
if (error?.name === "APIError" && error.data.isRetryable) {
attempt++
const delay = SessionRetry.getRetryDelayInMs(error, attempt)
if (delay) {
SessionStatus.set(input.sessionID, {
type: "retry",
attempt,
message: error.data.message,
})
await SessionRetry.sleep(delay, input.abort).catch(() => {})
continue
}
}
input.assistantMessage.error = error
Bus.publish(Session.Event.Error, {
sessionID: input.assistantMessage.sessionID,
error: input.assistantMessage.error,
})
}
const p = await MessageV2.parts(input.assistantMessage.id)
for (const part of p) {
if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
await Session.updatePart({
...part,
state: {
...part.state,
status: "error",
error: "Tool execution aborted",
time: {
start: Date.now(),
end: Date.now(),
},
},
})
}
}
input.assistantMessage.time.completed = Date.now()
await Session.updateMessage(input.assistantMessage)
if (blocked) return "stop"
if (input.assistantMessage.error) return "stop"
return "continue"
}
},
}
return result
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ import { Log } from "../util/log"
import { splitWhen } from "remeda"
import { Storage } from "../storage/storage"
import { Bus } from "../bus"
import { SessionLock } from "./lock"
import { SessionPrompt } from "./prompt"
export namespace SessionRevert {
const log = Log.create({ service: "session.revert" })
@@ -20,11 +20,7 @@ export namespace SessionRevert {
export type RevertInput = z.infer<typeof RevertInput>
export async function revert(input: RevertInput) {
SessionLock.assertUnlocked(input.sessionID)
using _ = SessionLock.acquire({
sessionID: input.sessionID,
})
SessionPrompt.assertNotBusy(input.sessionID)
const all = await Session.messages({ sessionID: input.sessionID })
let lastUser: MessageV2.User | undefined
const session = await Session.get(input.sessionID)
@@ -70,10 +66,7 @@ export namespace SessionRevert {
export async function unrevert(input: { sessionID: string }) {
log.info("unreverting", input)
SessionLock.assertUnlocked(input.sessionID)
using _ = SessionLock.acquire({
sessionID: input.sessionID,
})
SessionPrompt.assertNotBusy(input.sessionID)
const session = await Session.get(input.sessionID)
if (!session.revert) return session
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)

View File

@@ -0,0 +1,63 @@
import { Bus } from "@/bus"
import { Instance } from "@/project/instance"
import z from "zod"
export namespace SessionStatus {
export const Info = z
.union([
z.object({
type: z.literal("idle"),
}),
z.object({
type: z.literal("retry"),
attempt: z.number(),
message: z.string(),
}),
z.object({
type: z.literal("busy"),
}),
])
.meta({
ref: "SessionStatus",
})
export type Info = z.infer<typeof Info>
export const Event = {
Status: Bus.event(
"session.status",
z.object({
sessionID: z.string(),
status: Info,
}),
),
}
const state = Instance.state(() => {
const data: Record<string, Info> = {}
return data
})
export function get(sessionID: string) {
return (
state()[sessionID] ?? {
type: "idle",
}
)
}
export function list() {
return Object.values(state())
}
export function set(sessionID: string, status: Info) {
Bus.publish(Event.Status, {
sessionID,
status,
})
if (status.type === "idle") {
delete state()[sessionID]
return
}
state()[sessionID] = status
}
}

View File

@@ -43,7 +43,7 @@ export namespace SystemPrompt {
` Platform: ${process.platform}`,
` Today's date: ${new Date().toDateString()}`,
`</env>`,
`<project>`,
`<files>`,
` ${
project.vcs === "git"
? await Ripgrep.tree({
@@ -52,7 +52,7 @@ export namespace SystemPrompt {
})
: ""
}`,
`</project>`,
`</files>`,
].join("\n"),
]
}