mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-23 00:54:43 +00:00
204 lines
6.2 KiB
TypeScript
import type { LanguageModelV2 } from "@ai-sdk/provider"
|
|
import { type FetchFunction, withoutTrailingSlash, safeParseJSON, EventSourceParserStream } from "@ai-sdk/provider-utils"
|
|
import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model"
|
|
import { Log } from "@/util/log"
|
|
|
|
// Provider package version. NOTE(review): not referenced anywhere in this chunk —
// confirm it is consumed elsewhere (e.g. a version header) or remove it.
const VERSION = "1.0.0"

// Scoped logger for this provider (project logging utility).
const log = Log.create({ service: "provider.toothfairyai" })
|
|
|
|
/** Opaque model identifier forwarded to the ToothFairyAI API; no client-side validation is performed. */
export type ToothFairyAIModelId = string
|
|
|
|
/** Configuration for {@link createToothFairyAI}. All fields are optional. */
export interface ToothFairyAIProviderSettings {
  /** API key, sent on every request as the `x-api-key` header when set. */
  apiKey?: string

  /** Workspace id, injected into each JSON request body as `workspaceid`. */
  workspaceId?: string

  /** Deployment region used to pick the streaming base URL. Defaults to `"au"`. */
  region?: "dev" | "au" | "eu" | "us"

  /** Explicit base URL; when set it takes precedence over {@link region}. */
  baseURL?: string

  /** Provider name used in the model's `provider` id; defaults to `"toothfairyai"`. */
  name?: string

  /** Extra headers merged over the provider defaults (may override `x-api-key`). */
  headers?: Record<string, string>

  /** Custom fetch implementation; defaults to `globalThis.fetch`. */
  fetch?: FetchFunction
}
|
|
|
|
/**
 * Callable provider returned by {@link createToothFairyAI}.
 * `provider(modelId)` and `provider.languageModel(modelId)` are equivalent.
 */
export interface ToothFairyAIProvider {
  /** Creates a language model for the given model id. */
  (modelId: ToothFairyAIModelId): LanguageModelV2

  /** Alias for the call signature, matching the AI SDK provider convention. */
  languageModel(modelId: ToothFairyAIModelId): LanguageModelV2
}
|
|
|
|
const REGION_STREAMING_URLS: Record<string, string> = {
|
|
dev: "https://ais.toothfairylab.link",
|
|
au: "https://ais.toothfairyai.com",
|
|
eu: "https://ais.eu.toothfairyai.com",
|
|
us: "https://ais.us.toothfairyai.com",
|
|
}
|
|
|
|
/**
 * Creates a ToothFairyAI provider exposing the AI SDK `LanguageModelV2`
 * interface, backed by the OpenAI-compatible chat model implementation.
 *
 * The provider installs a `fetch` wrapper that:
 *  - rewrites `/chat/completions` requests to ToothFairyAI's `/predictions`,
 *  - injects `x-api-key` / custom headers and the workspace id into requests,
 *  - filters ToothFairyAI status events out of the SSE response stream so the
 *    downstream OpenAI-compatible parser only sees chat-completion chunks.
 *
 * @param options - Connection settings (API key, workspace, region/baseURL, headers, fetch).
 * @returns A callable provider: `provider(modelId)` / `provider.languageModel(modelId)`.
 */
export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}): ToothFairyAIProvider {
  // Region defaults to "au". `||` (not `??`) also maps "" to the default.
  const region = options.region || "au"
  // Explicit baseURL wins; otherwise the region's streaming URL, falling back to "au".
  const baseURL = withoutTrailingSlash(options.baseURL ?? REGION_STREAMING_URLS[region] ?? REGION_STREAMING_URLS.au)

  // Headers applied to every request. Spread order means caller-supplied
  // `options.headers` can override the generated `x-api-key` entry.
  const baseHeaders: Record<string, string> = {
    ...(options.apiKey && { "x-api-key": options.apiKey }),
    ...options.headers,
  }

  const workspaceId = options.workspaceId

  // fetch wrapper used by the underlying OpenAI-compatible model: URL rewrite,
  // header/body injection, request/response logging, and SSE filtering.
  const customFetch = async (input: RequestInfo | URL, init?: RequestInit) => {
    // Normalize the three possible input shapes (string | URL | Request) to a URL.
    const url = new URL(typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url)

    // The OpenAI-compatible model targets /chat/completions; ToothFairyAI serves /predictions.
    if (url.pathname === "/chat/completions") {
      url.pathname = "/predictions"
      log.debug("redirecting /chat/completions to /predictions")
    }

    // Best-effort parse of the outgoing JSON body; non-JSON or non-string
    // bodies pass through untouched (the empty catch is deliberate).
    let body: Record<string, unknown> | undefined
    if (init?.body && typeof init.body === "string") {
      try {
        body = JSON.parse(init.body)
      } catch {}
    }

    // Inject the workspace id the backend expects (lowercase key "workspaceid").
    if (body && workspaceId) {
      body.workspaceid = workspaceId
    }

    // Layer provider headers over whatever the caller set; provider headers win.
    const headers = new Headers(init?.headers)
    for (const [key, value] of Object.entries(baseHeaders)) {
      headers.set(key, value)
    }
    // Strip any Authorization header added upstream — auth is via x-api-key only.
    headers.delete("Authorization")

    log.info("request", {
      url: url.href,
      method: init?.method ?? "GET",
      hasTools: !!body?.tools,
      toolCount: (body?.tools as any[])?.length ?? 0,
      model: body?.model,
    })

    // Verbose body logging, truncated (message content to 100 chars, tool
    // descriptions to 50) to keep log lines bounded.
    log.debug("request body", {
      messages: body?.messages
        ? (body.messages as any[]).map((m) => ({
            role: m.role,
            content: typeof m.content === "string" ? m.content.slice(0, 100) : m.content,
            tool_calls: m.tool_calls,
            tool_call_id: m.tool_call_id,
          }))
        : undefined,
      tools: body?.tools
        ? (body.tools as any[]).map((t) => ({
            name: t.function?.name,
            description: t.function?.description?.slice(0, 50),
          }))
        : undefined,
      tool_choice: body?.tool_choice,
      max_tokens: body?.max_tokens,
      temperature: body?.temperature,
    })

    // Re-serialize the (possibly mutated) body; fall back to the original
    // init body when it was never parsed.
    const fetchFn = options.fetch ?? globalThis.fetch
    const res = await fetchFn(url, {
      ...init,
      headers,
      body: body ? JSON.stringify(body) : init?.body,
    })

    log.info("response", {
      status: res.status,
      statusText: res.statusText,
      contentType: res.headers.get("content-type"),
    })

    // Log error bodies for diagnostics, but do not consume the stream path:
    // the (non-ok, non-SSE) response is still returned to the caller below.
    if (!res.ok) {
      const errorBody = await res.text().catch(() => "Unable to read error body")
      log.error("request failed", {
        status: res.status,
        statusText: res.statusText,
        errorBody: errorBody.slice(0, 500),
      })
    }

    // For SSE responses, filter the event stream before handing it back:
    // bytes → text → SSE events → drop ToothFairyAI status events → re-encode as SSE bytes.
    if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) {
      const filteredStream = res.body
        .pipeThrough(new TextDecoderStream())
        .pipeThrough(new EventSourceParserStream())
        .pipeThrough(
          new TransformStream({
            async transform({ data }, controller) {
              // NOTE(review): the "[DONE]" sentinel is swallowed here — this
              // presumably relies on stream close to terminate the downstream
              // parser; confirm OpenAICompatibleChatLanguageModel handles that.
              if (data === "[DONE]") {
                return
              }

              // schema: null → parse without validation; value is the raw JSON.
              const parsed = await safeParseJSON({ text: data, schema: null })

              if (!parsed.success) {
                log.error("Failed to parse SSE chunk", {
                  chunk: data.slice(0, 100),
                  error: parsed.error,
                })
                // Forward unparseable chunks unchanged rather than dropping data.
                controller.enqueue({ data })
                return
              }

              const value = parsed.value

              // ToothFairyAI-specific connection-status events are not OpenAI
              // chat chunks; drop them so they never reach the parser.
              if (value.status === "initialising" || value.status === "connected") {
                log.debug("filtered connection status", { status: value.status })
                return
              }

              // Observability only: log terminal chunks and usage, then forward.
              if (value.choices?.[0]?.finish_reason) {
                log.info("stream finish_reason", {
                  finish_reason: value.choices[0].finish_reason,
                })
              }

              if (value.usage) {
                log.info("stream usage", {
                  prompt_tokens: value.usage.prompt_tokens,
                  completion_tokens: value.usage.completion_tokens,
                  total_tokens: value.usage.total_tokens,
                })
              }

              controller.enqueue({ data })
            },
          }),
        )
        .pipeThrough(
          new TransformStream({
            transform({ data }, controller) {
              // Re-serialize each surviving event into SSE wire format.
              controller.enqueue(new TextEncoder().encode(`data: ${data}\n\n`))
            },
          }),
        )

      // Wrap the filtered stream with the upstream status/headers.
      // NOTE(review): res.headers may still carry upstream entries such as
      // Content-Length that no longer match the filtered body — confirm harmless.
      return new Response(filteredStream, {
        headers: res.headers,
        status: res.status,
        statusText: res.statusText,
      })
    }

    // Non-SSE responses (including errors) pass through untouched.
    return res
  }

  // Builds the OpenAI-compatible chat model wired to the custom fetch above.
  const createChatModel = (modelId: ToothFairyAIModelId) => {
    return new OpenAICompatibleChatLanguageModel(modelId, {
      provider: `${options.name ?? "toothfairyai"}.chat`,
      headers: () => baseHeaders,
      url: ({ path }) => `${baseURL}${path}`,
      fetch: customFetch as FetchFunction,
      includeUsage: true,
    })
  }

  const createLanguageModel = (modelId: ToothFairyAIModelId) => createChatModel(modelId)

  // The provider is callable directly and also exposes .languageModel,
  // matching the ToothFairyAIProvider interface.
  const provider = function (modelId: ToothFairyAIModelId) {
    return createChatModel(modelId)
  }

  provider.languageModel = createLanguageModel

  return provider as ToothFairyAIProvider
}
|