diff --git a/packages/tfcode/src/provider/models.ts b/packages/tfcode/src/provider/models.ts index da42fb359..c09c5c9a3 100644 --- a/packages/tfcode/src/provider/models.ts +++ b/packages/tfcode/src/provider/models.ts @@ -22,6 +22,13 @@ export namespace ModelsDev { us: "https://ai.us.toothfairyai.com", } + const REGION_STREAMING_URLS: Record = { + dev: "https://ais.toothfairylab.link", + au: "https://ais.toothfairyai.com", + eu: "https://ais.eu.toothfairyai.com", + us: "https://ais.us.toothfairyai.com", + } + export const Model = z.object({ id: z.string(), name: z.string(), @@ -163,7 +170,8 @@ export namespace ModelsDev { // Only include serverless models if (model.deploymentType && model.deploymentType !== "serverless") continue - const modelId = key.startsWith("z/") ? key.slice(2) : key + // Use the full key as the model ID (API expects the exact key) + const modelId = key tfModels[modelId] = { id: modelId, @@ -194,6 +202,8 @@ export namespace ModelsDev { id: "toothfairyai", name: "ToothFairyAI", env: ["TF_API_KEY", "TF_WORKSPACE_ID"], + npm: "@toothfairyai/sdk", + api: REGION_STREAMING_URLS[tfRegion], models: tfModels, } } @@ -208,6 +218,8 @@ export namespace ModelsDev { id: "toothfairyai", name: "ToothFairyAI", env: ["TF_API_KEY", "TF_WORKSPACE_ID"], + npm: "@toothfairyai/sdk", + api: "https://ais.toothfairyai.com", models: { sorcerer: { id: "sorcerer", diff --git a/packages/tfcode/src/provider/provider.ts b/packages/tfcode/src/provider/provider.ts index 4f1cefd69..a36b1ca27 100644 --- a/packages/tfcode/src/provider/provider.ts +++ b/packages/tfcode/src/provider/provider.ts @@ -30,6 +30,7 @@ import { createOpenAI } from "@ai-sdk/openai" import { createOpenAICompatible } from "@ai-sdk/openai-compatible" import { createOpenRouter, type LanguageModelV2 } from "@openrouter/ai-sdk-provider" import { createOpenaiCompatible as createGitHubCopilotOpenAICompatible } from "./sdk/copilot" +import { createToothFairyAI } from "./sdk/toothfairyai" import { createXai } from "@ai-sdk/xai" import { createMistral } from "@ai-sdk/mistral" import { createGroq } from "@ai-sdk/groq" @@ -130,6 +131,7 @@ export namespace Provider { "@ai-sdk/perplexity": createPerplexity, "@ai-sdk/vercel": createVercel, "gitlab-ai-provider": createGitLab, + "@toothfairyai/sdk": createToothFairyAI as any, // @ts-ignore (TODO: kill this code so we dont have to maintain it) "@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible, } @@ -183,26 +185,50 @@ export namespace Provider { } }, async toothfairyai(input) { - const hasCredentials = await (async () => { + const credentials = await (async () => { const env = Env.all() - if (env.TF_API_KEY) return true - if (await Auth.get(input.id)) return true + let apiKey = env.TF_API_KEY + let workspaceId = env.TF_WORKSPACE_ID + let region = env.TF_REGION || "au" + + // Check auth storage + const auth = await Auth.get(input.id) + if (auth?.type === "api") { + apiKey = apiKey || auth.key + } + + // Check config const config = await Config.get() - if (config.provider?.["toothfairyai"]?.options?.apiKey) return true - + const tfConfig = config.provider?.["toothfairyai"] + if (tfConfig?.options?.apiKey) apiKey = apiKey || tfConfig.options.apiKey + if (tfConfig?.options?.workspaceId) workspaceId = workspaceId || tfConfig.options.workspaceId + if (tfConfig?.options?.region) region = tfConfig.options.region + // Check stored credentials file try { const credPath = path.join(Global.Path.data, ".tfcode", "credentials.json") - const credData = await Bun.file(credPath).json() as { api_key?: string } - if (credData.api_key) return true + const credData = (await Bun.file(credPath).json()) as { + api_key?: string + workspace_id?: string + region?: string + } + if (credData.api_key) apiKey = apiKey || credData.api_key + if (credData.workspace_id) workspaceId = workspaceId || credData.workspace_id + if (credData.region) region = credData.region } catch {} - - return false + + return { apiKey, workspaceId, region } })() return { - autoload: hasCredentials, - options: hasCredentials ? {} : { apiKey: "setup-required" }, + autoload: !!credentials.apiKey, + options: credentials.apiKey + ? { + apiKey: credentials.apiKey, + workspaceId: credentials.workspaceId, + region: credentials.region, + } + : { apiKey: "setup-required" }, } }, openai: async () => { diff --git a/packages/tfcode/src/provider/sdk/toothfairyai/index.ts b/packages/tfcode/src/provider/sdk/toothfairyai/index.ts new file mode 100644 index 000000000..c69e6babf --- /dev/null +++ b/packages/tfcode/src/provider/sdk/toothfairyai/index.ts @@ -0,0 +1,2 @@ +export { createToothFairyAI } from "./toothfairyai-provider" +export type { ToothFairyAIProvider, ToothFairyAIProviderSettings, ToothFairyAIModelId } from "./toothfairyai-provider" diff --git a/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts b/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts new file mode 100644 index 000000000..3e6bb0ac4 --- /dev/null +++ b/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts @@ -0,0 +1,136 @@ +import type { LanguageModelV2 } from "@ai-sdk/provider" +import { type FetchFunction, withoutTrailingSlash } from "@ai-sdk/provider-utils" +import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model" + +const VERSION = "1.0.0" + +export type ToothFairyAIModelId = string + +export interface ToothFairyAIProviderSettings { + apiKey?: string + workspaceId?: string + region?: "dev" | "au" | "eu" | "us" + baseURL?: string + name?: string + headers?: Record + fetch?: FetchFunction +} + +export interface ToothFairyAIProvider { + (modelId: ToothFairyAIModelId): LanguageModelV2 + languageModel(modelId: ToothFairyAIModelId): LanguageModelV2 +} + +const REGION_STREAMING_URLS: Record = { + dev: "https://ais.toothfairylab.link", + au: "https://ais.toothfairyai.com", + eu: "https://ais.eu.toothfairyai.com", + us: "https://ais.us.toothfairyai.com", +} + +export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}): ToothFairyAIProvider { + const region = options.region || "au" + const baseURL = withoutTrailingSlash(options.baseURL ?? REGION_STREAMING_URLS[region] ?? REGION_STREAMING_URLS.au) + + const baseHeaders: Record = { + ...(options.apiKey && { "x-api-key": options.apiKey }), + ...options.headers, + } + + const workspaceId = options.workspaceId + + const customFetch = async (input: RequestInfo | URL, init?: RequestInit) => { + const url = new URL(typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url) + + if (url.pathname === "/chat/completions") { + url.pathname = "/predictions" + } + + let body: Record | undefined + if (init?.body && typeof init.body === "string") { + try { + body = JSON.parse(init.body) + } catch {} + } + + if (body && workspaceId) { + body.workspaceid = workspaceId + } + + const headers = new Headers(init?.headers) + for (const [key, value] of Object.entries(baseHeaders)) { + headers.set(key, value) + } + headers.delete("Authorization") + + const fetchFn = options.fetch ?? globalThis.fetch + const res = await fetchFn(url, { + ...init, + headers, + body: body ? JSON.stringify(body) : init?.body, + }) + + if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) { + const reader = res.body.getReader() + const decoder = new TextDecoder() + const encoder = new TextEncoder() + + const filteredStream = new ReadableStream({ + async pull(controller) { + const { done, value } = await reader.read() + if (done) { + controller.close() + return + } + + const text = decoder.decode(value, { stream: true }) + const lines = text.split("\n") + + const filtered: string[] = [] + for (const line of lines) { + if (line.startsWith("data: ")) { + const json = line.slice(6).trim() + if (json && !json.startsWith('{"status":')) { + filtered.push(line) + } + } else { + filtered.push(line) + } + } + + controller.enqueue(encoder.encode(filtered.join("\n"))) + }, + cancel() { + reader.cancel() + }, + }) + + return new Response(filteredStream, { + headers: res.headers, + status: res.status, + statusText: res.statusText, + }) + } + + return res + } + + const createChatModel = (modelId: ToothFairyAIModelId) => { + return new OpenAICompatibleChatLanguageModel(modelId, { + provider: `${options.name ?? "toothfairyai"}.chat`, + headers: () => baseHeaders, + url: ({ path }) => `${baseURL}${path}`, + fetch: customFetch as FetchFunction, + }) + } + + const createLanguageModel = (modelId: ToothFairyAIModelId) => createChatModel(modelId) + + const provider = function (modelId: ToothFairyAIModelId) { + return createChatModel(modelId) + } + + provider.languageModel = createLanguageModel + + return provider as ToothFairyAIProvider +}