mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-03-29 21:33:54 +00:00
feat: toothfairyai as provdier
This commit is contained in:
parent
1460f80d1a
commit
b4c7faa842
@ -22,6 +22,13 @@ export namespace ModelsDev {
|
||||
us: "https://ai.us.toothfairyai.com",
|
||||
}
|
||||
|
||||
const REGION_STREAMING_URLS: Record<string, string> = {
|
||||
dev: "https://ais.toothfairylab.link",
|
||||
au: "https://ais.toothfairyai.com",
|
||||
eu: "https://ais.eu.toothfairyai.com",
|
||||
us: "https://ais.us.toothfairyai.com",
|
||||
}
|
||||
|
||||
export const Model = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
@ -163,7 +170,8 @@ export namespace ModelsDev {
|
||||
// Only include serverless models
|
||||
if (model.deploymentType && model.deploymentType !== "serverless") continue
|
||||
|
||||
const modelId = key.startsWith("z/") ? key.slice(2) : key
|
||||
// Use the full key as the model ID (API expects the exact key)
|
||||
const modelId = key
|
||||
|
||||
tfModels[modelId] = {
|
||||
id: modelId,
|
||||
@ -194,6 +202,8 @@ export namespace ModelsDev {
|
||||
id: "toothfairyai",
|
||||
name: "ToothFairyAI",
|
||||
env: ["TF_API_KEY", "TF_WORKSPACE_ID"],
|
||||
npm: "@toothfairyai/sdk",
|
||||
api: REGION_STREAMING_URLS[tfRegion],
|
||||
models: tfModels,
|
||||
}
|
||||
}
|
||||
@ -208,6 +218,8 @@ export namespace ModelsDev {
|
||||
id: "toothfairyai",
|
||||
name: "ToothFairyAI",
|
||||
env: ["TF_API_KEY", "TF_WORKSPACE_ID"],
|
||||
npm: "@toothfairyai/sdk",
|
||||
api: "https://ais.toothfairyai.com",
|
||||
models: {
|
||||
sorcerer: {
|
||||
id: "sorcerer",
|
||||
|
||||
@ -30,6 +30,7 @@ import { createOpenAI } from "@ai-sdk/openai"
|
||||
import { createOpenAICompatible } from "@ai-sdk/openai-compatible"
|
||||
import { createOpenRouter, type LanguageModelV2 } from "@openrouter/ai-sdk-provider"
|
||||
import { createOpenaiCompatible as createGitHubCopilotOpenAICompatible } from "./sdk/copilot"
|
||||
import { createToothFairyAI } from "./sdk/toothfairyai"
|
||||
import { createXai } from "@ai-sdk/xai"
|
||||
import { createMistral } from "@ai-sdk/mistral"
|
||||
import { createGroq } from "@ai-sdk/groq"
|
||||
@ -130,6 +131,7 @@ export namespace Provider {
|
||||
"@ai-sdk/perplexity": createPerplexity,
|
||||
"@ai-sdk/vercel": createVercel,
|
||||
"gitlab-ai-provider": createGitLab,
|
||||
"@toothfairyai/sdk": createToothFairyAI as any,
|
||||
// @ts-ignore (TODO: kill this code so we dont have to maintain it)
|
||||
"@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible,
|
||||
}
|
||||
@ -183,26 +185,50 @@ export namespace Provider {
|
||||
}
|
||||
},
|
||||
async toothfairyai(input) {
|
||||
const hasCredentials = await (async () => {
|
||||
const credentials = await (async () => {
|
||||
const env = Env.all()
|
||||
if (env.TF_API_KEY) return true
|
||||
if (await Auth.get(input.id)) return true
|
||||
let apiKey = env.TF_API_KEY
|
||||
let workspaceId = env.TF_WORKSPACE_ID
|
||||
let region = env.TF_REGION || "au"
|
||||
|
||||
// Check auth storage
|
||||
const auth = await Auth.get(input.id)
|
||||
if (auth?.type === "api") {
|
||||
apiKey = apiKey || auth.key
|
||||
}
|
||||
|
||||
// Check config
|
||||
const config = await Config.get()
|
||||
if (config.provider?.["toothfairyai"]?.options?.apiKey) return true
|
||||
|
||||
const tfConfig = config.provider?.["toothfairyai"]
|
||||
if (tfConfig?.options?.apiKey) apiKey = apiKey || tfConfig.options.apiKey
|
||||
if (tfConfig?.options?.workspaceId) workspaceId = workspaceId || tfConfig.options.workspaceId
|
||||
if (tfConfig?.options?.region) region = tfConfig.options.region
|
||||
|
||||
// Check stored credentials file
|
||||
try {
|
||||
const credPath = path.join(Global.Path.data, ".tfcode", "credentials.json")
|
||||
const credData = await Bun.file(credPath).json() as { api_key?: string }
|
||||
if (credData.api_key) return true
|
||||
const credData = (await Bun.file(credPath).json()) as {
|
||||
api_key?: string
|
||||
workspace_id?: string
|
||||
region?: string
|
||||
}
|
||||
if (credData.api_key) apiKey = apiKey || credData.api_key
|
||||
if (credData.workspace_id) workspaceId = workspaceId || credData.workspace_id
|
||||
if (credData.region) region = credData.region
|
||||
} catch {}
|
||||
|
||||
return false
|
||||
|
||||
return { apiKey, workspaceId, region }
|
||||
})()
|
||||
|
||||
return {
|
||||
autoload: hasCredentials,
|
||||
options: hasCredentials ? {} : { apiKey: "setup-required" },
|
||||
autoload: !!credentials.apiKey,
|
||||
options: credentials.apiKey
|
||||
? {
|
||||
apiKey: credentials.apiKey,
|
||||
workspaceId: credentials.workspaceId,
|
||||
region: credentials.region,
|
||||
}
|
||||
: { apiKey: "setup-required" },
|
||||
}
|
||||
},
|
||||
openai: async () => {
|
||||
|
||||
2
packages/tfcode/src/provider/sdk/toothfairyai/index.ts
Normal file
2
packages/tfcode/src/provider/sdk/toothfairyai/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export { createToothFairyAI } from "./toothfairyai-provider"
|
||||
export type { ToothFairyAIProvider, ToothFairyAIProviderSettings, ToothFairyAIModelId } from "./toothfairyai-provider"
|
||||
@ -0,0 +1,136 @@
|
||||
import type { LanguageModelV2 } from "@ai-sdk/provider"
|
||||
import { type FetchFunction, withoutTrailingSlash } from "@ai-sdk/provider-utils"
|
||||
import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model"
|
||||
|
||||
const VERSION = "1.0.0"
|
||||
|
||||
export type ToothFairyAIModelId = string
|
||||
|
||||
export interface ToothFairyAIProviderSettings {
|
||||
apiKey?: string
|
||||
workspaceId?: string
|
||||
region?: "dev" | "au" | "eu" | "us"
|
||||
baseURL?: string
|
||||
name?: string
|
||||
headers?: Record<string, string>
|
||||
fetch?: FetchFunction
|
||||
}
|
||||
|
||||
export interface ToothFairyAIProvider {
|
||||
(modelId: ToothFairyAIModelId): LanguageModelV2
|
||||
languageModel(modelId: ToothFairyAIModelId): LanguageModelV2
|
||||
}
|
||||
|
||||
const REGION_STREAMING_URLS: Record<string, string> = {
|
||||
dev: "https://ais.toothfairylab.link",
|
||||
au: "https://ais.toothfairyai.com",
|
||||
eu: "https://ais.eu.toothfairyai.com",
|
||||
us: "https://ais.us.toothfairyai.com",
|
||||
}
|
||||
|
||||
export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}): ToothFairyAIProvider {
|
||||
const region = options.region || "au"
|
||||
const baseURL = withoutTrailingSlash(options.baseURL ?? REGION_STREAMING_URLS[region] ?? REGION_STREAMING_URLS.au)
|
||||
|
||||
const baseHeaders: Record<string, string> = {
|
||||
...(options.apiKey && { "x-api-key": options.apiKey }),
|
||||
...options.headers,
|
||||
}
|
||||
|
||||
const workspaceId = options.workspaceId
|
||||
|
||||
const customFetch = async (input: RequestInfo | URL, init?: RequestInit) => {
|
||||
const url = new URL(typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url)
|
||||
|
||||
if (url.pathname === "/chat/completions") {
|
||||
url.pathname = "/predictions"
|
||||
}
|
||||
|
||||
let body: Record<string, unknown> | undefined
|
||||
if (init?.body && typeof init.body === "string") {
|
||||
try {
|
||||
body = JSON.parse(init.body)
|
||||
} catch {}
|
||||
}
|
||||
|
||||
if (body && workspaceId) {
|
||||
body.workspaceid = workspaceId
|
||||
}
|
||||
|
||||
const headers = new Headers(init?.headers)
|
||||
for (const [key, value] of Object.entries(baseHeaders)) {
|
||||
headers.set(key, value)
|
||||
}
|
||||
headers.delete("Authorization")
|
||||
|
||||
const fetchFn = options.fetch ?? globalThis.fetch
|
||||
const res = await fetchFn(url, {
|
||||
...init,
|
||||
headers,
|
||||
body: body ? JSON.stringify(body) : init?.body,
|
||||
})
|
||||
|
||||
if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) {
|
||||
const reader = res.body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
const filteredStream = new ReadableStream({
|
||||
async pull(controller) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
|
||||
const text = decoder.decode(value, { stream: true })
|
||||
const lines = text.split("\n")
|
||||
|
||||
const filtered: string[] = []
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const json = line.slice(6).trim()
|
||||
if (json && !json.startsWith('{"status":')) {
|
||||
filtered.push(line)
|
||||
}
|
||||
} else {
|
||||
filtered.push(line)
|
||||
}
|
||||
}
|
||||
|
||||
controller.enqueue(encoder.encode(filtered.join("\n")))
|
||||
},
|
||||
cancel() {
|
||||
reader.cancel()
|
||||
},
|
||||
})
|
||||
|
||||
return new Response(filteredStream, {
|
||||
headers: res.headers,
|
||||
status: res.status,
|
||||
statusText: res.statusText,
|
||||
})
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
const createChatModel = (modelId: ToothFairyAIModelId) => {
|
||||
return new OpenAICompatibleChatLanguageModel(modelId, {
|
||||
provider: `${options.name ?? "toothfairyai"}.chat`,
|
||||
headers: () => baseHeaders,
|
||||
url: ({ path }) => `${baseURL}${path}`,
|
||||
fetch: customFetch as FetchFunction,
|
||||
})
|
||||
}
|
||||
|
||||
const createLanguageModel = (modelId: ToothFairyAIModelId) => createChatModel(modelId)
|
||||
|
||||
const provider = function (modelId: ToothFairyAIModelId) {
|
||||
return createChatModel(modelId)
|
||||
}
|
||||
|
||||
provider.languageModel = createLanguageModel
|
||||
|
||||
return provider as ToothFairyAIProvider
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user