2025-12-11 00:43:00 -05:00

452 lines
13 KiB
TypeScript

import { BusEvent } from "@/bus/bus-event"
import { Bus } from "@/bus"
import { Decimal } from "decimal.js"
import z from "zod"
import { type LanguageModelUsage, type ProviderMetadata } from "ai"
import { Config } from "../config/config"
import { Flag } from "../flag/flag"
import { Identifier } from "../id/id"
import { Installation } from "../installation"
import { Storage } from "../storage/storage"
import { Log } from "../util/log"
import { MessageV2 } from "./message-v2"
import { Instance } from "../project/instance"
import { SessionPrompt } from "./prompt"
import { fn } from "@/util/fn"
import { Command } from "../command"
import { Snapshot } from "@/snapshot"
import type { Provider } from "@/provider/provider"
export namespace Session {
const log = Log.create({ service: "session" })
const parentTitlePrefix = "New session - "
const childTitlePrefix = "Child session - "
function createDefaultTitle(isChild = false) {
return (isChild ? childTitlePrefix : parentTitlePrefix) + new Date().toISOString()
}
export function isDefaultTitle(title: string) {
return new RegExp(
`^(${parentTitlePrefix}|${childTitlePrefix})\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}\\.\\d{3}Z$`,
).test(title)
}
export const Info = z
.object({
id: Identifier.schema("session"),
projectID: z.string(),
directory: z.string(),
parentID: Identifier.schema("session").optional(),
summary: z
.object({
additions: z.number(),
deletions: z.number(),
files: z.number(),
diffs: Snapshot.FileDiff.array().optional(),
})
.optional(),
share: z
.object({
url: z.string(),
})
.optional(),
title: z.string(),
version: z.string(),
time: z.object({
created: z.number(),
updated: z.number(),
compacting: z.number().optional(),
archived: z.number().optional(),
}),
revert: z
.object({
messageID: z.string(),
partID: z.string().optional(),
snapshot: z.string().optional(),
diff: z.string().optional(),
})
.optional(),
})
.meta({
ref: "Session",
})
export type Info = z.output<typeof Info>
export const ShareInfo = z
.object({
secret: z.string(),
url: z.string(),
})
.meta({
ref: "SessionShare",
})
export type ShareInfo = z.output<typeof ShareInfo>
export const Event = {
Created: BusEvent.define(
"session.created",
z.object({
info: Info,
}),
),
Updated: BusEvent.define(
"session.updated",
z.object({
info: Info,
}),
),
Deleted: BusEvent.define(
"session.deleted",
z.object({
info: Info,
}),
),
Diff: BusEvent.define(
"session.diff",
z.object({
sessionID: z.string(),
diff: Snapshot.FileDiff.array(),
}),
),
Error: BusEvent.define(
"session.error",
z.object({
sessionID: z.string().optional(),
error: MessageV2.Assistant.shape.error,
}),
),
}
export const create = fn(
z
.object({
parentID: Identifier.schema("session").optional(),
title: z.string().optional(),
})
.optional(),
async (input) => {
return createNext({
parentID: input?.parentID,
directory: Instance.directory,
title: input?.title,
})
},
)
export const fork = fn(
z.object({
sessionID: Identifier.schema("session"),
messageID: Identifier.schema("message").optional(),
}),
async (input) => {
const session = await createNext({
directory: Instance.directory,
})
const msgs = await messages({ sessionID: input.sessionID })
for (const msg of msgs) {
if (input.messageID && msg.info.id >= input.messageID) break
const cloned = await updateMessage({
...msg.info,
sessionID: session.id,
id: Identifier.ascending("message"),
})
for (const part of msg.parts) {
await updatePart({
...part,
id: Identifier.ascending("part"),
messageID: cloned.id,
sessionID: session.id,
})
}
}
return session
},
)
export const touch = fn(Identifier.schema("session"), async (sessionID) => {
await update(sessionID, (draft) => {
draft.time.updated = Date.now()
})
})
export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
const result: Info = {
id: Identifier.descending("session", input.id),
version: Installation.VERSION,
projectID: Instance.project.id,
directory: input.directory,
parentID: input.parentID,
title: input.title ?? createDefaultTitle(!!input.parentID),
time: {
created: Date.now(),
updated: Date.now(),
},
}
log.info("created", result)
await Storage.write(["session", Instance.project.id, result.id], result)
Bus.publish(Event.Created, {
info: result,
})
const cfg = await Config.get()
if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto"))
share(result.id)
.then((share) => {
update(result.id, (draft) => {
draft.share = share
})
})
.catch(() => {
// Silently ignore sharing errors during session creation
})
Bus.publish(Event.Updated, {
info: result,
})
return result
}
export const get = fn(Identifier.schema("session"), async (id) => {
const read = await Storage.read<Info>(["session", Instance.project.id, id])
return read as Info
})
export const getShare = fn(Identifier.schema("session"), async (id) => {
return Storage.read<ShareInfo>(["share", id])
})
export const share = fn(Identifier.schema("session"), async (id) => {
const cfg = await Config.get()
if (cfg.share === "disabled") {
throw new Error("Sharing is disabled in configuration")
}
const { ShareNext } = await import("@/share/share-next")
const share = await ShareNext.create(id)
await update(id, (draft) => {
draft.share = {
url: share.url,
}
})
return share
})
export const unshare = fn(Identifier.schema("session"), async (id) => {
const cfg = await Config.get()
if (cfg.enterprise?.url) {
const { ShareNext } = await import("@/share/share-next")
await ShareNext.remove(id)
await update(id, (draft) => {
draft.share = undefined
})
}
const share = await getShare(id)
if (!share) return
await Storage.remove(["share", id])
await update(id, (draft) => {
draft.share = undefined
})
const { Share } = await import("../share/share")
await Share.remove(id, share.secret)
})
export async function update(id: string, editor: (session: Info) => void) {
const project = Instance.project
const result = await Storage.update<Info>(["session", project.id, id], (draft) => {
editor(draft)
draft.time.updated = Date.now()
})
Bus.publish(Event.Updated, {
info: result,
})
return result
}
export const diff = fn(Identifier.schema("session"), async (sessionID) => {
const diffs = await Storage.read<Snapshot.FileDiff[]>(["session_diff", sessionID])
return diffs ?? []
})
export const messages = fn(
z.object({
sessionID: Identifier.schema("session"),
limit: z.number().optional(),
}),
async (input) => {
const result = [] as MessageV2.WithParts[]
for await (const msg of MessageV2.stream(input.sessionID)) {
if (input.limit && result.length >= input.limit) break
result.push(msg)
}
result.reverse()
return result
},
)
export async function* list() {
const project = Instance.project
for (const item of await Storage.list(["session", project.id])) {
yield Storage.read<Info>(item)
}
}
export const children = fn(Identifier.schema("session"), async (parentID) => {
const project = Instance.project
const result = [] as Session.Info[]
for (const item of await Storage.list(["session", project.id])) {
const session = await Storage.read<Info>(item)
if (session.parentID !== parentID) continue
result.push(session)
}
return result
})
export const remove = fn(Identifier.schema("session"), async (sessionID) => {
const project = Instance.project
try {
const session = await get(sessionID)
for (const child of await children(sessionID)) {
await remove(child.id)
}
await unshare(sessionID).catch(() => {})
for (const msg of await Storage.list(["message", sessionID])) {
for (const part of await Storage.list(["part", msg.at(-1)!])) {
await Storage.remove(part)
}
await Storage.remove(msg)
}
await Storage.remove(["session", project.id, sessionID])
Bus.publish(Event.Deleted, {
info: session,
})
} catch (e) {
log.error(e)
}
})
export const updateMessage = fn(MessageV2.Info, async (msg) => {
await Storage.write(["message", msg.sessionID, msg.id], msg)
Bus.publish(MessageV2.Event.Updated, {
info: msg,
})
return msg
})
export const removeMessage = fn(
z.object({
sessionID: Identifier.schema("session"),
messageID: Identifier.schema("message"),
}),
async (input) => {
await Storage.remove(["message", input.sessionID, input.messageID])
Bus.publish(MessageV2.Event.Removed, {
sessionID: input.sessionID,
messageID: input.messageID,
})
return input.messageID
},
)
const UpdatePartInput = z.union([
MessageV2.Part,
z.object({
part: MessageV2.TextPart,
delta: z.string(),
}),
z.object({
part: MessageV2.ReasoningPart,
delta: z.string(),
}),
])
export const updatePart = fn(UpdatePartInput, async (input) => {
const part = "delta" in input ? input.part : input
const delta = "delta" in input ? input.delta : undefined
await Storage.write(["part", part.messageID, part.id], part)
Bus.publish(MessageV2.Event.PartUpdated, {
part,
delta,
})
return part
})
export const getUsage = fn(
z.object({
model: z.custom<Provider.Model>(),
usage: z.custom<LanguageModelUsage>(),
metadata: z.custom<ProviderMetadata>().optional(),
}),
(input) => {
const cachedInputTokens = input.usage.cachedInputTokens ?? 0
const excludesCachedTokens = !!(input.metadata?.["anthropic"] || input.metadata?.["bedrock"])
const adjustedInputTokens = excludesCachedTokens
? (input.usage.inputTokens ?? 0)
: (input.usage.inputTokens ?? 0) - cachedInputTokens
const safe = (value: number) => {
if (!Number.isFinite(value)) return 0
return value
}
const tokens = {
input: safe(adjustedInputTokens),
output: safe(input.usage.outputTokens ?? 0),
reasoning: safe(input.usage?.reasoningTokens ?? 0),
cache: {
write: safe(
(input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
// @ts-expect-error
input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
0) as number,
),
read: safe(cachedInputTokens),
},
}
const costInfo =
input.model.cost?.experimentalOver200K && tokens.input + tokens.cache.read > 200_000
? input.model.cost.experimentalOver200K
: input.model.cost
return {
cost: safe(
new Decimal(0)
.add(new Decimal(tokens.input).mul(costInfo?.input ?? 0).div(1_000_000))
.add(new Decimal(tokens.output).mul(costInfo?.output ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.read).mul(costInfo?.cache?.read ?? 0).div(1_000_000))
.add(new Decimal(tokens.cache.write).mul(costInfo?.cache?.write ?? 0).div(1_000_000))
// TODO: update models.dev to have better pricing model, for now:
// charge reasoning tokens at the same rate as output tokens
.add(new Decimal(tokens.reasoning).mul(costInfo?.output ?? 0).div(1_000_000))
.toNumber(),
),
tokens,
}
},
)
export class BusyError extends Error {
constructor(public readonly sessionID: string) {
super(`Session ${sessionID} is busy`)
}
}
export const initialize = fn(
z.object({
sessionID: Identifier.schema("session"),
modelID: z.string(),
providerID: z.string(),
messageID: Identifier.schema("message"),
}),
async (input) => {
await SessionPrompt.command({
sessionID: input.sessionID,
messageID: input.messageID,
model: input.providerID + "/" + input.modelID,
command: Command.Default.INIT,
arguments: "",
})
},
)
}