refactor(provider): flow branded ProviderID/ModelID through internal signatures (#17182)

This commit is contained in:
Kit Langton 2026-03-12 10:48:17 -04:00 committed by GitHub
parent a4f8d66a9b
commit 1cb7df7159
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 227 additions and 205 deletions

View File

@ -46,7 +46,7 @@
"@solidjs/router": "catalog:",
"@thisbeyond/solid-dnd": "0.7.5",
"diff": "catalog:",
"effect": "4.0.0-beta.29",
"effect": "4.0.0-beta.31",
"fuzzysort": "catalog:",
"ghostty-web": "github:anomalyco/ghostty-web#main",
"luxon": "catalog:",
@ -227,7 +227,7 @@
"@solid-primitives/storage": "catalog:",
"@solidjs/meta": "catalog:",
"@solidjs/router": "0.15.4",
"effect": "4.0.0-beta.29",
"effect": "4.0.0-beta.31",
"electron-log": "^5",
"electron-store": "^10",
"electron-updater": "^6",
@ -614,7 +614,7 @@
"dompurify": "3.3.1",
"drizzle-kit": "1.0.0-beta.16-ea816b6",
"drizzle-orm": "1.0.0-beta.16-ea816b6",
"effect": "4.0.0-beta.29",
"effect": "4.0.0-beta.31",
"fuzzysort": "3.1.0",
"hono": "4.10.7",
"hono-openapi": "1.1.2",
@ -2738,7 +2738,7 @@
"ee-first": ["ee-first@1.1.1", "", {}, "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow=="],
"effect": ["effect@4.0.0-beta.29", "", { "dependencies": { "@standard-schema/spec": "^1.1.0", "fast-check": "^4.5.3", "find-my-way-ts": "^0.1.6", "ini": "^6.0.0", "kubernetes-types": "^1.30.0", "msgpackr": "^1.11.8", "multipasta": "^0.2.7", "toml": "^3.0.0", "uuid": "^13.0.0", "yaml": "^2.8.2" } }, "sha512-7UoBAEiktoS81XLMX/39Mq/Ymq8whxmqFpsI0MEYdMlbDcbytzQlyuyhvrwEIdrd9qrqa8DZ5mKblWasamryqw=="],
"effect": ["effect@4.0.0-beta.31", "", { "dependencies": { "@standard-schema/spec": "^1.1.0", "fast-check": "^4.5.3", "find-my-way-ts": "^0.1.6", "ini": "^6.0.0", "kubernetes-types": "^1.30.0", "msgpackr": "^1.11.8", "multipasta": "^0.2.7", "toml": "^3.0.0", "uuid": "^13.0.0", "yaml": "^2.8.2" } }, "sha512-w3QwJnlaLtWWiUSzhCXUTIisnULPsxLzpO6uqaBFjXybKx6FvCqsLJT6v4dV7G9eA9jeTtG6Gv7kF+jGe3HxzA=="],
"ejs": ["ejs@3.1.10", "", { "dependencies": { "jake": "^10.8.5" }, "bin": { "ejs": "bin/cli.js" } }, "sha512-UeJmFfOrAQS8OJWPZ4qtgHyWExa088/MtK5UEyoJGFH67cDEXkZSviOiKRCZ4Xij0zxI3JECgYs3oKx+AizQBA=="],
@ -5226,6 +5226,10 @@
"@solidjs/start/vite": ["vite@7.1.10", "", { "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", "picomatch": "^4.0.3", "postcss": "^8.5.6", "rollup": "^4.43.0", "tinyglobby": "^0.2.15" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^20.19.0 || >=22.12.0", "jiti": ">=1.21.0", "less": "^4.0.0", "lightningcss": "^1.21.0", "sass": "^1.70.0", "sass-embedded": "^1.70.0", "stylus": ">=0.54.8", "sugarss": "^5.0.0", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "jiti", "less", "lightningcss", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-CmuvUBzVJ/e3HGxhg6cYk88NGgTnBoOo7ogtfJJ0fefUWAxN/WDSUa50o+oVBxuIhO8FoEZW0j2eW7sfjs5EtA=="],
"@standard-community/standard-json/effect": ["effect@4.0.0-beta.29", "", { "dependencies": { "@standard-schema/spec": "^1.1.0", "fast-check": "^4.5.3", "find-my-way-ts": "^0.1.6", "ini": "^6.0.0", "kubernetes-types": "^1.30.0", "msgpackr": "^1.11.8", "multipasta": "^0.2.7", "toml": "^3.0.0", "uuid": "^13.0.0", "yaml": "^2.8.2" } }, "sha512-7UoBAEiktoS81XLMX/39Mq/Ymq8whxmqFpsI0MEYdMlbDcbytzQlyuyhvrwEIdrd9qrqa8DZ5mKblWasamryqw=="],
"@standard-community/standard-openapi/effect": ["effect@4.0.0-beta.29", "", { "dependencies": { "@standard-schema/spec": "^1.1.0", "fast-check": "^4.5.3", "find-my-way-ts": "^0.1.6", "ini": "^6.0.0", "kubernetes-types": "^1.30.0", "msgpackr": "^1.11.8", "multipasta": "^0.2.7", "toml": "^3.0.0", "uuid": "^13.0.0", "yaml": "^2.8.2" } }, "sha512-7UoBAEiktoS81XLMX/39Mq/Ymq8whxmqFpsI0MEYdMlbDcbytzQlyuyhvrwEIdrd9qrqa8DZ5mKblWasamryqw=="],
"@tailwindcss/oxide/detect-libc": ["detect-libc@2.1.2", "", {}, "sha512-Btj2BOOO83o3WyH59e8MgXsxEQVcarkUOpEYrubB0urwnN10yQ364rsiByU11nZlqWYZm05i/of7io4mzihBtQ=="],
"@tailwindcss/oxide-wasm32-wasi/@emnapi/core": ["@emnapi/core@1.8.1", "", { "dependencies": { "@emnapi/wasi-threads": "1.1.0", "tslib": "^2.4.0" }, "bundled": true }, "sha512-AvT9QFpxK0Zd8J0jopedNm+w/2fIzvtPKPjqyw9jwvBaReTTqPBk9Hixaz7KbjimP+QNz605/XnjFcDAL2pqBg=="],
@ -6124,6 +6128,10 @@
"@solidjs/start/shiki/@shikijs/types": ["@shikijs/types@1.29.2", "", { "dependencies": { "@shikijs/vscode-textmate": "^10.0.1", "@types/hast": "^3.0.4" } }, "sha512-VJjK0eIijTZf0QSTODEXCqinjBn0joAHQ+aPSBzrv4O2d/QSbsMw+ZeSRx03kV34Hy7NzUvV/7NqfYGRLrASmw=="],
"@standard-community/standard-json/effect/@standard-schema/spec": ["@standard-schema/spec@1.1.0", "", {}, "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w=="],
"@standard-community/standard-openapi/effect/@standard-schema/spec": ["@standard-schema/spec@1.1.0", "", {}, "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w=="],
"@tailwindcss/oxide-wasm32-wasi/@napi-rs/wasm-runtime/@tybys/wasm-util": ["@tybys/wasm-util@0.10.1", "", { "dependencies": { "tslib": "^2.4.0" } }, "sha512-9tTaPJLSiejZKx+Bmog4uSubteqTvFrVrURwkmHixBo0G4seD0zUxp98E1DzUBJxLQ3NPwXrGKDiVjwx/DpPsg=="],
"@vitest/expect/@vitest/utils/@vitest/pretty-format": ["@vitest/pretty-format@3.2.4", "", { "dependencies": { "tinyrainbow": "^2.0.0" } }, "sha512-IVNZik8IVRJRTr9fxlitMKeJeXFFFN0JaB9PHPGQ8NKQbGpfjlTx9zO4RefN8gp7eqjNy8nyK3NZmBzOPeIxtA=="],

View File

@ -43,7 +43,7 @@
"dompurify": "3.3.1",
"drizzle-kit": "1.0.0-beta.16-ea816b6",
"drizzle-orm": "1.0.0-beta.16-ea816b6",
"effect": "4.0.0-beta.29",
"effect": "4.0.0-beta.31",
"ai": "5.0.124",
"hono": "4.10.7",
"hono-openapi": "1.1.2",

View File

@ -56,7 +56,7 @@
"@solidjs/router": "catalog:",
"@thisbeyond/solid-dnd": "0.7.5",
"diff": "catalog:",
"effect": "4.0.0-beta.29",
"effect": "4.0.0-beta.31",
"fuzzysort": "catalog:",
"ghostty-web": "github:anomalyco/ghostty-web#main",
"luxon": "catalog:",

View File

@ -30,7 +30,7 @@
"@solid-primitives/storage": "catalog:",
"@solidjs/meta": "catalog:",
"@solidjs/router": "0.15.4",
"effect": "4.0.0-beta.29",
"effect": "4.0.0-beta.31",
"electron-log": "^5",
"electron-store": "^10",
"electron-updater": "^6",

View File

@ -35,7 +35,7 @@ import { Hash } from "../util/hash"
import { ACPSessionManager } from "./session"
import type { ACPConfig } from "./types"
import { Provider } from "../provider/provider"
import { ProviderID } from "../provider/schema"
import { ModelID, ProviderID } from "../provider/schema"
import { Agent as AgentModule } from "../agent/agent"
import { Installation } from "@/installation"
import { MessageV2 } from "@/session/message-v2"
@ -56,8 +56,8 @@ export namespace ACP {
async function getContextLimit(
sdk: OpencodeClient,
providerID: string,
modelID: string,
providerID: ProviderID,
modelID: ModelID,
directory: string,
): Promise<number | null> {
const providers = await sdk.config
@ -97,7 +97,8 @@ export namespace ACP {
if (!lastAssistant) return
const msg = lastAssistant.info
const size = await getContextLimit(sdk, msg.providerID, msg.modelID, directory)
if (!msg.providerID || !msg.modelID) return
const size = await getContextLimit(sdk, ProviderID.make(msg.providerID), ModelID.make(msg.modelID), directory)
if (!size) {
// Cannot calculate usage without known context size
@ -637,8 +638,8 @@ export namespace ACP {
if (lastUser?.role === "user") {
result.models.currentModelId = `${lastUser.model.providerID}/${lastUser.model.modelID}`
this.sessionManager.setModel(sessionId, {
providerID: lastUser.model.providerID,
modelID: lastUser.model.modelID,
providerID: ProviderID.make(lastUser.model.providerID),
modelID: ModelID.make(lastUser.model.modelID),
})
if (result.modes?.availableModes.some((m) => m.id === lastUser.agent)) {
result.modes.currentModeId = lastUser.agent
@ -1526,7 +1527,7 @@ export namespace ACP {
}
}
async function defaultModel(config: ACPConfig, cwd?: string) {
async function defaultModel(config: ACPConfig, cwd?: string): Promise<{ providerID: ProviderID; modelID: ModelID }> {
const sdk = config.sdk
const configured = config.defaultModel
if (configured) return configured
@ -1538,11 +1539,7 @@ export namespace ACP {
.then((resp) => {
const cfg = resp.data
if (!cfg || !cfg.model) return undefined
const parsed = Provider.parseModel(cfg.model)
return {
providerID: parsed.providerID,
modelID: parsed.modelID,
}
return Provider.parseModel(cfg.model)
})
.catch((error) => {
log.error("failed to load user config for default model", { error })
@ -1567,13 +1564,13 @@ export namespace ACP {
const opencodeProvider = providers.find((p) => p.id === "opencode")
if (opencodeProvider) {
if (opencodeProvider.models["big-pickle"]) {
return { providerID: "opencode", modelID: "big-pickle" }
return { providerID: ProviderID.opencode, modelID: ModelID.make("big-pickle") }
}
const [best] = Provider.sort(Object.values(opencodeProvider.models))
if (best) {
return {
providerID: best.providerID,
modelID: best.id,
providerID: ProviderID.make(best.providerID),
modelID: ModelID.make(best.id),
}
}
}
@ -1582,14 +1579,14 @@ export namespace ACP {
const [best] = Provider.sort(models)
if (best) {
return {
providerID: best.providerID,
modelID: best.id,
providerID: ProviderID.make(best.providerID),
modelID: ModelID.make(best.id),
}
}
if (specified) return specified
return { providerID: "opencode", modelID: "big-pickle" }
return { providerID: ProviderID.opencode, modelID: ModelID.make("big-pickle") }
}
function parseUri(
@ -1652,7 +1649,7 @@ export namespace ACP {
function modelVariantsFromProviders(
providers: Array<{ id: string; models: Record<string, { variants?: Record<string, any> }> }>,
model: { providerID: string; modelID: string },
model: { providerID: ProviderID; modelID: ModelID },
): string[] {
const provider = providers.find((entry) => entry.id === model.providerID)
if (!provider) return []
@ -1688,7 +1685,7 @@ export namespace ACP {
}
function formatModelIdWithVariant(
model: { providerID: string; modelID: string },
model: { providerID: ProviderID; modelID: ModelID },
variant: string | undefined,
availableVariants: string[],
includeVariant: boolean,
@ -1699,7 +1696,7 @@ export namespace ACP {
}
function buildVariantMeta(input: {
model: { providerID: string; modelID: string }
model: { providerID: ProviderID; modelID: ModelID }
variant?: string
availableVariants: string[]
}) {
@ -1715,7 +1712,7 @@ export namespace ACP {
function parseModelSelection(
modelId: string,
providers: Array<{ id: string; models: Record<string, { variants?: Record<string, any> }> }>,
): { model: { providerID: string; modelID: string }; variant?: string } {
): { model: { providerID: ProviderID; modelID: ModelID }; variant?: string } {
const parsed = Provider.parseModel(modelId)
const provider = providers.find((p) => p.id === parsed.providerID)
if (!provider) {
@ -1735,7 +1732,7 @@ export namespace ACP {
const baseModelInfo = provider.models[baseModelId]
if (baseModelInfo?.variants && candidateVariant in baseModelInfo.variants) {
return {
model: { providerID: parsed.providerID, modelID: baseModelId },
model: { providerID: parsed.providerID, modelID: ModelID.make(baseModelId) },
variant: candidateVariant,
}
}

View File

@ -1,5 +1,6 @@
import type { McpServer } from "@agentclientprotocol/sdk"
import type { OpencodeClient } from "@opencode-ai/sdk/v2"
import type { ProviderID, ModelID } from "../provider/schema"
export interface ACPSessionState {
id: string
@ -7,8 +8,8 @@ export interface ACPSessionState {
mcpServers: McpServer[]
createdAt: Date
model?: {
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
}
variant?: string
modeId?: string
@ -17,7 +18,7 @@ export interface ACPSessionState {
export interface ACPConfig {
sdk: OpencodeClient
defaultModel?: {
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
}
}

View File

@ -281,7 +281,7 @@ export namespace Agent {
return primaryVisible.name
}
export async function generate(input: { description: string; model?: { providerID: string; modelID: string } }) {
export async function generate(input: { description: string; model?: { providerID: ProviderID; modelID: ModelID } }) {
const cfg = await Config.get()
const defaultModel = input.model ?? (await Provider.defaultModel())
const model = await Provider.getModel(defaultModel.providerID, defaultModel.modelID)

View File

@ -1,6 +1,7 @@
import type { Argv } from "yargs"
import { Instance } from "../../project/instance"
import { Provider } from "../../provider/provider"
import { ProviderID } from "../../provider/schema"
import { ModelsDev } from "../../provider/models"
import { cmd } from "./cmd"
import { UI } from "../ui"
@ -36,7 +37,7 @@ export const ModelsCommand = cmd({
async fn() {
const providers = await Provider.list()
function printModels(providerID: string, verbose?: boolean) {
function printModels(providerID: ProviderID, verbose?: boolean) {
const provider = providers[providerID]
const sortedModels = Object.entries(provider.models).sort(([a], [b]) => a.localeCompare(b))
for (const [modelID, model] of sortedModels) {
@ -56,7 +57,7 @@ export const ModelsCommand = cmd({
return
}
printModels(args.provider, args.verbose)
printModels(ProviderID.make(args.provider), args.verbose)
return
}
@ -69,7 +70,7 @@ export const ModelsCommand = cmd({
})
for (const providerID of providerIDs) {
printModels(providerID, args.verbose)
printModels(ProviderID.make(providerID), args.verbose)
}
},
})

View File

@ -15,9 +15,13 @@ export namespace Permission {
return pattern === undefined ? [type] : Array.isArray(pattern) ? pattern : [pattern]
}
function covered(keys: string[], approved: Record<string, boolean>): boolean {
const pats = Object.keys(approved)
return keys.every((k) => pats.some((p) => Wildcard.match(k, p)))
function covered(keys: string[], approved: Map<string, boolean>): boolean {
return keys.every((k) => {
for (const p of approved.keys()) {
if (Wildcard.match(k, p)) return true
}
return false
})
}
export const Info = z
@ -39,6 +43,12 @@ export namespace Permission {
})
export type Info = z.infer<typeof Info>
interface PendingEntry {
info: Info
resolve: () => void
reject: (e: any) => void
}
export const Event = {
Updated: BusEvent.define("permission.updated", Info),
Replied: BusEvent.define(
@ -52,31 +62,13 @@ export namespace Permission {
}
const state = Instance.state(
() => {
const pending: {
[sessionID: string]: {
[permissionID: string]: {
info: Info
resolve: () => void
reject: (e: any) => void
}
}
} = {}
const approved: {
[sessionID: string]: {
[permissionID: string]: boolean
}
} = {}
return {
pending,
approved,
}
},
() => ({
pending: new Map<SessionID, Map<PermissionID, PendingEntry>>(),
approved: new Map<SessionID, Map<string, boolean>>(),
}),
async (state) => {
for (const pending of Object.values(state.pending)) {
for (const item of Object.values(pending)) {
for (const session of state.pending.values()) {
for (const item of session.values()) {
item.reject(new RejectedError(item.info.sessionID, item.info.id, item.info.callID, item.info.metadata))
}
}
@ -90,8 +82,8 @@ export namespace Permission {
export function list() {
const { pending } = state()
const result: Info[] = []
for (const items of Object.values(pending)) {
for (const item of Object.values(items)) {
for (const session of pending.values()) {
for (const item of session.values()) {
result.push(item.info)
}
}
@ -114,9 +106,9 @@ export namespace Permission {
toolCallID: input.callID,
pattern: input.pattern,
})
const approvedForSession = approved[input.sessionID] || {}
const approvedForSession = approved.get(input.sessionID)
const keys = toKeys(input.pattern, input.type)
if (covered(keys, approvedForSession)) return
if (approvedForSession && covered(keys, approvedForSession)) return
const info: Info = {
id: PermissionID.ascending(),
type: input.type,
@ -142,13 +134,13 @@ export namespace Permission {
return
}
pending[input.sessionID] = pending[input.sessionID] || {}
if (!pending.has(input.sessionID)) pending.set(input.sessionID, new Map())
return new Promise<void>((resolve, reject) => {
pending[input.sessionID][info.id] = {
pending.get(input.sessionID)!.set(info.id, {
info,
resolve,
reject,
}
})
Bus.publish(Event.Updated, info)
})
}
@ -159,9 +151,11 @@ export namespace Permission {
export function respond(input: { sessionID: Info["sessionID"]; permissionID: Info["id"]; response: Response }) {
log.info("response", input)
const { pending, approved } = state()
const match = pending[input.sessionID]?.[input.permissionID]
if (!match) return
delete pending[input.sessionID][input.permissionID]
const session = pending.get(input.sessionID)
const match = session?.get(input.permissionID)
if (!session || !match) return
session.delete(input.permissionID)
if (session.size === 0) pending.delete(input.sessionID)
Bus.publish(Event.Replied, {
sessionID: input.sessionID,
permissionID: input.permissionID,
@ -173,30 +167,35 @@ export namespace Permission {
}
match.resolve()
if (input.response === "always") {
approved[input.sessionID] = approved[input.sessionID] || {}
if (!approved.has(input.sessionID)) approved.set(input.sessionID, new Map())
const approvedSession = approved.get(input.sessionID)!
const approveKeys = toKeys(match.info.pattern, match.info.type)
for (const k of approveKeys) {
approved[input.sessionID][k] = true
approvedSession.set(k, true)
}
const items = pending[input.sessionID]
const items = pending.get(input.sessionID)
if (!items) return
for (const item of Object.values(items)) {
const toRespond: Info[] = []
for (const item of items.values()) {
const itemKeys = toKeys(item.info.pattern, item.info.type)
if (covered(itemKeys, approved[input.sessionID])) {
respond({
sessionID: item.info.sessionID,
permissionID: item.info.id,
response: input.response,
})
if (covered(itemKeys, approvedSession)) {
toRespond.push(item.info)
}
}
for (const item of toRespond) {
respond({
sessionID: item.sessionID,
permissionID: item.id,
response: input.response,
})
}
}
}
export class RejectedError extends Error {
constructor(
public readonly sessionID: string,
public readonly permissionID: string,
public readonly sessionID: SessionID,
public readonly permissionID: PermissionID,
public readonly toolCallID?: string,
public readonly metadata?: Record<string, any>,
public readonly reason?: string,

View File

@ -108,6 +108,12 @@ export namespace PermissionNext {
),
}
interface PendingEntry {
info: Request
resolve: () => void
reject: (e: any) => void
}
const state = Instance.state(() => {
const projectID = Instance.project.id
const row = Database.use((db) =>
@ -115,17 +121,8 @@ export namespace PermissionNext {
)
const stored = row?.data ?? ([] as Ruleset)
const pending: Record<
string,
{
info: Request
resolve: () => void
reject: (e: any) => void
}
> = {}
return {
pending,
pending: new Map<PermissionID, PendingEntry>(),
approved: stored,
}
})
@ -149,11 +146,11 @@ export namespace PermissionNext {
id,
...request,
}
s.pending[id] = {
s.pending.set(id, {
info,
resolve,
reject,
}
})
Bus.publish(Event.Asked, info)
})
}
@ -170,9 +167,9 @@ export namespace PermissionNext {
}),
async (input) => {
const s = await state()
const existing = s.pending[input.requestID]
const existing = s.pending.get(input.requestID)
if (!existing) return
delete s.pending[input.requestID]
s.pending.delete(input.requestID)
Bus.publish(Event.Replied, {
sessionID: existing.info.sessionID,
requestID: existing.info.id,
@ -182,9 +179,9 @@ export namespace PermissionNext {
existing.reject(input.message ? new CorrectedError(input.message) : new RejectedError())
// Reject all other pending permissions for this session
const sessionID = existing.info.sessionID
for (const [id, pending] of Object.entries(s.pending)) {
for (const [id, pending] of s.pending) {
if (pending.info.sessionID === sessionID) {
delete s.pending[id]
s.pending.delete(id)
Bus.publish(Event.Replied, {
sessionID: pending.info.sessionID,
requestID: pending.info.id,
@ -211,13 +208,13 @@ export namespace PermissionNext {
existing.resolve()
const sessionID = existing.info.sessionID
for (const [id, pending] of Object.entries(s.pending)) {
for (const [id, pending] of s.pending) {
if (pending.info.sessionID !== sessionID) continue
const ok = pending.info.patterns.every(
(pattern) => evaluate(pending.info.permission, pattern, s.approved).action === "allow",
)
if (!ok) continue
delete s.pending[id]
s.pending.delete(id)
Bus.publish(Event.Replied, {
sessionID: pending.info.sessionID,
requestID: pending.info.id,
@ -283,6 +280,6 @@ export namespace PermissionNext {
export async function list() {
const s = await state()
return Object.values(s.pending).map((x) => x.info)
return Array.from(s.pending.values(), (x) => x.info)
}
}

View File

@ -377,7 +377,7 @@ export async function CodexAuthPlugin(input: PluginInput): Promise<Hooks> {
if (!provider.models["gpt-5.3-codex"]) {
const model = {
id: ModelID.make("gpt-5.3-codex"),
providerID: ProviderID.make("openai"),
providerID: ProviderID.openai,
api: {
id: "gpt-5.3-codex",
url: "https://chatgpt.com/backend-api/codex",

View File

@ -1,6 +1,7 @@
import { APICallError } from "ai"
import { STATUS_CODES } from "http"
import { iife } from "@/util/iife"
import type { ProviderID } from "./schema"
export namespace ProviderError {
// Adapted from overflow detection patterns in:
@ -40,7 +41,7 @@ export namespace ProviderError {
return /^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message)
}
function message(providerID: string, e: APICallError) {
function message(providerID: ProviderID, e: APICallError) {
return iife(() => {
const msg = e.message
if (msg === "") {
@ -164,7 +165,7 @@ export namespace ProviderError {
metadata?: Record<string, string>
}
export function parseAPICallError(input: { providerID: string; error: APICallError }): ParsedAPICallError {
export function parseAPICallError(input: { providerID: ProviderID; error: APICallError }): ParsedAPICallError {
const m = message(input.providerID, input.error)
if (isOverflow(m) || input.error.statusCode === 413) {
return {

View File

@ -845,7 +845,7 @@ export namespace Provider {
const disabled = new Set(config.disabled_providers ?? [])
const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null
function isProviderAllowed(providerID: string): boolean {
function isProviderAllowed(providerID: ProviderID): boolean {
if (enabled && !enabled.has(providerID)) return false
if (disabled.has(providerID)) return false
return true
@ -867,16 +867,16 @@ export namespace Provider {
const githubCopilot = database["github-copilot"]
database["github-copilot-enterprise"] = {
...githubCopilot,
id: ProviderID.make("github-copilot-enterprise"),
id: ProviderID.githubCopilotEnterprise,
name: "GitHub Copilot Enterprise",
models: mapValues(githubCopilot.models, (model) => ({
...model,
providerID: ProviderID.make("github-copilot-enterprise"),
providerID: ProviderID.githubCopilotEnterprise,
})),
}
}
function mergeProvider(providerID: string, provider: Partial<Info>) {
function mergeProvider(providerID: ProviderID, provider: Partial<Info>) {
const existing = providers[providerID]
if (existing) {
// @ts-expect-error
@ -974,7 +974,8 @@ export namespace Provider {
// load env
const env = Env.all()
for (const [providerID, provider] of Object.entries(database)) {
for (const [id, provider] of Object.entries(database)) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
const apiKey = provider.env.map((item) => env[item]).find(Boolean)
if (!apiKey) continue
@ -985,7 +986,8 @@ export namespace Provider {
}
// load apikeys
for (const [providerID, provider] of Object.entries(await Auth.all())) {
for (const [id, provider] of Object.entries(await Auth.all())) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
if (provider.type === "api") {
mergeProvider(providerID, {
@ -997,7 +999,7 @@ export namespace Provider {
for (const plugin of await Plugin.list()) {
if (!plugin.auth) continue
const providerID = plugin.auth.provider
const providerID = ProviderID.make(plugin.auth.provider)
if (disabled.has(providerID)) continue
// For github-copilot plugin, check if auth exists for either github-copilot or github-copilot-enterprise
@ -1006,7 +1008,7 @@ export namespace Provider {
if (auth) hasAuth = true
// Special handling for github-copilot: also check for enterprise auth
if (providerID === "github-copilot" && !hasAuth) {
if (providerID === ProviderID.githubCopilot && !hasAuth) {
const enterpriseAuth = await Auth.get("github-copilot-enterprise")
if (enterpriseAuth) hasAuth = true
}
@ -1023,8 +1025,8 @@ export namespace Provider {
}
// If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
if (providerID === "github-copilot") {
const enterpriseProviderID = "github-copilot-enterprise"
if (providerID === ProviderID.githubCopilot) {
const enterpriseProviderID = ProviderID.githubCopilotEnterprise
if (!disabled.has(enterpriseProviderID)) {
const enterpriseAuth = await Auth.get(enterpriseProviderID)
if (enterpriseAuth) {
@ -1042,7 +1044,8 @@ export namespace Provider {
}
}
for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
for (const [id, fn] of Object.entries(CUSTOM_LOADERS)) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
const data = database[providerID]
if (!data) {
@ -1059,7 +1062,8 @@ export namespace Provider {
}
// load config
for (const [providerID, provider] of configProviders) {
for (const [id, provider] of configProviders) {
const providerID = ProviderID.make(id)
const partial: Partial<Info> = { source: "config" }
if (provider.env) partial.env = provider.env
if (provider.name) partial.name = provider.name
@ -1067,7 +1071,8 @@ export namespace Provider {
mergeProvider(providerID, partial)
}
for (const [providerID, provider] of Object.entries(providers)) {
for (const [id, provider] of Object.entries(providers)) {
const providerID = ProviderID.make(id)
if (!isProviderAllowed(providerID)) {
delete providers[providerID]
continue
@ -1077,7 +1082,7 @@ export namespace Provider {
for (const [modelID, model] of Object.entries(provider.models)) {
model.api.id = model.api.id ?? model.id ?? modelID
if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat"))
if (modelID === "gpt-5-chat-latest" || (providerID === ProviderID.openrouter && modelID === "openai/gpt-5-chat"))
delete provider.models[modelID]
if (model.status === "alpha" && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) delete provider.models[modelID]
if (model.status === "deprecated") delete provider.models[modelID]
@ -1230,11 +1235,11 @@ export namespace Provider {
}
}
export async function getProvider(providerID: string) {
export async function getProvider(providerID: ProviderID) {
return state().then((s) => s.providers[providerID])
}
export async function getModel(providerID: string, modelID: string) {
export async function getModel(providerID: ProviderID, modelID: ModelID) {
const s = await state()
const provider = s.providers[providerID]
if (!provider) {
@ -1281,7 +1286,7 @@ export namespace Provider {
}
}
export async function closest(providerID: string, query: string[]) {
export async function closest(providerID: ProviderID, query: string[]) {
const s = await state()
const provider = s.providers[providerID]
if (!provider) return undefined
@ -1296,7 +1301,7 @@ export namespace Provider {
}
}
export async function getSmallModel(providerID: string) {
export async function getSmallModel(providerID: ProviderID) {
const cfg = await Config.get()
if (cfg.small_model) {
@ -1323,7 +1328,7 @@ export namespace Provider {
priority = ["gpt-5-mini", "claude-haiku-4.5", ...priority]
}
for (const item of priority) {
if (providerID === "amazon-bedrock") {
if (providerID === ProviderID.amazonBedrock) {
const crossRegionPrefixes = ["global.", "us.", "eu."]
const candidates = Object.keys(provider.models).filter((m) => m.includes(item))
@ -1332,22 +1337,22 @@ export namespace Provider {
// 2. User's region prefix (us., eu.)
// 3. Unprefixed model
const globalMatch = candidates.find((m) => m.startsWith("global."))
if (globalMatch) return getModel(providerID, globalMatch)
if (globalMatch) return getModel(providerID, ModelID.make(globalMatch))
const region = provider.options?.region
if (region) {
const regionPrefix = region.split("-")[0]
if (regionPrefix === "us" || regionPrefix === "eu") {
const regionalMatch = candidates.find((m) => m.startsWith(`${regionPrefix}.`))
if (regionalMatch) return getModel(providerID, regionalMatch)
if (regionalMatch) return getModel(providerID, ModelID.make(regionalMatch))
}
}
const unprefixed = candidates.find((m) => !crossRegionPrefixes.some((p) => m.startsWith(p)))
if (unprefixed) return getModel(providerID, unprefixed)
if (unprefixed) return getModel(providerID, ModelID.make(unprefixed))
} else {
for (const model of Object.keys(provider.models)) {
if (model.includes(item)) return getModel(providerID, model)
if (model.includes(item)) return getModel(providerID, ModelID.make(model))
}
}
}

View File

@ -11,6 +11,18 @@ export const ProviderID = providerIdSchema.pipe(
withStatics((schema: typeof providerIdSchema) => ({
make: (id: string) => schema.makeUnsafe(id),
zod: z.string().pipe(z.custom<ProviderID>()),
// Well-known providers
opencode: schema.makeUnsafe("opencode"),
anthropic: schema.makeUnsafe("anthropic"),
openai: schema.makeUnsafe("openai"),
google: schema.makeUnsafe("google"),
googleVertex: schema.makeUnsafe("google-vertex"),
githubCopilot: schema.makeUnsafe("github-copilot"),
githubCopilotEnterprise: schema.makeUnsafe("github-copilot-enterprise"),
amazonBedrock: schema.makeUnsafe("amazon-bedrock"),
azure: schema.makeUnsafe("azure"),
openrouter: schema.makeUnsafe("openrouter"),
mistral: schema.makeUnsafe("mistral"),
})),
)

View File

@ -91,7 +91,7 @@ export namespace Pty {
}
const state = Instance.state(
() => new Map<string, ActiveSession>(),
() => new Map<PtyID, ActiveSession>(),
async (sessions) => {
for (const session of sessions.values()) {
try {
@ -113,7 +113,7 @@ export namespace Pty {
return Array.from(state().values()).map((s) => s.info)
}
export function get(id: string) {
export function get(id: PtyID) {
return state().get(id)?.info
}
@ -205,7 +205,7 @@ export namespace Pty {
return info
}
export async function update(id: string, input: UpdateInput) {
export async function update(id: PtyID, input: UpdateInput) {
const session = state().get(id)
if (!session) return
if (input.title) {
@ -218,7 +218,7 @@ export namespace Pty {
return session.info
}
export async function remove(id: string) {
export async function remove(id: PtyID) {
const session = state().get(id)
if (!session) return
state().delete(id)
@ -237,21 +237,21 @@ export namespace Pty {
Bus.publish(Event.Deleted, { id: session.info.id })
}
export function resize(id: string, cols: number, rows: number) {
export function resize(id: PtyID, cols: number, rows: number) {
const session = state().get(id)
if (session && session.info.status === "running") {
session.process.resize(cols, rows)
}
}
export function write(id: string, data: string) {
export function write(id: PtyID, data: string) {
const session = state().get(id)
if (session && session.info.status === "running") {
session.process.write(data)
}
}
export function connect(id: string, ws: Socket, cursor?: number) {
export function connect(id: PtyID, ws: Socket, cursor?: number) {
const session = state().get(id)
if (!session) {
ws.close()

View File

@ -80,20 +80,15 @@ export namespace Question {
),
}
const state = Instance.state(async () => {
const pending: Record<
string,
{
info: Request
resolve: (answers: Answer[]) => void
reject: (e: any) => void
}
> = {}
interface PendingEntry {
info: Request
resolve: (answers: Answer[]) => void
reject: (e: any) => void
}
return {
pending,
}
})
const state = Instance.state(async () => ({
pending: new Map<QuestionID, PendingEntry>(),
}))
export async function ask(input: {
sessionID: SessionID
@ -112,23 +107,23 @@ export namespace Question {
questions: input.questions,
tool: input.tool,
}
s.pending[id] = {
s.pending.set(id, {
info,
resolve,
reject,
}
})
Bus.publish(Event.Asked, info)
})
}
export async function reply(input: { requestID: string; answers: Answer[] }): Promise<void> {
export async function reply(input: { requestID: QuestionID; answers: Answer[] }): Promise<void> {
const s = await state()
const existing = s.pending[input.requestID]
const existing = s.pending.get(input.requestID)
if (!existing) {
log.warn("reply for unknown request", { requestID: input.requestID })
return
}
delete s.pending[input.requestID]
s.pending.delete(input.requestID)
log.info("replied", { requestID: input.requestID, answers: input.answers })
@ -141,14 +136,14 @@ export namespace Question {
existing.resolve(input.answers)
}
export async function reject(requestID: string): Promise<void> {
export async function reject(requestID: QuestionID): Promise<void> {
const s = await state()
const existing = s.pending[requestID]
const existing = s.pending.get(requestID)
if (!existing) {
log.warn("reject for unknown request", { requestID })
return
}
delete s.pending[requestID]
s.pending.delete(requestID)
log.info("rejected", { requestID })
@ -167,6 +162,6 @@ export namespace Question {
}
export async function list() {
return state().then((x) => Object.values(x.pending).map((x) => x.info))
return state().then((x) => Array.from(x.pending.values(), (x) => x.info))
}
}

View File

@ -1,6 +1,7 @@
import { Hono } from "hono"
import { describeRoute, validator, resolver } from "hono-openapi"
import z from "zod"
import { ProviderID, ModelID } from "../../provider/schema"
import { ToolRegistry } from "../../tool/registry"
import { Worktree } from "../../worktree"
import { Instance } from "../../project/instance"
@ -77,7 +78,7 @@ export const ExperimentalRoutes = lazy(() =>
),
async (c) => {
const { provider, model } = c.req.valid("query")
const tools = await ToolRegistry.tools({ providerID: provider, modelID: model })
const tools = await ToolRegistry.tools({ providerID: ProviderID.make(provider), modelID: ModelID.make(model) })
return c.json(
tools.map((t) => ({
id: t.id,

View File

@ -237,7 +237,7 @@ export namespace SessionPrompt {
return parts
}
function start(sessionID: string) {
function start(sessionID: SessionID) {
const s = state()
if (s[sessionID]) return
const controller = new AbortController()
@ -248,7 +248,7 @@ export namespace SessionPrompt {
return controller.signal
}
function resume(sessionID: string) {
function resume(sessionID: SessionID) {
const s = state()
if (!s[sessionID]) return
@ -788,7 +788,7 @@ export namespace SessionPrompt {
})
for (const item of await ToolRegistry.tools(
{ modelID: input.model.api.id, providerID: input.model.providerID },
{ modelID: ModelID.make(input.model.api.id), providerID: input.model.providerID },
input.agent,
)) {
const schema = ProviderTransform.schema(input.model, z.toJSONSchema(item.parameters))
@ -1898,8 +1898,8 @@ NOTE: At any point in time through this workflow you should feel free to ask the
async function ensureTitle(input: {
session: Session.Info
history: MessageV2.WithParts[]
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
}) {
if (input.session.parentID) return
if (!Session.isDefaultTitle(input.session.title)) return

View File

@ -2,6 +2,7 @@ import { Bus } from "@/bus"
import { Account } from "@/account"
import { Config } from "@/config/config"
import { Provider } from "@/provider/provider"
import { ProviderID, ModelID } from "@/provider/schema"
import { Session } from "@/session"
import type { SessionID } from "@/session/schema"
import { MessageV2 } from "@/session/message-v2"
@ -262,7 +263,7 @@ export namespace ShareNext {
.map((m) => (m.info as SDK.UserMessage).model)
.map((m) => [`${m.providerID}/${m.modelID}`, m] as const),
).values(),
).map((m) => Provider.getModel(m.providerID, m.modelID).then((item) => item)),
).map((m) => Provider.getModel(ProviderID.make(m.providerID), ModelID.make(m.modelID)).then((item) => item)),
)
await sync(sessionID, [
{

View File

@ -1,5 +1,6 @@
import z from "zod"
import { Tool } from "./tool"
import { ProviderID, ModelID } from "../provider/schema"
import DESCRIPTION from "./batch.txt"
const DISALLOWED = new Set(["batch"])
@ -37,7 +38,7 @@ export const BatchTool = Tool.define("batch", async () => {
const discardedCalls = params.tool_calls.slice(25)
const { ToolRegistry } = await import("./registry")
const availableTools = await ToolRegistry.tools({ modelID: "", providerID: "" })
const availableTools = await ToolRegistry.tools({ modelID: ModelID.make(""), providerID: ProviderID.make("") })
const toolMap = new Map(availableTools.map((t) => [t.id, t]))
const executeCall = async (call: (typeof toolCalls)[0]) => {

View File

@ -20,6 +20,7 @@ import path from "path"
import { type ToolContext as PluginToolContext, type ToolDefinition } from "@opencode-ai/plugin"
import z from "zod"
import { Plugin } from "../plugin"
import { ProviderID, type ModelID } from "../provider/schema"
import { WebSearchTool } from "./websearch"
import { CodeSearchTool } from "./codesearch"
import { Flag } from "@/flag/flag"
@ -130,8 +131,8 @@ export namespace ToolRegistry {
export async function tools(
model: {
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
},
agent?: Agent.Info,
) {
@ -141,7 +142,7 @@ export namespace ToolRegistry {
.filter((t) => {
// Enable websearch/codesearch for zen users OR via enable flag
if (t.id === "codesearch" || t.id === "websearch") {
return model.providerID === "opencode" || Flag.OPENCODE_ENABLE_EXA
return model.providerID === ProviderID.opencode || Flag.OPENCODE_ENABLE_EXA
}
// use apply tool in same format as codex

View File

@ -4,6 +4,7 @@ import path from "path"
import { tmpdir } from "../fixture/fixture"
import { Instance } from "../../src/project/instance"
import { Provider } from "../../src/provider/provider"
import { ProviderID, ModelID } from "../../src/provider/schema"
import { Env } from "../../src/env"
test("provider loaded from env variable", async () => {
@ -300,7 +301,7 @@ test("getModel returns model for valid provider/model", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
const model = await Provider.getModel("anthropic", "claude-sonnet-4-20250514")
const model = await Provider.getModel(ProviderID.anthropic, ModelID.make("claude-sonnet-4-20250514"))
expect(model).toBeDefined()
expect(String(model.providerID)).toBe("anthropic")
expect(String(model.id)).toBe("claude-sonnet-4-20250514")
@ -327,7 +328,7 @@ test("getModel throws ModelNotFoundError for invalid model", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
expect(Provider.getModel("anthropic", "nonexistent-model")).rejects.toThrow()
expect(Provider.getModel(ProviderID.anthropic, ModelID.make("nonexistent-model"))).rejects.toThrow()
},
})
})
@ -346,7 +347,7 @@ test("getModel throws ModelNotFoundError for invalid provider", async () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
expect(Provider.getModel("nonexistent-provider", "some-model")).rejects.toThrow()
expect(Provider.getModel(ProviderID.make("nonexistent-provider"), ModelID.make("some-model"))).rejects.toThrow()
},
})
})
@ -572,10 +573,10 @@ test("closest finds model by partial match", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
const result = await Provider.closest("anthropic", ["sonnet-4"])
const result = await Provider.closest(ProviderID.anthropic, ["sonnet-4"])
expect(result).toBeDefined()
expect(result?.providerID).toBe("anthropic")
expect(result?.modelID).toContain("sonnet-4")
expect(String(result?.providerID)).toBe("anthropic")
expect(String(result?.modelID)).toContain("sonnet-4")
},
})
})
@ -594,7 +595,7 @@ test("closest returns undefined for nonexistent provider", async () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
const result = await Provider.closest("nonexistent", ["model"])
const result = await Provider.closest(ProviderID.make("nonexistent"), ["model"])
expect(result).toBeUndefined()
},
})
@ -630,7 +631,7 @@ test("getModel uses realIdByKey for aliased models", async () => {
const providers = await Provider.list()
expect(providers["anthropic"].models["my-sonnet"]).toBeDefined()
const model = await Provider.getModel("anthropic", "my-sonnet")
const model = await Provider.getModel(ProviderID.anthropic, ModelID.make("my-sonnet"))
expect(model).toBeDefined()
expect(String(model.id)).toBe("my-sonnet")
expect(model.name).toBe("My Sonnet Alias")
@ -933,7 +934,7 @@ test("getSmallModel returns appropriate small model", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
const model = await Provider.getSmallModel("anthropic")
const model = await Provider.getSmallModel(ProviderID.anthropic)
expect(model).toBeDefined()
expect(model?.id).toContain("haiku")
},
@ -958,7 +959,7 @@ test("getSmallModel respects config small_model override", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
const model = await Provider.getSmallModel("anthropic")
const model = await Provider.getSmallModel(ProviderID.anthropic)
expect(model).toBeDefined()
expect(String(model?.providerID)).toBe("anthropic")
expect(String(model?.id)).toBe("claude-sonnet-4-20250514")
@ -1466,8 +1467,8 @@ test("getModel returns consistent results", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
const model1 = await Provider.getModel("anthropic", "claude-sonnet-4-20250514")
const model2 = await Provider.getModel("anthropic", "claude-sonnet-4-20250514")
const model1 = await Provider.getModel(ProviderID.anthropic, ModelID.make("claude-sonnet-4-20250514"))
const model2 = await Provider.getModel(ProviderID.anthropic, ModelID.make("claude-sonnet-4-20250514"))
expect(model1.providerID).toEqual(model2.providerID)
expect(model1.id).toEqual(model2.id)
expect(model1).toEqual(model2)
@ -1528,7 +1529,7 @@ test("ModelNotFoundError includes suggestions for typos", async () => {
},
fn: async () => {
try {
await Provider.getModel("anthropic", "claude-sonet-4") // typo: sonet instead of sonnet
await Provider.getModel(ProviderID.anthropic, ModelID.make("claude-sonet-4")) // typo: sonet instead of sonnet
expect(true).toBe(false) // Should not reach here
} catch (e: any) {
expect(e.data.suggestions).toBeDefined()
@ -1556,7 +1557,7 @@ test("ModelNotFoundError for provider includes suggestions", async () => {
},
fn: async () => {
try {
await Provider.getModel("antropic", "claude-sonnet-4") // typo: antropic
await Provider.getModel(ProviderID.make("antropic"), ModelID.make("claude-sonnet-4")) // typo: antropic
expect(true).toBe(false) // Should not reach here
} catch (e: any) {
expect(e.data.suggestions).toBeDefined()
@ -1580,7 +1581,7 @@ test("getProvider returns undefined for nonexistent provider", async () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
const provider = await Provider.getProvider("nonexistent")
const provider = await Provider.getProvider(ProviderID.make("nonexistent"))
expect(provider).toBeUndefined()
},
})
@ -1603,7 +1604,7 @@ test("getProvider returns provider info", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
const provider = await Provider.getProvider("anthropic")
const provider = await Provider.getProvider(ProviderID.anthropic)
expect(provider).toBeDefined()
expect(String(provider?.id)).toBe("anthropic")
},
@ -1627,7 +1628,7 @@ test("closest returns undefined when no partial match found", async () => {
Env.set("ANTHROPIC_API_KEY", "test-api-key")
},
fn: async () => {
const result = await Provider.closest("anthropic", ["nonexistent-xyz-model"])
const result = await Provider.closest(ProviderID.anthropic, ["nonexistent-xyz-model"])
expect(result).toBeUndefined()
},
})
@ -1651,7 +1652,7 @@ test("closest checks multiple query terms in order", async () => {
},
fn: async () => {
// First term won't match, second will
const result = await Provider.closest("anthropic", ["nonexistent", "haiku"])
const result = await Provider.closest(ProviderID.anthropic, ["nonexistent", "haiku"])
expect(result).toBeDefined()
expect(result?.modelID).toContain("haiku")
},

View File

@ -2,6 +2,7 @@ import { describe, expect, test } from "bun:test"
import { Bus } from "../../src/bus"
import { Instance } from "../../src/project/instance"
import { Pty } from "../../src/pty"
import type { PtyID } from "../../src/pty/schema"
import { tmpdir } from "../fixture/fixture"
import { setTimeout as sleep } from "node:timers/promises"
@ -14,7 +15,7 @@ const wait = async (fn: () => boolean, ms = 2000) => {
throw new Error("timeout waiting for pty events")
}
const pick = (log: Array<{ type: "created" | "exited" | "deleted"; id: string }>, id: string) => {
const pick = (log: Array<{ type: "created" | "exited" | "deleted"; id: PtyID }>, id: PtyID) => {
return log.filter((evt) => evt.id === id).map((evt) => evt.type)
}
@ -27,23 +28,23 @@ describe("pty", () => {
await Instance.provide({
directory: dir.path,
fn: async () => {
const log: Array<{ type: "created" | "exited" | "deleted"; id: string }> = []
const log: Array<{ type: "created" | "exited" | "deleted"; id: PtyID }> = []
const off = [
Bus.subscribe(Pty.Event.Created, (evt) => log.push({ type: "created", id: evt.properties.info.id })),
Bus.subscribe(Pty.Event.Exited, (evt) => log.push({ type: "exited", id: evt.properties.id })),
Bus.subscribe(Pty.Event.Deleted, (evt) => log.push({ type: "deleted", id: evt.properties.id })),
]
let id = ""
let id: PtyID | undefined
try {
const info = await Pty.create({ command: "/bin/ls", title: "ls" })
id = info.id
await wait(() => pick(log, id).includes("exited"))
await wait(() => pick(log, id!).includes("exited"))
await Pty.remove(id)
await wait(() => pick(log, id).length >= 3)
expect(pick(log, id)).toEqual(["created", "exited", "deleted"])
await wait(() => pick(log, id!).length >= 3)
expect(pick(log, id!)).toEqual(["created", "exited", "deleted"])
} finally {
off.forEach((x) => x())
if (id) await Pty.remove(id)
@ -60,14 +61,14 @@ describe("pty", () => {
await Instance.provide({
directory: dir.path,
fn: async () => {
const log: Array<{ type: "created" | "exited" | "deleted"; id: string }> = []
const log: Array<{ type: "created" | "exited" | "deleted"; id: PtyID }> = []
const off = [
Bus.subscribe(Pty.Event.Created, (evt) => log.push({ type: "created", id: evt.properties.info.id })),
Bus.subscribe(Pty.Event.Exited, (evt) => log.push({ type: "exited", id: evt.properties.id })),
Bus.subscribe(Pty.Event.Deleted, (evt) => log.push({ type: "deleted", id: evt.properties.id })),
]
let id = ""
let id: PtyID | undefined
try {
const info = await Pty.create({ command: "/bin/sh", title: "sh" })
id = info.id
@ -75,8 +76,8 @@ describe("pty", () => {
await sleep(100)
await Pty.remove(id)
await wait(() => pick(log, id).length >= 3)
expect(pick(log, id)).toEqual(["created", "exited", "deleted"])
await wait(() => pick(log, id!).length >= 3)
expect(pick(log, id!)).toEqual(["created", "exited", "deleted"])
} finally {
off.forEach((x) => x())
if (id) await Pty.remove(id)

View File

@ -7,7 +7,7 @@ import { Instance } from "../../src/project/instance"
import { Provider } from "../../src/provider/provider"
import { ProviderTransform } from "../../src/provider/transform"
import { ModelsDev } from "../../src/provider/models"
import { ProviderID } from "../../src/provider/schema"
import { ProviderID, ModelID } from "../../src/provider/schema"
import { Filesystem } from "../../src/util/filesystem"
import { tmpdir } from "../fixture/fixture"
import type { Agent } from "../../src/agent/agent"
@ -266,7 +266,7 @@ describe("session.llm.stream", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
const resolved = await Provider.getModel(providerID, model.id)
const resolved = await Provider.getModel(ProviderID.make(providerID), ModelID.make(model.id))
const sessionID = SessionID.make("session-test-1")
const agent = {
name: "test",
@ -396,7 +396,7 @@ describe("session.llm.stream", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
const resolved = await Provider.getModel("openai", model.id)
const resolved = await Provider.getModel(ProviderID.openai, ModelID.make(model.id))
const sessionID = SessionID.make("session-test-2")
const agent = {
name: "test",
@ -518,7 +518,7 @@ describe("session.llm.stream", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
const resolved = await Provider.getModel(providerID, model.id)
const resolved = await Provider.getModel(ProviderID.make(providerID), ModelID.make(model.id))
const sessionID = SessionID.make("session-test-3")
const agent = {
name: "test",
@ -619,7 +619,7 @@ describe("session.llm.stream", () => {
await Instance.provide({
directory: tmp.path,
fn: async () => {
const resolved = await Provider.getModel(providerID, model.id)
const resolved = await Provider.getModel(ProviderID.make(providerID), ModelID.make(model.id))
const sessionID = SessionID.make("session-test-4")
const agent = {
name: "test",