mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-03-30 05:43:55 +00:00
218 lines
6.3 KiB
TypeScript
218 lines
6.3 KiB
TypeScript
import { Provider } from "@/provider/provider"
|
|
|
|
import { fn } from "@/util/fn"
|
|
import z from "zod"
|
|
import { Session } from "."
|
|
|
|
import { MessageV2 } from "./message-v2"
|
|
import { Identifier } from "@/id/id"
|
|
import { Snapshot } from "@/snapshot"
|
|
|
|
import { Log } from "@/util/log"
|
|
import path from "path"
|
|
import { Instance } from "@/project/instance"
|
|
import { Storage } from "@/storage/storage"
|
|
import { Bus } from "@/bus"
|
|
|
|
import { LLM } from "./llm"
|
|
import { Agent } from "@/agent/agent"
|
|
|
|
export namespace SessionSummary {
|
|
// Logger for this namespace, created with the service tag "session.summary".
const log = Log.create({ service: "session.summary" })
|
|
|
|
/**
 * Decodes a path that git printed in its C-style quoted form.
 *
 * Git wraps a path in double quotes and backslash-escapes it when it contains
 * special characters. Escapes are either a single letter (`\n`, `\t`, ...),
 * an escaped `\` or `"`, or a 1-3 digit octal byte value. Octal escapes
 * describe raw bytes, so the whole path is assembled as a byte sequence and
 * decoded as UTF-8 at the end.
 *
 * @param input A path as reported by git, quoted or not.
 * @returns The decoded path, or `input` unchanged when it is not wrapped in a
 *   pair of double quotes.
 */
function unquoteGitPath(input: string) {
  // A quoted path needs an opening AND a closing quote; a lone `"` is not one.
  if (input.length < 2 || !input.startsWith('"') || !input.endsWith('"')) return input

  const body = input.slice(1, -1)
  const bytes: number[] = []
  const octalEscape = /^[0-7]{1,3}/
  // Single-character escapes and the byte each one stands for.
  const simpleEscapes = new Map<string, number>([
    ["a", 0x07],
    ["b", 0x08],
    ["t", 0x09],
    ["n", 0x0a],
    ["v", 0x0b],
    ["f", 0x0c],
    ["r", 0x0d],
    ["\\", 0x5c],
    ['"', 0x22],
  ])

  let i = 0
  while (i < body.length) {
    if (body[i] !== "\\") {
      // Copy the literal run up to the next backslash as UTF-8 bytes. Pushing
      // UTF-16 code units instead would truncate every non-ASCII character to
      // a single byte and corrupt the path.
      const stop = body.indexOf("\\", i)
      const end = stop === -1 ? body.length : stop
      for (const byte of Buffer.from(body.slice(i, end), "utf8")) bytes.push(byte)
      i = end
      continue
    }

    const next = body[i + 1]
    if (next === undefined) {
      // Trailing backslash with nothing to escape: keep it verbatim.
      bytes.push(0x5c)
      i++
      continue
    }

    const octal = octalEscape.exec(body.slice(i + 1, i + 4))
    if (octal) {
      // Mask to one byte so out-of-range values such as \777 stay a single byte.
      bytes.push(parseInt(octal[0], 8) & 0xff)
      i += 1 + octal[0].length
      continue
    }

    const mapped = simpleEscapes.get(next)
    if (mapped !== undefined) {
      bytes.push(mapped)
      i += 2
      continue
    }

    // Unknown escape: drop the backslash and let the following character be
    // handled as ordinary literal text on the next iteration.
    i++
  }

  return Buffer.from(bytes).toString()
}
|
|
|
|
/**
 * Recomputes the session-level summary and the summary of one message.
 *
 * The session's messages are loaded once and handed to both summarizers,
 * which run concurrently; the call settles when both have finished.
 */
export const summarize = fn(
  z.object({
    sessionID: z.string(),
    messageID: z.string(),
  }),
  async ({ sessionID, messageID }) => {
    const messages = await Session.messages({ sessionID })
    // Start both jobs before awaiting so they overlap.
    const sessionJob = summarizeSession({ sessionID, messages })
    const messageJob = summarizeMessage({ messageID, messages })
    await Promise.all([sessionJob, messageJob])
  },
)
|
|
|
|
/**
 * Stores and broadcasts the diff summary for a whole session.
 *
 * Only diff entries whose file appears in one of the session's "patch" parts
 * are kept. The totals are written to the session's `summary`, the filtered
 * diffs are stored under `["session_diff", sessionID]`, and a
 * `Session.Event.Diff` event is published with the same diffs.
 */
async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
  const { sessionID, messages } = input

  // Worktree-relative, forward-slash paths of every file touched by a patch part.
  const touched = new Set<string>()
  for (const message of messages) {
    for (const part of message.parts) {
      if (part.type !== "patch") continue
      for (const file of part.files) {
        touched.add(path.relative(Instance.worktree, file).replaceAll("\\", "/"))
      }
    }
  }

  const computed = await computeDiff({ messages })
  const diffs = computed.filter((entry) => touched.has(entry.file))

  let additions = 0
  let deletions = 0
  for (const entry of diffs) {
    additions += entry.additions
    deletions += entry.deletions
  }

  await Session.update(sessionID, (draft) => {
    draft.summary = { additions, deletions, files: diffs.length }
  })
  await Storage.write(["session_diff", sessionID], diffs)
  Bus.publish(Session.Event.Diff, { sessionID, diff: diffs })
}
|
|
|
|
/**
 * Updates the summary of a single message: its diffs and, if missing, a title.
 *
 * The diff is computed over the target message plus every assistant message
 * whose `parentID` points at it, and saved on the message's `summary`. When
 * the message has a non-synthetic text part and no title yet, the "title"
 * agent is asked to produce one from that text, and the message is saved again.
 *
 * Returns without doing anything when `messageID` is not among `messages`,
 * and skips title generation when no "title" agent is available.
 */
async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
  const messages = input.messages.filter(
    (m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),
  )
  const msgWithParts = messages.find((m) => m.info.id === input.messageID)
  if (!msgWithParts) {
    // Nothing to summarize; bail out instead of crashing on an undefined message.
    log.info("message not found, skipping summary", { messageID: input.messageID })
    return
  }

  // NOTE(review): assumes messageID refers to a user message — confirm with callers.
  const userMsg = msgWithParts.info as MessageV2.User
  const diffs = await computeDiff({ messages })
  userMsg.summary = {
    ...userMsg.summary,
    diffs,
  }
  await Session.updateMessage(userMsg)

  const textPart = msgWithParts.parts.find((p) => p.type === "text" && !p.synthetic) as
    | MessageV2.TextPart
    | undefined
  if (!textPart || userMsg.summary?.title) return

  const agent = await Agent.get("title")
  if (!agent) return

  // Prefer the agent's own model, then the provider's small model, then the
  // model the user message was sent with.
  const model = agent.model
    ? await Provider.getModel(agent.model.providerID, agent.model.modelID)
    : ((await Provider.getSmallModel(userMsg.model.providerID)) ??
      (await Provider.getModel(userMsg.model.providerID, userMsg.model.modelID)))

  const stream = await LLM.stream({
    agent,
    user: userMsg,
    tools: {},
    model,
    small: true,
    messages: [
      {
        role: "user" as const,
        content: `
The following is the text to summarize:
<text>
${textPart?.text ?? ""}
</text>
`,
      },
    ],
    abort: new AbortController().signal,
    sessionID: userMsg.sessionID,
    system: [],
    retries: 3,
  })
  const result = await stream.text
  log.info("title", { title: result })
  userMsg.summary.title = result
  await Session.updateMessage(userMsg)
}
|
|
|
|
/**
 * Returns the stored diff list of a session with git-quoted paths decoded.
 *
 * A failed read yields an empty list. If any path had to be decoded, the
 * cleaned list is written back in the background; a failure of that write is
 * ignored, and the caller receives the cleaned list either way.
 */
export const diff = fn(
  z.object({
    sessionID: Identifier.schema("session"),
    messageID: Identifier.schema("message").optional(),
  }),
  async (input) => {
    const stored = await Storage.read<Snapshot.FileDiff[]>(["session_diff", input.sessionID]).catch(() => [])

    const cleaned: Snapshot.FileDiff[] = []
    let rewritten = false
    for (const entry of stored) {
      const file = unquoteGitPath(entry.file)
      if (file === entry.file) {
        // Already plain: reuse the stored object as-is.
        cleaned.push(entry)
        continue
      }
      rewritten = true
      cleaned.push({ ...entry, file })
    }

    // Fire-and-forget write-back so later reads skip the decoding step.
    if (rewritten) Storage.write(["session_diff", input.sessionID], cleaned).catch(() => {})
    return cleaned
  },
)
|
|
|
|
/**
 * Computes the full diff between the first and last snapshots found in a list
 * of messages.
 *
 * The start is the first "step-start" part carrying a snapshot, in message
 * order. The end is taken from the last message that has a "step-finish" part
 * with a snapshot, using the first such part within that message.
 *
 * @returns The result of `Snapshot.diffFull(start, end)`, or an empty array
 *   when either snapshot is missing.
 */
export async function computeDiff(input: { messages: MessageV2.WithParts[] }) {
  let earliest: string | undefined
  let latest: string | undefined

  for (const { parts } of input.messages) {
    // The start snapshot is fixed once found; later messages never replace it.
    if (!earliest) {
      for (const part of parts) {
        if (part.type !== "step-start" || !part.snapshot) continue
        earliest = part.snapshot
        break
      }
    }

    // The end snapshot is overwritten by every later message that has one.
    for (const part of parts) {
      if (part.type !== "step-finish" || !part.snapshot) continue
      latest = part.snapshot
      break
    }
  }

  if (!earliest || !latest) return []
  return Snapshot.diffFull(earliest, latest)
}
|
|
}
|