import z from "zod" import path from "path" import { App } from "../app/app" import { Config } from "../config/config" import { mergeDeep, sortBy } from "remeda" import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai" import { Log } from "../util/log" import { BunProc } from "../bun" import { BashTool } from "../tool/bash" import { EditTool } from "../tool/edit" import { WebFetchTool } from "../tool/webfetch" import { GlobTool } from "../tool/glob" import { GrepTool } from "../tool/grep" import { ListTool } from "../tool/ls" import { LspDiagnosticTool } from "../tool/lsp-diagnostics" import { LspHoverTool } from "../tool/lsp-hover" import { PatchTool } from "../tool/patch" import { ReadTool } from "../tool/read" import type { Tool } from "../tool/tool" import { WriteTool } from "../tool/write" import { TodoReadTool, TodoWriteTool } from "../tool/todo" import { AuthAnthropic } from "../auth/anthropic" import { ModelsDev } from "./models" import { NamedError } from "../util/error" import { Auth } from "../auth" import { TaskTool } from "../tool/task" import { GlobalConfig } from "../global/config" import { Global } from "../global" export namespace Provider { const log = Log.create({ service: "provider" }) type CustomLoader = ( provider: ModelsDev.Provider, ) => Promise | false> type Source = "env" | "config" | "custom" | "api" const CUSTOM_LOADERS: Record = { async anthropic(provider) { const access = await AuthAnthropic.access() if (!access) return false for (const model of Object.values(provider.models)) { model.cost = { input: 0, output: 0, } } return { apiKey: "", async fetch(input: any, init: any) { const access = await AuthAnthropic.access() const headers = { ...init.headers, authorization: `Bearer ${access}`, "anthropic-beta": "oauth-2025-04-20", } delete headers["x-api-key"] return fetch(input, { ...init, headers, }) }, } }, "amazon-bedrock": async () => { if (!process.env["AWS_PROFILE"]) return false const { fromNodeProviderChain } = await import( await BunProc.install("@aws-sdk/credential-providers") ) return { region: process.env["AWS_REGION"] ?? "us-east-1", credentialProvider: fromNodeProviderChain(), } }, } const state = App.state("provider", async () => { const config = await Config.get() const database = await ModelsDev.get() const providers: { [providerID: string]: { source: Source info: ModelsDev.Provider options: Record } } = {} const models = new Map< string, { info: ModelsDev.Model; language: LanguageModel } >() const sdk = new Map() log.info("init") function mergeProvider( id: string, options: Record, source: Source, ) { const provider = providers[id] if (!provider) { providers[id] = { source, info: database[id], options, } return } provider.options = mergeDeep(provider.options, options) provider.source = source } const configProviders = Object.entries(config.provider ?? {}) for await (const providerPath of new Bun.Glob("*/provider.toml").scan({ cwd: Global.Path.providers, })) { const [providerID] = providerPath.split("/") const toml = await import( path.join(Global.Path.providers, providerPath), { with: { type: "toml", }, } ).then((mod) => mod.default) toml.models = {} const modelsPath = path.join(Global.Path.providers, providerID, "models") for await (const modelPath of new Bun.Glob("**/*.toml").scan({ cwd: modelsPath, })) { const modelID = modelPath.slice(0, -5) toml.models[modelID] = await import(path.join(modelsPath, modelPath), { with: { type: "toml", }, }) } configProviders.unshift([providerID, toml]) } for (const [providerID, provider] of configProviders) { const existing = database[providerID] const parsed: ModelsDev.Provider = { id: providerID, npm: provider.npm ?? existing?.npm, name: provider.name ?? existing?.name ?? providerID, env: provider.env ?? existing?.env ?? [], models: existing?.models ?? {}, } for (const [modelID, model] of Object.entries(provider.models ?? {})) { const existing = parsed.models[modelID] const parsedModel: ModelsDev.Model = { id: modelID, name: model.name ?? existing?.name ?? modelID, attachment: model.attachment ?? existing?.attachment ?? false, reasoning: model.reasoning ?? existing?.reasoning ?? false, temperature: model.temperature ?? existing?.temperature ?? false, cost: model.cost ?? existing?.cost ?? { input: 0, output: 0, inputCached: 0, outputCached: 0, }, limit: model.limit ?? existing?.limit ?? { context: 0, output: 0, }, } parsed.models[modelID] = parsedModel } database[providerID] = parsed } const disabled = await GlobalConfig.get().then( (cfg) => new Set(cfg.disabled_providers ?? []), ) // load env for (const [providerID, provider] of Object.entries(database)) { if (disabled.has(providerID)) continue if (provider.env.some((item) => process.env[item])) { mergeProvider(providerID, {}, "env") } } // load apikeys for (const [providerID, provider] of Object.entries(await Auth.all())) { if (disabled.has(providerID)) continue if (provider.type === "api") { mergeProvider(providerID, { apiKey: provider.key }, "api") } } // load custom for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) { if (disabled.has(providerID)) continue const result = await fn(database[providerID]) if (result) mergeProvider(providerID, result, "custom") } // load config for (const [providerID, provider] of Object.entries( config.provider ?? {}, )) { mergeProvider(providerID, provider.options ?? {}, "config") } for (const providerID of Object.keys(providers)) { log.info("found", { providerID }) } return { models, providers, sdk, } }) export async function list() { return state().then((state) => state.providers) } async function getSDK(provider: ModelsDev.Provider) { return (async () => { using _ = log.time("getSDK", { providerID: provider.id, }) const s = await state() const existing = s.sdk.get(provider.id) if (existing) return existing const pkg = provider.npm ?? provider.id const mod = await import(await BunProc.install(pkg, "latest")) const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!] const loaded = fn(s.providers[provider.id]?.options) s.sdk.set(provider.id, loaded) return loaded as SDK })().catch((e) => { throw new InitError({ providerID: provider.id }, { cause: e }) }) } export async function getModel(providerID: string, modelID: string) { const key = `${providerID}/${modelID}` const s = await state() if (s.models.has(key)) return s.models.get(key)! log.info("getModel", { providerID, modelID, }) const provider = s.providers[providerID] if (!provider) throw new ModelNotFoundError({ providerID, modelID }) const info = provider.info.models[modelID] if (!info) throw new ModelNotFoundError({ providerID, modelID }) const sdk = await getSDK(provider.info) try { const language = // @ts-expect-error "responses" in sdk ? sdk.responses(modelID) : sdk.languageModel(modelID) log.info("found", { providerID, modelID }) s.models.set(key, { info, language, }) return { info, language, } } catch (e) { if (e instanceof NoSuchModelError) throw new ModelNotFoundError( { modelID: modelID, providerID, }, { cause: e }, ) throw e } } const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"] export function sort(models: ModelsDev.Model[]) { return sortBy( models, [ (model) => priority.findIndex((filter) => model.id.includes(filter)), "desc", ], [(model) => (model.id.includes("latest") ? 0 : 1), "asc"], [(model) => model.id, "desc"], ) } export async function defaultModel() { const cfg = await GlobalConfig.get() const provider = await list() .then((val) => Object.values(val)) .then((x) => x.find((p) => !cfg.provider || cfg.provider === p.info.id)) if (!provider) throw new Error("no providers found") const [model] = sort(Object.values(provider.info.models)) if (!model) throw new Error("no models found") return { providerID: provider.info.id, modelID: model.id, } } const TOOLS = [ BashTool, EditTool, WebFetchTool, GlobTool, GrepTool, ListTool, LspDiagnosticTool, LspHoverTool, PatchTool, ReadTool, EditTool, // MultiEditTool, WriteTool, TodoWriteTool, TaskTool, TodoReadTool, ] const TOOL_MAPPING: Record = { anthropic: TOOLS.filter((t) => t.id !== "opencode.patch"), openai: TOOLS.map((t) => ({ ...t, parameters: optionalToNullable(t.parameters), })), google: TOOLS, } export async function tools(providerID: string) { /* const cfg = await Config.get() if (cfg.tool?.provider?.[providerID]) return cfg.tool.provider[providerID].map( (id) => TOOLS.find((t) => t.id === id)!, ) */ return TOOL_MAPPING[providerID] ?? TOOLS } function optionalToNullable(schema: z.ZodTypeAny): z.ZodTypeAny { if (schema instanceof z.ZodObject) { const shape = schema.shape const newShape: Record = {} for (const [key, value] of Object.entries(shape)) { const zodValue = value as z.ZodTypeAny if (zodValue instanceof z.ZodOptional) { newShape[key] = zodValue.unwrap().nullable() } else { newShape[key] = optionalToNullable(zodValue) } } return z.object(newShape) } if (schema instanceof z.ZodArray) { return z.array(optionalToNullable(schema.element)) } if (schema instanceof z.ZodUnion) { return z.union( schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option), ) as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]], ) } return schema } export const ModelNotFoundError = NamedError.create( "ProviderModelNotFoundError", z.object({ providerID: z.string(), modelID: z.string(), }), ) export const InitError = NamedError.create( "ProviderInitError", z.object({ providerID: z.string(), }), ) export const AuthError = NamedError.create( "ProviderAuthError", z.object({ providerID: z.string(), message: z.string(), }), ) }