fix: tool management

This commit is contained in:
Gab 2026-03-25 16:20:15 +11:00
parent 8ecbe43e2d
commit 06db5cd050

View File

@ -1,8 +1,10 @@
import type { LanguageModelV2 } from "@ai-sdk/provider" import type { LanguageModelV2 } from "@ai-sdk/provider"
import { type FetchFunction, withoutTrailingSlash } from "@ai-sdk/provider-utils" import { type FetchFunction, withoutTrailingSlash } from "@ai-sdk/provider-utils"
import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model" import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model"
import { Log } from "@/util/log"
const VERSION = "1.0.0" const VERSION = "1.0.0"
const log = Log.create({ service: "provider.toothfairyai" })
export type ToothFairyAIModelId = string export type ToothFairyAIModelId = string
@ -44,6 +46,7 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
if (url.pathname === "/chat/completions") { if (url.pathname === "/chat/completions") {
url.pathname = "/predictions" url.pathname = "/predictions"
log.debug("redirecting /chat/completions to /predictions")
} }
let body: Record<string, unknown> | undefined let body: Record<string, unknown> | undefined
@ -63,6 +66,34 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
} }
headers.delete("Authorization") headers.delete("Authorization")
log.info("request", {
url: url.href,
method: init?.method ?? "GET",
hasTools: !!body?.tools,
toolCount: (body?.tools as any[])?.length ?? 0,
model: body?.model,
})
log.debug("request body", {
messages: body?.messages
? (body.messages as any[]).map((m) => ({
role: m.role,
content: typeof m.content === "string" ? m.content.slice(0, 100) : m.content,
tool_calls: m.tool_calls,
tool_call_id: m.tool_call_id,
}))
: undefined,
tools: body?.tools
? (body.tools as any[]).map((t) => ({
name: t.function?.name,
description: t.function?.description?.slice(0, 50),
}))
: undefined,
tool_choice: body?.tool_choice,
max_tokens: body?.max_tokens,
temperature: body?.temperature,
})
const fetchFn = options.fetch ?? globalThis.fetch const fetchFn = options.fetch ?? globalThis.fetch
const res = await fetchFn(url, { const res = await fetchFn(url, {
...init, ...init,
@ -70,6 +101,21 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
body: body ? JSON.stringify(body) : init?.body, body: body ? JSON.stringify(body) : init?.body,
}) })
log.info("response", {
status: res.status,
statusText: res.statusText,
contentType: res.headers.get("content-type"),
})
if (!res.ok) {
const errorBody = await res.text().catch(() => "Unable to read error body")
log.error("request failed", {
status: res.status,
statusText: res.statusText,
errorBody: errorBody.slice(0, 500),
})
}
if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) { if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) {
const reader = res.body.getReader() const reader = res.body.getReader()
const decoder = new TextDecoder() const decoder = new TextDecoder()
@ -92,6 +138,20 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
const json = line.slice(6).trim() const json = line.slice(6).trim()
if (json && !json.startsWith('{"status":')) { if (json && !json.startsWith('{"status":')) {
filtered.push(line) filtered.push(line)
// Log tool calls and finish_reason
try {
const parsed = JSON.parse(json)
if (parsed.choices?.[0]?.delta?.tool_calls) {
log.debug("stream tool_calls", {
tool_calls: parsed.choices[0].delta.tool_calls,
})
}
if (parsed.choices?.[0]?.finish_reason) {
log.info("stream finish_reason", {
finish_reason: parsed.choices[0].finish_reason,
})
}
} catch {}
} }
} else { } else {
filtered.push(line) filtered.push(line)