Files
tf_code/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts
2026-03-29 19:24:02 +11:00

204 lines
6.2 KiB
TypeScript

import type { LanguageModelV2 } from "@ai-sdk/provider"
import { type FetchFunction, withoutTrailingSlash, safeParseJSON, EventSourceParserStream } from "@ai-sdk/provider-utils"
import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model"
import { Log } from "@/util/log"
/** Logger scoped to this provider; every request/response/stream log line below goes through it. */
const log = Log.create({ service: "provider.toothfairyai" })
/** Version string of this provider implementation (currently not referenced within this module). */
const VERSION = "1.0.0"
/** Identifier of a ToothFairyAI model; passed through to the chat model without validation here. */
export type ToothFairyAIModelId = string

/** Deployment region; selects the default streaming endpoint when no `baseURL` is given. */
export type ToothFairyAIRegion = "dev" | "au" | "eu" | "us"

/** Options accepted by `createToothFairyAI`. Every field is optional. */
export interface ToothFairyAIProviderSettings {
  /** Sent as the `x-api-key` header when non-empty. */
  apiKey?: string
  /** When set, injected as `workspaceid` into JSON request bodies. */
  workspaceId?: string
  /** Region whose streaming endpoint is used when `baseURL` is absent; defaults to `"au"`. */
  region?: ToothFairyAIRegion
  /** Explicit endpoint overriding the region lookup; a trailing slash is removed. */
  baseURL?: string
  /** Prefix of the model's provider id, which becomes `${name}.chat` (default `"toothfairyai"`). */
  name?: string
  /**
   * Extra headers, merged after the API-key header so they can override it.
   * Note: an `Authorization` header is stripped from outgoing requests.
   */
  headers?: Record<string, string>
  /** Custom fetch implementation; defaults to `globalThis.fetch`. */
  fetch?: FetchFunction
}

/** Callable provider: `provider(id)` and `provider.languageModel(id)` both create a chat model. */
export interface ToothFairyAIProvider {
  (modelId: ToothFairyAIModelId): LanguageModelV2
  languageModel(modelId: ToothFairyAIModelId): LanguageModelV2
}

/**
 * Default streaming endpoint for each region. Keyed by `ToothFairyAIRegion` so that adding a
 * region to the union without giving it a URL (or misspelling a key) fails to type-check.
 */
const REGION_STREAMING_URLS: Record<ToothFairyAIRegion, string> = {
  dev: "https://ais.toothfairylab.link",
  au: "https://ais.toothfairyai.com",
  eu: "https://ais.eu.toothfairyai.com",
  us: "https://ais.us.toothfairyai.com",
}
/**
 * Resolves the URL of an outgoing request and rewrites a trailing `/chat/completions` path to
 * `/predictions`, the endpoint this provider sends chat requests to. Matching the suffix rather
 * than the whole pathname keeps the rewrite working when a custom `baseURL` has a path prefix.
 */
function resolveRequestUrl(input: RequestInfo | URL): URL {
  const url = new URL(typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url)
  const chatCompletionsPath = "/chat/completions"
  if (url.pathname.endsWith(chatCompletionsPath)) {
    url.pathname = `${url.pathname.slice(0, -chatCompletionsPath.length)}/predictions`
    log.debug("redirecting /chat/completions to /predictions")
  }
  return url
}

/**
 * Parses a string request body into a JSON object so fields can be inspected and injected.
 * Returns `undefined` for non-string or empty bodies, malformed JSON, and JSON that is not a plain
 * object (`null`, arrays, primitives); callers then forward the original body unchanged.
 */
function parseJsonObjectBody(raw: BodyInit | null | undefined): Record<string, unknown> | undefined {
  if (typeof raw !== "string" || raw === "") {
    return undefined
  }
  try {
    const parsed: unknown = JSON.parse(raw)
    // Only plain objects can safely receive extra fields; assigning a property to a primitive
    // would throw in strict mode.
    if (typeof parsed === "object" && parsed !== null && !Array.isArray(parsed)) {
      return parsed as Record<string, unknown>
    }
  } catch {
    // Best-effort: a body that is not JSON is forwarded untouched.
  }
  return undefined
}

/** Logs an info-level summary of an outgoing request plus a truncated debug view of its payload. */
function logRequest(url: URL, method: string, body: Record<string, unknown> | undefined): void {
  log.info("request", {
    url: url.href,
    method,
    hasTools: !!body?.tools,
    toolCount: (body?.tools as any[])?.length ?? 0,
    model: body?.model,
  })
  log.debug("request body", {
    messages: body?.messages
      ? (body.messages as any[]).map((m) => ({
          role: m.role,
          content: typeof m.content === "string" ? m.content.slice(0, 100) : m.content,
          tool_calls: m.tool_calls,
          tool_call_id: m.tool_call_id,
        }))
      : undefined,
    tools: body?.tools
      ? (body.tools as any[]).map((t) => ({
          name: t.function?.name,
          description: t.function?.description?.slice(0, 50),
        }))
      : undefined,
    tool_choice: body?.tool_choice,
    max_tokens: body?.max_tokens,
    temperature: body?.temperature,
  })
}

/**
 * Rewrites a ToothFairyAI server-sent-event stream into plain `data: …` frames for the
 * OpenAI-compatible stream parser. Drops the `[DONE]` sentinel and `initialising`/`connected`
 * status events, logs `finish_reason` and `usage` when present, and forwards every other event
 * (including chunks that are not valid JSON) unchanged. Event names and ids are not preserved.
 */
function filterEventStream(stream: NonNullable<Response["body"]>) {
  // One encoder for the whole stream rather than one per event.
  const encoder = new TextEncoder()
  return stream
    .pipeThrough(new TextDecoderStream())
    .pipeThrough(new EventSourceParserStream())
    .pipeThrough(
      new TransformStream({
        async transform({ data }, controller) {
          if (data === "[DONE]") {
            return
          }
          const parsed = await safeParseJSON({ text: data, schema: null })
          if (!parsed.success) {
            log.error("Failed to parse SSE chunk", {
              chunk: data.slice(0, 100),
              error: parsed.error,
            })
            controller.enqueue({ data })
            return
          }
          // Optional chaining guards against a literal `null` payload, which would otherwise
          // throw here and error the entire stream.
          const value = parsed.value
          if (value?.status === "initialising" || value?.status === "connected") {
            log.debug("filtered connection status", { status: value.status })
            return
          }
          const finishReason = value?.choices?.[0]?.finish_reason
          if (finishReason) {
            log.info("stream finish_reason", {
              finish_reason: finishReason,
            })
          }
          if (value?.usage) {
            log.info("stream usage", {
              prompt_tokens: value.usage.prompt_tokens,
              completion_tokens: value.usage.completion_tokens,
              total_tokens: value.usage.total_tokens,
            })
          }
          controller.enqueue({ data })
        },
      }),
    )
    .pipeThrough(
      new TransformStream({
        transform({ data }, controller) {
          controller.enqueue(encoder.encode(`data: ${data}\n\n`))
        },
      }),
    )
}

/**
 * Creates a ToothFairyAI provider backed by the OpenAI-compatible chat model.
 *
 * Requests go to `options.baseURL` (trailing slash removed) or, when absent, to the streaming
 * endpoint for `options.region` (default `"au"`, also used for unknown regions). A fetch wrapper
 * adapts every request:
 * - rewrites `/chat/completions` to `/predictions`;
 * - adds `workspaceid` to JSON object bodies when `options.workspaceId` is set;
 * - applies the base headers and strips `Authorization`;
 * - re-frames SSE responses (see `filterEventStream`).
 *
 * Failed responses are logged from a clone, so the caller still receives an unread body.
 *
 * @param options - provider settings; every field is optional.
 * @returns a callable provider whose `languageModel` method is equivalent to calling it directly.
 */
export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}): ToothFairyAIProvider {
  const region = options.region ?? "au"
  const baseURL = withoutTrailingSlash(options.baseURL ?? REGION_STREAMING_URLS[region] ?? REGION_STREAMING_URLS.au)
  const baseHeaders: Record<string, string> = {
    ...(options.apiKey && { "x-api-key": options.apiKey }),
    ...options.headers,
  }
  const workspaceId = options.workspaceId

  const customFetch = async (input: RequestInfo | URL, init?: RequestInit) => {
    const url = resolveRequestUrl(input)
    const body = parseJsonObjectBody(init?.body)
    if (body && workspaceId) {
      body.workspaceid = workspaceId
    }

    const headers = new Headers(init?.headers)
    for (const [key, value] of Object.entries(baseHeaders)) {
      headers.set(key, value)
    }
    // Presumably auth is carried solely by `x-api-key`; drop any bearer header added upstream.
    headers.delete("Authorization")

    logRequest(url, init?.method ?? "GET", body)

    const fetchFn = options.fetch ?? globalThis.fetch
    const res = await fetchFn(url, {
      ...init,
      headers,
      body: body ? JSON.stringify(body) : init?.body,
    })

    log.info("response", {
      status: res.status,
      statusText: res.statusText,
      contentType: res.headers.get("content-type"),
    })

    if (!res.ok) {
      // Read the error text from a clone: consuming `res` itself would leave the SDK's error
      // handler (and the SSE branch below) with an already-used body.
      let errorBody: string
      try {
        errorBody = await res.clone().text()
      } catch {
        errorBody = "Unable to read error body"
      }
      log.error("request failed", {
        status: res.status,
        statusText: res.statusText,
        errorBody: errorBody.slice(0, 500),
      })
    }

    if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) {
      return new Response(filterEventStream(res.body), {
        headers: res.headers,
        status: res.status,
        statusText: res.statusText,
      })
    }
    return res
  }

  const createChatModel = (modelId: ToothFairyAIModelId) =>
    new OpenAICompatibleChatLanguageModel(modelId, {
      provider: `${options.name ?? "toothfairyai"}.chat`,
      headers: () => baseHeaders,
      url: ({ path }) => `${baseURL}${path}`,
      fetch: customFetch as FetchFunction,
      includeUsage: true,
    })

  const provider = function (modelId: ToothFairyAIModelId) {
    return createChatModel(modelId)
  }
  provider.languageModel = createChatModel
  return provider as ToothFairyAIProvider
}