feat: toothfairyai as provdier

This commit is contained in:
Gab 2026-03-24 18:37:53 +11:00
parent 1460f80d1a
commit b4c7faa842
4 changed files with 188 additions and 12 deletions

View File

@ -22,6 +22,13 @@ export namespace ModelsDev {
us: "https://ai.us.toothfairyai.com", us: "https://ai.us.toothfairyai.com",
} }
const REGION_STREAMING_URLS: Record<string, string> = {
dev: "https://ais.toothfairylab.link",
au: "https://ais.toothfairyai.com",
eu: "https://ais.eu.toothfairyai.com",
us: "https://ais.us.toothfairyai.com",
}
export const Model = z.object({ export const Model = z.object({
id: z.string(), id: z.string(),
name: z.string(), name: z.string(),
@ -163,7 +170,8 @@ export namespace ModelsDev {
// Only include serverless models // Only include serverless models
if (model.deploymentType && model.deploymentType !== "serverless") continue if (model.deploymentType && model.deploymentType !== "serverless") continue
const modelId = key.startsWith("z/") ? key.slice(2) : key // Use the full key as the model ID (API expects the exact key)
const modelId = key
tfModels[modelId] = { tfModels[modelId] = {
id: modelId, id: modelId,
@ -194,6 +202,8 @@ export namespace ModelsDev {
id: "toothfairyai", id: "toothfairyai",
name: "ToothFairyAI", name: "ToothFairyAI",
env: ["TF_API_KEY", "TF_WORKSPACE_ID"], env: ["TF_API_KEY", "TF_WORKSPACE_ID"],
npm: "@toothfairyai/sdk",
api: REGION_STREAMING_URLS[tfRegion],
models: tfModels, models: tfModels,
} }
} }
@ -208,6 +218,8 @@ export namespace ModelsDev {
id: "toothfairyai", id: "toothfairyai",
name: "ToothFairyAI", name: "ToothFairyAI",
env: ["TF_API_KEY", "TF_WORKSPACE_ID"], env: ["TF_API_KEY", "TF_WORKSPACE_ID"],
npm: "@toothfairyai/sdk",
api: "https://ais.toothfairyai.com",
models: { models: {
sorcerer: { sorcerer: {
id: "sorcerer", id: "sorcerer",

View File

@ -30,6 +30,7 @@ import { createOpenAI } from "@ai-sdk/openai"
import { createOpenAICompatible } from "@ai-sdk/openai-compatible" import { createOpenAICompatible } from "@ai-sdk/openai-compatible"
import { createOpenRouter, type LanguageModelV2 } from "@openrouter/ai-sdk-provider" import { createOpenRouter, type LanguageModelV2 } from "@openrouter/ai-sdk-provider"
import { createOpenaiCompatible as createGitHubCopilotOpenAICompatible } from "./sdk/copilot" import { createOpenaiCompatible as createGitHubCopilotOpenAICompatible } from "./sdk/copilot"
import { createToothFairyAI } from "./sdk/toothfairyai"
import { createXai } from "@ai-sdk/xai" import { createXai } from "@ai-sdk/xai"
import { createMistral } from "@ai-sdk/mistral" import { createMistral } from "@ai-sdk/mistral"
import { createGroq } from "@ai-sdk/groq" import { createGroq } from "@ai-sdk/groq"
@ -130,6 +131,7 @@ export namespace Provider {
"@ai-sdk/perplexity": createPerplexity, "@ai-sdk/perplexity": createPerplexity,
"@ai-sdk/vercel": createVercel, "@ai-sdk/vercel": createVercel,
"gitlab-ai-provider": createGitLab, "gitlab-ai-provider": createGitLab,
"@toothfairyai/sdk": createToothFairyAI as any,
// @ts-ignore (TODO: kill this code so we dont have to maintain it) // @ts-ignore (TODO: kill this code so we dont have to maintain it)
"@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible, "@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible,
} }
@ -183,26 +185,50 @@ export namespace Provider {
} }
}, },
async toothfairyai(input) { async toothfairyai(input) {
const hasCredentials = await (async () => { const credentials = await (async () => {
const env = Env.all() const env = Env.all()
if (env.TF_API_KEY) return true let apiKey = env.TF_API_KEY
if (await Auth.get(input.id)) return true let workspaceId = env.TF_WORKSPACE_ID
let region = env.TF_REGION || "au"
// Check auth storage
const auth = await Auth.get(input.id)
if (auth?.type === "api") {
apiKey = apiKey || auth.key
}
// Check config
const config = await Config.get() const config = await Config.get()
if (config.provider?.["toothfairyai"]?.options?.apiKey) return true const tfConfig = config.provider?.["toothfairyai"]
if (tfConfig?.options?.apiKey) apiKey = apiKey || tfConfig.options.apiKey
if (tfConfig?.options?.workspaceId) workspaceId = workspaceId || tfConfig.options.workspaceId
if (tfConfig?.options?.region) region = tfConfig.options.region
// Check stored credentials file // Check stored credentials file
try { try {
const credPath = path.join(Global.Path.data, ".tfcode", "credentials.json") const credPath = path.join(Global.Path.data, ".tfcode", "credentials.json")
const credData = await Bun.file(credPath).json() as { api_key?: string } const credData = (await Bun.file(credPath).json()) as {
if (credData.api_key) return true api_key?: string
workspace_id?: string
region?: string
}
if (credData.api_key) apiKey = apiKey || credData.api_key
if (credData.workspace_id) workspaceId = workspaceId || credData.workspace_id
if (credData.region) region = credData.region
} catch {} } catch {}
return false return { apiKey, workspaceId, region }
})() })()
return { return {
autoload: hasCredentials, autoload: !!credentials.apiKey,
options: hasCredentials ? {} : { apiKey: "setup-required" }, options: credentials.apiKey
? {
apiKey: credentials.apiKey,
workspaceId: credentials.workspaceId,
region: credentials.region,
}
: { apiKey: "setup-required" },
} }
}, },
openai: async () => { openai: async () => {

View File

@ -0,0 +1,2 @@
export { createToothFairyAI } from "./toothfairyai-provider"
export type { ToothFairyAIProvider, ToothFairyAIProviderSettings, ToothFairyAIModelId } from "./toothfairyai-provider"

View File

@ -0,0 +1,136 @@
import type { LanguageModelV2 } from "@ai-sdk/provider"
import { type FetchFunction, withoutTrailingSlash } from "@ai-sdk/provider-utils"
import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model"
const VERSION = "1.0.0"
export type ToothFairyAIModelId = string
export interface ToothFairyAIProviderSettings {
apiKey?: string
workspaceId?: string
region?: "dev" | "au" | "eu" | "us"
baseURL?: string
name?: string
headers?: Record<string, string>
fetch?: FetchFunction
}
export interface ToothFairyAIProvider {
(modelId: ToothFairyAIModelId): LanguageModelV2
languageModel(modelId: ToothFairyAIModelId): LanguageModelV2
}
const REGION_STREAMING_URLS: Record<string, string> = {
dev: "https://ais.toothfairylab.link",
au: "https://ais.toothfairyai.com",
eu: "https://ais.eu.toothfairyai.com",
us: "https://ais.us.toothfairyai.com",
}
export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}): ToothFairyAIProvider {
const region = options.region || "au"
const baseURL = withoutTrailingSlash(options.baseURL ?? REGION_STREAMING_URLS[region] ?? REGION_STREAMING_URLS.au)
const baseHeaders: Record<string, string> = {
...(options.apiKey && { "x-api-key": options.apiKey }),
...options.headers,
}
const workspaceId = options.workspaceId
const customFetch = async (input: RequestInfo | URL, init?: RequestInit) => {
const url = new URL(typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url)
if (url.pathname === "/chat/completions") {
url.pathname = "/predictions"
}
let body: Record<string, unknown> | undefined
if (init?.body && typeof init.body === "string") {
try {
body = JSON.parse(init.body)
} catch {}
}
if (body && workspaceId) {
body.workspaceid = workspaceId
}
const headers = new Headers(init?.headers)
for (const [key, value] of Object.entries(baseHeaders)) {
headers.set(key, value)
}
headers.delete("Authorization")
const fetchFn = options.fetch ?? globalThis.fetch
const res = await fetchFn(url, {
...init,
headers,
body: body ? JSON.stringify(body) : init?.body,
})
if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) {
const reader = res.body.getReader()
const decoder = new TextDecoder()
const encoder = new TextEncoder()
const filteredStream = new ReadableStream({
async pull(controller) {
const { done, value } = await reader.read()
if (done) {
controller.close()
return
}
const text = decoder.decode(value, { stream: true })
const lines = text.split("\n")
const filtered: string[] = []
for (const line of lines) {
if (line.startsWith("data: ")) {
const json = line.slice(6).trim()
if (json && !json.startsWith('{"status":')) {
filtered.push(line)
}
} else {
filtered.push(line)
}
}
controller.enqueue(encoder.encode(filtered.join("\n")))
},
cancel() {
reader.cancel()
},
})
return new Response(filteredStream, {
headers: res.headers,
status: res.status,
statusText: res.statusText,
})
}
return res
}
const createChatModel = (modelId: ToothFairyAIModelId) => {
return new OpenAICompatibleChatLanguageModel(modelId, {
provider: `${options.name ?? "toothfairyai"}.chat`,
headers: () => baseHeaders,
url: ({ path }) => `${baseURL}${path}`,
fetch: customFetch as FetchFunction,
})
}
const createLanguageModel = (modelId: ToothFairyAIModelId) => createChatModel(modelId)
const provider = function (modelId: ToothFairyAIModelId) {
return createChatModel(modelId)
}
provider.languageModel = createLanguageModel
return provider as ToothFairyAIProvider
}