mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-02 15:13:46 +00:00
- Rename packages/opencode → packages/tfcode (directory only)
- Rename bin/opencode → bin/tfcode (CLI binary)
- Rename .opencode → .tfcode (config directory)
- Update package.json name and bin field
- Update config directory path references (.tfcode)
- Keep internal code references as 'opencode' for easy upstream sync
- Keep @opencode-ai/* workspace package names

This minimal branding approach allows clean merges from the upstream opencode repository while providing tfcode branding for users.
139 lines
4.8 KiB
TypeScript
139 lines
4.8 KiB
TypeScript
import z from "zod"
|
|
import { SessionID, MessageID, PartID } from "./schema"
|
|
import { Snapshot } from "../snapshot"
|
|
import { MessageV2 } from "./message-v2"
|
|
import { Session } from "."
|
|
import { Log } from "../util/log"
|
|
import { Database, eq } from "../storage/db"
|
|
import { MessageTable, PartTable } from "./session.sql"
|
|
import { Storage } from "@/storage/storage"
|
|
import { Bus } from "../bus"
|
|
import { SessionPrompt } from "./prompt"
|
|
import { SessionSummary } from "./summary"
|
|
|
|
/**
 * Session revert support: rolling a session back to an earlier message or
 * part, undoing such a rollback, and permanently discarding the reverted
 * messages and parts.
 */
export namespace SessionRevert {
  const log = Log.create({ service: "session.revert" })

  /**
   * Input accepted by {@link revert}: the session, the message to revert at,
   * and optionally a single part inside that message.
   */
  export const RevertInput = z.object({
    sessionID: SessionID.zod,
    messageID: MessageID.zod,
    partID: PartID.zod.optional(),
  })
  export type RevertInput = z.infer<typeof RevertInput>

  /**
   * Marks a revert point in the session and rolls back the patches recorded
   * after it.
   *
   * Messages are walked in the order returned by `Session.messages`. The first
   * part that matches the input becomes the revert point; every `patch` part
   * encountered after it is collected and handed to `Snapshot.revert`. The
   * resulting diffs are written to storage under `session_diff`, published on
   * the bus, and summarised in the call to `Session.setRevert`.
   *
   * @param input Session, message and optional part identifying the revert point.
   * @returns The result of `Session.setRevert` when a revert point was found,
   *   otherwise the session as returned by `Session.get`.
   */
  export async function revert(input: RevertInput) {
    // Presumably throws while the session is still processing a prompt — confirm in SessionPrompt.
    SessionPrompt.assertNotBusy(input.sessionID)
    const all = await Session.messages({ sessionID: input.sessionID })
    const session = await Session.get(input.sessionID)

    let lastUser: MessageV2.User | undefined
    let revert: Session.Info["revert"]
    const patches: Snapshot.Patch[] = []

    // NOTE(review): the match below is evaluated once per part, so a message
    // that has no parts can never become the revert point — confirm intended.
    for (const msg of all) {
      if (msg.info.role === "user") lastUser = msg.info
      // Parts of the current message that precede the revert point.
      const remaining = []
      for (const part of msg.parts) {
        if (revert) {
          // Past the revert point: only patch parts matter, they are what gets undone.
          if (part.type === "patch") patches.push(part)
          continue
        }

        const matchesMessage = msg.info.id === input.messageID && !input.partID
        if (matchesMessage || part.id === input.partID) {
          // if no useful parts left in message, same as reverting whole message
          const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
          revert = {
            // A whole-message revert is anchored at the most recent user message when there is one.
            messageID: !partID && lastUser ? lastUser.id : msg.info.id,
            partID,
          }
        }
        remaining.push(part)
      }
    }

    if (!revert) return session

    // Keep the snapshot of an already pending revert; otherwise track a new one.
    revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
    await Snapshot.revert(patches)
    if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)

    const fromMessageID = revert.messageID
    const rangeMessages = all.filter((msg) => msg.info.id >= fromMessageID)
    const diffs = await SessionSummary.computeDiff({ messages: rangeMessages })
    await Storage.write(["session_diff", input.sessionID], diffs)
    // NOTE(review): not awaited here, unlike the publishes in cleanup() — confirm this is deliberate.
    Bus.publish(Session.Event.Diff, {
      sessionID: input.sessionID,
      diff: diffs,
    })

    return Session.setRevert({
      sessionID: input.sessionID,
      revert,
      summary: {
        additions: diffs.reduce((sum, x) => sum + x.additions, 0),
        deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
        files: diffs.length,
      },
    })
  }

  /**
   * Undoes a pending revert: restores the snapshot stored on `session.revert`
   * (when one was recorded) and clears the revert marker.
   *
   * @param input The session whose revert should be undone.
   * @returns The session unchanged when no revert is pending, otherwise the
   *   result of `Session.clearRevert`.
   */
  export async function unrevert(input: { sessionID: SessionID }) {
    log.info("unreverting", input)
    SessionPrompt.assertNotBusy(input.sessionID)
    const session = await Session.get(input.sessionID)
    if (!session.revert) return session
    if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
    return Session.clearRevert(input.sessionID)
  }

  /**
   * Permanently removes everything covered by the session's pending revert and
   * then clears the revert marker. Does nothing when no revert is pending.
   *
   * Messages with an ID greater than the revert message are deleted. The revert
   * message itself is deleted too, unless the revert names a part: in that case
   * the message is kept and only that part and the parts after it are deleted.
   * A `Removed` / `PartRemoved` event is published for each deletion.
   *
   * @param session The session to clean up; its `revert` field selects what is removed.
   */
  export async function cleanup(session: Session.Info) {
    if (!session.revert) return
    const sessionID = session.id
    const msgs = await Session.messages({ sessionID })
    const messageID = session.revert.messageID
    const remove = [] as MessageV2.WithParts[]
    let target: MessageV2.WithParts | undefined

    for (const msg of msgs) {
      // Messages before the revert point are left untouched.
      if (msg.info.id < messageID) continue
      if (msg.info.id > messageID) {
        remove.push(msg)
        continue
      }
      // The revert message itself: keep it for a part-level revert, drop it otherwise.
      if (session.revert.partID) {
        target = msg
        continue
      }
      remove.push(msg)
    }

    for (const msg of remove) {
      Database.use((db) => db.delete(MessageTable).where(eq(MessageTable.id, msg.info.id)).run())
      await Bus.publish(MessageV2.Event.Removed, { sessionID: sessionID, messageID: msg.info.id })
    }

    if (session.revert.partID && target) {
      const partID = session.revert.partID
      const removeStart = target.parts.findIndex((part) => part.id === partID)
      if (removeStart >= 0) {
        const preserveParts = target.parts.slice(0, removeStart)
        const removeParts = target.parts.slice(removeStart)
        // Trim the in-memory message to the parts that survive.
        target.parts = preserveParts
        for (const part of removeParts) {
          Database.use((db) => db.delete(PartTable).where(eq(PartTable.id, part.id)).run())
          await Bus.publish(MessageV2.Event.PartRemoved, {
            sessionID: sessionID,
            messageID: target.info.id,
            partID: part.id,
          })
        }
      }
    }

    await Session.clearRevert(sessionID)
  }
}
|