fix: ensure variants also work for completely custom models (#6481)

Co-authored-by: Daniel Smolsky <dannysmo@gmail.com>
This commit is contained in:
Aiden Cline
2025-12-30 14:37:32 -08:00
committed by GitHub
parent 3fe5d91372
commit 81fef60266
10 changed files with 955 additions and 27 deletions

View File

@@ -319,9 +319,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
const provider = sync.data.provider.find((x) => x.id === m.providerID)
const info = provider?.models[m.modelID]
if (!info?.variants) return []
return Object.entries(info.variants)
.filter(([_, v]) => !v.disabled)
.map(([name]) => name)
return Object.keys(info.variants)
},
set(value: string | undefined) {
const m = currentModel()

View File

@@ -620,7 +620,24 @@ export namespace Config {
.extend({
whitelist: z.array(z.string()).optional(),
blacklist: z.array(z.string()).optional(),
models: z.record(z.string(), ModelsDev.Model.partial()).optional(),
models: z
.record(
z.string(),
ModelsDev.Model.partial().extend({
variants: z
.record(
z.string(),
z
.object({
disabled: z.boolean().optional().describe("Disable this variant for the model"),
})
.catchall(z.any()),
)
.optional()
.describe("Variant-specific configuration"),
}),
)
.optional(),
options: z
.object({
apiKey: z.string().optional(),

View File

@@ -60,6 +60,7 @@ export namespace ModelsDev {
options: z.record(z.string(), z.any()),
headers: z.record(z.string(), z.string()).optional(),
provider: z.object({ npm: z.string() }).optional(),
variants: z.record(z.string(), z.record(z.string(), z.any())).optional(),
})
export type Model = z.infer<typeof Model>

View File

@@ -1,7 +1,7 @@
import z from "zod"
import fuzzysort from "fuzzysort"
import { Config } from "../config/config"
import { mapValues, mergeDeep, sortBy } from "remeda"
import { mapValues, mergeDeep, omit, pickBy, sortBy } from "remeda"
import { NoSuchModelError, type Provider as SDK } from "ai"
import { Log } from "../util/log"
import { BunProc } from "../bun"
@@ -405,16 +405,6 @@ export namespace Provider {
},
}
export const Variant = z
.object({
disabled: z.boolean(),
})
.catchall(z.any())
.meta({
ref: "Variant",
})
export type Variant = z.infer<typeof Variant>
export const Model = z
.object({
id: z.string(),
@@ -478,7 +468,7 @@ export namespace Provider {
options: z.record(z.string(), z.any()),
headers: z.record(z.string(), z.string()),
release_date: z.string(),
variants: z.record(z.string(), Variant).optional(),
variants: z.record(z.string(), z.record(z.string(), z.any())).optional(),
})
.meta({
ref: "Model",
@@ -561,7 +551,7 @@ export namespace Provider {
variants: {},
}
m.variants = mapValues(ProviderTransform.variants(m), (v) => ({ disabled: false, ...v }))
m.variants = mapValues(ProviderTransform.variants(m), (v) => v)
return m
}
@@ -697,7 +687,13 @@ export namespace Provider {
headers: mergeDeep(existingModel?.headers ?? {}, model.headers ?? {}),
family: model.family ?? existingModel?.family ?? "",
release_date: model.release_date ?? existingModel?.release_date ?? "",
variants: {},
}
const merged = mergeDeep(ProviderTransform.variants(parsedModel), model.variants ?? {})
parsedModel.variants = mapValues(
pickBy(merged, (v) => !v.disabled),
(v) => omit(v, ["disabled"]),
)
parsed.models[modelID] = parsedModel
}
database[providerID] = parsed
@@ -822,6 +818,16 @@ export namespace Provider {
(configProvider?.whitelist && !configProvider.whitelist.includes(modelID))
)
delete provider.models[modelID]
// Filter out disabled variants from config
const configVariants = configProvider?.models?.[modelID]?.variants
if (configVariants && model.variants) {
const merged = mergeDeep(model.variants, configVariants)
model.variants = mapValues(
pickBy(merged, (v) => !v.disabled),
(v) => omit(v, ["disabled"]),
)
}
}
if (Object.keys(provider.models).length === 0) {

View File

@@ -246,7 +246,7 @@ export namespace ProviderTransform {
const WIDELY_SUPPORTED_EFFORTS = ["low", "medium", "high"]
const OPENAI_EFFORTS = ["none", "minimal", ...WIDELY_SUPPORTED_EFFORTS, "xhigh"]
export function variants(model: Provider.Model) {
export function variants(model: Provider.Model): Record<string, Record<string, any>> {
if (!model.capabilities.reasoning) return {}
const id = model.id.toLowerCase()

View File

@@ -82,13 +82,14 @@ export namespace LLM {
}
const provider = await Provider.getProvider(input.model.providerID)
const variant = input.model.variants && input.user.variant ? input.model.variants[input.user.variant] : undefined
const small = input.small ? ProviderTransform.smallOptions(input.model) : {}
const variant = input.model.variants && input.user.variant ? input.model.variants[input.user.variant] : {}
const options = pipe(
ProviderTransform.options(input.model, input.sessionID, provider.options),
mergeDeep(input.small ? ProviderTransform.smallOptions(input.model) : {}),
mergeDeep(small),
mergeDeep(input.model.options),
mergeDeep(input.agent.options),
mergeDeep(variant && !variant.disabled ? variant : {}),
mergeDeep(variant),
)
const params = await Plugin.trigger(