mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-06 16:59:01 +00:00
basic undo feature (#1268)
Co-authored-by: adamdotdevin <2363879+adamdottv@users.noreply.github.com> Co-authored-by: Jay V <air@live.ca> Co-authored-by: Aiden Cline <63023139+rekram1-node@users.noreply.github.com> Co-authored-by: Andrew Joslin <andrew@ajoslin.com> Co-authored-by: GitHub Action <action@github.com> Co-authored-by: Tobias Walle <9933601+tobias-walle@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,7 @@ import { MessageV2 } from "./message-v2"
|
||||
import { Mode } from "./mode"
|
||||
import { LSP } from "../lsp"
|
||||
import { ReadTool } from "../tool/read"
|
||||
import { splitWhen } from "remeda"
|
||||
|
||||
export namespace Session {
|
||||
const log = Log.create({ service: "session" })
|
||||
@@ -64,7 +65,7 @@ export namespace Session {
|
||||
revert: z
|
||||
.object({
|
||||
messageID: z.string(),
|
||||
part: z.number(),
|
||||
partID: z.string().optional(),
|
||||
snapshot: z.string().optional(),
|
||||
})
|
||||
.optional(),
|
||||
@@ -246,7 +247,7 @@ export namespace Session {
|
||||
const read = await Storage.readJSON<MessageV2.Info>(p)
|
||||
result.push({
|
||||
info: read,
|
||||
parts: await parts(sessionID, read.id),
|
||||
parts: await getParts(sessionID, read.id),
|
||||
})
|
||||
}
|
||||
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
|
||||
@@ -257,7 +258,7 @@ export namespace Session {
|
||||
return Storage.readJSON<MessageV2.Info>("session/message/" + sessionID + "/" + messageID)
|
||||
}
|
||||
|
||||
export async function parts(sessionID: string, messageID: string) {
|
||||
export async function getParts(sessionID: string, messageID: string) {
|
||||
const result = [] as MessageV2.Part[]
|
||||
for (const item of await Storage.list("session/part/" + sessionID + "/" + messageID)) {
|
||||
const read = await Storage.readJSON<MessageV2.Part>(item)
|
||||
@@ -531,30 +532,26 @@ export namespace Session {
|
||||
const session = await get(input.sessionID)
|
||||
|
||||
if (session.revert) {
|
||||
const trimmed = []
|
||||
for (const msg of msgs) {
|
||||
if (
|
||||
msg.info.id > session.revert.messageID ||
|
||||
(msg.info.id === session.revert.messageID && session.revert.part === 0)
|
||||
) {
|
||||
await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
|
||||
await Bus.publish(MessageV2.Event.Removed, {
|
||||
sessionID: input.sessionID,
|
||||
messageID: msg.info.id,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if (msg.info.id === session.revert.messageID) {
|
||||
if (session.revert.part === 0) break
|
||||
msg.parts = msg.parts.slice(0, session.revert.part)
|
||||
}
|
||||
trimmed.push(msg)
|
||||
const messageID = session.revert.messageID
|
||||
const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID)
|
||||
msgs = preserve
|
||||
for (const msg of remove) {
|
||||
await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`)
|
||||
await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: msg.info.id })
|
||||
}
|
||||
const last = preserve.at(-1)
|
||||
if (session.revert.partID && last) {
|
||||
const partID = session.revert.partID
|
||||
const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
|
||||
last.parts = preserveParts
|
||||
for (const part of removeParts) {
|
||||
await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`)
|
||||
await Bus.publish(MessageV2.Event.PartRemoved, {
|
||||
messageID: last.info.id,
|
||||
partID: part.id,
|
||||
})
|
||||
}
|
||||
}
|
||||
msgs = trimmed
|
||||
await update(input.sessionID, (draft) => {
|
||||
draft.revert = undefined
|
||||
})
|
||||
}
|
||||
|
||||
const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
|
||||
@@ -831,7 +828,7 @@ export namespace Session {
|
||||
})
|
||||
switch (value.type) {
|
||||
case "start":
|
||||
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
||||
const snapshot = await Snapshot.create()
|
||||
if (snapshot)
|
||||
await updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
@@ -895,7 +892,7 @@ export namespace Session {
|
||||
},
|
||||
})
|
||||
delete toolCalls[value.toolCallId]
|
||||
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
||||
const snapshot = await Snapshot.create()
|
||||
if (snapshot)
|
||||
await updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
@@ -924,7 +921,7 @@ export namespace Session {
|
||||
},
|
||||
})
|
||||
delete toolCalls[value.toolCallId]
|
||||
const snapshot = await Snapshot.create(assistantMsg.sessionID)
|
||||
const snapshot = await Snapshot.create()
|
||||
if (snapshot)
|
||||
await updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
@@ -1043,7 +1040,7 @@ export namespace Session {
|
||||
error: assistantMsg.error,
|
||||
})
|
||||
}
|
||||
const p = await parts(assistantMsg.sessionID, assistantMsg.id)
|
||||
const p = await getParts(assistantMsg.sessionID, assistantMsg.id)
|
||||
for (const part of p) {
|
||||
if (part.type === "tool" && part.state.status !== "completed") {
|
||||
updatePart({
|
||||
@@ -1067,47 +1064,53 @@ export namespace Session {
|
||||
}
|
||||
}
|
||||
|
||||
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
|
||||
// TODO
|
||||
/*
|
||||
const message = await getMessage(input.sessionID, input.messageID)
|
||||
if (!message) return
|
||||
const part = message.parts[input.part]
|
||||
if (!part) return
|
||||
export const RevertInput = z.object({
|
||||
sessionID: Identifier.schema("session"),
|
||||
messageID: Identifier.schema("message"),
|
||||
partID: Identifier.schema("part").optional(),
|
||||
})
|
||||
export type RevertInput = z.infer<typeof RevertInput>
|
||||
|
||||
export async function revert(input: RevertInput) {
|
||||
const all = await messages(input.sessionID)
|
||||
const session = await get(input.sessionID)
|
||||
const snapshot =
|
||||
session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
|
||||
const old = (() => {
|
||||
if (message.role === "assistant") {
|
||||
const lastTool = message.parts.findLast(
|
||||
(part, index) =>
|
||||
part.type === "tool-invocation" && index < input.part,
|
||||
)
|
||||
if (lastTool && lastTool.type === "tool-invocation")
|
||||
return message.metadata.tool[lastTool.toolInvocation.toolCallId]
|
||||
.snapshot
|
||||
let lastUser: MessageV2.User | undefined
|
||||
let lastSnapshot: MessageV2.SnapshotPart | undefined
|
||||
for (const msg of all) {
|
||||
if (msg.info.role === "user") lastUser = msg.info
|
||||
const remaining = []
|
||||
for (const part of msg.parts) {
|
||||
if (part.type === "snapshot") lastSnapshot = part
|
||||
if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
|
||||
// if no useful parts left in message, same as reverting whole message
|
||||
const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
|
||||
const snapshot = session.revert?.snapshot ?? (await Snapshot.create(true))
|
||||
log.info("revert snapshot", { snapshot })
|
||||
if (lastSnapshot) await Snapshot.restore(lastSnapshot.snapshot)
|
||||
const next = await update(input.sessionID, (draft) => {
|
||||
draft.revert = {
|
||||
// if not part id jump to the last user message
|
||||
messageID: !partID && lastUser ? lastUser.id : msg.info.id,
|
||||
partID,
|
||||
snapshot,
|
||||
}
|
||||
})
|
||||
return next
|
||||
}
|
||||
remaining.push(part)
|
||||
}
|
||||
return message.metadata.snapshot
|
||||
})()
|
||||
if (old) await Snapshot.restore(input.sessionID, old)
|
||||
await update(input.sessionID, (draft) => {
|
||||
draft.revert = {
|
||||
messageID: input.messageID,
|
||||
part: input.part,
|
||||
snapshot,
|
||||
}
|
||||
})
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
export async function unrevert(sessionID: string) {
|
||||
const session = await get(sessionID)
|
||||
if (!session) return
|
||||
if (!session.revert) return
|
||||
if (session.revert.snapshot) await Snapshot.restore(sessionID, session.revert.snapshot)
|
||||
update(sessionID, (draft) => {
|
||||
export async function unrevert(input: { sessionID: string }) {
|
||||
log.info("unreverting", input)
|
||||
const session = await get(input.sessionID)
|
||||
if (!session.revert) return session
|
||||
if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
|
||||
const next = await update(input.sessionID, (draft) => {
|
||||
draft.revert = undefined
|
||||
})
|
||||
return next
|
||||
}
|
||||
|
||||
export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
|
||||
|
||||
@@ -272,6 +272,13 @@ export namespace MessageV2 {
|
||||
part: Part,
|
||||
}),
|
||||
),
|
||||
PartRemoved: Bus.event(
|
||||
"message.part.removed",
|
||||
z.object({
|
||||
messageID: z.string(),
|
||||
partID: z.string(),
|
||||
}),
|
||||
),
|
||||
}
|
||||
|
||||
export function fromV1(v1: Message.Info) {
|
||||
|
||||
Reference in New Issue
Block a user