refactor(provider): flow branded ProviderID/ModelID through internal signatures (#17182)

This commit is contained in:
Kit Langton
2026-03-12 10:48:17 -04:00
committed by GitHub
parent a4f8d66a9b
commit 1cb7df7159
24 changed files with 227 additions and 205 deletions

View File

@@ -1,6 +1,7 @@
import { APICallError } from "ai"
import { STATUS_CODES } from "http"
import { iife } from "@/util/iife"
import type { ProviderID } from "./schema"
export namespace ProviderError {
// Adapted from overflow detection patterns in:
@@ -40,7 +41,7 @@ export namespace ProviderError {
return /^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message)
}
function message(providerID: string, e: APICallError) {
function message(providerID: ProviderID, e: APICallError) {
return iife(() => {
const msg = e.message
if (msg === "") {
@@ -164,7 +165,7 @@ export namespace ProviderError {
metadata?: Record<string, string>
}
export function parseAPICallError(input: { providerID: string; error: APICallError }): ParsedAPICallError {
export function parseAPICallError(input: { providerID: ProviderID; error: APICallError }): ParsedAPICallError {
const m = message(input.providerID, input.error)
if (isOverflow(m) || input.error.statusCode === 413) {
return {

View File

@@ -845,7 +845,7 @@ export namespace Provider {
const disabled = new Set(config.disabled_providers ?? [])
const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null
function isProviderAllowed(providerID: string): boolean {
function isProviderAllowed(providerID: ProviderID): boolean {
if (enabled && !enabled.has(providerID)) return false
if (disabled.has(providerID)) return false
return true
@@ -867,16 +867,16 @@ export namespace Provider {
const githubCopilot = database["github-copilot"]
database["github-copilot-enterprise"] = {
...githubCopilot,
id: ProviderID.make("github-copilot-enterprise"),
id: ProviderID.githubCopilotEnterprise,
name: "GitHub Copilot Enterprise",
models: mapValues(githubCopilot.models, (model) => ({
...model,
providerID: ProviderID.make("github-copilot-enterprise"),
providerID: ProviderID.githubCopilotEnterprise,
})),
}
}
function mergeProvider(providerID: string, provider: Partial<Info>) {
function mergeProvider(providerID: ProviderID, provider: Partial<Info>) {
const existing = providers[providerID]
if (existing) {
// @ts-expect-error
@@ -974,7 +974,8 @@ export namespace Provider {
// load env
const env = Env.all()
for (const [providerID, provider] of Object.entries(database)) {
for (const [id, provider] of Object.entries(database)) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
const apiKey = provider.env.map((item) => env[item]).find(Boolean)
if (!apiKey) continue
@@ -985,7 +986,8 @@ export namespace Provider {
}
// load apikeys
for (const [providerID, provider] of Object.entries(await Auth.all())) {
for (const [id, provider] of Object.entries(await Auth.all())) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
if (provider.type === "api") {
mergeProvider(providerID, {
@@ -997,7 +999,7 @@ export namespace Provider {
for (const plugin of await Plugin.list()) {
if (!plugin.auth) continue
const providerID = plugin.auth.provider
const providerID = ProviderID.make(plugin.auth.provider)
if (disabled.has(providerID)) continue
// For github-copilot plugin, check if auth exists for either github-copilot or github-copilot-enterprise
@@ -1006,7 +1008,7 @@ export namespace Provider {
if (auth) hasAuth = true
// Special handling for github-copilot: also check for enterprise auth
if (providerID === "github-copilot" && !hasAuth) {
if (providerID === ProviderID.githubCopilot && !hasAuth) {
const enterpriseAuth = await Auth.get("github-copilot-enterprise")
if (enterpriseAuth) hasAuth = true
}
@@ -1023,8 +1025,8 @@ export namespace Provider {
}
// If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
if (providerID === "github-copilot") {
const enterpriseProviderID = "github-copilot-enterprise"
if (providerID === ProviderID.githubCopilot) {
const enterpriseProviderID = ProviderID.githubCopilotEnterprise
if (!disabled.has(enterpriseProviderID)) {
const enterpriseAuth = await Auth.get(enterpriseProviderID)
if (enterpriseAuth) {
@@ -1042,7 +1044,8 @@ export namespace Provider {
}
}
for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
for (const [id, fn] of Object.entries(CUSTOM_LOADERS)) {
const providerID = ProviderID.make(id)
if (disabled.has(providerID)) continue
const data = database[providerID]
if (!data) {
@@ -1059,7 +1062,8 @@ export namespace Provider {
}
// load config
for (const [providerID, provider] of configProviders) {
for (const [id, provider] of configProviders) {
const providerID = ProviderID.make(id)
const partial: Partial<Info> = { source: "config" }
if (provider.env) partial.env = provider.env
if (provider.name) partial.name = provider.name
@@ -1067,7 +1071,8 @@ export namespace Provider {
mergeProvider(providerID, partial)
}
for (const [providerID, provider] of Object.entries(providers)) {
for (const [id, provider] of Object.entries(providers)) {
const providerID = ProviderID.make(id)
if (!isProviderAllowed(providerID)) {
delete providers[providerID]
continue
@@ -1077,7 +1082,7 @@ export namespace Provider {
for (const [modelID, model] of Object.entries(provider.models)) {
model.api.id = model.api.id ?? model.id ?? modelID
if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat"))
if (modelID === "gpt-5-chat-latest" || (providerID === ProviderID.openrouter && modelID === "openai/gpt-5-chat"))
delete provider.models[modelID]
if (model.status === "alpha" && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) delete provider.models[modelID]
if (model.status === "deprecated") delete provider.models[modelID]
@@ -1230,11 +1235,11 @@ export namespace Provider {
}
}
export async function getProvider(providerID: string) {
export async function getProvider(providerID: ProviderID) {
return state().then((s) => s.providers[providerID])
}
export async function getModel(providerID: string, modelID: string) {
export async function getModel(providerID: ProviderID, modelID: ModelID) {
const s = await state()
const provider = s.providers[providerID]
if (!provider) {
@@ -1281,7 +1286,7 @@ export namespace Provider {
}
}
export async function closest(providerID: string, query: string[]) {
export async function closest(providerID: ProviderID, query: string[]) {
const s = await state()
const provider = s.providers[providerID]
if (!provider) return undefined
@@ -1296,7 +1301,7 @@ export namespace Provider {
}
}
export async function getSmallModel(providerID: string) {
export async function getSmallModel(providerID: ProviderID) {
const cfg = await Config.get()
if (cfg.small_model) {
@@ -1323,7 +1328,7 @@ export namespace Provider {
priority = ["gpt-5-mini", "claude-haiku-4.5", ...priority]
}
for (const item of priority) {
if (providerID === "amazon-bedrock") {
if (providerID === ProviderID.amazonBedrock) {
const crossRegionPrefixes = ["global.", "us.", "eu."]
const candidates = Object.keys(provider.models).filter((m) => m.includes(item))
@@ -1332,22 +1337,22 @@ export namespace Provider {
// 2. User's region prefix (us., eu.)
// 3. Unprefixed model
const globalMatch = candidates.find((m) => m.startsWith("global."))
if (globalMatch) return getModel(providerID, globalMatch)
if (globalMatch) return getModel(providerID, ModelID.make(globalMatch))
const region = provider.options?.region
if (region) {
const regionPrefix = region.split("-")[0]
if (regionPrefix === "us" || regionPrefix === "eu") {
const regionalMatch = candidates.find((m) => m.startsWith(`${regionPrefix}.`))
if (regionalMatch) return getModel(providerID, regionalMatch)
if (regionalMatch) return getModel(providerID, ModelID.make(regionalMatch))
}
}
const unprefixed = candidates.find((m) => !crossRegionPrefixes.some((p) => m.startsWith(p)))
if (unprefixed) return getModel(providerID, unprefixed)
if (unprefixed) return getModel(providerID, ModelID.make(unprefixed))
} else {
for (const model of Object.keys(provider.models)) {
if (model.includes(item)) return getModel(providerID, model)
if (model.includes(item)) return getModel(providerID, ModelID.make(model))
}
}
}

View File

@@ -11,6 +11,18 @@ export const ProviderID = providerIdSchema.pipe(
withStatics((schema: typeof providerIdSchema) => ({
make: (id: string) => schema.makeUnsafe(id),
zod: z.string().pipe(z.custom<ProviderID>()),
// Well-known providers
opencode: schema.makeUnsafe("opencode"),
anthropic: schema.makeUnsafe("anthropic"),
openai: schema.makeUnsafe("openai"),
google: schema.makeUnsafe("google"),
googleVertex: schema.makeUnsafe("google-vertex"),
githubCopilot: schema.makeUnsafe("github-copilot"),
githubCopilotEnterprise: schema.makeUnsafe("github-copilot-enterprise"),
amazonBedrock: schema.makeUnsafe("amazon-bedrock"),
azure: schema.makeUnsafe("azure"),
openrouter: schema.makeUnsafe("openrouter"),
mistral: schema.makeUnsafe("mistral"),
})),
)