// Mirror of https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
// Synced 2026-03-30 05:43:55 +00:00 — 429 lines, 13 KiB, TypeScript
import z from "zod"
|
|
import { App } from "../app/app"
|
|
import { Config } from "../config/config"
|
|
import { mergeDeep, sortBy } from "remeda"
|
|
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
|
|
import { Log } from "../util/log"
|
|
import { BunProc } from "../bun"
|
|
import { Plugin } from "../plugin"
|
|
import { ModelsDev } from "./models"
|
|
import { NamedError } from "../util/error"
|
|
import { Auth } from "../auth"
|
|
|
|
export namespace Provider {
  const log = Log.create({ service: "provider" })

  /** Resolves a language model for `modelID` from an already-loaded provider SDK instance. */
  type ModelLoader = (sdk: any, modelID: string) => Promise<any>

  /**
   * Provider-specific hook run while building provider state.
   * - `autoload: true` enables the provider even if no env/auth/config source did.
   * - `options` are merged into the provider's SDK options.
   * - `getModel` overrides the default `sdk.languageModel(modelID)` construction.
   */
  type CustomLoader = (
    provider: ModelsDev.Provider,
    api?: string,
  ) => Promise<{
    autoload: boolean
    getModel?: ModelLoader
    options?: Record<string, any>
  }>

  /** Which loading stage last contributed options for a provider (later stages overwrite it). */
  type Source = "env" | "config" | "custom" | "api"

  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
    async anthropic() {
      return {
        autoload: false,
        options: {
          headers: {
            "anthropic-beta":
              "claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14",
          },
        },
      }
    },
    openai: async () => {
      return {
        autoload: false,
        // Route OpenAI models through the Responses API rather than the default language model factory.
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    azure: async () => {
      return {
        autoload: false,
        async getModel(sdk: any, modelID: string) {
          return sdk.responses(modelID)
        },
        options: {},
      }
    },
    "amazon-bedrock": async () => {
      // Only autoload Bedrock when some AWS credential source is configured in the environment.
      if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"] && !process.env["AWS_BEARER_TOKEN_BEDROCK"])
        return { autoload: false }

      const region = process.env["AWS_REGION"] ?? "us-east-1"

      const { fromNodeProviderChain } = await import(await BunProc.install("@aws-sdk/credential-providers"))
      return {
        autoload: true,
        options: {
          region,
          credentialProvider: fromNodeProviderChain(),
        },
        // Some models must be addressed with a regional prefix (e.g. "us.", "eu.", "apac.");
        // which models and regions need it is encoded in the lists below.
        async getModel(sdk: any, modelID: string) {
          let regionPrefix = region.split("-")[0]

          switch (regionPrefix) {
            case "us": {
              const modelRequiresPrefix = ["claude", "deepseek"].some((m) => modelID.includes(m))
              if (modelRequiresPrefix) {
                modelID = `${regionPrefix}.${modelID}`
              }
              break
            }
            case "eu": {
              const regionRequiresPrefix = [
                "eu-west-1",
                "eu-west-3",
                "eu-north-1",
                "eu-central-1",
                "eu-south-1",
                "eu-south-2",
              ].some((r) => region.includes(r))
              const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "llama3", "pixtral"].some((m) =>
                modelID.includes(m),
              )
              if (regionRequiresPrefix && modelRequiresPrefix) {
                modelID = `${regionPrefix}.${modelID}`
              }
              break
            }
            case "ap": {
              const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "nova-pro"].some((m) =>
                modelID.includes(m),
              )
              if (modelRequiresPrefix) {
                // Asia-Pacific uses "apac." rather than the raw "ap" region prefix.
                regionPrefix = "apac"
                modelID = `${regionPrefix}.${modelID}`
              }
              break
            }
          }

          return sdk.languageModel(modelID)
        },
      }
    },
    openrouter: async () => {
      return {
        autoload: false,
        options: {
          headers: {
            "HTTP-Referer": "https://opencode.ai/",
            "X-Title": "opencode",
          },
        },
      }
    },
    vercel: async () => {
      return {
        autoload: false,
        options: {
          headers: {
            "http-referer": "https://opencode.ai/",
            "x-title": "opencode",
          },
        },
      }
    },
  }

  /**
   * Lazily-built provider registry. Sources are applied in this order, each merging over the last:
   * env vars → stored API keys → custom loaders → plugin auth loaders → config options.
   * Providers that end up with no models are dropped.
   */
  const state = App.state("provider", async () => {
    const config = await Config.get()
    const database = await ModelsDev.get()

    const providers: {
      [providerID: string]: {
        source: Source
        info: ModelsDev.Provider
        getModel?: ModelLoader
        options: Record<string, any>
      }
    } = {}
    const models = new Map<string, { info: ModelsDev.Model; language: LanguageModel }>()
    const sdk = new Map<string, SDK>()

    log.info("init")

    /**
     * Registers provider `id` (if it exists in the database) or deep-merges `options` into an
     * already-registered one. A supplied `getModel` replaces the previous one; omitting it keeps it.
     */
    function mergeProvider(id: string, options: Record<string, any>, source: Source, getModel?: ModelLoader) {
      const provider = providers[id]
      if (!provider) {
        const info = database[id]
        if (!info) return
        // Copy before defaulting baseURL so the caller's object (e.g. the config's options) is never mutated.
        const initial: Record<string, any> = { ...options }
        if (info.api && !initial["baseURL"]) initial["baseURL"] = info.api
        providers[id] = {
          source,
          info,
          options: initial,
          getModel,
        }
        return
      }
      provider.options = mergeDeep(provider.options, options)
      provider.source = source
      provider.getModel = getModel ?? provider.getModel
    }

    const configProviders = Object.entries(config.provider ?? {})

    // Overlay config-declared providers and models onto the models.dev database; config wins field by field.
    for (const [providerID, provider] of configProviders) {
      const existingProvider = database[providerID]
      const parsed: ModelsDev.Provider = {
        id: providerID,
        npm: provider.npm ?? existingProvider?.npm,
        name: provider.name ?? existingProvider?.name ?? providerID,
        env: provider.env ?? existingProvider?.env ?? [],
        api: provider.api ?? existingProvider?.api,
        models: existingProvider?.models ?? {},
      }

      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
        const existingModel = parsed.models[modelID]
        const parsedModel: ModelsDev.Model = {
          id: modelID,
          name: model.name ?? existingModel?.name ?? modelID,
          release_date: model.release_date ?? existingModel?.release_date,
          attachment: model.attachment ?? existingModel?.attachment ?? false,
          reasoning: model.reasoning ?? existingModel?.reasoning ?? false,
          temperature: model.temperature ?? existingModel?.temperature ?? false,
          tool_call: model.tool_call ?? existingModel?.tool_call ?? true,
          cost:
            !model.cost && !existingModel?.cost
              ? {
                  input: 0,
                  output: 0,
                  cache_read: 0,
                  cache_write: 0,
                }
              : {
                  cache_read: 0,
                  cache_write: 0,
                  ...existingModel?.cost,
                  ...model.cost,
                },
          options: {
            ...existingModel?.options,
            ...model.options,
          },
          limit: model.limit ??
            existingModel?.limit ?? {
              context: 0,
              output: 0,
            },
        }
        parsed.models[modelID] = parsedModel
      }
      database[providerID] = parsed
    }

    const disabled = new Set(config.disabled_providers ?? [])

    // load env
    for (const [providerID, provider] of Object.entries(database)) {
      if (disabled.has(providerID)) continue
      // Take the first env var that is actually set, not merely the first one listed;
      // otherwise providers configured through an alternative variable are never enabled.
      const apiKey = provider.env.map((item) => process.env[item]).find(Boolean)
      if (!apiKey) continue
      mergeProvider(
        providerID,
        // only include apiKey if there's only one potential option
        provider.env.length === 1 ? { apiKey } : {},
        "env",
      )
    }

    // load apikeys
    for (const [providerID, provider] of Object.entries(await Auth.all())) {
      if (disabled.has(providerID)) continue
      if (provider.type === "api") {
        mergeProvider(providerID, { apiKey: provider.key }, "api")
      }
    }

    // load custom: a loader only takes effect if it autoloads or the provider was already enabled above
    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
      if (disabled.has(providerID)) continue
      const result = await fn(database[providerID])
      if (result && (result.autoload || providers[providerID])) {
        mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
      }
    }

    // load plugin auth loaders for providers that have stored credentials
    for (const plugin of await Plugin.list()) {
      if (!plugin.auth) continue
      const providerID = plugin.auth.provider
      if (disabled.has(providerID)) continue
      const auth = await Auth.get(providerID)
      if (!auth) continue
      if (!plugin.auth.loader) continue
      const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[providerID])
      mergeProvider(providerID, options ?? {}, "custom")
    }

    // load config
    for (const [providerID, provider] of configProviders) {
      mergeProvider(providerID, provider.options ?? {}, "config")
    }

    // Drop providers that expose no models; Object.entries snapshots, so deleting while iterating is safe.
    for (const [providerID, provider] of Object.entries(providers)) {
      if (Object.keys(provider.info.models).length === 0) {
        delete providers[providerID]
        continue
      }
      log.info("found", { providerID })
    }

    return {
      models,
      providers,
      sdk,
    }
  })

  /** Returns the map of enabled providers keyed by provider ID. */
  export async function list() {
    const s = await state()
    return s.providers
  }

  /**
   * Installs (if needed) and instantiates the provider's SDK package, caching it per provider ID.
   * Any failure is rethrown as `InitError` with the original error as `cause`.
   */
  async function getSDK(provider: ModelsDev.Provider) {
    return (async () => {
      using _ = log.time("getSDK", {
        providerID: provider.id,
      })
      const s = await state()
      const existing = s.sdk.get(provider.id)
      if (existing) return existing
      const pkg = provider.npm ?? provider.id
      const mod = await import(await BunProc.install(pkg, "latest"))
      // SDK packages expose a `createXxx` factory; use the first such export.
      const factoryName = Object.keys(mod).find((key) => key.startsWith("create"))
      if (!factoryName) throw new Error(`package "${pkg}" does not export a create* factory`)
      const fn = mod[factoryName]
      const loaded = fn({
        name: provider.id,
        ...s.providers[provider.id]?.options,
      })
      s.sdk.set(provider.id, loaded)
      return loaded as SDK
    })().catch((e) => {
      throw new InitError({ providerID: provider.id }, { cause: e })
    })
  }

  /** Returns the enabled provider with `providerID`, or undefined if it is not enabled. */
  export async function getProvider(providerID: string) {
    const s = await state()
    return s.providers[providerID]
  }

  /**
   * Resolves (and caches) the language model `modelID` of provider `providerID`.
   * @throws ModelNotFoundError if the provider is not enabled, the model is unknown, or the SDK rejects it.
   * @throws InitError if the provider SDK cannot be loaded.
   */
  export async function getModel(providerID: string, modelID: string) {
    const key = `${providerID}/${modelID}`
    const s = await state()
    if (s.models.has(key)) return s.models.get(key)!

    log.info("getModel", {
      providerID,
      modelID,
    })

    const provider = s.providers[providerID]
    if (!provider) throw new ModelNotFoundError({ providerID, modelID })
    const info = provider.info.models[modelID]
    if (!info) throw new ModelNotFoundError({ providerID, modelID })
    const sdk = await getSDK(provider.info)

    try {
      const language = provider.getModel ? await provider.getModel(sdk, modelID) : sdk.languageModel(modelID)
      log.info("found", { providerID, modelID })
      s.models.set(key, {
        info,
        language,
      })
      return {
        info,
        language,
      }
    } catch (e) {
      if (e instanceof NoSuchModelError)
        throw new ModelNotFoundError(
          {
            modelID: modelID,
            providerID,
          },
          { cause: e },
        )
      throw e
    }
  }

  /**
   * Returns a cheap model for auxiliary tasks: the configured `small_model` if set, otherwise the
   * first model of `providerID` matching the candidate list (in list order), or undefined.
   */
  export async function getSmallModel(providerID: string) {
    const cfg = await Config.get()

    if (cfg.small_model) {
      const parsed = parseModel(cfg.small_model)
      return getModel(parsed.providerID, parsed.modelID)
    }

    const provider = await getProvider(providerID)
    if (!provider) return
    const smallCandidates = ["3-5-haiku", "3.5-haiku", "gemini-2.5-flash", "gpt-5-nano"]
    for (const item of smallCandidates) {
      for (const model of Object.keys(provider.info.models)) {
        if (model.includes(item)) return getModel(providerID, model)
      }
    }
  }

  // Preferred model ID fragments. `sort` orders by index *descending*, so LATER entries rank HIGHER;
  // models matching none (index -1) come last.
  const priority = ["gemini-2.5-pro-preview", "gpt-5", "claude-sonnet-4"]

  /** Returns a new array of models ordered by preference: priority match, then "latest" IDs, then ID descending. */
  export function sort(models: ModelsDev.Model[]) {
    return sortBy(
      models,
      [(model) => priority.findIndex((filter) => model.id.includes(filter)), "desc"],
      [(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
      [(model) => model.id, "desc"],
    )
  }

  /**
   * Picks the default model: the configured `model` if set, otherwise the best-sorted model of the
   * first enabled provider (restricted to config-declared providers when any are declared).
   * @throws Error when no provider or no model is available.
   */
  export async function defaultModel() {
    const cfg = await Config.get()
    if (cfg.model) return parseModel(cfg.model)
    const providers = Object.values(await list())
    const provider = providers.find((p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id))
    if (!provider) throw new Error("no providers found")
    const [model] = sort(Object.values(provider.info.models))
    if (!model) throw new Error("no models found")
    return {
      providerID: provider.info.id,
      modelID: model.id,
    }
  }

  /** Splits "provider/model[/...]" at the first "/"; the model ID may itself contain slashes. */
  export function parseModel(model: string) {
    const [providerID, ...rest] = model.split("/")
    return {
      providerID: providerID,
      modelID: rest.join("/"),
    }
  }

  export const ModelNotFoundError = NamedError.create(
    "ProviderModelNotFoundError",
    z.object({
      providerID: z.string(),
      modelID: z.string(),
    }),
  )

  export const InitError = NamedError.create(
    "ProviderInitError",
    z.object({
      providerID: z.string(),
    }),
  )
}
|