mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-09 18:29:39 +00:00
perf(server): paginate session history (#17134)
This commit is contained in:
@@ -6,8 +6,8 @@ import { APICallError, convertToModelMessages, LoadAPIKeyError, type ModelMessag
|
||||
import { LSP } from "../lsp"
|
||||
import { Snapshot } from "@/snapshot"
|
||||
import { fn } from "@/util/fn"
|
||||
import { Database, eq, desc, inArray } from "@/storage/db"
|
||||
import { MessageTable, PartTable } from "./session.sql"
|
||||
import { Database, NotFoundError, and, desc, eq, inArray, lt, or } from "@/storage/db"
|
||||
import { MessageTable, PartTable, SessionTable } from "./session.sql"
|
||||
import { ProviderTransform } from "@/provider/transform"
|
||||
import { STATUS_CODES } from "http"
|
||||
import { Storage } from "@/storage/storage"
|
||||
@@ -494,6 +494,68 @@ export namespace MessageV2 {
|
||||
})
|
||||
export type WithParts = z.infer<typeof WithParts>
|
||||
|
||||
const Cursor = z.object({
|
||||
id: MessageID.zod,
|
||||
time: z.number(),
|
||||
})
|
||||
type Cursor = z.infer<typeof Cursor>
|
||||
|
||||
export const cursor = {
|
||||
encode(input: Cursor) {
|
||||
return Buffer.from(JSON.stringify(input)).toString("base64url")
|
||||
},
|
||||
decode(input: string) {
|
||||
return Cursor.parse(JSON.parse(Buffer.from(input, "base64url").toString("utf8")))
|
||||
},
|
||||
}
|
||||
|
||||
const info = (row: typeof MessageTable.$inferSelect) =>
|
||||
({
|
||||
...row.data,
|
||||
id: row.id,
|
||||
sessionID: row.session_id,
|
||||
}) as MessageV2.Info
|
||||
|
||||
const part = (row: typeof PartTable.$inferSelect) =>
|
||||
({
|
||||
...row.data,
|
||||
id: row.id,
|
||||
sessionID: row.session_id,
|
||||
messageID: row.message_id,
|
||||
}) as MessageV2.Part
|
||||
|
||||
const older = (row: Cursor) =>
|
||||
or(
|
||||
lt(MessageTable.time_created, row.time),
|
||||
and(eq(MessageTable.time_created, row.time), lt(MessageTable.id, row.id)),
|
||||
)
|
||||
|
||||
async function hydrate(rows: (typeof MessageTable.$inferSelect)[]) {
|
||||
const ids = rows.map((row) => row.id)
|
||||
const partByMessage = new Map<string, MessageV2.Part[]>()
|
||||
if (ids.length > 0) {
|
||||
const partRows = Database.use((db) =>
|
||||
db
|
||||
.select()
|
||||
.from(PartTable)
|
||||
.where(inArray(PartTable.message_id, ids))
|
||||
.orderBy(PartTable.message_id, PartTable.id)
|
||||
.all(),
|
||||
)
|
||||
for (const row of partRows) {
|
||||
const next = part(row)
|
||||
const list = partByMessage.get(row.message_id)
|
||||
if (list) list.push(next)
|
||||
else partByMessage.set(row.message_id, [next])
|
||||
}
|
||||
}
|
||||
|
||||
return rows.map((row) => ({
|
||||
info: info(row),
|
||||
parts: partByMessage.get(row.id) ?? [],
|
||||
}))
|
||||
}
|
||||
|
||||
export function toModelMessages(
|
||||
input: WithParts[],
|
||||
model: Provider.Model,
|
||||
@@ -729,56 +791,61 @@ export namespace MessageV2 {
|
||||
)
|
||||
}
|
||||
|
||||
export const stream = fn(SessionID.zod, async function* (sessionID) {
|
||||
const size = 50
|
||||
let offset = 0
|
||||
while (true) {
|
||||
export const page = fn(
|
||||
z.object({
|
||||
sessionID: SessionID.zod,
|
||||
limit: z.number().int().positive(),
|
||||
before: z.string().optional(),
|
||||
}),
|
||||
async (input) => {
|
||||
const before = input.before ? cursor.decode(input.before) : undefined
|
||||
const where = before
|
||||
? and(eq(MessageTable.session_id, input.sessionID), older(before))
|
||||
: eq(MessageTable.session_id, input.sessionID)
|
||||
const rows = Database.use((db) =>
|
||||
db
|
||||
.select()
|
||||
.from(MessageTable)
|
||||
.where(eq(MessageTable.session_id, sessionID))
|
||||
.orderBy(desc(MessageTable.time_created))
|
||||
.limit(size)
|
||||
.offset(offset)
|
||||
.where(where)
|
||||
.orderBy(desc(MessageTable.time_created), desc(MessageTable.id))
|
||||
.limit(input.limit + 1)
|
||||
.all(),
|
||||
)
|
||||
if (rows.length === 0) break
|
||||
|
||||
const ids = rows.map((row) => row.id)
|
||||
const partsByMessage = new Map<string, MessageV2.Part[]>()
|
||||
if (ids.length > 0) {
|
||||
const partRows = Database.use((db) =>
|
||||
db
|
||||
.select()
|
||||
.from(PartTable)
|
||||
.where(inArray(PartTable.message_id, ids))
|
||||
.orderBy(PartTable.message_id, PartTable.id)
|
||||
.all(),
|
||||
if (rows.length === 0) {
|
||||
const row = Database.use((db) =>
|
||||
db.select({ id: SessionTable.id }).from(SessionTable).where(eq(SessionTable.id, input.sessionID)).get(),
|
||||
)
|
||||
for (const row of partRows) {
|
||||
const part = {
|
||||
...row.data,
|
||||
id: row.id,
|
||||
sessionID: row.session_id,
|
||||
messageID: row.message_id,
|
||||
} as MessageV2.Part
|
||||
const list = partsByMessage.get(row.message_id)
|
||||
if (list) list.push(part)
|
||||
else partsByMessage.set(row.message_id, [part])
|
||||
if (!row) throw new NotFoundError({ message: `Session not found: ${input.sessionID}` })
|
||||
return {
|
||||
items: [] as MessageV2.WithParts[],
|
||||
more: false,
|
||||
}
|
||||
}
|
||||
|
||||
for (const row of rows) {
|
||||
const info = { ...row.data, id: row.id, sessionID: row.session_id } as MessageV2.Info
|
||||
yield {
|
||||
info,
|
||||
parts: partsByMessage.get(row.id) ?? [],
|
||||
}
|
||||
const more = rows.length > input.limit
|
||||
const page = more ? rows.slice(0, input.limit) : rows
|
||||
const items = await hydrate(page)
|
||||
items.reverse()
|
||||
const tail = page.at(-1)
|
||||
return {
|
||||
items,
|
||||
more,
|
||||
cursor: more && tail ? cursor.encode({ id: tail.id, time: tail.time_created }) : undefined,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
offset += rows.length
|
||||
if (rows.length < size) break
|
||||
export const stream = fn(SessionID.zod, async function* (sessionID) {
|
||||
const size = 50
|
||||
let before: string | undefined
|
||||
while (true) {
|
||||
const next = await page({ sessionID, limit: size, before })
|
||||
if (next.items.length === 0) break
|
||||
for (let i = next.items.length - 1; i >= 0; i--) {
|
||||
yield next.items[i]
|
||||
}
|
||||
if (!next.more || !next.cursor) break
|
||||
before = next.cursor
|
||||
}
|
||||
})
|
||||
|
||||
@@ -797,11 +864,16 @@ export namespace MessageV2 {
|
||||
messageID: MessageID.zod,
|
||||
}),
|
||||
async (input): Promise<WithParts> => {
|
||||
const row = Database.use((db) => db.select().from(MessageTable).where(eq(MessageTable.id, input.messageID)).get())
|
||||
if (!row) throw new Error(`Message not found: ${input.messageID}`)
|
||||
const info = { ...row.data, id: row.id, sessionID: row.session_id } as MessageV2.Info
|
||||
const row = Database.use((db) =>
|
||||
db
|
||||
.select()
|
||||
.from(MessageTable)
|
||||
.where(and(eq(MessageTable.id, input.messageID), eq(MessageTable.session_id, input.sessionID)))
|
||||
.get(),
|
||||
)
|
||||
if (!row) throw new NotFoundError({ message: `Message not found: ${input.messageID}` })
|
||||
return {
|
||||
info,
|
||||
info: info(row),
|
||||
parts: await parts(input.messageID),
|
||||
}
|
||||
},
|
||||
|
||||
@@ -54,7 +54,7 @@ export const MessageTable = sqliteTable(
|
||||
...Timestamps,
|
||||
data: text({ mode: "json" }).notNull().$type<InfoData>(),
|
||||
},
|
||||
(table) => [index("message_session_idx").on(table.session_id)],
|
||||
(table) => [index("message_session_time_created_id_idx").on(table.session_id, table.time_created, table.id)],
|
||||
)
|
||||
|
||||
export const PartTable = sqliteTable(
|
||||
@@ -69,7 +69,10 @@ export const PartTable = sqliteTable(
|
||||
...Timestamps,
|
||||
data: text({ mode: "json" }).notNull().$type<PartData>(),
|
||||
},
|
||||
(table) => [index("part_message_idx").on(table.message_id), index("part_session_idx").on(table.session_id)],
|
||||
(table) => [
|
||||
index("part_message_id_id_idx").on(table.message_id, table.id),
|
||||
index("part_session_idx").on(table.session_id),
|
||||
],
|
||||
)
|
||||
|
||||
export const TodoTable = sqliteTable(
|
||||
|
||||
Reference in New Issue
Block a user