v2 message format and upgrade to ai sdk v5 (#743)

Co-authored-by: GitHub Action <action@github.com>
Co-authored-by: Liang-Shih Lin <liangshihlin@proton.me>
Co-authored-by: Dominik Engelhardt <dominikengelhardt@ymail.com>
Co-authored-by: Jay V <air@live.ca>
Co-authored-by: adamdottv <2363879+adamdottv@users.noreply.github.com>
This commit is contained in:
Dax
2025-07-07 15:53:43 -04:00
committed by GitHub
parent 76b2e4539c
commit f884766445
116 changed files with 4707 additions and 6950 deletions

View File

@@ -91,8 +91,7 @@ export namespace Provider {
if (!info || info.type !== "oauth") return
if (!info.access || info.expires < Date.now()) {
const tokens = await copilot.access(info.refresh)
if (!tokens)
throw new Error("GitHub Copilot authentication expired")
if (!tokens) throw new Error("GitHub Copilot authentication expired")
await Auth.set("github-copilot", {
type: "oauth",
...tokens,
@@ -101,15 +100,9 @@ export namespace Provider {
}
let isAgentCall = false
try {
const body =
typeof init.body === "string"
? JSON.parse(init.body)
: init.body
const body = typeof init.body === "string" ? JSON.parse(init.body) : init.body
if (body?.messages) {
isAgentCall = body.messages.some(
(msg: any) =>
msg.role && ["tool", "assistant"].includes(msg.role),
)
isAgentCall = body.messages.some((msg: any) => msg.role && ["tool", "assistant"].includes(msg.role))
}
} catch {}
const headers = {
@@ -138,14 +131,11 @@ export namespace Provider {
}
},
"amazon-bedrock": async () => {
if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"])
return { autoload: false }
if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"]) return { autoload: false }
const region = process.env["AWS_REGION"] ?? "us-east-1"
const { fromNodeProviderChain } = await import(
await BunProc.install("@aws-sdk/credential-providers")
)
const { fromNodeProviderChain } = await import(await BunProc.install("@aws-sdk/credential-providers"))
return {
autoload: true,
options: {
@@ -157,9 +147,7 @@ export namespace Provider {
switch (regionPrefix) {
case "us": {
const modelRequiresPrefix = ["claude", "deepseek"].some((m) =>
modelID.includes(m),
)
const modelRequiresPrefix = ["claude", "deepseek"].some((m) => modelID.includes(m))
if (modelRequiresPrefix) {
modelID = `${regionPrefix}.${modelID}`
}
@@ -174,25 +162,18 @@ export namespace Provider {
"eu-south-1",
"eu-south-2",
].some((r) => region.includes(r))
const modelRequiresPrefix = [
"claude",
"nova-lite",
"nova-micro",
"llama3",
"pixtral",
].some((m) => modelID.includes(m))
const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "llama3", "pixtral"].some((m) =>
modelID.includes(m),
)
if (regionRequiresPrefix && modelRequiresPrefix) {
modelID = `${regionPrefix}.${modelID}`
}
break
}
case "ap": {
const modelRequiresPrefix = [
"claude",
"nova-lite",
"nova-micro",
"nova-pro",
].some((m) => modelID.includes(m))
const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "nova-pro"].some((m) =>
modelID.includes(m),
)
if (modelRequiresPrefix) {
regionPrefix = "apac"
modelID = `${regionPrefix}.${modelID}`
@@ -230,10 +211,7 @@ export namespace Provider {
options: Record<string, any>
}
} = {}
const models = new Map<
string,
{ info: ModelsDev.Model; language: LanguageModel }
>()
const models = new Map<string, { info: ModelsDev.Model; language: LanguageModel }>()
const sdk = new Map<string, SDK>()
log.info("init")
@@ -308,9 +286,7 @@ export namespace Provider {
database[providerID] = parsed
}
const disabled = await Config.get().then(
(cfg) => new Set(cfg.disabled_providers ?? []),
)
const disabled = await Config.get().then((cfg) => new Set(cfg.disabled_providers ?? []))
// load env
for (const [providerID, provider] of Object.entries(database)) {
if (disabled.has(providerID)) continue
@@ -337,12 +313,7 @@ export namespace Provider {
if (disabled.has(providerID)) continue
const result = await fn(database[providerID])
if (result && (result.autoload || providers[providerID])) {
mergeProvider(
providerID,
result.options ?? {},
"custom",
result.getModel,
)
mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
}
}
@@ -379,7 +350,7 @@ export namespace Provider {
const existing = s.sdk.get(provider.id)
if (existing) return existing
const pkg = provider.npm ?? provider.id
const mod = await import(await BunProc.install(pkg, "latest"))
const mod = await import(await BunProc.install(pkg, "beta"))
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
const loaded = fn(s.providers[provider.id]?.options)
s.sdk.set(provider.id, loaded)
@@ -406,9 +377,7 @@ export namespace Provider {
const sdk = await getSDK(provider.info)
try {
const language = provider.getModel
? await provider.getModel(sdk, modelID)
: sdk.languageModel(modelID)
const language = provider.getModel ? await provider.getModel(sdk, modelID) : sdk.languageModel(modelID)
log.info("found", { providerID, modelID })
s.models.set(key, {
info,
@@ -435,10 +404,7 @@ export namespace Provider {
export function sort(models: ModelsDev.Model[]) {
return sortBy(
models,
[
(model) => priority.findIndex((filter) => model.id.includes(filter)),
"desc",
],
[(model) => priority.findIndex((filter) => model.id.includes(filter)), "desc"],
[(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
[(model) => model.id, "desc"],
)
@@ -449,11 +415,7 @@ export namespace Provider {
if (cfg.model) return parseModel(cfg.model)
const provider = await list()
.then((val) => Object.values(val))
.then((x) =>
x.find(
(p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id),
),
)
.then((x) => x.find((p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id)))
if (!provider) throw new Error("no providers found")
const [model] = sort(Object.values(provider.info.models))
if (!model) throw new Error("no models found")
@@ -536,9 +498,11 @@ export namespace Provider {
if (schema instanceof z.ZodUnion) {
return z.union(
schema.options.map((option: z.ZodTypeAny) =>
optionalToNullable(option),
) as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]],
schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
z.ZodTypeAny,
z.ZodTypeAny,
...z.ZodTypeAny[],
],
)
}

View File

@@ -1,19 +1,15 @@
import type { LanguageModelV1Prompt } from "ai"
import type { ModelMessage } from "ai"
import { unique } from "remeda"
export namespace ProviderTransform {
export function message(
msgs: LanguageModelV1Prompt,
providerID: string,
modelID: string,
) {
export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
if (providerID === "anthropic" || modelID.includes("anthropic")) {
const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
for (const msg of unique([...system, ...final])) {
msg.providerMetadata = {
...msg.providerMetadata,
msg.providerOptions = {
...msg.providerOptions,
anthropic: {
cacheControl: { type: "ephemeral" },
},
@@ -28,8 +24,8 @@ export namespace ProviderTransform {
const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
for (const msg of unique([...system, ...final])) {
msg.providerMetadata = {
...msg.providerMetadata,
msg.providerOptions = {
...msg.providerOptions,
bedrock: {
cachePoint: { type: "ephemeral" },
},