refactor(provider): flow branded ProviderID/ModelID through internal signatures (#17182)

This commit is contained in:
Kit Langton
2026-03-12 10:48:17 -04:00
committed by GitHub
parent a4f8d66a9b
commit 1cb7df7159
24 changed files with 227 additions and 205 deletions

View File

@@ -35,7 +35,7 @@ import { Hash } from "../util/hash"
import { ACPSessionManager } from "./session"
import type { ACPConfig } from "./types"
import { Provider } from "../provider/provider"
import { ProviderID } from "../provider/schema"
import { ModelID, ProviderID } from "../provider/schema"
import { Agent as AgentModule } from "../agent/agent"
import { Installation } from "@/installation"
import { MessageV2 } from "@/session/message-v2"
@@ -56,8 +56,8 @@ export namespace ACP {
async function getContextLimit(
sdk: OpencodeClient,
providerID: string,
modelID: string,
providerID: ProviderID,
modelID: ModelID,
directory: string,
): Promise<number | null> {
const providers = await sdk.config
@@ -97,7 +97,8 @@ export namespace ACP {
if (!lastAssistant) return
const msg = lastAssistant.info
const size = await getContextLimit(sdk, msg.providerID, msg.modelID, directory)
if (!msg.providerID || !msg.modelID) return
const size = await getContextLimit(sdk, ProviderID.make(msg.providerID), ModelID.make(msg.modelID), directory)
if (!size) {
// Cannot calculate usage without known context size
@@ -637,8 +638,8 @@ export namespace ACP {
if (lastUser?.role === "user") {
result.models.currentModelId = `${lastUser.model.providerID}/${lastUser.model.modelID}`
this.sessionManager.setModel(sessionId, {
providerID: lastUser.model.providerID,
modelID: lastUser.model.modelID,
providerID: ProviderID.make(lastUser.model.providerID),
modelID: ModelID.make(lastUser.model.modelID),
})
if (result.modes?.availableModes.some((m) => m.id === lastUser.agent)) {
result.modes.currentModeId = lastUser.agent
@@ -1526,7 +1527,7 @@ export namespace ACP {
}
}
async function defaultModel(config: ACPConfig, cwd?: string) {
async function defaultModel(config: ACPConfig, cwd?: string): Promise<{ providerID: ProviderID; modelID: ModelID }> {
const sdk = config.sdk
const configured = config.defaultModel
if (configured) return configured
@@ -1538,11 +1539,7 @@ export namespace ACP {
.then((resp) => {
const cfg = resp.data
if (!cfg || !cfg.model) return undefined
const parsed = Provider.parseModel(cfg.model)
return {
providerID: parsed.providerID,
modelID: parsed.modelID,
}
return Provider.parseModel(cfg.model)
})
.catch((error) => {
log.error("failed to load user config for default model", { error })
@@ -1567,13 +1564,13 @@ export namespace ACP {
const opencodeProvider = providers.find((p) => p.id === "opencode")
if (opencodeProvider) {
if (opencodeProvider.models["big-pickle"]) {
return { providerID: "opencode", modelID: "big-pickle" }
return { providerID: ProviderID.opencode, modelID: ModelID.make("big-pickle") }
}
const [best] = Provider.sort(Object.values(opencodeProvider.models))
if (best) {
return {
providerID: best.providerID,
modelID: best.id,
providerID: ProviderID.make(best.providerID),
modelID: ModelID.make(best.id),
}
}
}
@@ -1582,14 +1579,14 @@ export namespace ACP {
const [best] = Provider.sort(models)
if (best) {
return {
providerID: best.providerID,
modelID: best.id,
providerID: ProviderID.make(best.providerID),
modelID: ModelID.make(best.id),
}
}
if (specified) return specified
return { providerID: "opencode", modelID: "big-pickle" }
return { providerID: ProviderID.opencode, modelID: ModelID.make("big-pickle") }
}
function parseUri(
@@ -1652,7 +1649,7 @@ export namespace ACP {
function modelVariantsFromProviders(
providers: Array<{ id: string; models: Record<string, { variants?: Record<string, any> }> }>,
model: { providerID: string; modelID: string },
model: { providerID: ProviderID; modelID: ModelID },
): string[] {
const provider = providers.find((entry) => entry.id === model.providerID)
if (!provider) return []
@@ -1688,7 +1685,7 @@ export namespace ACP {
}
function formatModelIdWithVariant(
model: { providerID: string; modelID: string },
model: { providerID: ProviderID; modelID: ModelID },
variant: string | undefined,
availableVariants: string[],
includeVariant: boolean,
@@ -1699,7 +1696,7 @@ export namespace ACP {
}
function buildVariantMeta(input: {
model: { providerID: string; modelID: string }
model: { providerID: ProviderID; modelID: ModelID }
variant?: string
availableVariants: string[]
}) {
@@ -1715,7 +1712,7 @@ export namespace ACP {
function parseModelSelection(
modelId: string,
providers: Array<{ id: string; models: Record<string, { variants?: Record<string, any> }> }>,
): { model: { providerID: string; modelID: string }; variant?: string } {
): { model: { providerID: ProviderID; modelID: ModelID }; variant?: string } {
const parsed = Provider.parseModel(modelId)
const provider = providers.find((p) => p.id === parsed.providerID)
if (!provider) {
@@ -1735,7 +1732,7 @@ export namespace ACP {
const baseModelInfo = provider.models[baseModelId]
if (baseModelInfo?.variants && candidateVariant in baseModelInfo.variants) {
return {
model: { providerID: parsed.providerID, modelID: baseModelId },
model: { providerID: parsed.providerID, modelID: ModelID.make(baseModelId) },
variant: candidateVariant,
}
}

View File

@@ -1,5 +1,6 @@
import type { McpServer } from "@agentclientprotocol/sdk"
import type { OpencodeClient } from "@opencode-ai/sdk/v2"
import type { ProviderID, ModelID } from "../provider/schema"
export interface ACPSessionState {
id: string
@@ -7,8 +8,8 @@ export interface ACPSessionState {
mcpServers: McpServer[]
createdAt: Date
model?: {
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
}
variant?: string
modeId?: string
@@ -17,7 +18,7 @@ export interface ACPSessionState {
export interface ACPConfig {
sdk: OpencodeClient
defaultModel?: {
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
}
}

View File

@@ -281,7 +281,7 @@ export namespace Agent {
return primaryVisible.name
}
export async function generate(input: { description: string; model?: { providerID: string; modelID: string } }) {
export async function generate(input: { description: string; model?: { providerID: ProviderID; modelID: ModelID } }) {
const cfg = await Config.get()
const defaultModel = input.model ?? (await Provider.defaultModel())
const model = await Provider.getModel(defaultModel.providerID, defaultModel.modelID)

View File

@@ -1,6 +1,7 @@
import type { Argv } from "yargs"
import { Instance } from "../../project/instance"
import { Provider } from "../../provider/provider"
import { ProviderID } from "../../provider/schema"
import { ModelsDev } from "../../provider/models"
import { cmd } from "./cmd"
import { UI } from "../ui"
@@ -36,7 +37,7 @@ export const ModelsCommand = cmd({
async fn() {
const providers = await Provider.list()
function printModels(providerID: string, verbose?: boolean) {
function printModels(providerID: ProviderID, verbose?: boolean) {
const provider = providers[providerID]
const sortedModels = Object.entries(provider.models).sort(([a], [b]) => a.localeCompare(b))
for (const [modelID, model] of sortedModels) {
@@ -56,7 +57,7 @@ export const ModelsCommand = cmd({
return
}
printModels(args.provider, args.verbose)
printModels(ProviderID.make(args.provider), args.verbose)
return
}
@@ -69,7 +70,7 @@ export const ModelsCommand = cmd({
})
for (const providerID of providerIDs) {
printModels(providerID, args.verbose)
printModels(ProviderID.make(providerID), args.verbose)
}
},
})

View File

@@ -15,9 +15,13 @@ export namespace Permission {
return pattern === undefined ? [type] : Array.isArray(pattern) ? pattern : [pattern]
}
function covered(keys: string[], approved: Record<string, boolean>): boolean {
const pats = Object.keys(approved)
return keys.every((k) => pats.some((p) => Wildcard.match(k, p)))
function covered(keys: string[], approved: Map<string, boolean>): boolean {
return keys.every((k) => {
for (const p of approved.keys()) {
if (Wildcard.match(k, p)) return true
}
return false
})
}
export const Info = z
@@ -39,6 +43,12 @@ export namespace Permission {
})
export type Info = z.infer<typeof Info>
interface PendingEntry {
info: Info
resolve: () => void
reject: (e: any) => void
}
export const Event = {
Updated: BusEvent.define("permission.updated", Info),
Replied: BusEvent.define(
@@ -52,31 +62,13 @@ export namespace Permission {
}
const state = Instance.state(
() => {
const pending: {
[sessionID: string]: {
[permissionID: string]: {
info: Info
resolve: () => void
reject: (e: any) => void
}
}
} = {}
const approved: {
[sessionID: string]: {
[permissionID: string]: boolean
}
} = {}
return {
pending,
approved,
}
},
() => ({
pending: new Map<SessionID, Map<PermissionID, PendingEntry>>(),
approved: new Map<SessionID, Map<string, boolean>>(),
}),
async (state) => {
for (const pending of Object.values(state.pending)) {
for (const item of Object.values(pending)) {
for (const session of state.pending.values()) {
for (const item of session.values()) {
item.reject(new RejectedError(item.info.sessionID, item.info.id, item.info.callID, item.info.metadata))
}
}
@@ -90,8 +82,8 @@ export namespace Permission {
export function list() {
const { pending } = state()
const result: Info[] = []
for (const items of Object.values(pending)) {
for (const item of Object.values(items)) {
for (const session of pending.values()) {
for (const item of session.values()) {
result.push(item.info)
}
}
@@ -114,9 +106,9 @@ export namespace Permission {
toolCallID: input.callID,
pattern: input.pattern,
})
const approvedForSession = approved[input.sessionID] || {}
const approvedForSession = approved.get(input.sessionID)
const keys = toKeys(input.pattern, input.type)
if (covered(keys, approvedForSession)) return
if (approvedForSession && covered(keys, approvedForSession)) return
const info: Info = {
id: PermissionID.ascending(),
type: input.type,
@@ -142,13 +134,13 @@ export namespace Permission {
return
}
pending[input.sessionID] = pending[input.sessionID] || {}
if (!pending.has(input.sessionID)) pending.set(input.sessionID, new Map())
return new Promise<void>((resolve, reject) => {
pending[input.sessionID][info.id] = {
pending.get(input.sessionID)!.set(info.id, {
info,
resolve,
reject,
}
})
Bus.publish(Event.Updated, info)
})
}
@@ -159,9 +151,11 @@ export namespace Permission {
export function respond(input: { sessionID: Info["sessionID"]; permissionID: Info["id"]; response: Response }) {
log.info("response", input)
const { pending, approved } = state()
const match = pending[input.sessionID]?.[input.permissionID]
if (!match) return
delete pending[input.sessionID][input.permissionID]
const session = pending.get(input.sessionID)
const match = session?.get(input.permissionID)
if (!session || !match) return
session.delete(input.permissionID)
if (session.size === 0) pending.delete(input.sessionID)
Bus.publish(Event.Replied, {
sessionID: input.sessionID,
permissionID: input.permissionID,
@@ -173,30 +167,35 @@ export namespace Permission {
}
match.resolve()
if (input.response === "always") {
approved[input.sessionID] = approved[input.sessionID] || {}
if (!approved.has(input.sessionID)) approved.set(input.sessionID, new Map())
const approvedSession = approved.get(input.sessionID)!
const approveKeys = toKeys(match.info.pattern, match.info.type)
for (const k of approveKeys) {
approved[input.sessionID][k] = true
approvedSession.set(k, true)
}
const items = pending[input.sessionID]
const items = pending.get(input.sessionID)
if (!items) return
for (const item of Object.values(items)) {
const toRespond: Info[] = []
for (const item of items.values()) {
const itemKeys = toKeys(item.info.pattern, item.info.type)
if (covered(itemKeys, approved[input.sessionID])) {
respond({
sessionID: item.info.sessionID,
permissionID: item.info.id,
response: input.response,
})
if (covered(itemKeys, approvedSession)) {
toRespond.push(item.info)
}
}
for (const item of toRespond) {
respond({
sessionID: item.sessionID,
permissionID: item.id,
response: input.response,
})
}
}
}
export class RejectedError extends Error {
constructor(
public readonly sessionID: string,
public readonly permissionID: string,
public readonly sessionID: SessionID,
public readonly permissionID: PermissionID,
public readonly toolCallID?: string,
public readonly metadata?: Record<string, any>,
public readonly reason?: string,

View File

@@ -108,6 +108,12 @@ export namespace PermissionNext {
),
}
interface PendingEntry {
info: Request
resolve: () => void
reject: (e: any) => void
}
const state = Instance.state(() => {
const projectID = Instance.project.id
const row = Database.use((db) =>
@@ -115,17 +121,8 @@ export namespace PermissionNext {
)
const stored = row?.data ?? ([] as Ruleset)
const pending: Record<
string,
{
info: Request
resolve: () => void
reject: (e: any) => void
}
> = {}
return {
pending,
pending: new Map<PermissionID, PendingEntry>(),
approved: stored,
}
})
@@ -149,11 +146,11 @@ export namespace PermissionNext {
id,
...request,
}
s.pending[id] = {
s.pending.set(id, {
info,
resolve,
reject,
}
})
Bus.publish(Event.Asked, info)
})
}
@@ -170,9 +167,9 @@ export namespace PermissionNext {
}),
async (input) => {
const s = await state()
const existing = s.pending[input.requestID]
const existing = s.pending.get(input.requestID)
if (!existing) return
delete s.pending[input.requestID]
s.pending.delete(input.requestID)
Bus.publish(Event.Replied, {
sessionID: existing.info.sessionID,
requestID: existing.info.id,
@@ -182,9 +179,9 @@ export namespace PermissionNext {
existing.reject(input.message ? new CorrectedError(input.message) : new RejectedError())
// Reject all other pending permissions for this session
const sessionID = existing.info.sessionID
for (const [id, pending] of Object.entries(s.pending)) {
for (const [id, pending] of s.pending) {
if (pending.info.sessionID === sessionID) {
delete s.pending[id]
s.pending.delete(id)
Bus.publish(Event.Replied, {
sessionID: pending.info.sessionID,
requestID: pending.info.id,
@@ -211,13 +208,13 @@ export namespace PermissionNext {
existing.resolve()
const sessionID = existing.info.sessionID
for (const [id, pending] of Object.entries(s.pending)) {
for (const [id, pending] of s.pending) {
if (pending.info.sessionID !== sessionID) continue
const ok = pending.info.patterns.every(
(pattern) => evaluate(pending.info.permission, pattern, s.approved).action === "allow",
)
if (!ok) continue
delete s.pending[id]
s.pending.delete(id)
Bus.publish(Event.Replied, {
sessionID: pending.info.sessionID,
requestID: pending.info.id,
@@ -283,6 +280,6 @@ export namespace PermissionNext {
export async function list() {
const s = await state()
return Object.values(s.pending).map((x) => x.info)
return Array.from(s.pending.values(), (x) => x.info)
}
}

View File

@@ -377,7 +377,7 @@ export async function CodexAuthPlugin(input: PluginInput): Promise<Hooks> {
if (!provider.models["gpt-5.3-codex"]) {
const model = {
id: ModelID.make("gpt-5.3-codex"),
providerID: ProviderID.make("openai"),
providerID: ProviderID.openai,
api: {
id: "gpt-5.3-codex",
url: "https://chatgpt.com/backend-api/codex",

View File

@@ -1,6 +1,7 @@
import { APICallError } from "ai"
import { STATUS_CODES } from "http"
import { iife } from "@/util/iife"
import type { ProviderID } from "./schema"
export namespace ProviderError {
// Adapted from overflow detection patterns in:
@@ -40,7 +41,7 @@ export namespace ProviderError {
return /^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message)
}
function message(providerID: string, e: APICallError) {
function message(providerID: ProviderID, e: APICallError) {
return iife(() => {
const msg = e.message
if (msg === "") {
@@ -164,7 +165,7 @@ export namespace ProviderError {
metadata?: Record<string, string>
}
export function parseAPICallError(input: { providerID: string; error: APICallError }): ParsedAPICallError {
export function parseAPICallError(input: { providerID: ProviderID; error: APICallError }): ParsedAPICallError {
const m = message(input.providerID, input.error)
if (isOverflow(m) || input.error.statusCode === 413) {
return {

View File

@@ -845,7 +845,7 @@ export namespace Provider {
const disabled = new Set(config.disabled_providers ?? [])
const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null
function isProviderAllowed(providerID: string): boolean {
function isProviderAllowed(providerID: ProviderID): boolean {
if (enabled && !enabled.has(providerID)) return false
if (disabled.has(providerID)) return false
return true
@@ -867,16 +867,16 @@ export namespace Provider {
const githubCopilot = database["github-copilot"]
database["github-copilot-enterprise"] = {
...githubCopilot,
id: ProviderID.make("github-copilot-enterprise"),
id: ProviderID.githubCopilotEnterprise,
name: "GitHub Copilot Enterprise",
models: mapValues(githubCopilot.models, (model) => ({
...model,
providerID: ProviderID.make("github-copilot-enterprise"),
providerID: ProviderID.githubCopilotEnterprise,
})),
}
}
function mergeProvider(providerID: string, provider: Partial<Info>) {
function mergeProvider(providerID: ProviderID, provider: Partial<Info>) {
const existing = providers[providerID]
if (existing) {
// @ts-expect-error
@@ -974,7 +974,8 @@ export namespace Provider {
// load env
const env = Env.all()
for (const [providerID, provider] of Object.entries(database)) {
for (const [id, provider] of Object.entries(database)) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
const apiKey = provider.env.map((item) => env[item]).find(Boolean)
if (!apiKey) continue
@@ -985,7 +986,8 @@ export namespace Provider {
}
// load apikeys
for (const [providerID, provider] of Object.entries(await Auth.all())) {
for (const [id, provider] of Object.entries(await Auth.all())) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
if (provider.type === "api") {
mergeProvider(providerID, {
@@ -997,7 +999,7 @@ export namespace Provider {
for (const plugin of await Plugin.list()) {
if (!plugin.auth) continue
const providerID = plugin.auth.provider
const providerID = ProviderID.make(plugin.auth.provider)
if (disabled.has(providerID)) continue
// For github-copilot plugin, check if auth exists for either github-copilot or github-copilot-enterprise
@@ -1006,7 +1008,7 @@ export namespace Provider {
if (auth) hasAuth = true
// Special handling for github-copilot: also check for enterprise auth
if (providerID === "github-copilot" && !hasAuth) {
if (providerID === ProviderID.githubCopilot && !hasAuth) {
const enterpriseAuth = await Auth.get("github-copilot-enterprise")
if (enterpriseAuth) hasAuth = true
}
@@ -1023,8 +1025,8 @@ export namespace Provider {
}
// If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
if (providerID === "github-copilot") {
const enterpriseProviderID = "github-copilot-enterprise"
if (providerID === ProviderID.githubCopilot) {
const enterpriseProviderID = ProviderID.githubCopilotEnterprise
if (!disabled.has(enterpriseProviderID)) {
const enterpriseAuth = await Auth.get(enterpriseProviderID)
if (enterpriseAuth) {
@@ -1042,7 +1044,8 @@ export namespace Provider {
}
}
for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
for (const [id, fn] of Object.entries(CUSTOM_LOADERS)) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
const data = database[providerID]
if (!data) {
@@ -1059,7 +1062,8 @@ export namespace Provider {
}
// load config
for (const [providerID, provider] of configProviders) {
for (const [id, provider] of configProviders) {
const providerID = ProviderID.make(id)
const partial: Partial<Info> = { source: "config" }
if (provider.env) partial.env = provider.env
if (provider.name) partial.name = provider.name
@@ -1067,7 +1071,8 @@ export namespace Provider {
mergeProvider(providerID, partial)
}
for (const [providerID, provider] of Object.entries(providers)) {
for (const [id, provider] of Object.entries(providers)) {
const providerID = ProviderID.make(id)
if (!isProviderAllowed(providerID)) {
delete providers[providerID]
continue
@@ -1077,7 +1082,7 @@ export namespace Provider {
for (const [modelID, model] of Object.entries(provider.models)) {
model.api.id = model.api.id ?? model.id ?? modelID
if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat"))
if (modelID === "gpt-5-chat-latest" || (providerID === ProviderID.openrouter && modelID === "openai/gpt-5-chat"))
delete provider.models[modelID]
if (model.status === "alpha" && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) delete provider.models[modelID]
if (model.status === "deprecated") delete provider.models[modelID]
@@ -1230,11 +1235,11 @@ export namespace Provider {
}
}
export async function getProvider(providerID: string) {
export async function getProvider(providerID: ProviderID) {
return state().then((s) => s.providers[providerID])
}
export async function getModel(providerID: string, modelID: string) {
export async function getModel(providerID: ProviderID, modelID: ModelID) {
const s = await state()
const provider = s.providers[providerID]
if (!provider) {
@@ -1281,7 +1286,7 @@ export namespace Provider {
}
}
export async function closest(providerID: string, query: string[]) {
export async function closest(providerID: ProviderID, query: string[]) {
const s = await state()
const provider = s.providers[providerID]
if (!provider) return undefined
@@ -1296,7 +1301,7 @@ export namespace Provider {
}
}
export async function getSmallModel(providerID: string) {
export async function getSmallModel(providerID: ProviderID) {
const cfg = await Config.get()
if (cfg.small_model) {
@@ -1323,7 +1328,7 @@ export namespace Provider {
priority = ["gpt-5-mini", "claude-haiku-4.5", ...priority]
}
for (const item of priority) {
if (providerID === "amazon-bedrock") {
if (providerID === ProviderID.amazonBedrock) {
const crossRegionPrefixes = ["global.", "us.", "eu."]
const candidates = Object.keys(provider.models).filter((m) => m.includes(item))
@@ -1332,22 +1337,22 @@ export namespace Provider {
// 2. User's region prefix (us., eu.)
// 3. Unprefixed model
const globalMatch = candidates.find((m) => m.startsWith("global."))
if (globalMatch) return getModel(providerID, globalMatch)
if (globalMatch) return getModel(providerID, ModelID.make(globalMatch))
const region = provider.options?.region
if (region) {
const regionPrefix = region.split("-")[0]
if (regionPrefix === "us" || regionPrefix === "eu") {
const regionalMatch = candidates.find((m) => m.startsWith(`${regionPrefix}.`))
if (regionalMatch) return getModel(providerID, regionalMatch)
if (regionalMatch) return getModel(providerID, ModelID.make(regionalMatch))
}
}
const unprefixed = candidates.find((m) => !crossRegionPrefixes.some((p) => m.startsWith(p)))
if (unprefixed) return getModel(providerID, unprefixed)
if (unprefixed) return getModel(providerID, ModelID.make(unprefixed))
} else {
for (const model of Object.keys(provider.models)) {
if (model.includes(item)) return getModel(providerID, model)
if (model.includes(item)) return getModel(providerID, ModelID.make(model))
}
}
}

View File

@@ -11,6 +11,18 @@ export const ProviderID = providerIdSchema.pipe(
withStatics((schema: typeof providerIdSchema) => ({
make: (id: string) => schema.makeUnsafe(id),
zod: z.string().pipe(z.custom<ProviderID>()),
// Well-known providers
opencode: schema.makeUnsafe("opencode"),
anthropic: schema.makeUnsafe("anthropic"),
openai: schema.makeUnsafe("openai"),
google: schema.makeUnsafe("google"),
googleVertex: schema.makeUnsafe("google-vertex"),
githubCopilot: schema.makeUnsafe("github-copilot"),
githubCopilotEnterprise: schema.makeUnsafe("github-copilot-enterprise"),
amazonBedrock: schema.makeUnsafe("amazon-bedrock"),
azure: schema.makeUnsafe("azure"),
openrouter: schema.makeUnsafe("openrouter"),
mistral: schema.makeUnsafe("mistral"),
})),
)

View File

@@ -91,7 +91,7 @@ export namespace Pty {
}
const state = Instance.state(
() => new Map<string, ActiveSession>(),
() => new Map<PtyID, ActiveSession>(),
async (sessions) => {
for (const session of sessions.values()) {
try {
@@ -113,7 +113,7 @@ export namespace Pty {
return Array.from(state().values()).map((s) => s.info)
}
export function get(id: string) {
export function get(id: PtyID) {
return state().get(id)?.info
}
@@ -205,7 +205,7 @@ export namespace Pty {
return info
}
export async function update(id: string, input: UpdateInput) {
export async function update(id: PtyID, input: UpdateInput) {
const session = state().get(id)
if (!session) return
if (input.title) {
@@ -218,7 +218,7 @@ export namespace Pty {
return session.info
}
export async function remove(id: string) {
export async function remove(id: PtyID) {
const session = state().get(id)
if (!session) return
state().delete(id)
@@ -237,21 +237,21 @@ export namespace Pty {
Bus.publish(Event.Deleted, { id: session.info.id })
}
export function resize(id: string, cols: number, rows: number) {
export function resize(id: PtyID, cols: number, rows: number) {
const session = state().get(id)
if (session && session.info.status === "running") {
session.process.resize(cols, rows)
}
}
export function write(id: string, data: string) {
export function write(id: PtyID, data: string) {
const session = state().get(id)
if (session && session.info.status === "running") {
session.process.write(data)
}
}
export function connect(id: string, ws: Socket, cursor?: number) {
export function connect(id: PtyID, ws: Socket, cursor?: number) {
const session = state().get(id)
if (!session) {
ws.close()

View File

@@ -80,20 +80,15 @@ export namespace Question {
),
}
const state = Instance.state(async () => {
const pending: Record<
string,
{
info: Request
resolve: (answers: Answer[]) => void
reject: (e: any) => void
}
> = {}
interface PendingEntry {
info: Request
resolve: (answers: Answer[]) => void
reject: (e: any) => void
}
return {
pending,
}
})
const state = Instance.state(async () => ({
pending: new Map<QuestionID, PendingEntry>(),
}))
export async function ask(input: {
sessionID: SessionID
@@ -112,23 +107,23 @@ export namespace Question {
questions: input.questions,
tool: input.tool,
}
s.pending[id] = {
s.pending.set(id, {
info,
resolve,
reject,
}
})
Bus.publish(Event.Asked, info)
})
}
export async function reply(input: { requestID: string; answers: Answer[] }): Promise<void> {
export async function reply(input: { requestID: QuestionID; answers: Answer[] }): Promise<void> {
const s = await state()
const existing = s.pending[input.requestID]
const existing = s.pending.get(input.requestID)
if (!existing) {
log.warn("reply for unknown request", { requestID: input.requestID })
return
}
delete s.pending[input.requestID]
s.pending.delete(input.requestID)
log.info("replied", { requestID: input.requestID, answers: input.answers })
@@ -141,14 +136,14 @@ export namespace Question {
existing.resolve(input.answers)
}
export async function reject(requestID: string): Promise<void> {
export async function reject(requestID: QuestionID): Promise<void> {
const s = await state()
const existing = s.pending[requestID]
const existing = s.pending.get(requestID)
if (!existing) {
log.warn("reject for unknown request", { requestID })
return
}
delete s.pending[requestID]
s.pending.delete(requestID)
log.info("rejected", { requestID })
@@ -167,6 +162,6 @@ export namespace Question {
}
export async function list() {
return state().then((x) => Object.values(x.pending).map((x) => x.info))
return state().then((x) => Array.from(x.pending.values(), (x) => x.info))
}
}

View File

@@ -1,6 +1,7 @@
import { Hono } from "hono"
import { describeRoute, validator, resolver } from "hono-openapi"
import z from "zod"
import { ProviderID, ModelID } from "../../provider/schema"
import { ToolRegistry } from "../../tool/registry"
import { Worktree } from "../../worktree"
import { Instance } from "../../project/instance"
@@ -77,7 +78,7 @@ export const ExperimentalRoutes = lazy(() =>
),
async (c) => {
const { provider, model } = c.req.valid("query")
const tools = await ToolRegistry.tools({ providerID: provider, modelID: model })
const tools = await ToolRegistry.tools({ providerID: ProviderID.make(provider), modelID: ModelID.make(model) })
return c.json(
tools.map((t) => ({
id: t.id,

View File

@@ -237,7 +237,7 @@ export namespace SessionPrompt {
return parts
}
function start(sessionID: string) {
function start(sessionID: SessionID) {
const s = state()
if (s[sessionID]) return
const controller = new AbortController()
@@ -248,7 +248,7 @@ export namespace SessionPrompt {
return controller.signal
}
function resume(sessionID: string) {
function resume(sessionID: SessionID) {
const s = state()
if (!s[sessionID]) return
@@ -788,7 +788,7 @@ export namespace SessionPrompt {
})
for (const item of await ToolRegistry.tools(
{ modelID: input.model.api.id, providerID: input.model.providerID },
{ modelID: ModelID.make(input.model.api.id), providerID: input.model.providerID },
input.agent,
)) {
const schema = ProviderTransform.schema(input.model, z.toJSONSchema(item.parameters))
@@ -1898,8 +1898,8 @@ NOTE: At any point in time through this workflow you should feel free to ask the
async function ensureTitle(input: {
session: Session.Info
history: MessageV2.WithParts[]
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
}) {
if (input.session.parentID) return
if (!Session.isDefaultTitle(input.session.title)) return

View File

@@ -2,6 +2,7 @@ import { Bus } from "@/bus"
import { Account } from "@/account"
import { Config } from "@/config/config"
import { Provider } from "@/provider/provider"
import { ProviderID, ModelID } from "@/provider/schema"
import { Session } from "@/session"
import type { SessionID } from "@/session/schema"
import { MessageV2 } from "@/session/message-v2"
@@ -262,7 +263,7 @@ export namespace ShareNext {
.map((m) => (m.info as SDK.UserMessage).model)
.map((m) => [`${m.providerID}/${m.modelID}`, m] as const),
).values(),
).map((m) => Provider.getModel(m.providerID, m.modelID).then((item) => item)),
).map((m) => Provider.getModel(ProviderID.make(m.providerID), ModelID.make(m.modelID)).then((item) => item)),
)
await sync(sessionID, [
{

View File

@@ -1,5 +1,6 @@
import z from "zod"
import { Tool } from "./tool"
import { ProviderID, ModelID } from "../provider/schema"
import DESCRIPTION from "./batch.txt"
const DISALLOWED = new Set(["batch"])
@@ -37,7 +38,7 @@ export const BatchTool = Tool.define("batch", async () => {
const discardedCalls = params.tool_calls.slice(25)
const { ToolRegistry } = await import("./registry")
const availableTools = await ToolRegistry.tools({ modelID: "", providerID: "" })
const availableTools = await ToolRegistry.tools({ modelID: ModelID.make(""), providerID: ProviderID.make("") })
const toolMap = new Map(availableTools.map((t) => [t.id, t]))
const executeCall = async (call: (typeof toolCalls)[0]) => {

View File

@@ -20,6 +20,7 @@ import path from "path"
import { type ToolContext as PluginToolContext, type ToolDefinition } from "@opencode-ai/plugin"
import z from "zod"
import { Plugin } from "../plugin"
import { ProviderID, type ModelID } from "../provider/schema"
import { WebSearchTool } from "./websearch"
import { CodeSearchTool } from "./codesearch"
import { Flag } from "@/flag/flag"
@@ -130,8 +131,8 @@ export namespace ToolRegistry {
export async function tools(
model: {
providerID: string
modelID: string
providerID: ProviderID
modelID: ModelID
},
agent?: Agent.Info,
) {
@@ -141,7 +142,7 @@ export namespace ToolRegistry {
.filter((t) => {
// Enable websearch/codesearch for zen users OR via enable flag
if (t.id === "codesearch" || t.id === "websearch") {
return model.providerID === "opencode" || Flag.OPENCODE_ENABLE_EXA
return model.providerID === ProviderID.opencode || Flag.OPENCODE_ENABLE_EXA
}
// use apply tool in same format as codex