wip: more snapshot stuff

This commit is contained in:
Dax Raad
2025-07-24 17:34:15 -04:00
parent 22c9e2942b
commit 284c01018e
12 changed files with 340 additions and 144 deletions

View File

@@ -661,6 +661,7 @@ export namespace Session {
description: item.description,
inputSchema: item.parameters as ZodSchema,
async execute(args, options) {
await processor.track(options.toolCallId)
const result = await item.execute(args, {
sessionID: input.sessionID,
abort: abort.signal,
@@ -699,6 +700,7 @@ export namespace Session {
const execute = item.execute
if (!execute) continue
item.execute = async (args, opts) => {
await processor.track(opts.toolCallId)
const result = await execute(args, opts)
const output = result.content
.filter((x: any) => x.type === "text")
@@ -814,7 +816,12 @@ export namespace Session {
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
const toolCalls: Record<string, MessageV2.ToolPart> = {}
const snapshots: Record<string, string> = {}
return {
async track(toolCallID: string) {
const hash = await Snapshot.track()
if (hash) snapshots[toolCallID] = hash
},
partFromToolCall(toolCallID: string) {
return toolCalls[toolCallID]
},
@@ -828,15 +835,6 @@ export namespace Session {
})
switch (value.type) {
case "start":
const snapshot = await Snapshot.create()
if (snapshot)
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "snapshot",
snapshot,
})
break
case "tool-input-start":
@@ -857,6 +855,9 @@ export namespace Session {
case "tool-input-delta":
break
case "tool-input-end":
break
case "tool-call": {
const match = toolCalls[value.toolCallId]
if (match) {
@@ -892,15 +893,20 @@ export namespace Session {
},
})
delete toolCalls[value.toolCallId]
const snapshot = await Snapshot.create()
if (snapshot)
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "snapshot",
snapshot,
})
const snapshot = snapshots[value.toolCallId]
if (snapshot) {
const patch = await Snapshot.patch(snapshot)
if (patch.files.length) {
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "patch",
hash: patch.hash,
files: patch.files,
})
}
}
}
break
}
@@ -921,15 +927,18 @@ export namespace Session {
},
})
delete toolCalls[value.toolCallId]
const snapshot = await Snapshot.create()
if (snapshot)
const snapshot = snapshots[value.toolCallId]
if (snapshot) {
const patch = await Snapshot.patch(snapshot)
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "snapshot",
snapshot,
type: "patch",
hash: patch.hash,
files: patch.files,
})
}
}
break
}
@@ -1073,33 +1082,45 @@ export namespace Session {
export async function revert(input: RevertInput) {
const all = await messages(input.sessionID)
const session = await get(input.sessionID)
let lastUser: MessageV2.User | undefined
let lastSnapshot: MessageV2.SnapshotPart | undefined
const session = await get(input.sessionID)
let revert: Info["revert"]
const patches: Snapshot.Patch[] = []
for (const msg of all) {
if (msg.info.role === "user") lastUser = msg.info
const remaining = []
for (const part of msg.parts) {
if (part.type === "snapshot") lastSnapshot = part
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
// if no useful parts left in message, same as reverting whole message
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
const snapshot = session.revert?.snapshot ?? (await Snapshot.create())
log.info("revert snapshot", { snapshot })
if (lastSnapshot) await Snapshot.restore(lastSnapshot.snapshot)
const next = await update(input.sessionID, (draft) => {
draft.revert = {
// if not part id jump to the last user message
if (revert) {
if (part.type === "patch") {
patches.push(part)
}
continue
}
if (!revert) {
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
// if no useful parts left in message, same as reverting whole message
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
revert = {
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
partID,
snapshot,
}
})
return next
}
remaining.push(part)
}
remaining.push(part)
}
}
if (revert) {
const session = await get(input.sessionID)
revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
await Snapshot.revert(patches)
return update(input.sessionID, (draft) => {
draft.revert = revert
})
}
return session
}
export async function unrevert(input: { sessionID: string }) {

View File

@@ -94,6 +94,15 @@ export namespace MessageV2 {
})
export type SnapshotPart = z.infer<typeof SnapshotPart>
export const PatchPart = PartBase.extend({
type: z.literal("patch"),
hash: z.string(),
files: z.string().array(),
}).openapi({
ref: "PatchPart",
})
export type PatchPart = z.infer<typeof PatchPart>
export const TextPart = PartBase.extend({
type: z.literal("text"),
text: z.string(),
@@ -203,7 +212,7 @@ export namespace MessageV2 {
export type User = z.infer<typeof User>
export const Part = z
.discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart, SnapshotPart])
.discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart, SnapshotPart, PatchPart])
.openapi({
ref: "Part",
})