feat(id): brand SessionID through Drizzle and Zod schemas (#16953)

This commit is contained in:
Kit Langton
2026-03-11 19:16:56 -04:00
committed by GitHub
parent 4e73473119
commit cb67465675
44 changed files with 226 additions and 158 deletions

View File

@@ -2,6 +2,7 @@ import { BusEvent } from "@/bus/bus-event"
import { Bus } from "@/bus"
import { Session } from "."
import { Identifier } from "../id/id"
import { SessionID } from "./schema"
import { Instance } from "../project/instance"
import { Provider } from "../provider/provider"
import { MessageV2 } from "./message-v2"
@@ -22,7 +23,7 @@ export namespace SessionCompaction {
Compacted: BusEvent.define(
"session.compacted",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
}),
),
}
@@ -55,7 +56,7 @@ export namespace SessionCompaction {
// goes backwards through parts until there are 40_000 tokens worth of tool
// calls. then erases output of previous tool calls. idea is to throw away old
// tool calls that are no longer relevant.
export async function prune(input: { sessionID: string }) {
export async function prune(input: { sessionID: SessionID }) {
const config = await Config.get()
if (config.compaction?.prune === false) return
log.info("pruning")
@@ -101,7 +102,7 @@ export namespace SessionCompaction {
export async function process(input: {
parentID: string
messages: MessageV2.WithParts[]
sessionID: string
sessionID: SessionID
abort: AbortSignal
auto: boolean
overflow?: boolean
@@ -295,7 +296,7 @@ When constructing the summary, try to stick to this template:
export const create = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
agent: z.string(),
model: z.object({
providerID: z.string(),

View File

@@ -24,6 +24,7 @@ import { Command } from "../command"
import { Snapshot } from "@/snapshot"
import { WorkspaceContext } from "../control-plane/workspace-context"
import { ProjectID } from "../project/schema"
import { SessionID } from "./schema"
import type { Provider } from "@/provider/provider"
import { PermissionNext } from "@/permission/next"
@@ -119,12 +120,12 @@ export namespace Session {
export const Info = z
.object({
id: Identifier.schema("session"),
id: SessionID.zod,
slug: z.string(),
projectID: ProjectID.zod,
workspaceID: z.string().optional(),
directory: z.string(),
parentID: Identifier.schema("session").optional(),
parentID: SessionID.zod.optional(),
summary: z
.object({
additions: z.number(),
@@ -201,14 +202,14 @@ export namespace Session {
Diff: BusEvent.define(
"session.diff",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
diff: Snapshot.FileDiff.array(),
}),
),
Error: BusEvent.define(
"session.error",
z.object({
sessionID: z.string().optional(),
sessionID: SessionID.zod.optional(),
error: MessageV2.Assistant.shape.error,
}),
),
@@ -217,7 +218,7 @@ export namespace Session {
export const create = fn(
z
.object({
parentID: Identifier.schema("session").optional(),
parentID: SessionID.zod.optional(),
title: z.string().optional(),
permission: Info.shape.permission,
workspaceID: Identifier.schema("workspace").optional(),
@@ -236,7 +237,7 @@ export namespace Session {
export const fork = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
messageID: Identifier.schema("message").optional(),
}),
async (input) => {
@@ -277,7 +278,7 @@ export namespace Session {
},
)
export const touch = fn(Identifier.schema("session"), async (sessionID) => {
export const touch = fn(SessionID.zod, async (sessionID) => {
const now = Date.now()
Database.use((db) => {
const row = db
@@ -293,15 +294,15 @@ export namespace Session {
})
export async function createNext(input: {
id?: string
id?: SessionID
title?: string
parentID?: string
parentID?: SessionID
workspaceID?: string
directory: string
permission?: PermissionNext.Ruleset
}) {
const result: Info = {
id: Identifier.descending("session", input.id),
id: SessionID.descending(input.id),
slug: Slug.create(),
version: Installation.VERSION,
projectID: Instance.project.id,
@@ -342,13 +343,13 @@ export namespace Session {
return path.join(base, [input.time.created, input.slug].join("-") + ".md")
}
export const get = fn(Identifier.schema("session"), async (id) => {
export const get = fn(SessionID.zod, async (id) => {
const row = Database.use((db) => db.select().from(SessionTable).where(eq(SessionTable.id, id)).get())
if (!row) throw new NotFoundError({ message: `Session not found: ${id}` })
return fromRow(row)
})
export const share = fn(Identifier.schema("session"), async (id) => {
export const share = fn(SessionID.zod, async (id) => {
const cfg = await Config.get()
if (cfg.share === "disabled") {
throw new Error("Sharing is disabled in configuration")
@@ -364,7 +365,7 @@ export namespace Session {
return share
})
export const unshare = fn(Identifier.schema("session"), async (id) => {
export const unshare = fn(SessionID.zod, async (id) => {
// Use ShareNext to remove the share (same as share function uses ShareNext to create)
const { ShareNext } = await import("@/share/share-next")
await ShareNext.remove(id)
@@ -378,7 +379,7 @@ export namespace Session {
export const setTitle = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
title: z.string(),
}),
async (input) => {
@@ -399,7 +400,7 @@ export namespace Session {
export const setArchived = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
time: z.number().optional(),
}),
async (input) => {
@@ -420,7 +421,7 @@ export namespace Session {
export const setPermission = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
permission: PermissionNext.Ruleset,
}),
async (input) => {
@@ -441,7 +442,7 @@ export namespace Session {
export const setRevert = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
revert: Info.shape.revert,
summary: Info.shape.summary,
}),
@@ -467,7 +468,7 @@ export namespace Session {
},
)
export const clearRevert = fn(Identifier.schema("session"), async (sessionID) => {
export const clearRevert = fn(SessionID.zod, async (sessionID) => {
return Database.use((db) => {
const row = db
.update(SessionTable)
@@ -487,7 +488,7 @@ export namespace Session {
export const setSummary = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
summary: Info.shape.summary,
}),
async (input) => {
@@ -511,7 +512,7 @@ export namespace Session {
},
)
export const diff = fn(Identifier.schema("session"), async (sessionID) => {
export const diff = fn(SessionID.zod, async (sessionID) => {
try {
return await Storage.read<Snapshot.FileDiff[]>(["session_diff", sessionID])
} catch {
@@ -521,7 +522,7 @@ export namespace Session {
export const messages = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
limit: z.number().optional(),
}),
async (input) => {
@@ -647,7 +648,7 @@ export namespace Session {
}
}
export const children = fn(Identifier.schema("session"), async (parentID) => {
export const children = fn(SessionID.zod, async (parentID) => {
const project = Instance.project
const rows = Database.use((db) =>
db
@@ -659,7 +660,7 @@ export namespace Session {
return rows.map(fromRow)
})
export const remove = fn(Identifier.schema("session"), async (sessionID) => {
export const remove = fn(SessionID.zod, async (sessionID) => {
const project = Instance.project
try {
const session = await get(sessionID)
@@ -705,7 +706,7 @@ export namespace Session {
export const removeMessage = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
messageID: Identifier.schema("message"),
}),
async (input) => {
@@ -727,7 +728,7 @@ export namespace Session {
export const removePart = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
messageID: Identifier.schema("message"),
partID: Identifier.schema("part"),
}),
@@ -775,7 +776,7 @@ export namespace Session {
export const updatePartDelta = fn(
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
messageID: z.string(),
partID: z.string(),
field: z.string(),
@@ -873,7 +874,7 @@ export namespace Session {
export const initialize = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
modelID: z.string(),
providerID: z.string(),
messageID: Identifier.schema("message"),

View File

@@ -1,4 +1,5 @@
import { BusEvent } from "@/bus/bus-event"
import { SessionID } from "./schema"
import z from "zod"
import { NamedError } from "@opencode-ai/util/error"
import { APICallError, convertToModelMessages, LoadAPIKeyError, type ModelMessage, type UIMessage } from "ai"
@@ -79,7 +80,7 @@ export namespace MessageV2 {
const PartBase = z.object({
id: z.string(),
sessionID: z.string(),
sessionID: SessionID.zod,
messageID: z.string(),
})
@@ -344,7 +345,7 @@ export namespace MessageV2 {
const Base = z.object({
id: z.string(),
sessionID: z.string(),
sessionID: SessionID.zod,
})
export const User = Base.extend({
@@ -457,7 +458,7 @@ export namespace MessageV2 {
Removed: BusEvent.define(
"message.removed",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
messageID: z.string(),
}),
),
@@ -470,7 +471,7 @@ export namespace MessageV2 {
PartDelta: BusEvent.define(
"message.part.delta",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
messageID: z.string(),
partID: z.string(),
field: z.string(),
@@ -480,7 +481,7 @@ export namespace MessageV2 {
PartRemoved: BusEvent.define(
"message.part.removed",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
messageID: z.string(),
partID: z.string(),
}),
@@ -728,7 +729,7 @@ export namespace MessageV2 {
)
}
export const stream = fn(Identifier.schema("session"), async function* (sessionID) {
export const stream = fn(SessionID.zod, async function* (sessionID) {
const size = 50
let offset = 0
while (true) {
@@ -792,7 +793,7 @@ export namespace MessageV2 {
export const get = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
messageID: Identifier.schema("message"),
}),
async (input): Promise<WithParts> => {

View File

@@ -1,4 +1,5 @@
import z from "zod"
import { SessionID } from "./schema"
import { NamedError } from "@opencode-ai/util/error"
export namespace Message {
@@ -142,7 +143,7 @@ export namespace Message {
error: z
.discriminatedUnion("name", [AuthError.Schema, NamedError.Unknown.Schema, OutputLengthError.Schema])
.optional(),
sessionID: z.string(),
sessionID: SessionID.zod,
tool: z.record(
z.string(),
z

View File

@@ -15,6 +15,7 @@ import { Config } from "@/config/config"
import { SessionCompaction } from "./compaction"
import { PermissionNext } from "@/permission/next"
import { Question } from "@/question"
import type { SessionID } from "./schema"
export namespace SessionProcessor {
const DOOM_LOOP_THRESHOLD = 3
@@ -25,7 +26,7 @@ export namespace SessionProcessor {
export function create(input: {
assistantMessage: MessageV2.Assistant
sessionID: string
sessionID: SessionID
model: Provider.Model
abort: AbortSignal
}) {

View File

@@ -4,6 +4,7 @@ import fs from "fs/promises"
import z from "zod"
import { Filesystem } from "../util/filesystem"
import { Identifier } from "../id/id"
import { SessionID } from "./schema"
import { MessageV2 } from "./message-v2"
import { Log } from "../util/log"
import { SessionRevert } from "./revert"
@@ -84,13 +85,13 @@ export namespace SessionPrompt {
},
)
export function assertNotBusy(sessionID: string) {
export function assertNotBusy(sessionID: SessionID) {
const match = state()[sessionID]
if (match) throw new Session.BusyError(sessionID)
}
export const PromptInput = z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
messageID: Identifier.schema("message").optional(),
model: z
.object({
@@ -254,7 +255,7 @@ export namespace SessionPrompt {
return s[sessionID].abort.signal
}
export function cancel(sessionID: string) {
export function cancel(sessionID: SessionID) {
log.info("cancel", { sessionID })
const s = state()
const match = s[sessionID]
@@ -269,7 +270,7 @@ export namespace SessionPrompt {
}
export const LoopInput = z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
resume_existing: z.boolean().optional(),
})
export const loop = fn(LoopInput, async (input) => {
@@ -731,7 +732,7 @@ export namespace SessionPrompt {
throw new Error("Impossible")
})
async function lastModel(sessionID: string) {
async function lastModel(sessionID: SessionID) {
for await (const item of MessageV2.stream(sessionID)) {
if (item.info.role === "user" && item.info.model) return item.info.model
}
@@ -1467,7 +1468,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
}
export const ShellInput = z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
agent: z.string(),
model: z
.object({
@@ -1719,7 +1720,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
export const CommandInput = z.object({
messageID: Identifier.schema("message").optional(),
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
agent: z.string().optional(),
model: z.string().optional(),
arguments: z.string(),

View File

@@ -1,5 +1,6 @@
import z from "zod"
import { Identifier } from "../id/id"
import { SessionID } from "./schema"
import { Snapshot } from "../snapshot"
import { MessageV2 } from "./message-v2"
import { Session } from "."
@@ -15,7 +16,7 @@ export namespace SessionRevert {
const log = Log.create({ service: "session.revert" })
export const RevertInput = z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
messageID: Identifier.schema("message"),
partID: Identifier.schema("part").optional(),
})
@@ -79,7 +80,7 @@ export namespace SessionRevert {
return session
}
export async function unrevert(input: { sessionID: string }) {
export async function unrevert(input: { sessionID: SessionID }) {
log.info("unreverting", input)
SessionPrompt.assertNotBusy(input.sessionID)
const session = await Session.get(input.sessionID)

View File

@@ -0,0 +1,17 @@
import { Schema } from "effect"
import z from "zod"
import { withStatics } from "@/util/schema"
import { Identifier } from "@/id/id"
const sessionIdSchema = Schema.String.pipe(Schema.brand("SessionId"))
export type SessionID = typeof sessionIdSchema.Type
export const SessionID = sessionIdSchema.pipe(
withStatics((schema: typeof sessionIdSchema) => ({
make: (id: string) => schema.makeUnsafe(id),
descending: (id?: string) => schema.makeUnsafe(Identifier.descending("session", id)),
zod: z.string().startsWith("ses").pipe(z.custom<SessionID>()),
})),
)

View File

@@ -4,6 +4,7 @@ import type { MessageV2 } from "./message-v2"
import type { Snapshot } from "../snapshot"
import type { PermissionNext } from "../permission/next"
import type { ProjectID } from "../project/schema"
import type { SessionID } from "./schema"
import { Timestamps } from "../storage/schema.sql"
type PartData = Omit<MessageV2.Part, "id" | "sessionID" | "messageID">
@@ -12,13 +13,13 @@ type InfoData = Omit<MessageV2.Info, "id" | "sessionID">
export const SessionTable = sqliteTable(
"session",
{
id: text().primaryKey(),
id: text().$type<SessionID>().primaryKey(),
project_id: text()
.$type<ProjectID>()
.notNull()
.references(() => ProjectTable.id, { onDelete: "cascade" }),
workspace_id: text(),
parent_id: text(),
parent_id: text().$type<SessionID>(),
slug: text().notNull(),
directory: text().notNull(),
title: text().notNull(),
@@ -46,6 +47,7 @@ export const MessageTable = sqliteTable(
{
id: text().primaryKey(),
session_id: text()
.$type<SessionID>()
.notNull()
.references(() => SessionTable.id, { onDelete: "cascade" }),
...Timestamps,
@@ -61,7 +63,7 @@ export const PartTable = sqliteTable(
message_id: text()
.notNull()
.references(() => MessageTable.id, { onDelete: "cascade" }),
session_id: text().notNull(),
session_id: text().$type<SessionID>().notNull(),
...Timestamps,
data: text({ mode: "json" }).notNull().$type<PartData>(),
},
@@ -72,6 +74,7 @@ export const TodoTable = sqliteTable(
"todo",
{
session_id: text()
.$type<SessionID>()
.notNull()
.references(() => SessionTable.id, { onDelete: "cascade" }),
content: text().notNull(),

View File

@@ -1,6 +1,7 @@
import { BusEvent } from "@/bus/bus-event"
import { Bus } from "@/bus"
import { Instance } from "@/project/instance"
import { SessionID } from "./schema"
import z from "zod"
export namespace SessionStatus {
@@ -28,7 +29,7 @@ export namespace SessionStatus {
Status: BusEvent.define(
"session.status",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
status: Info,
}),
),
@@ -36,7 +37,7 @@ export namespace SessionStatus {
Idle: BusEvent.define(
"session.idle",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
}),
),
}
@@ -46,7 +47,7 @@ export namespace SessionStatus {
return data
})
export function get(sessionID: string) {
export function get(sessionID: SessionID) {
return (
state()[sessionID] ?? {
type: "idle",
@@ -58,7 +59,7 @@ export namespace SessionStatus {
return state()
}
export function set(sessionID: string, status: Info) {
export function set(sessionID: SessionID, status: Info) {
Bus.publish(Event.Status, {
sessionID,
status,

View File

@@ -4,6 +4,7 @@ import { Session } from "."
import { MessageV2 } from "./message-v2"
import { Identifier } from "@/id/id"
import { SessionID } from "./schema"
import { Snapshot } from "@/snapshot"
import { Storage } from "@/storage/storage"
@@ -68,7 +69,7 @@ export namespace SessionSummary {
export const summarize = fn(
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
messageID: z.string(),
}),
async (input) => {
@@ -80,7 +81,7 @@ export namespace SessionSummary {
},
)
async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
async function summarizeSession(input: { sessionID: SessionID; messages: MessageV2.WithParts[] }) {
const diffs = await computeDiff({ messages: input.messages })
await Session.setSummary({
sessionID: input.sessionID,
@@ -113,7 +114,7 @@ export namespace SessionSummary {
export const diff = fn(
z.object({
sessionID: Identifier.schema("session"),
sessionID: SessionID.zod,
messageID: Identifier.schema("message").optional(),
}),
async (input) => {

View File

@@ -1,5 +1,6 @@
import { BusEvent } from "@/bus/bus-event"
import { Bus } from "@/bus"
import { SessionID } from "./schema"
import z from "zod"
import { Database, eq, asc } from "../storage/db"
import { TodoTable } from "./session.sql"
@@ -18,13 +19,13 @@ export namespace Todo {
Updated: BusEvent.define(
"todo.updated",
z.object({
sessionID: z.string(),
sessionID: SessionID.zod,
todos: z.array(Info),
}),
),
}
export function update(input: { sessionID: string; todos: Info[] }) {
export function update(input: { sessionID: SessionID; todos: Info[] }) {
Database.transaction((db) => {
db.delete(TodoTable).where(eq(TodoTable.session_id, input.sessionID)).run()
if (input.todos.length === 0) return
@@ -43,7 +44,7 @@ export namespace Todo {
Bus.publish(Event.Updated, input)
}
export function get(sessionID: string) {
export function get(sessionID: SessionID) {
const rows = Database.use((db) =>
db.select().from(TodoTable).where(eq(TodoTable.session_id, sessionID)).orderBy(asc(TodoTable.position)).all(),
)