mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-02 23:23:45 +00:00
compaction improvements
This commit is contained in:
@@ -86,6 +86,7 @@ export namespace Session {
|
||||
time: z.object({
|
||||
created: z.number(),
|
||||
updated: z.number(),
|
||||
compacting: z.number().optional(),
|
||||
}),
|
||||
revert: z
|
||||
.object({
|
||||
@@ -137,12 +138,17 @@ export namespace Session {
|
||||
error: MessageV2.Assistant.shape.error,
|
||||
}),
|
||||
),
|
||||
Compacted: Bus.event(
|
||||
"session.compacted",
|
||||
z.object({
|
||||
sessionID: z.string(),
|
||||
}),
|
||||
),
|
||||
}
|
||||
|
||||
const state = Instance.state(
|
||||
() => {
|
||||
const pending = new Map<string, AbortController>()
|
||||
const autoCompacting = new Map<string, boolean>()
|
||||
const queued = new Map<
|
||||
string,
|
||||
{
|
||||
@@ -156,7 +162,6 @@ export namespace Session {
|
||||
|
||||
return {
|
||||
pending,
|
||||
autoCompacting,
|
||||
queued,
|
||||
}
|
||||
},
|
||||
@@ -714,24 +719,8 @@ export namespace Session {
|
||||
})().then((x) => Provider.getModel(x.providerID, x.modelID))
|
||||
let msgs = await messages(input.sessionID)
|
||||
|
||||
const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
|
||||
const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
|
||||
|
||||
// auto summarize if too long
|
||||
if (previous && previous.tokens) {
|
||||
const tokens =
|
||||
previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
|
||||
if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
|
||||
state().autoCompacting.set(input.sessionID, true)
|
||||
|
||||
await summarize({
|
||||
sessionID: input.sessionID,
|
||||
providerID: model.providerID,
|
||||
modelID: model.info.id,
|
||||
})
|
||||
return prompt(input)
|
||||
}
|
||||
}
|
||||
using abort = lock(input.sessionID)
|
||||
|
||||
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
|
||||
@@ -999,7 +988,38 @@ export namespace Session {
|
||||
error: e,
|
||||
})
|
||||
},
|
||||
async prepareStep({ messages }) {
|
||||
async prepareStep({ messages, steps }) {
|
||||
// Auto compact if too long
|
||||
const tokens = (() => {
|
||||
if (steps.length) {
|
||||
const previous = steps.at(-1)
|
||||
if (previous) return getUsage(model.info, previous.usage, previous.providerMetadata).tokens
|
||||
}
|
||||
const msg = msgs.findLast((x) => x.info.role === "assistant")?.info as MessageV2.Assistant
|
||||
if (msg && msg.tokens) {
|
||||
return msg.tokens
|
||||
}
|
||||
})()
|
||||
if (tokens) {
|
||||
log.info("compact check", tokens)
|
||||
const count = tokens.input + tokens.cache.read + tokens.cache.write + tokens.output
|
||||
if (model.info.limit.context && count > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
|
||||
log.info("compacting in prepareStep")
|
||||
const summarized = await summarize({
|
||||
sessionID: input.sessionID,
|
||||
providerID: model.providerID,
|
||||
modelID: model.info.id,
|
||||
})
|
||||
const msgs = await Session.messages(input.sessionID).then((x) =>
|
||||
x.filter((x) => x.info.id >= summarized.id),
|
||||
)
|
||||
return {
|
||||
messages: MessageV2.toModelMessage(msgs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add queued messages to the stream
|
||||
const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
|
||||
if (queue.length) {
|
||||
for (const item of queue) {
|
||||
@@ -1756,10 +1776,22 @@ export namespace Session {
|
||||
}
|
||||
|
||||
export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
|
||||
using abort = lock(input.sessionID)
|
||||
await update(input.sessionID, (draft) => {
|
||||
draft.time.compacting = Date.now()
|
||||
})
|
||||
await using _ = defer(async () => {
|
||||
await update(input.sessionID, (draft) => {
|
||||
draft.time.compacting = undefined
|
||||
})
|
||||
})
|
||||
const msgs = await messages(input.sessionID)
|
||||
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
|
||||
const filtered = msgs.filter((msg) => !lastSummary || msg.info.id >= lastSummary.info.id)
|
||||
const start = Math.max(
|
||||
0,
|
||||
msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
|
||||
)
|
||||
const split = start + Math.floor((msgs.length - start) / 2)
|
||||
log.info("summarizing", { start, split })
|
||||
const toSummarize = msgs.slice(start, split)
|
||||
const model = await Provider.getModel(input.providerID, input.modelID)
|
||||
const system = [
|
||||
...SystemPrompt.summarize(model.providerID),
|
||||
@@ -1767,36 +1799,8 @@ export namespace Session {
|
||||
...(await SystemPrompt.custom()),
|
||||
]
|
||||
|
||||
const next: MessageV2.Info = {
|
||||
id: Identifier.ascending("message"),
|
||||
role: "assistant",
|
||||
sessionID: input.sessionID,
|
||||
system,
|
||||
mode: "build",
|
||||
path: {
|
||||
cwd: Instance.directory,
|
||||
root: Instance.worktree,
|
||||
},
|
||||
summary: true,
|
||||
cost: 0,
|
||||
modelID: input.modelID,
|
||||
providerID: model.providerID,
|
||||
tokens: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
reasoning: 0,
|
||||
cache: { read: 0, write: 0 },
|
||||
},
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
}
|
||||
await updateMessage(next)
|
||||
|
||||
const processor = createProcessor(next, model.info)
|
||||
const stream = streamText({
|
||||
const generated = await generateText({
|
||||
maxRetries: 10,
|
||||
abortSignal: abort.signal,
|
||||
model: model.language,
|
||||
messages: [
|
||||
...system.map(
|
||||
@@ -1805,7 +1809,7 @@ export namespace Session {
|
||||
content: x,
|
||||
}),
|
||||
),
|
||||
...MessageV2.toModelMessage(filtered),
|
||||
...MessageV2.toModelMessage(toSummarize),
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
@@ -1817,9 +1821,45 @@ export namespace Session {
|
||||
},
|
||||
],
|
||||
})
|
||||
const usage = getUsage(model.info, generated.usage, generated.providerMetadata)
|
||||
const msg: MessageV2.Info = {
|
||||
id: Identifier.create("message", false, toSummarize.at(-1)!.info.time.created + 1),
|
||||
role: "assistant",
|
||||
sessionID: input.sessionID,
|
||||
system,
|
||||
mode: "build",
|
||||
path: {
|
||||
cwd: Instance.directory,
|
||||
root: Instance.worktree,
|
||||
},
|
||||
summary: true,
|
||||
cost: usage.cost,
|
||||
tokens: usage.tokens,
|
||||
modelID: input.modelID,
|
||||
providerID: model.providerID,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
completed: Date.now(),
|
||||
},
|
||||
}
|
||||
await updateMessage(msg)
|
||||
await updatePart({
|
||||
type: "text",
|
||||
sessionID: input.sessionID,
|
||||
messageID: msg.id,
|
||||
id: Identifier.ascending("part"),
|
||||
text: generated.text,
|
||||
time: {
|
||||
start: Date.now(),
|
||||
end: Date.now(),
|
||||
},
|
||||
})
|
||||
|
||||
const result = await processor.process(stream)
|
||||
return result
|
||||
Bus.publish(Event.Compacted, {
|
||||
sessionID: input.sessionID,
|
||||
})
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
function isLocked(sessionID: string) {
|
||||
@@ -1837,12 +1877,6 @@ export namespace Session {
|
||||
log.info("unlocking", { sessionID })
|
||||
state().pending.delete(sessionID)
|
||||
|
||||
const isAutoCompacting = state().autoCompacting.get(sessionID) ?? false
|
||||
if (isAutoCompacting) {
|
||||
state().autoCompacting.delete(sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
const session = await get(sessionID)
|
||||
if (session.parentID) return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user