This commit is contained in:
Dax Raad
2025-06-09 14:01:11 -04:00
parent fa3253d1b6
commit 021fd3fcb5
8 changed files with 734 additions and 198 deletions

View File

@@ -23,6 +23,7 @@ import { WriteTool } from "../tool/write"
import { TodoReadTool, TodoWriteTool } from "../tool/todo"
import { AuthAnthropic } from "../auth/anthropic"
import { ModelsDev } from "./models"
import { NamedError } from "../util/error"
export namespace Provider {
const log = Log.create({ service: "provider" })
@@ -75,9 +76,18 @@ export namespace Provider {
string,
(provider: Info) => Promise<Record<string, any> | false>
> = {
anthropic: async () => {
async anthropic(provider) {
const access = await AuthAnthropic.access()
if (access)
if (access) {
// claude sub doesn't have usage cost
for (const model of Object.values(provider.models)) {
model.cost = {
input: 0,
inputCached: 0,
output: 0,
outputCached: 0,
}
}
return {
apiKey: "",
headers: {
@@ -85,16 +95,15 @@ export namespace Provider {
"anthropic-beta": "oauth-2025-04-20",
},
}
return env("ANTHROPIC_API_KEY")
}
return env("ANTHROPIC_API_KEY")(provider)
},
google: env("GOOGLE_GENERATIVE_AI_API_KEY"),
openai: env("OPENAI_API_KEY"),
}
const state = App.state("provider", async () => {
log.info("loading config")
const config = await Config.get()
log.info("loading providers")
const database: Record<string, Provider.Info> = await ModelsDev.get()
const providers: {
@@ -134,6 +143,10 @@ export namespace Provider {
}
}
for (const providerID of Object.keys(providers)) {
log.info("loaded", { providerID })
}
return {
models,
providers,
@@ -148,28 +161,32 @@ export namespace Provider {
}
async function getSDK(providerID: string) {
const s = await state()
if (s.sdk.has(providerID)) return s.sdk.get(providerID)!
const dir = path.join(
Global.Path.cache,
`node_modules`,
`@ai-sdk`,
providerID,
)
if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
log.info("installing", {
return (async () => {
const s = await state()
const existing = s.sdk.get(providerID)
if (existing) return existing
const dir = path.join(
Global.Path.cache,
`node_modules`,
`@ai-sdk`,
providerID,
})
await BunProc.run(["add", `@ai-sdk/${providerID}@alpha`], {
cwd: Global.Path.cache,
})
}
const mod = await import(path.join(dir))
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
const loaded = fn(s.providers[providerID]?.options)
s.sdk.set(providerID, loaded)
return loaded as SDK
)
if (!(await Bun.file(path.join(dir, "package.json")).exists())) {
log.info("installing", {
providerID,
})
await BunProc.run(["add", `@ai-sdk/${providerID}@alpha`], {
cwd: Global.Path.cache,
})
}
const mod = await import(path.join(dir))
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
const loaded = fn(s.providers[providerID]?.options)
s.sdk.set(providerID, loaded)
return loaded as SDK
})().catch((e) => {
throw new InitError({ providerID: providerID }, { cause: e })
})
}
export async function getModel(providerID: string, modelID: string) {
@@ -183,12 +200,11 @@ export namespace Provider {
})
const provider = s.providers[providerID]
if (!provider) throw new ModelNotFoundError(modelID)
if (!provider) throw new ModelNotFoundError({ providerID, modelID })
const info = provider.info.models[modelID]
if (!info) throw new ModelNotFoundError(modelID)
if (!info) throw new ModelNotFoundError({ providerID, modelID })
const sdk = await getSDK(providerID)
if (!sdk) throw new ModelNotFoundError(modelID)
try {
const language = sdk.languageModel(modelID)
@@ -202,7 +218,14 @@ export namespace Provider {
language,
}
} catch (e) {
if (e instanceof NoSuchModelError) throw new ModelNotFoundError(modelID)
if (e instanceof NoSuchModelError)
throw new ModelNotFoundError(
{
modelID: modelID,
providerID,
},
{ cause: e },
)
throw e
}
}
@@ -259,9 +282,26 @@ export namespace Provider {
return TOOL_MAPPING[providerID] ?? TOOLS
}
class ModelNotFoundError extends Error {
constructor(public readonly model: string) {
super()
}
}
export const ModelNotFoundError = NamedError.create(
"ProviderModelNotFoundError",
z.object({
providerID: z.string(),
modelID: z.string(),
}),
)
export const InitError = NamedError.create(
"ProviderInitError",
z.object({
providerID: z.string(),
}),
)
export const AuthError = NamedError.create(
"ProviderAuthError",
z.object({
providerID: z.string(),
message: z.string(),
}),
)
}