mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-02 15:13:46 +00:00
refactor(provider): flow branded ProviderID/ModelID through internal signatures (#17182)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import { APICallError } from "ai"
|
||||
import { STATUS_CODES } from "http"
|
||||
import { iife } from "@/util/iife"
|
||||
import type { ProviderID } from "./schema"
|
||||
|
||||
export namespace ProviderError {
|
||||
// Adapted from overflow detection patterns in:
|
||||
@@ -40,7 +41,7 @@ export namespace ProviderError {
|
||||
return /^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message)
|
||||
}
|
||||
|
||||
function message(providerID: string, e: APICallError) {
|
||||
function message(providerID: ProviderID, e: APICallError) {
|
||||
return iife(() => {
|
||||
const msg = e.message
|
||||
if (msg === "") {
|
||||
@@ -164,7 +165,7 @@ export namespace ProviderError {
|
||||
metadata?: Record<string, string>
|
||||
}
|
||||
|
||||
export function parseAPICallError(input: { providerID: string; error: APICallError }): ParsedAPICallError {
|
||||
export function parseAPICallError(input: { providerID: ProviderID; error: APICallError }): ParsedAPICallError {
|
||||
const m = message(input.providerID, input.error)
|
||||
if (isOverflow(m) || input.error.statusCode === 413) {
|
||||
return {
|
||||
|
||||
@@ -845,7 +845,7 @@ export namespace Provider {
|
||||
const disabled = new Set(config.disabled_providers ?? [])
|
||||
const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null
|
||||
|
||||
function isProviderAllowed(providerID: string): boolean {
|
||||
function isProviderAllowed(providerID: ProviderID): boolean {
|
||||
if (enabled && !enabled.has(providerID)) return false
|
||||
if (disabled.has(providerID)) return false
|
||||
return true
|
||||
@@ -867,16 +867,16 @@ export namespace Provider {
|
||||
const githubCopilot = database["github-copilot"]
|
||||
database["github-copilot-enterprise"] = {
|
||||
...githubCopilot,
|
||||
id: ProviderID.make("github-copilot-enterprise"),
|
||||
id: ProviderID.githubCopilotEnterprise,
|
||||
name: "GitHub Copilot Enterprise",
|
||||
models: mapValues(githubCopilot.models, (model) => ({
|
||||
...model,
|
||||
providerID: ProviderID.make("github-copilot-enterprise"),
|
||||
providerID: ProviderID.githubCopilotEnterprise,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
function mergeProvider(providerID: string, provider: Partial<Info>) {
|
||||
function mergeProvider(providerID: ProviderID, provider: Partial<Info>) {
|
||||
const existing = providers[providerID]
|
||||
if (existing) {
|
||||
// @ts-expect-error
|
||||
@@ -974,7 +974,8 @@ export namespace Provider {
|
||||
|
||||
// load env
|
||||
const env = Env.all()
|
||||
for (const [providerID, provider] of Object.entries(database)) {
|
||||
for (const [id, provider] of Object.entries(database)) {
|
||||
const providerID = ProviderID.make(id)
|
||||
if (disabled.has(providerID)) continue
|
||||
const apiKey = provider.env.map((item) => env[item]).find(Boolean)
|
||||
if (!apiKey) continue
|
||||
@@ -985,7 +986,8 @@ export namespace Provider {
|
||||
}
|
||||
|
||||
// load apikeys
|
||||
for (const [providerID, provider] of Object.entries(await Auth.all())) {
|
||||
for (const [id, provider] of Object.entries(await Auth.all())) {
|
||||
const providerID = ProviderID.make(id)
|
||||
if (disabled.has(providerID)) continue
|
||||
if (provider.type === "api") {
|
||||
mergeProvider(providerID, {
|
||||
@@ -997,7 +999,7 @@ export namespace Provider {
|
||||
|
||||
for (const plugin of await Plugin.list()) {
|
||||
if (!plugin.auth) continue
|
||||
const providerID = plugin.auth.provider
|
||||
const providerID = ProviderID.make(plugin.auth.provider)
|
||||
if (disabled.has(providerID)) continue
|
||||
|
||||
// For github-copilot plugin, check if auth exists for either github-copilot or github-copilot-enterprise
|
||||
@@ -1006,7 +1008,7 @@ export namespace Provider {
|
||||
if (auth) hasAuth = true
|
||||
|
||||
// Special handling for github-copilot: also check for enterprise auth
|
||||
if (providerID === "github-copilot" && !hasAuth) {
|
||||
if (providerID === ProviderID.githubCopilot && !hasAuth) {
|
||||
const enterpriseAuth = await Auth.get("github-copilot-enterprise")
|
||||
if (enterpriseAuth) hasAuth = true
|
||||
}
|
||||
@@ -1023,8 +1025,8 @@ export namespace Provider {
|
||||
}
|
||||
|
||||
// If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
|
||||
if (providerID === "github-copilot") {
|
||||
const enterpriseProviderID = "github-copilot-enterprise"
|
||||
if (providerID === ProviderID.githubCopilot) {
|
||||
const enterpriseProviderID = ProviderID.githubCopilotEnterprise
|
||||
if (!disabled.has(enterpriseProviderID)) {
|
||||
const enterpriseAuth = await Auth.get(enterpriseProviderID)
|
||||
if (enterpriseAuth) {
|
||||
@@ -1042,7 +1044,8 @@ export namespace Provider {
|
||||
}
|
||||
}
|
||||
|
||||
for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
|
||||
for (const [id, fn] of Object.entries(CUSTOM_LOADERS)) {
|
||||
const providerID = ProviderID.make(id)
|
||||
if (disabled.has(providerID)) continue
|
||||
const data = database[providerID]
|
||||
if (!data) {
|
||||
@@ -1059,7 +1062,8 @@ export namespace Provider {
|
||||
}
|
||||
|
||||
// load config
|
||||
for (const [providerID, provider] of configProviders) {
|
||||
for (const [id, provider] of configProviders) {
|
||||
const providerID = ProviderID.make(id)
|
||||
const partial: Partial<Info> = { source: "config" }
|
||||
if (provider.env) partial.env = provider.env
|
||||
if (provider.name) partial.name = provider.name
|
||||
@@ -1067,7 +1071,8 @@ export namespace Provider {
|
||||
mergeProvider(providerID, partial)
|
||||
}
|
||||
|
||||
for (const [providerID, provider] of Object.entries(providers)) {
|
||||
for (const [id, provider] of Object.entries(providers)) {
|
||||
const providerID = ProviderID.make(id)
|
||||
if (!isProviderAllowed(providerID)) {
|
||||
delete providers[providerID]
|
||||
continue
|
||||
@@ -1077,7 +1082,7 @@ export namespace Provider {
|
||||
|
||||
for (const [modelID, model] of Object.entries(provider.models)) {
|
||||
model.api.id = model.api.id ?? model.id ?? modelID
|
||||
if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat"))
|
||||
if (modelID === "gpt-5-chat-latest" || (providerID === ProviderID.openrouter && modelID === "openai/gpt-5-chat"))
|
||||
delete provider.models[modelID]
|
||||
if (model.status === "alpha" && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) delete provider.models[modelID]
|
||||
if (model.status === "deprecated") delete provider.models[modelID]
|
||||
@@ -1230,11 +1235,11 @@ export namespace Provider {
|
||||
}
|
||||
}
|
||||
|
||||
export async function getProvider(providerID: string) {
|
||||
export async function getProvider(providerID: ProviderID) {
|
||||
return state().then((s) => s.providers[providerID])
|
||||
}
|
||||
|
||||
export async function getModel(providerID: string, modelID: string) {
|
||||
export async function getModel(providerID: ProviderID, modelID: ModelID) {
|
||||
const s = await state()
|
||||
const provider = s.providers[providerID]
|
||||
if (!provider) {
|
||||
@@ -1281,7 +1286,7 @@ export namespace Provider {
|
||||
}
|
||||
}
|
||||
|
||||
export async function closest(providerID: string, query: string[]) {
|
||||
export async function closest(providerID: ProviderID, query: string[]) {
|
||||
const s = await state()
|
||||
const provider = s.providers[providerID]
|
||||
if (!provider) return undefined
|
||||
@@ -1296,7 +1301,7 @@ export namespace Provider {
|
||||
}
|
||||
}
|
||||
|
||||
export async function getSmallModel(providerID: string) {
|
||||
export async function getSmallModel(providerID: ProviderID) {
|
||||
const cfg = await Config.get()
|
||||
|
||||
if (cfg.small_model) {
|
||||
@@ -1323,7 +1328,7 @@ export namespace Provider {
|
||||
priority = ["gpt-5-mini", "claude-haiku-4.5", ...priority]
|
||||
}
|
||||
for (const item of priority) {
|
||||
if (providerID === "amazon-bedrock") {
|
||||
if (providerID === ProviderID.amazonBedrock) {
|
||||
const crossRegionPrefixes = ["global.", "us.", "eu."]
|
||||
const candidates = Object.keys(provider.models).filter((m) => m.includes(item))
|
||||
|
||||
@@ -1332,22 +1337,22 @@ export namespace Provider {
|
||||
// 2. User's region prefix (us., eu.)
|
||||
// 3. Unprefixed model
|
||||
const globalMatch = candidates.find((m) => m.startsWith("global."))
|
||||
if (globalMatch) return getModel(providerID, globalMatch)
|
||||
if (globalMatch) return getModel(providerID, ModelID.make(globalMatch))
|
||||
|
||||
const region = provider.options?.region
|
||||
if (region) {
|
||||
const regionPrefix = region.split("-")[0]
|
||||
if (regionPrefix === "us" || regionPrefix === "eu") {
|
||||
const regionalMatch = candidates.find((m) => m.startsWith(`${regionPrefix}.`))
|
||||
if (regionalMatch) return getModel(providerID, regionalMatch)
|
||||
if (regionalMatch) return getModel(providerID, ModelID.make(regionalMatch))
|
||||
}
|
||||
}
|
||||
|
||||
const unprefixed = candidates.find((m) => !crossRegionPrefixes.some((p) => m.startsWith(p)))
|
||||
if (unprefixed) return getModel(providerID, unprefixed)
|
||||
if (unprefixed) return getModel(providerID, ModelID.make(unprefixed))
|
||||
} else {
|
||||
for (const model of Object.keys(provider.models)) {
|
||||
if (model.includes(item)) return getModel(providerID, model)
|
||||
if (model.includes(item)) return getModel(providerID, ModelID.make(model))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,18 @@ export const ProviderID = providerIdSchema.pipe(
|
||||
withStatics((schema: typeof providerIdSchema) => ({
|
||||
make: (id: string) => schema.makeUnsafe(id),
|
||||
zod: z.string().pipe(z.custom<ProviderID>()),
|
||||
// Well-known providers
|
||||
opencode: schema.makeUnsafe("opencode"),
|
||||
anthropic: schema.makeUnsafe("anthropic"),
|
||||
openai: schema.makeUnsafe("openai"),
|
||||
google: schema.makeUnsafe("google"),
|
||||
googleVertex: schema.makeUnsafe("google-vertex"),
|
||||
githubCopilot: schema.makeUnsafe("github-copilot"),
|
||||
githubCopilotEnterprise: schema.makeUnsafe("github-copilot-enterprise"),
|
||||
amazonBedrock: schema.makeUnsafe("amazon-bedrock"),
|
||||
azure: schema.makeUnsafe("azure"),
|
||||
openrouter: schema.makeUnsafe("openrouter"),
|
||||
mistral: schema.makeUnsafe("mistral"),
|
||||
})),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user