Part data model (#950)

This commit is contained in:
Dax
2025-07-13 17:22:11 -04:00
committed by GitHub
parent 736396fc70
commit 90d6c4ab41
27 changed files with 1447 additions and 965 deletions

View File

@@ -4,6 +4,7 @@ import { Provider } from "../provider/provider"
import { NamedError } from "../util/error"
import { Message } from "./message"
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
import { Identifier } from "../id/id"
export namespace MessageV2 {
export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))
@@ -72,67 +73,69 @@ export namespace MessageV2 {
ref: "ToolState",
})
export const TextPart = z
.object({
type: z.literal("text"),
text: z.string(),
synthetic: z.boolean().optional(),
})
.openapi({
ref: "TextPart",
})
const PartBase = z.object({
id: z.string(),
sessionID: z.string(),
messageID: z.string(),
})
export const TextPart = PartBase.extend({
type: z.literal("text"),
text: z.string(),
synthetic: z.boolean().optional(),
time: z
.object({
start: z.number(),
end: z.number().optional(),
})
.optional(),
}).openapi({
ref: "TextPart",
})
export type TextPart = z.infer<typeof TextPart>
export const ToolPart = z
.object({
type: z.literal("tool"),
id: z.string(),
tool: z.string(),
state: ToolState,
})
.openapi({
ref: "ToolPart",
})
export const ToolPart = PartBase.extend({
type: z.literal("tool"),
callID: z.string(),
tool: z.string(),
state: ToolState,
}).openapi({
ref: "ToolPart",
})
export type ToolPart = z.infer<typeof ToolPart>
export const FilePart = z
.object({
type: z.literal("file"),
mime: z.string(),
filename: z.string().optional(),
url: z.string(),
})
.openapi({
ref: "FilePart",
})
export const FilePart = PartBase.extend({
type: z.literal("file"),
mime: z.string(),
filename: z.string().optional(),
url: z.string(),
}).openapi({
ref: "FilePart",
})
export type FilePart = z.infer<typeof FilePart>
export const StepStartPart = z
.object({
type: z.literal("step-start"),
})
.openapi({
ref: "StepStartPart",
})
export const StepStartPart = PartBase.extend({
type: z.literal("step-start"),
}).openapi({
ref: "StepStartPart",
})
export type StepStartPart = z.infer<typeof StepStartPart>
export const StepFinishPart = z
.object({
type: z.literal("step-finish"),
cost: z.number(),
tokens: z.object({
input: z.number(),
output: z.number(),
reasoning: z.number(),
cache: z.object({
read: z.number(),
write: z.number(),
}),
export const StepFinishPart = PartBase.extend({
type: z.literal("step-finish"),
cost: z.number(),
tokens: z.object({
input: z.number(),
output: z.number(),
reasoning: z.number(),
cache: z.object({
read: z.number(),
write: z.number(),
}),
})
.openapi({
ref: "StepFinishPart",
})
}),
}).openapi({
ref: "StepFinishPart",
})
export type StepFinishPart = z.infer<typeof StepFinishPart>
const Base = z.object({
@@ -140,14 +143,8 @@ export namespace MessageV2 {
sessionID: z.string(),
})
export const UserPart = z.discriminatedUnion("type", [TextPart, FilePart]).openapi({
ref: "UserMessagePart",
})
export type UserPart = z.infer<typeof UserPart>
export const User = Base.extend({
role: z.literal("user"),
parts: z.array(UserPart),
time: z.object({
created: z.number(),
}),
@@ -156,16 +153,15 @@ export namespace MessageV2 {
})
export type User = z.infer<typeof User>
export const AssistantPart = z
.discriminatedUnion("type", [TextPart, ToolPart, StepStartPart, StepFinishPart])
export const Part = z
.discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart])
.openapi({
ref: "AssistantMessagePart",
ref: "Part",
})
export type AssistantPart = z.infer<typeof AssistantPart>
export type Part = z.infer<typeof Part>
export const Assistant = Base.extend({
role: z.literal("assistant"),
parts: z.array(AssistantPart),
time: z.object({
created: z.number(),
completed: z.number().optional(),
@@ -223,16 +219,14 @@ export namespace MessageV2 {
PartUpdated: Bus.event(
"message.part.updated",
z.object({
part: AssistantPart,
sessionID: z.string(),
messageID: z.string(),
part: Part,
}),
),
}
export function fromV1(v1: Message.Info) {
if (v1.role === "assistant") {
const result: Assistant = {
const info: Assistant = {
id: v1.id,
sessionID: v1.metadata.sessionID,
role: "assistant",
@@ -248,109 +242,135 @@ export namespace MessageV2 {
providerID: v1.metadata.assistant!.providerID,
system: v1.metadata.assistant!.system,
error: v1.metadata.error,
parts: v1.parts.flatMap((part): AssistantPart[] => {
if (part.type === "text") {
return [
{
type: "text",
text: part.text,
},
]
}
if (part.type === "step-start") {
return [
{
type: "step-start",
},
]
}
if (part.type === "tool-invocation") {
return [
{
type: "tool",
id: part.toolInvocation.toolCallId,
tool: part.toolInvocation.toolName,
state: (() => {
if (part.toolInvocation.state === "partial-call") {
return {
status: "pending",
}
}
const { title, time, ...metadata } = v1.metadata.tool[part.toolInvocation.toolCallId] ?? {}
if (part.toolInvocation.state === "call") {
return {
status: "running",
input: part.toolInvocation.args,
time: {
start: time?.start,
},
}
}
if (part.toolInvocation.state === "result") {
return {
status: "completed",
input: part.toolInvocation.args,
output: part.toolInvocation.result,
title,
time,
metadata,
}
}
throw new Error("unknown tool invocation state")
})(),
},
]
}
return []
}),
}
return result
const parts = v1.parts.flatMap((part): Part[] => {
const base = {
id: Identifier.ascending("part"),
messageID: v1.id,
sessionID: v1.metadata.sessionID,
}
if (part.type === "text") {
return [
{
...base,
type: "text",
text: part.text,
},
]
}
if (part.type === "step-start") {
return [
{
...base,
type: "step-start",
},
]
}
if (part.type === "tool-invocation") {
return [
{
...base,
type: "tool",
callID: part.toolInvocation.toolCallId,
tool: part.toolInvocation.toolName,
state: (() => {
if (part.toolInvocation.state === "partial-call") {
return {
status: "pending",
}
}
const { title, time, ...metadata } = v1.metadata.tool[part.toolInvocation.toolCallId] ?? {}
if (part.toolInvocation.state === "call") {
return {
status: "running",
input: part.toolInvocation.args,
time: {
start: time?.start,
},
}
}
if (part.toolInvocation.state === "result") {
return {
status: "completed",
input: part.toolInvocation.args,
output: part.toolInvocation.result,
title,
time,
metadata,
}
}
throw new Error("unknown tool invocation state")
})(),
},
]
}
return []
})
return {
info,
parts,
}
}
if (v1.role === "user") {
const result: User = {
const info: User = {
id: v1.id,
sessionID: v1.metadata.sessionID,
role: "user",
time: {
created: v1.metadata.time.created,
},
parts: v1.parts.flatMap((part): UserPart[] => {
if (part.type === "text") {
return [
{
type: "text",
text: part.text,
},
]
}
if (part.type === "file") {
return [
{
type: "file",
mime: part.mediaType,
filename: part.filename,
url: part.url,
},
]
}
return []
}),
}
return result
const parts = v1.parts.flatMap((part): Part[] => {
const base = {
id: Identifier.ascending("part"),
messageID: v1.id,
sessionID: v1.metadata.sessionID,
}
if (part.type === "text") {
return [
{
...base,
type: "text",
text: part.text,
},
]
}
if (part.type === "file") {
return [
{
...base,
type: "file",
mime: part.mediaType,
filename: part.filename,
url: part.url,
},
]
}
return []
})
return { info, parts }
}
throw new Error("unknown message type")
}
export function toModelMessage(input: Info[]): ModelMessage[] {
export function toModelMessage(
input: {
info: Info
parts: Part[]
}[],
): ModelMessage[] {
const result: UIMessage[] = []
for (const msg of input) {
if (msg.parts.length === 0) continue
if (msg.role === "user") {
if (msg.info.role === "user") {
result.push({
id: msg.id,
id: msg.info.id,
role: "user",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
@@ -374,9 +394,9 @@ export namespace MessageV2 {
})
}
if (msg.role === "assistant") {
if (msg.info.role === "assistant") {
result.push({
id: msg.id,
id: msg.info.id,
role: "assistant",
parts: msg.parts.flatMap((part): UIMessage["parts"] => {
if (part.type === "text")
@@ -398,7 +418,7 @@ export namespace MessageV2 {
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-available",
toolCallId: part.id,
toolCallId: part.callID,
input: part.state.input,
output: part.state.output,
},
@@ -408,7 +428,7 @@ export namespace MessageV2 {
{
type: ("tool-" + part.tool) as `tool-${string}`,
state: "output-error",
toolCallId: part.id,
toolCallId: part.callID,
input: part.state.input,
errorText: part.state.error,
},