Part data model (#950)

This commit is contained in:
Dax
2025-07-13 17:22:11 -04:00
committed by GitHub
parent 736396fc70
commit 90d6c4ab41
27 changed files with 1447 additions and 965 deletions

View File

@@ -12,6 +12,7 @@ import {
type ProviderMetadata,
type ModelMessage,
stepCountIs,
type StreamTextResult,
} from "ai"
import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
@@ -190,7 +191,10 @@ export namespace Session {
await Storage.writeJSON<ShareInfo>("session/share/" + id, share)
await Share.sync("session/info/" + id, session)
for (const msg of await messages(id)) {
await Share.sync("session/message/" + id + "/" + msg.id, msg)
await Share.sync("session/message/" + id + "/" + msg.info.id, msg.info)
for (const part of msg.parts) {
await Share.sync("session/part/" + id + "/" + msg.info.id + "/" + part.id, part)
}
}
return share
}
@@ -220,13 +224,19 @@ export namespace Session {
}
export async function messages(sessionID: string) {
const result = [] as MessageV2.Info[]
const result = [] as {
info: MessageV2.Info
parts: MessageV2.Part[]
}[]
const list = Storage.list("session/message/" + sessionID)
for await (const p of list) {
const read = await Storage.readJSON<MessageV2.Info>(p)
result.push(read)
result.push({
info: read,
parts: await parts(sessionID, read.id),
})
}
result.sort((a, b) => (a.id > b.id ? 1 : -1))
result.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
return result
}
@@ -234,6 +244,16 @@ export namespace Session {
return Storage.readJSON<MessageV2.Info>("session/message/" + sessionID + "/" + messageID)
}
export async function parts(sessionID: string, messageID: string) {
const result = [] as MessageV2.Part[]
for await (const item of Storage.list("session/part/" + sessionID + "/" + messageID)) {
const read = await Storage.readJSON<MessageV2.Part>(item)
result.push(read)
}
result.sort((a, b) => (a.id > b.id ? 1 : -1))
return result
}
export async function* list() {
for await (const item of Storage.list("session/info")) {
const sessionID = path.basename(item, ".json")
@@ -289,12 +309,21 @@ export namespace Session {
})
}
async function updatePart(part: MessageV2.Part) {
await Storage.writeJSON(["session", "part", part.sessionID, part.messageID, part.id].join("/"), part)
Bus.publish(MessageV2.Event.PartUpdated, {
part,
})
return part
}
export async function chat(input: {
sessionID: string
messageID: string
providerID: string
modelID: string
mode?: string
parts: MessageV2.UserPart[]
parts: (MessageV2.TextPart | MessageV2.FilePart)[]
}) {
const l = log.clone().tag("session", input.sessionID)
l.info("chatting")
@@ -306,16 +335,19 @@ export namespace Session {
if (session.revert) {
const trimmed = []
for (const msg of msgs) {
if (msg.id > session.revert.messageID || (msg.id === session.revert.messageID && session.revert.part === 0)) {
await Storage.remove("session/message/" + input.sessionID + "/" + msg.id)
if (
msg.info.id > session.revert.messageID ||
(msg.info.id === session.revert.messageID && session.revert.part === 0)
) {
await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
await Bus.publish(MessageV2.Event.Removed, {
sessionID: input.sessionID,
messageID: msg.id,
messageID: msg.info.id,
})
continue
}
if (msg.id === session.revert.messageID) {
if (msg.info.id === session.revert.messageID) {
if (session.revert.part === 0) break
msg.parts = msg.parts.slice(0, session.revert.part)
}
@@ -327,7 +359,7 @@ export namespace Session {
})
}
const previous = msgs.at(-1) as MessageV2.Assistant
const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
// auto summarize if too long
@@ -346,12 +378,21 @@ export namespace Session {
using abort = lock(input.sessionID)
const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)
if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
const userMsg: MessageV2.Info = {
id: input.messageID,
role: "user",
sessionID: input.sessionID,
time: {
created: Date.now(),
},
}
const app = App.info()
input.parts = await Promise.all(
input.parts.map(async (part): Promise<MessageV2.UserPart[]> => {
const userParts = await Promise.all(
input.parts.map(async (part): Promise<MessageV2.Part[]> => {
if (part.type === "file") {
const url = new URL(part.url)
switch (url.protocol) {
@@ -406,11 +447,17 @@ export namespace Session {
})
return [
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
synthetic: true,
text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
},
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
synthetic: true,
text: result.output,
@@ -422,11 +469,17 @@ export namespace Session {
FileTime.read(input.sessionID, filePath)
return [
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
text: `Called the Read tool with the following input: {\"filePath\":\"${pathname}\"}`,
synthetic: true,
},
{
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "file",
url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
mime: part.mime,
@@ -440,7 +493,10 @@ export namespace Session {
).then((x) => x.flat())
if (input.mode === "plan")
input.parts.push({
userParts.push({
id: Identifier.ascending("part"),
messageID: userMsg.id,
sessionID: input.sessionID,
type: "text",
text: PROMPT_PLAN,
synthetic: true,
@@ -459,13 +515,15 @@ export namespace Session {
),
...MessageV2.toModelMessage([
{
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
parts: input.parts,
time: {
created: Date.now(),
info: {
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
time: {
created: Date.now(),
},
},
parts: userParts,
},
]),
],
@@ -479,17 +537,11 @@ export namespace Session {
})
.catch(() => {})
}
const msg: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "user",
sessionID: input.sessionID,
parts: input.parts,
time: {
created: Date.now(),
},
await updateMessage(userMsg)
for (const part of userParts) {
await updatePart(part)
}
await updateMessage(msg)
msgs.push(msg)
msgs.push({ info: userMsg, parts: userParts })
const mode = await Mode.get(input.mode ?? "build")
let system = mode.prompt ? [mode.prompt] : SystemPrompt.provider(input.providerID, input.modelID)
@@ -499,10 +551,9 @@ export namespace Session {
const [first, ...rest] = system
system = [first, rest.join("\n")]
const next: MessageV2.Info = {
const assistantMsg: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "assistant",
parts: [],
system,
path: {
cwd: app.path.cwd,
@@ -522,7 +573,7 @@ export namespace Session {
},
sessionID: input.sessionID,
}
await updateMessage(next)
await updateMessage(assistantMsg)
const tools: Record<string, AITool> = {}
for (const item of await Provider.tools(input.providerID)) {
@@ -531,20 +582,29 @@ export namespace Session {
id: item.id as any,
description: item.description,
inputSchema: item.parameters as ZodSchema,
async execute(args, opts) {
async execute(args) {
const result = await item.execute(args, {
sessionID: input.sessionID,
abort: abort.signal,
messageID: next.id,
metadata: async (val) => {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === opts.toolCallId,
)
messageID: assistantMsg.id,
metadata: async () => {
/*
const match = toolCalls[opts.toolCallId]
if (match && match.state.status === "running") {
match.state.title = val.title
match.state.metadata = val.metadata
await updatePart({
...match,
state: {
title: val.title,
metadata: val.metadata,
status: "running",
input: args.input,
time: {
start: Date.now(),
},
},
})
}
await updateMessage(next)
*/
},
})
return result
@@ -582,10 +642,6 @@ export namespace Session {
tools[key] = item
}
let text: MessageV2.TextPart = {
type: "text",
text: "",
}
const result = streamText({
onError() {},
maxRetries: 10,
@@ -619,9 +675,20 @@ export namespace Session {
],
}),
})
return processStream(assistantMsg, model.info, result)
}
async function processStream(
assistantMsg: MessageV2.Assistant,
model: ModelsDev.Model,
stream: StreamTextResult<Record<string, AITool>, never>,
) {
try {
for await (const value of result.fullStream) {
l.info("part", {
let currentText: MessageV2.TextPart | undefined
const toolCalls: Record<string, MessageV2.ToolPart> = {}
for await (const value of stream.fullStream) {
log.info("part", {
type: value.type,
})
switch (value.type) {
@@ -629,88 +696,78 @@ export namespace Session {
break
case "tool-input-start":
next.parts.push({
const part = await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "tool",
tool: value.toolName,
id: value.id,
callID: value.id,
state: {
status: "pending",
},
})
Bus.publish(MessageV2.Event.PartUpdated, {
part: next.parts[next.parts.length - 1],
sessionID: next.sessionID,
messageID: next.id,
})
toolCalls[value.id] = part as MessageV2.ToolPart
break
case "tool-input-delta":
break
case "tool-call": {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
)
const match = toolCalls[value.toolCallId]
if (match) {
match.state = {
status: "running",
input: value.input,
time: {
start: Date.now(),
const part = await updatePart({
...match,
state: {
status: "running",
input: value.input,
time: {
start: Date.now(),
},
},
}
Bus.publish(MessageV2.Event.PartUpdated, {
part: match,
sessionID: next.sessionID,
messageID: next.id,
})
toolCalls[value.toolCallId] = part as MessageV2.ToolPart
}
break
}
case "tool-result": {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
)
const match = toolCalls[value.toolCallId]
if (match && match.state.status === "running") {
match.state = {
status: "completed",
input: value.input,
output: value.output.output,
metadata: value.output.metadata,
title: value.output.title,
time: {
start: match.state.time.start,
end: Date.now(),
await updatePart({
...match,
state: {
status: "completed",
input: value.input,
output: value.output.output,
metadata: value.output.metadata,
title: value.output.title,
time: {
start: match.state.time.start,
end: Date.now(),
},
},
}
Bus.publish(MessageV2.Event.PartUpdated, {
part: match,
sessionID: next.sessionID,
messageID: next.id,
})
delete toolCalls[value.toolCallId]
}
break
}
case "tool-error": {
const match = next.parts.find(
(p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
)
const match = toolCalls[value.toolCallId]
if (match && match.state.status === "running") {
match.state = {
status: "error",
input: value.input,
error: (value.error as any).toString(),
time: {
start: match.state.time.start,
end: Date.now(),
await updatePart({
...match,
state: {
status: "error",
input: value.input,
error: (value.error as any).toString(),
time: {
start: match.state.time.start,
end: Date.now(),
},
},
}
Bus.publish(MessageV2.Event.PartUpdated, {
part: match,
sessionID: next.sessionID,
messageID: next.id,
})
delete toolCalls[value.toolCallId]
}
break
}
@@ -719,53 +776,71 @@ export namespace Session {
throw value.error
case "start-step":
next.parts.push({
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "step-start",
})
break
case "finish-step":
const usage = getUsage(model.info, value.usage, value.providerMetadata)
next.cost += usage.cost
next.tokens = usage.tokens
next.parts.push({
const usage = getUsage(model, value.usage, value.providerMetadata)
assistantMsg.cost += usage.cost
assistantMsg.tokens = usage.tokens
await updatePart({
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "step-finish",
tokens: usage.tokens,
cost: usage.cost,
})
await updateMessage(assistantMsg)
break
case "text-start":
text = {
currentText = {
id: Identifier.ascending("part"),
messageID: assistantMsg.id,
sessionID: assistantMsg.sessionID,
type: "text",
text: "",
time: {
start: Date.now(),
},
}
break
case "text":
if (text.text === "") next.parts.push(text)
text.text += value.text
if (currentText) {
currentText.text += value.text
await updatePart(currentText)
}
break
case "text-end":
Bus.publish(MessageV2.Event.PartUpdated, {
part: text,
sessionID: next.sessionID,
messageID: next.id,
})
if (currentText && currentText.text) {
currentText.time = {
start: Date.now(),
end: Date.now(),
}
await updatePart(currentText)
}
currentText = undefined
break
case "finish":
next.time.completed = Date.now()
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
break
default:
l.info("unhandled", {
log.info("unhandled", {
...value,
})
continue
}
await updateMessage(next)
}
} catch (e) {
log.error("", {
@@ -773,7 +848,7 @@ export namespace Session {
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
next.error = new MessageV2.AbortedError(
assistantMsg.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
@@ -781,44 +856,48 @@ export namespace Session {
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
next.error = e
assistantMsg.error = e
break
case LoadAPIKeyError.isInstance(e):
next.error = new Provider.AuthError(
assistantMsg.error = new Provider.AuthError(
{
providerID: input.providerID,
providerID: model.id,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
}
Bus.publish(Event.Error, {
sessionID: next.sessionID,
error: next.error,
sessionID: assistantMsg.sessionID,
error: assistantMsg.error,
})
}
for (const part of next.parts) {
const p = await parts(assistantMsg.sessionID, assistantMsg.id)
for (const part of p) {
if (part.type === "tool" && part.state.status !== "completed") {
part.state = {
status: "error",
error: "Tool execution aborted",
time: {
start: Date.now(),
end: Date.now(),
updatePart({
...part,
state: {
status: "error",
error: "Tool execution aborted",
time: {
start: Date.now(),
end: Date.now(),
},
input: {},
},
input: {},
}
})
}
}
next.time.completed = Date.now()
await updateMessage(next)
return next
assistantMsg.time.completed = Date.now()
await updateMessage(assistantMsg)
return { info: assistantMsg, parts: p }
}
export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
@@ -867,8 +946,8 @@ export namespace Session {
export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
using abort = lock(input.sessionID)
const msgs = await messages(input.sessionID)
const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)?.id
const filtered = msgs.filter((msg) => !lastSummary || msg.id >= lastSummary)
const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
const filtered = msgs.filter((msg) => !lastSummary || msg.info.id >= lastSummary.info.id)
const model = await Provider.getModel(input.providerID, input.modelID)
const app = App.info()
const system = SystemPrompt.summarize(input.providerID)
@@ -876,7 +955,6 @@ export namespace Session {
const next: MessageV2.Info = {
id: Identifier.ascending("message"),
role: "assistant",
parts: [],
sessionID: input.sessionID,
system,
path: {
@@ -899,7 +977,6 @@ export namespace Session {
}
await updateMessage(next)
let text: MessageV2.TextPart | undefined
const result = streamText({
abortSignal: abort.signal,
model: model.language,
@@ -921,81 +998,9 @@ export namespace Session {
],
},
],
onStepFinish: async (step) => {
const usage = getUsage(model.info, step.usage, step.providerMetadata)
next.cost += usage.cost
next.tokens = usage.tokens
await updateMessage(next)
if (text) {
Bus.publish(MessageV2.Event.PartUpdated, {
part: text,
messageID: next.id,
sessionID: next.sessionID,
})
}
text = undefined
},
async onFinish(input) {
const usage = getUsage(model.info, input.usage, input.providerMetadata)
next.cost += usage.cost
next.tokens = usage.tokens
next.time.completed = Date.now()
await updateMessage(next)
},
})
try {
for await (const value of result.fullStream) {
switch (value.type) {
case "text":
if (!text) {
text = {
type: "text",
text: value.text,
}
next.parts.push(text)
} else text.text += value.text
await updateMessage(next)
break
}
}
} catch (e: any) {
log.error("summarize stream error", {
error: e,
})
switch (true) {
case e instanceof DOMException && e.name === "AbortError":
next.error = new MessageV2.AbortedError(
{ message: e.message },
{
cause: e,
},
).toObject()
break
case MessageV2.OutputLengthError.isInstance(e):
next.error = e
break
case LoadAPIKeyError.isInstance(e):
next.error = new Provider.AuthError(
{
providerID: input.providerID,
message: e.message,
},
{ cause: e },
).toObject()
break
case e instanceof Error:
next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
break
default:
next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
}
Bus.publish(Event.Error, {
error: next.error,
})
}
next.time.completed = Date.now()
await updateMessage(next)
return processStream(next, model.info, result)
}
function lock(sessionID: string) {
@@ -1045,14 +1050,23 @@ export namespace Session {
}
}
export async function initialize(input: { sessionID: string; modelID: string; providerID: string }) {
export async function initialize(input: {
sessionID: string
modelID: string
providerID: string
messageID: string
}) {
const app = App.info()
await Session.chat({
sessionID: input.sessionID,
messageID: input.messageID,
providerID: input.providerID,
modelID: input.modelID,
parts: [
{
id: Identifier.ascending("part"),
sessionID: input.sessionID,
messageID: input.messageID,
type: "text",
text: PROMPT_INITIALIZE.replace("${path}", app.path.root),
},