mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-02 23:23:45 +00:00
feat: retry parts (#3369)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
|
||||
import { streamText, type ModelMessage, LoadAPIKeyError, type StreamTextResult, type Tool as AITool } from "ai"
|
||||
import { Session } from "."
|
||||
import { Identifier } from "../id/id"
|
||||
import { Instance } from "../project/instance"
|
||||
@@ -14,8 +14,8 @@ import { Flag } from "../flag/flag"
|
||||
import { Token } from "../util/token"
|
||||
import { Log } from "../util/log"
|
||||
import { SessionLock } from "./lock"
|
||||
import { NamedError } from "../util/error"
|
||||
import { ProviderTransform } from "@/provider/transform"
|
||||
import { SessionRetry } from "./retry"
|
||||
|
||||
export namespace SessionCompaction {
|
||||
const log = Log.create({ service: "session.compaction" })
|
||||
@@ -41,6 +41,7 @@ export namespace SessionCompaction {
|
||||
|
||||
export const PRUNE_MINIMUM = 20_000
|
||||
export const PRUNE_PROTECT = 40_000
|
||||
const MAX_RETRIES = 10
|
||||
|
||||
// goes backwards through parts until there are 40_000 tokens worth of tool
|
||||
// calls. then erases output of previous tool calls. idea is to throw away old
|
||||
@@ -142,112 +143,173 @@ export namespace SessionCompaction {
|
||||
},
|
||||
})) as MessageV2.TextPart
|
||||
|
||||
const stream = streamText({
|
||||
maxRetries: 10,
|
||||
model: model.language,
|
||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
||||
abortSignal: signal,
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(toSummarize),
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
|
||||
},
|
||||
],
|
||||
const doStream = () =>
|
||||
streamText({
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
model: model.language,
|
||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
||||
abortSignal: signal,
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
})
|
||||
},
|
||||
],
|
||||
})
|
||||
messages: [
|
||||
...system.map(
|
||||
(x): ModelMessage => ({
|
||||
role: "system",
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(toSummarize),
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
try {
|
||||
for await (const value of stream.fullStream) {
|
||||
signal.throwIfAborted()
|
||||
switch (value.type) {
|
||||
case "text-delta":
|
||||
part.text += value.text
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
if (part.text) await Session.updatePart(part)
|
||||
continue
|
||||
case "text-end": {
|
||||
part.text = part.text.trimEnd()
|
||||
part.time = {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
// TODO: reduce duplication between compaction.ts & prompt.ts
|
||||
const process = async (
|
||||
stream: StreamTextResult<Record<string, AITool>, never>,
|
||||
retries: { count: number; max: number },
|
||||
) => {
|
||||
let shouldRetry = false
|
||||
try {
|
||||
for await (const value of stream.fullStream) {
|
||||
signal.throwIfAborted()
|
||||
switch (value.type) {
|
||||
case "text-delta":
|
||||
part.text += value.text
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
if (part.text) await Session.updatePart(part)
|
||||
continue
|
||||
case "text-end": {
|
||||
part.text = part.text.trimEnd()
|
||||
part.time = {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
}
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
await Session.updatePart(part)
|
||||
continue
|
||||
}
|
||||
if (value.providerMetadata) part.metadata = value.providerMetadata
|
||||
await Session.updatePart(part)
|
||||
continue
|
||||
case "finish-step": {
|
||||
const usage = Session.getUsage({
|
||||
model: model.info,
|
||||
usage: value.usage,
|
||||
metadata: value.providerMetadata,
|
||||
})
|
||||
msg.cost += usage.cost
|
||||
msg.tokens = usage.tokens
|
||||
await Session.updateMessage(msg)
|
||||
continue
|
||||
}
|
||||
case "error":
|
||||
throw value.error
|
||||
default:
|
||||
continue
|
||||
}
|
||||
case "finish-step": {
|
||||
const usage = Session.getUsage({
|
||||
model: model.info,
|
||||
usage: value.usage,
|
||||
metadata: value.providerMetadata,
|
||||
})
|
||||
msg.cost += usage.cost
|
||||
msg.tokens = usage.tokens
|
||||
await Session.updateMessage(msg)
|
||||
continue
|
||||
}
|
||||
case "error":
|
||||
throw value.error
|
||||
default:
|
||||
continue
|
||||
}
|
||||
} catch (e) {
|
||||
log.error("compaction error", {
|
||||
error: e,
|
||||
})
|
||||
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
||||
shouldRetry = true
|
||||
await Session.updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: msg.id,
|
||||
sessionID: msg.sessionID,
|
||||
type: "retry",
|
||||
attempt: retries.count + 1,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
error,
|
||||
})
|
||||
} else {
|
||||
msg.error = error
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: msg.sessionID,
|
||||
error: msg.error,
|
||||
})
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
log.error("compaction error", {
|
||||
error: e,
|
||||
})
|
||||
switch (true) {
|
||||
case e instanceof DOMException && e.name === "AbortError":
|
||||
msg.error = new MessageV2.AbortedError(
|
||||
{ message: e.message },
|
||||
{
|
||||
cause: e,
|
||||
},
|
||||
).toObject()
|
||||
break
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
msg.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
msg.error = new MessageV2.AuthError(
|
||||
{
|
||||
providerID: model.providerID,
|
||||
message: e.message,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
break
|
||||
case e instanceof Error:
|
||||
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
break
|
||||
default:
|
||||
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
|
||||
const parts = await Session.getParts(msg.id)
|
||||
return {
|
||||
info: msg,
|
||||
parts,
|
||||
shouldRetry,
|
||||
}
|
||||
}
|
||||
|
||||
let stream = doStream()
|
||||
let result = await process(stream, {
|
||||
count: 0,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
if (result.shouldRetry) {
|
||||
for (let retry = 1; retry < MAX_RETRIES; retry++) {
|
||||
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
|
||||
|
||||
if (lastRetryPart) {
|
||||
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
|
||||
|
||||
log.info("retrying with backoff", {
|
||||
attempt: retry,
|
||||
delayMs,
|
||||
})
|
||||
|
||||
const stop = await SessionRetry.sleep(delayMs, signal)
|
||||
.then(() => false)
|
||||
.catch((error) => {
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
const err = new MessageV2.AbortedError(
|
||||
{ message: error.message },
|
||||
{
|
||||
cause: error,
|
||||
},
|
||||
).toObject()
|
||||
result.info.error = err
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: result.info.sessionID,
|
||||
error: result.info.error,
|
||||
})
|
||||
return true
|
||||
}
|
||||
throw error
|
||||
})
|
||||
|
||||
if (stop) break
|
||||
}
|
||||
|
||||
stream = doStream()
|
||||
result = await process(stream, {
|
||||
count: retry,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
if (!result.shouldRetry) {
|
||||
break
|
||||
}
|
||||
}
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: input.sessionID,
|
||||
error: msg.error,
|
||||
})
|
||||
}
|
||||
|
||||
msg.time.completed = Date.now()
|
||||
|
||||
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
|
||||
if (
|
||||
!msg.error ||
|
||||
(MessageV2.AbortedError.isInstance(msg.error) &&
|
||||
result.parts.some((part) => part.type === "text" && part.text.length > 0))
|
||||
) {
|
||||
msg.summary = true
|
||||
Bus.publish(Event.Compacted, {
|
||||
sessionID: input.sessionID,
|
||||
@@ -257,7 +319,7 @@ export namespace SessionCompaction {
|
||||
|
||||
return {
|
||||
info: msg,
|
||||
parts: [part],
|
||||
parts: result.parts,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user