mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-02 23:23:45 +00:00
v2 message format and upgrade to ai sdk v5 (#743)
Co-authored-by: GitHub Action <action@github.com> Co-authored-by: Liang-Shih Lin <liangshihlin@proton.me> Co-authored-by: Dominik Engelhardt <dominikengelhardt@ymail.com> Co-authored-by: Jay V <air@live.ca> Co-authored-by: adamdottv <2363879+adamdottv@users.noreply.github.com>
This commit is contained in:
@@ -91,8 +91,7 @@ export namespace Provider {
|
||||
if (!info || info.type !== "oauth") return
|
||||
if (!info.access || info.expires < Date.now()) {
|
||||
const tokens = await copilot.access(info.refresh)
|
||||
if (!tokens)
|
||||
throw new Error("GitHub Copilot authentication expired")
|
||||
if (!tokens) throw new Error("GitHub Copilot authentication expired")
|
||||
await Auth.set("github-copilot", {
|
||||
type: "oauth",
|
||||
...tokens,
|
||||
@@ -101,15 +100,9 @@ export namespace Provider {
|
||||
}
|
||||
let isAgentCall = false
|
||||
try {
|
||||
const body =
|
||||
typeof init.body === "string"
|
||||
? JSON.parse(init.body)
|
||||
: init.body
|
||||
const body = typeof init.body === "string" ? JSON.parse(init.body) : init.body
|
||||
if (body?.messages) {
|
||||
isAgentCall = body.messages.some(
|
||||
(msg: any) =>
|
||||
msg.role && ["tool", "assistant"].includes(msg.role),
|
||||
)
|
||||
isAgentCall = body.messages.some((msg: any) => msg.role && ["tool", "assistant"].includes(msg.role))
|
||||
}
|
||||
} catch {}
|
||||
const headers = {
|
||||
@@ -138,14 +131,11 @@ export namespace Provider {
|
||||
}
|
||||
},
|
||||
"amazon-bedrock": async () => {
|
||||
if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"])
|
||||
return { autoload: false }
|
||||
if (!process.env["AWS_PROFILE"] && !process.env["AWS_ACCESS_KEY_ID"]) return { autoload: false }
|
||||
|
||||
const region = process.env["AWS_REGION"] ?? "us-east-1"
|
||||
|
||||
const { fromNodeProviderChain } = await import(
|
||||
await BunProc.install("@aws-sdk/credential-providers")
|
||||
)
|
||||
const { fromNodeProviderChain } = await import(await BunProc.install("@aws-sdk/credential-providers"))
|
||||
return {
|
||||
autoload: true,
|
||||
options: {
|
||||
@@ -157,9 +147,7 @@ export namespace Provider {
|
||||
|
||||
switch (regionPrefix) {
|
||||
case "us": {
|
||||
const modelRequiresPrefix = ["claude", "deepseek"].some((m) =>
|
||||
modelID.includes(m),
|
||||
)
|
||||
const modelRequiresPrefix = ["claude", "deepseek"].some((m) => modelID.includes(m))
|
||||
if (modelRequiresPrefix) {
|
||||
modelID = `${regionPrefix}.${modelID}`
|
||||
}
|
||||
@@ -174,25 +162,18 @@ export namespace Provider {
|
||||
"eu-south-1",
|
||||
"eu-south-2",
|
||||
].some((r) => region.includes(r))
|
||||
const modelRequiresPrefix = [
|
||||
"claude",
|
||||
"nova-lite",
|
||||
"nova-micro",
|
||||
"llama3",
|
||||
"pixtral",
|
||||
].some((m) => modelID.includes(m))
|
||||
const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "llama3", "pixtral"].some((m) =>
|
||||
modelID.includes(m),
|
||||
)
|
||||
if (regionRequiresPrefix && modelRequiresPrefix) {
|
||||
modelID = `${regionPrefix}.${modelID}`
|
||||
}
|
||||
break
|
||||
}
|
||||
case "ap": {
|
||||
const modelRequiresPrefix = [
|
||||
"claude",
|
||||
"nova-lite",
|
||||
"nova-micro",
|
||||
"nova-pro",
|
||||
].some((m) => modelID.includes(m))
|
||||
const modelRequiresPrefix = ["claude", "nova-lite", "nova-micro", "nova-pro"].some((m) =>
|
||||
modelID.includes(m),
|
||||
)
|
||||
if (modelRequiresPrefix) {
|
||||
regionPrefix = "apac"
|
||||
modelID = `${regionPrefix}.${modelID}`
|
||||
@@ -230,10 +211,7 @@ export namespace Provider {
|
||||
options: Record<string, any>
|
||||
}
|
||||
} = {}
|
||||
const models = new Map<
|
||||
string,
|
||||
{ info: ModelsDev.Model; language: LanguageModel }
|
||||
>()
|
||||
const models = new Map<string, { info: ModelsDev.Model; language: LanguageModel }>()
|
||||
const sdk = new Map<string, SDK>()
|
||||
|
||||
log.info("init")
|
||||
@@ -308,9 +286,7 @@ export namespace Provider {
|
||||
database[providerID] = parsed
|
||||
}
|
||||
|
||||
const disabled = await Config.get().then(
|
||||
(cfg) => new Set(cfg.disabled_providers ?? []),
|
||||
)
|
||||
const disabled = await Config.get().then((cfg) => new Set(cfg.disabled_providers ?? []))
|
||||
// load env
|
||||
for (const [providerID, provider] of Object.entries(database)) {
|
||||
if (disabled.has(providerID)) continue
|
||||
@@ -337,12 +313,7 @@ export namespace Provider {
|
||||
if (disabled.has(providerID)) continue
|
||||
const result = await fn(database[providerID])
|
||||
if (result && (result.autoload || providers[providerID])) {
|
||||
mergeProvider(
|
||||
providerID,
|
||||
result.options ?? {},
|
||||
"custom",
|
||||
result.getModel,
|
||||
)
|
||||
mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,7 +350,7 @@ export namespace Provider {
|
||||
const existing = s.sdk.get(provider.id)
|
||||
if (existing) return existing
|
||||
const pkg = provider.npm ?? provider.id
|
||||
const mod = await import(await BunProc.install(pkg, "latest"))
|
||||
const mod = await import(await BunProc.install(pkg, "beta"))
|
||||
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
|
||||
const loaded = fn(s.providers[provider.id]?.options)
|
||||
s.sdk.set(provider.id, loaded)
|
||||
@@ -406,9 +377,7 @@ export namespace Provider {
|
||||
const sdk = await getSDK(provider.info)
|
||||
|
||||
try {
|
||||
const language = provider.getModel
|
||||
? await provider.getModel(sdk, modelID)
|
||||
: sdk.languageModel(modelID)
|
||||
const language = provider.getModel ? await provider.getModel(sdk, modelID) : sdk.languageModel(modelID)
|
||||
log.info("found", { providerID, modelID })
|
||||
s.models.set(key, {
|
||||
info,
|
||||
@@ -435,10 +404,7 @@ export namespace Provider {
|
||||
export function sort(models: ModelsDev.Model[]) {
|
||||
return sortBy(
|
||||
models,
|
||||
[
|
||||
(model) => priority.findIndex((filter) => model.id.includes(filter)),
|
||||
"desc",
|
||||
],
|
||||
[(model) => priority.findIndex((filter) => model.id.includes(filter)), "desc"],
|
||||
[(model) => (model.id.includes("latest") ? 0 : 1), "asc"],
|
||||
[(model) => model.id, "desc"],
|
||||
)
|
||||
@@ -449,11 +415,7 @@ export namespace Provider {
|
||||
if (cfg.model) return parseModel(cfg.model)
|
||||
const provider = await list()
|
||||
.then((val) => Object.values(val))
|
||||
.then((x) =>
|
||||
x.find(
|
||||
(p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id),
|
||||
),
|
||||
)
|
||||
.then((x) => x.find((p) => !cfg.provider || Object.keys(cfg.provider).includes(p.info.id)))
|
||||
if (!provider) throw new Error("no providers found")
|
||||
const [model] = sort(Object.values(provider.info.models))
|
||||
if (!model) throw new Error("no models found")
|
||||
@@ -536,9 +498,11 @@ export namespace Provider {
|
||||
|
||||
if (schema instanceof z.ZodUnion) {
|
||||
return z.union(
|
||||
schema.options.map((option: z.ZodTypeAny) =>
|
||||
optionalToNullable(option),
|
||||
) as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]],
|
||||
schema.options.map((option: z.ZodTypeAny) => optionalToNullable(option)) as [
|
||||
z.ZodTypeAny,
|
||||
z.ZodTypeAny,
|
||||
...z.ZodTypeAny[],
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
import type { LanguageModelV1Prompt } from "ai"
|
||||
import type { ModelMessage } from "ai"
|
||||
import { unique } from "remeda"
|
||||
|
||||
export namespace ProviderTransform {
|
||||
export function message(
|
||||
msgs: LanguageModelV1Prompt,
|
||||
providerID: string,
|
||||
modelID: string,
|
||||
) {
|
||||
export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
|
||||
if (providerID === "anthropic" || modelID.includes("anthropic")) {
|
||||
const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
|
||||
const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
|
||||
|
||||
for (const msg of unique([...system, ...final])) {
|
||||
msg.providerMetadata = {
|
||||
...msg.providerMetadata,
|
||||
msg.providerOptions = {
|
||||
...msg.providerOptions,
|
||||
anthropic: {
|
||||
cacheControl: { type: "ephemeral" },
|
||||
},
|
||||
@@ -28,8 +24,8 @@ export namespace ProviderTransform {
|
||||
const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
|
||||
|
||||
for (const msg of unique([...system, ...final])) {
|
||||
msg.providerMetadata = {
|
||||
...msg.providerMetadata,
|
||||
msg.providerOptions = {
|
||||
...msg.providerOptions,
|
||||
bedrock: {
|
||||
cachePoint: { type: "ephemeral" },
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user