perf(server): paginate session history (#17134)

This commit is contained in:
Shoubhit Dash
2026-03-13 15:48:43 +05:30
committed by GitHub
parent ff748b82ca
commit 9457493696
10 changed files with 3042 additions and 363 deletions

View File

@@ -6,8 +6,8 @@ import { APICallError, convertToModelMessages, LoadAPIKeyError, type ModelMessag
import { LSP } from "../lsp"
import { Snapshot } from "@/snapshot"
import { fn } from "@/util/fn"
import { Database, eq, desc, inArray } from "@/storage/db"
import { MessageTable, PartTable } from "./session.sql"
import { Database, NotFoundError, and, desc, eq, inArray, lt, or } from "@/storage/db"
import { MessageTable, PartTable, SessionTable } from "./session.sql"
import { ProviderTransform } from "@/provider/transform"
import { STATUS_CODES } from "http"
import { Storage } from "@/storage/storage"
@@ -494,6 +494,68 @@ export namespace MessageV2 {
})
export type WithParts = z.infer<typeof WithParts>
const Cursor = z.object({
id: MessageID.zod,
time: z.number(),
})
type Cursor = z.infer<typeof Cursor>
export const cursor = {
encode(input: Cursor) {
return Buffer.from(JSON.stringify(input)).toString("base64url")
},
decode(input: string) {
return Cursor.parse(JSON.parse(Buffer.from(input, "base64url").toString("utf8")))
},
}
const info = (row: typeof MessageTable.$inferSelect) =>
({
...row.data,
id: row.id,
sessionID: row.session_id,
}) as MessageV2.Info
const part = (row: typeof PartTable.$inferSelect) =>
({
...row.data,
id: row.id,
sessionID: row.session_id,
messageID: row.message_id,
}) as MessageV2.Part
const older = (row: Cursor) =>
or(
lt(MessageTable.time_created, row.time),
and(eq(MessageTable.time_created, row.time), lt(MessageTable.id, row.id)),
)
async function hydrate(rows: (typeof MessageTable.$inferSelect)[]) {
const ids = rows.map((row) => row.id)
const partByMessage = new Map<string, MessageV2.Part[]>()
if (ids.length > 0) {
const partRows = Database.use((db) =>
db
.select()
.from(PartTable)
.where(inArray(PartTable.message_id, ids))
.orderBy(PartTable.message_id, PartTable.id)
.all(),
)
for (const row of partRows) {
const next = part(row)
const list = partByMessage.get(row.message_id)
if (list) list.push(next)
else partByMessage.set(row.message_id, [next])
}
}
return rows.map((row) => ({
info: info(row),
parts: partByMessage.get(row.id) ?? [],
}))
}
export function toModelMessages(
input: WithParts[],
model: Provider.Model,
@@ -729,56 +791,61 @@ export namespace MessageV2 {
)
}
export const stream = fn(SessionID.zod, async function* (sessionID) {
const size = 50
let offset = 0
while (true) {
export const page = fn(
z.object({
sessionID: SessionID.zod,
limit: z.number().int().positive(),
before: z.string().optional(),
}),
async (input) => {
const before = input.before ? cursor.decode(input.before) : undefined
const where = before
? and(eq(MessageTable.session_id, input.sessionID), older(before))
: eq(MessageTable.session_id, input.sessionID)
const rows = Database.use((db) =>
db
.select()
.from(MessageTable)
.where(eq(MessageTable.session_id, sessionID))
.orderBy(desc(MessageTable.time_created))
.limit(size)
.offset(offset)
.where(where)
.orderBy(desc(MessageTable.time_created), desc(MessageTable.id))
.limit(input.limit + 1)
.all(),
)
if (rows.length === 0) break
const ids = rows.map((row) => row.id)
const partsByMessage = new Map<string, MessageV2.Part[]>()
if (ids.length > 0) {
const partRows = Database.use((db) =>
db
.select()
.from(PartTable)
.where(inArray(PartTable.message_id, ids))
.orderBy(PartTable.message_id, PartTable.id)
.all(),
if (rows.length === 0) {
const row = Database.use((db) =>
db.select({ id: SessionTable.id }).from(SessionTable).where(eq(SessionTable.id, input.sessionID)).get(),
)
for (const row of partRows) {
const part = {
...row.data,
id: row.id,
sessionID: row.session_id,
messageID: row.message_id,
} as MessageV2.Part
const list = partsByMessage.get(row.message_id)
if (list) list.push(part)
else partsByMessage.set(row.message_id, [part])
if (!row) throw new NotFoundError({ message: `Session not found: ${input.sessionID}` })
return {
items: [] as MessageV2.WithParts[],
more: false,
}
}
for (const row of rows) {
const info = { ...row.data, id: row.id, sessionID: row.session_id } as MessageV2.Info
yield {
info,
parts: partsByMessage.get(row.id) ?? [],
}
const more = rows.length > input.limit
const page = more ? rows.slice(0, input.limit) : rows
const items = await hydrate(page)
items.reverse()
const tail = page.at(-1)
return {
items,
more,
cursor: more && tail ? cursor.encode({ id: tail.id, time: tail.time_created }) : undefined,
}
},
)
offset += rows.length
if (rows.length < size) break
export const stream = fn(SessionID.zod, async function* (sessionID) {
const size = 50
let before: string | undefined
while (true) {
const next = await page({ sessionID, limit: size, before })
if (next.items.length === 0) break
for (let i = next.items.length - 1; i >= 0; i--) {
yield next.items[i]
}
if (!next.more || !next.cursor) break
before = next.cursor
}
})
@@ -797,11 +864,16 @@ export namespace MessageV2 {
messageID: MessageID.zod,
}),
async (input): Promise<WithParts> => {
const row = Database.use((db) => db.select().from(MessageTable).where(eq(MessageTable.id, input.messageID)).get())
if (!row) throw new Error(`Message not found: ${input.messageID}`)
const info = { ...row.data, id: row.id, sessionID: row.session_id } as MessageV2.Info
const row = Database.use((db) =>
db
.select()
.from(MessageTable)
.where(and(eq(MessageTable.id, input.messageID), eq(MessageTable.session_id, input.sessionID)))
.get(),
)
if (!row) throw new NotFoundError({ message: `Message not found: ${input.messageID}` })
return {
info,
info: info(row),
parts: await parts(input.messageID),
}
},

View File

@@ -54,7 +54,7 @@ export const MessageTable = sqliteTable(
...Timestamps,
data: text({ mode: "json" }).notNull().$type<InfoData>(),
},
(table) => [index("message_session_idx").on(table.session_id)],
(table) => [index("message_session_time_created_id_idx").on(table.session_id, table.time_created, table.id)],
)
export const PartTable = sqliteTable(
@@ -69,7 +69,10 @@ export const PartTable = sqliteTable(
...Timestamps,
data: text({ mode: "json" }).notNull().$type<PartData>(),
},
(table) => [index("part_message_idx").on(table.message_id), index("part_session_idx").on(table.session_id)],
(table) => [
index("part_message_id_id_idx").on(table.message_id, table.id),
index("part_session_idx").on(table.session_id),
],
)
export const TodoTable = sqliteTable(