mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-03-30 05:43:55 +00:00
feat: better preidctiosn
This commit is contained in:
parent
a42f8fa99f
commit
12ae1cb9b5
2
bun.lock
2
bun.lock
@ -381,7 +381,7 @@
|
|||||||
},
|
},
|
||||||
"packages/tfcode": {
|
"packages/tfcode": {
|
||||||
"name": "tfcode",
|
"name": "tfcode",
|
||||||
"version": "1.0.6",
|
"version": "1.0.7",
|
||||||
"bin": {
|
"bin": {
|
||||||
"tfcode": "./bin/tfcode",
|
"tfcode": "./bin/tfcode",
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"$schema": "https://json.schemastore.org/package.json",
|
"$schema": "https://json.schemastore.org/package.json",
|
||||||
"version": "1.0.6",
|
"version": "1.0.7",
|
||||||
"name": "tfcode",
|
"name": "tfcode",
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
|
|||||||
179
packages/tfcode/src/cli/cmd/tui/component/dialog-tf-hooks.tsx
Normal file
179
packages/tfcode/src/cli/cmd/tui/component/dialog-tf-hooks.tsx
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
import { createMemo, createSignal, createResource } from "solid-js"
|
||||||
|
import { DialogSelect, type DialogSelectRef, type DialogSelectOption } from "@tui/ui/dialog-select"
|
||||||
|
import { useTheme } from "../context/theme"
|
||||||
|
import { TextAttributes } from "@opentui/core"
|
||||||
|
import { Global } from "@/global"
|
||||||
|
import path from "path"
|
||||||
|
import { useToast } from "../ui/toast"
|
||||||
|
import { Keybind } from "@/util/keybind"
|
||||||
|
|
||||||
|
interface Hook {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
description?: string
|
||||||
|
remote_environment_name?: string
|
||||||
|
code_execution_instructions?: string
|
||||||
|
predefined_code_snippet?: string
|
||||||
|
execute_as_static_script?: boolean
|
||||||
|
execute_as_python_tool?: boolean
|
||||||
|
secrets_injected?: boolean
|
||||||
|
created_at?: string
|
||||||
|
updated_at?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const TF_DEFAULT_REGION = "au"
|
||||||
|
|
||||||
|
const REGION_API_URLS: Record<string, string> = {
|
||||||
|
dev: "https://api.toothfairylab.link",
|
||||||
|
au: "https://api.toothfairyai.com",
|
||||||
|
eu: "https://api.eu.toothfairyai.com",
|
||||||
|
us: "https://api.us.toothfairyai.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
function HookStatus(props: { hook: Hook }) {
|
||||||
|
const { theme } = useTheme()
|
||||||
|
const features = []
|
||||||
|
if (props.hook.execute_as_static_script) features.push("Static")
|
||||||
|
if (props.hook.execute_as_python_tool) features.push("Python Tool")
|
||||||
|
if (props.hook.secrets_injected) features.push("Secrets")
|
||||||
|
if (props.hook.remote_environment_name) features.push("Remote")
|
||||||
|
|
||||||
|
if (features.length === 0) {
|
||||||
|
return <span style={{ fg: theme.textMuted }}>Default</span>
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<span style={{ fg: theme.success }}>
|
||||||
|
{features.map((f, i) => (
|
||||||
|
<>
|
||||||
|
{i > 0 && <span style={{ fg: theme.textMuted }}>·</span>}
|
||||||
|
{f}
|
||||||
|
</>
|
||||||
|
))}
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function DialogTfHooks() {
|
||||||
|
const toast = useToast()
|
||||||
|
const { theme } = useTheme()
|
||||||
|
const [, setRef] = createSignal<DialogSelectRef<unknown>>()
|
||||||
|
const [loading, setLoading] = createSignal(false)
|
||||||
|
const [refreshKey, setRefreshKey] = createSignal(0)
|
||||||
|
|
||||||
|
const [credentials] = createResource(refreshKey, async () => {
|
||||||
|
try {
|
||||||
|
const credPath = path.join(Global.Path.data, ".tfcode", "credentials.json")
|
||||||
|
const data = (await Bun.file(credPath).json()) as {
|
||||||
|
api_key?: string
|
||||||
|
workspace_id?: string
|
||||||
|
region?: string
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
} catch {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const [hooks, { refetch: refetchHooks }] = createResource(refreshKey, async () => {
|
||||||
|
const creds = credentials()
|
||||||
|
if (!creds?.api_key || !creds?.workspace_id) {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
setLoading(true)
|
||||||
|
try {
|
||||||
|
const region = creds.region || TF_DEFAULT_REGION
|
||||||
|
const baseUrl = REGION_API_URLS[region] || REGION_API_URLS[TF_DEFAULT_REGION]
|
||||||
|
|
||||||
|
const response = await fetch(`${baseUrl}/hook/list`, {
|
||||||
|
method: "GET",
|
||||||
|
headers: {
|
||||||
|
"x-api-key": creds.api_key,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorText = await response.text()
|
||||||
|
throw new Error(`HTTP ${response.status}: ${errorText}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = (await response.json()) as { success?: boolean; hooks?: Hook[] }
|
||||||
|
return data.hooks || []
|
||||||
|
} catch (error) {
|
||||||
|
toast.show({
|
||||||
|
variant: "error",
|
||||||
|
message: `Failed to fetch hooks: ${error}`,
|
||||||
|
duration: 5000,
|
||||||
|
})
|
||||||
|
return []
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const options = createMemo<DialogSelectOption<Hook>[]>(() => {
|
||||||
|
const hooksList = hooks() || []
|
||||||
|
const isLoading = loading()
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
value: { id: "loading", name: "Loading..." } as Hook,
|
||||||
|
title: "Loading hooks...",
|
||||||
|
description: "Fetching CodeExecutionEnvironments from ToothFairyAI",
|
||||||
|
footer: <span style={{ fg: theme.textMuted }}>⋯</span>,
|
||||||
|
category: "Code Execution Environments",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hooksList.length === 0) {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
value: { id: "empty", name: "No hooks found" } as Hook,
|
||||||
|
title: "No CodeExecutionEnvironments found",
|
||||||
|
description: "Create hooks in ToothFairyAI Settings > Hooks",
|
||||||
|
footer: <span style={{ fg: theme.textMuted }}>○</span>,
|
||||||
|
category: "Code Execution Environments",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
return hooksList.map((hook) => ({
|
||||||
|
value: hook,
|
||||||
|
title: hook.name,
|
||||||
|
description: hook.description || "No description",
|
||||||
|
footer: <HookStatus hook={hook} />,
|
||||||
|
category: "Code Execution Environments",
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
const keybinds = createMemo(() => [
|
||||||
|
{
|
||||||
|
keybind: Keybind.parse("space")[0],
|
||||||
|
title: "refresh",
|
||||||
|
onTrigger: async () => {
|
||||||
|
setRefreshKey((k) => k + 1)
|
||||||
|
toast.show({
|
||||||
|
variant: "info",
|
||||||
|
message: "Refreshing hooks...",
|
||||||
|
duration: 2000,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<DialogSelect
|
||||||
|
ref={setRef}
|
||||||
|
title="Code Execution Environments"
|
||||||
|
options={options()}
|
||||||
|
keybind={keybinds()}
|
||||||
|
onSelect={(option) => {
|
||||||
|
if (option.value.id === "loading" || option.value.id === "empty") return
|
||||||
|
// Don't close on select - just show the hook details
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
@ -36,6 +36,7 @@ import { useKV } from "../../context/kv"
|
|||||||
import { useTextareaKeybindings } from "../textarea-keybindings"
|
import { useTextareaKeybindings } from "../textarea-keybindings"
|
||||||
import { DialogSkill } from "../dialog-skill"
|
import { DialogSkill } from "../dialog-skill"
|
||||||
import { DialogTfMcp } from "../dialog-tf-mcp"
|
import { DialogTfMcp } from "../dialog-tf-mcp"
|
||||||
|
import { DialogTfHooks } from "../dialog-tf-hooks"
|
||||||
|
|
||||||
export type PromptProps = {
|
export type PromptProps = {
|
||||||
sessionID?: string
|
sessionID?: string
|
||||||
@ -365,6 +366,17 @@ export function Prompt(props: PromptProps) {
|
|||||||
dialog.replace(() => <DialogTfMcp />)
|
dialog.replace(() => <DialogTfMcp />)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
title: "Hooks",
|
||||||
|
value: "prompt.hooks",
|
||||||
|
category: "Prompt",
|
||||||
|
slash: {
|
||||||
|
name: "hooks",
|
||||||
|
},
|
||||||
|
onSelect: () => {
|
||||||
|
dialog.replace(() => <DialogTfHooks />)
|
||||||
|
},
|
||||||
|
},
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -136,7 +136,19 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
|
|||||||
for (const line of lines) {
|
for (const line of lines) {
|
||||||
if (line.startsWith("data: ")) {
|
if (line.startsWith("data: ")) {
|
||||||
const json = line.slice(6).trim()
|
const json = line.slice(6).trim()
|
||||||
if (json && !json.startsWith('{"status":')) {
|
// Filter out connection status messages like {"status":"initialising"}, {"status":"connected"}
|
||||||
|
// These are internal progress indicators, not OpenAI-format chunks
|
||||||
|
if (json) {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(json)
|
||||||
|
if (parsed.status === "initialising" || parsed.status === "connected") {
|
||||||
|
log.debug("filtered connection status", { status: parsed.status })
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Not valid JSON, keep the line
|
||||||
|
}
|
||||||
|
}
|
||||||
filtered.push(line)
|
filtered.push(line)
|
||||||
// Log tool calls and finish_reason
|
// Log tool calls and finish_reason
|
||||||
try {
|
try {
|
||||||
@ -159,7 +171,6 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
} catch {}
|
} catch {}
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
filtered.push(line)
|
filtered.push(line)
|
||||||
}
|
}
|
||||||
|
|||||||
423
packages/tfcode/test/provider/toothfairyai.test.ts
Normal file
423
packages/tfcode/test/provider/toothfairyai.test.ts
Normal file
@ -0,0 +1,423 @@
|
|||||||
|
import { test, expect, describe } from "bun:test"
|
||||||
|
import { createToothFairyAI } from "../../src/provider/sdk/toothfairyai"
|
||||||
|
import type { LanguageModelV2StreamPart } from "@ai-sdk/provider"
|
||||||
|
import type { FetchFunction } from "@ai-sdk/provider-utils"
|
||||||
|
|
||||||
|
async function collectStream(stream: ReadableStream<LanguageModelV2StreamPart>): Promise<LanguageModelV2StreamPart[]> {
|
||||||
|
const chunks: LanguageModelV2StreamPart[] = []
|
||||||
|
const reader = stream.getReader()
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read()
|
||||||
|
if (done) break
|
||||||
|
chunks.push(value)
|
||||||
|
}
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
function createValidResponse(): Response {
|
||||||
|
return new Response(
|
||||||
|
JSON.stringify({
|
||||||
|
id: "test-id",
|
||||||
|
object: "chat.completion",
|
||||||
|
created: 123,
|
||||||
|
model: "test",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
message: {
|
||||||
|
role: "assistant",
|
||||||
|
content: "Test response",
|
||||||
|
},
|
||||||
|
finish_reason: "stop",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
status: 200,
|
||||||
|
headers: { "content-type": "application/json" },
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function createSSEResponse(chunks: string[]): Response {
|
||||||
|
const encoder = new TextEncoder()
|
||||||
|
const stream = new ReadableStream({
|
||||||
|
start(controller) {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
controller.enqueue(encoder.encode(chunk))
|
||||||
|
}
|
||||||
|
controller.close()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return new Response(stream, {
|
||||||
|
status: 200,
|
||||||
|
headers: { "content-type": "text/event-stream" },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("ToothFairyAI Provider", () => {
|
||||||
|
test("creates provider with default settings", () => {
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
})
|
||||||
|
expect(provider).toBeDefined()
|
||||||
|
expect(typeof provider).toBe("function")
|
||||||
|
expect(provider.languageModel).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
test("creates provider with custom region", () => {
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
region: "us",
|
||||||
|
})
|
||||||
|
expect(provider).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
test("redirects /chat/completions to /predictions", async () => {
|
||||||
|
let capturedUrl: string | undefined
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (async (input) => {
|
||||||
|
capturedUrl = typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url
|
||||||
|
return createValidResponse()
|
||||||
|
}) as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
await model.doGenerate({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(capturedUrl).toBeDefined()
|
||||||
|
const url = new URL(capturedUrl!)
|
||||||
|
expect(url.pathname).toBe("/predictions")
|
||||||
|
})
|
||||||
|
|
||||||
|
test("adds workspaceid to request body", async () => {
|
||||||
|
let capturedBody: string | undefined
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (async (_input, init) => {
|
||||||
|
capturedBody = init?.body as string
|
||||||
|
return createValidResponse()
|
||||||
|
}) as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
await model.doGenerate({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(capturedBody).toBeDefined()
|
||||||
|
const body = JSON.parse(capturedBody!)
|
||||||
|
expect(body.workspaceid).toBe("test-workspace")
|
||||||
|
})
|
||||||
|
|
||||||
|
test("sets x-api-key header", async () => {
|
||||||
|
let capturedHeaders: Headers | undefined
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (async (_input, init) => {
|
||||||
|
capturedHeaders = new Headers(init?.headers as HeadersInit)
|
||||||
|
return createValidResponse()
|
||||||
|
}) as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
await model.doGenerate({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(capturedHeaders).toBeDefined()
|
||||||
|
expect(capturedHeaders!.get("x-api-key")).toBe("test-key")
|
||||||
|
})
|
||||||
|
|
||||||
|
test("removes Authorization header", async () => {
|
||||||
|
let capturedHeaders: Headers | undefined
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (async (_input, init) => {
|
||||||
|
capturedHeaders = new Headers(init?.headers as HeadersInit)
|
||||||
|
return createValidResponse()
|
||||||
|
}) as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
await model.doGenerate({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
headers: { Authorization: "Bearer should-be-removed" },
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(capturedHeaders).toBeDefined()
|
||||||
|
expect(capturedHeaders!.has("Authorization")).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("ToothFairyAI Streaming", () => {
|
||||||
|
test("filters connection status messages", async () => {
|
||||||
|
const sseChunks = [
|
||||||
|
'data: {"status": "initialising"}\n\n',
|
||||||
|
'data: {"status": "connected"}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n',
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doStream({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const chunks = await collectStream(result.stream)
|
||||||
|
const textDeltas = chunks.filter((c) => c.type === "text-delta")
|
||||||
|
expect(textDeltas.length).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("handles reasoning_content in stream", async () => {
|
||||||
|
const sseChunks = [
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"reasoning_content": "Let me think..."}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Answer"}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n',
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doStream({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const chunks = await collectStream(result.stream)
|
||||||
|
const reasoningDeltas = chunks.filter((c) => c.type === "reasoning-delta")
|
||||||
|
expect(reasoningDeltas.length).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("handles tool_calls in stream", async () => {
|
||||||
|
const sseChunks = [
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": "{\\"location\\": \\"Sydney\\"}"}}]}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}]}\n\n',
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doStream({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const chunks = await collectStream(result.stream)
|
||||||
|
const toolCalls = chunks.filter((c) => c.type === "tool-call")
|
||||||
|
expect(toolCalls.length).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("handles error chunks", async () => {
|
||||||
|
const sseChunks = [
|
||||||
|
'data: {"error": {"message": "Model parameter is required", "type": "invalid_request_error", "code": 400}}\n\n',
|
||||||
|
]
|
||||||
|
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doStream({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const chunks = await collectStream(result.stream)
|
||||||
|
const errorChunks = chunks.filter((c) => c.type === "error")
|
||||||
|
expect(errorChunks.length).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("handles usage information", async () => {
|
||||||
|
const sseChunks = [
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}\n\n',
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doStream({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const chunks = await collectStream(result.stream)
|
||||||
|
const finishChunks = chunks.filter((c) => c.type === "finish")
|
||||||
|
expect(finishChunks.length).toBe(1)
|
||||||
|
const finish = finishChunks[0] as any
|
||||||
|
expect(finish.usage?.inputTokens).toBe(10)
|
||||||
|
expect(finish.usage?.outputTokens).toBe(5)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("handles different finish reasons", async () => {
|
||||||
|
const finishReasons = ["stop", "length", "tool_calls"]
|
||||||
|
|
||||||
|
for (const reason of finishReasons) {
|
||||||
|
const sseChunks = [
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
|
||||||
|
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
|
||||||
|
`data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "${reason}"}]}\n\n`,
|
||||||
|
"data: [DONE]\n\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doStream({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const chunks = await collectStream(result.stream)
|
||||||
|
const finishChunks = chunks.filter((c) => c.type === "finish")
|
||||||
|
expect(finishChunks.length).toBe(1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("ToothFairyAI Non-streaming", () => {
|
||||||
|
test("handles non-streaming response with reasoning_content", async () => {
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() =>
|
||||||
|
new Response(
|
||||||
|
JSON.stringify({
|
||||||
|
id: "test-id",
|
||||||
|
object: "chat.completion",
|
||||||
|
created: 123,
|
||||||
|
model: "test",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
message: {
|
||||||
|
role: "assistant",
|
||||||
|
reasoning_content: "Let me think about this...",
|
||||||
|
content: "The answer is 42.",
|
||||||
|
},
|
||||||
|
finish_reason: "stop",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 20,
|
||||||
|
total_tokens: 30,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
status: 200,
|
||||||
|
headers: { "content-type": "application/json" },
|
||||||
|
},
|
||||||
|
)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doGenerate({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const reasoning = result.content.find((c) => c.type === "reasoning")
|
||||||
|
expect(reasoning).toBeDefined()
|
||||||
|
expect((reasoning as any).text).toBe("Let me think about this...")
|
||||||
|
|
||||||
|
const text = result.content.find((c) => c.type === "text")
|
||||||
|
expect(text).toBeDefined()
|
||||||
|
expect((text as any).text).toBe("The answer is 42.")
|
||||||
|
})
|
||||||
|
|
||||||
|
test("handles non-streaming response with tool_calls", async () => {
|
||||||
|
const provider = createToothFairyAI({
|
||||||
|
apiKey: "test-key",
|
||||||
|
workspaceId: "test-workspace",
|
||||||
|
fetch: (() =>
|
||||||
|
new Response(
|
||||||
|
JSON.stringify({
|
||||||
|
id: "test-id",
|
||||||
|
object: "chat.completion",
|
||||||
|
created: 123,
|
||||||
|
model: "test",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
message: {
|
||||||
|
role: "assistant",
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: "call_123",
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "get_weather",
|
||||||
|
arguments: '{"location": "Sydney"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: "tool_calls",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
status: 200,
|
||||||
|
headers: { "content-type": "application/json" },
|
||||||
|
},
|
||||||
|
)) as unknown as FetchFunction,
|
||||||
|
})
|
||||||
|
|
||||||
|
const model = provider("test-model")
|
||||||
|
const result = await model.doGenerate({
|
||||||
|
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||||
|
})
|
||||||
|
|
||||||
|
const toolCall = result.content.find((c) => c.type === "tool-call")
|
||||||
|
expect(toolCall).toBeDefined()
|
||||||
|
expect((toolCall as any).toolName).toBe("get_weather")
|
||||||
|
expect((toolCall as any).input).toBe('{"location": "Sydney"}')
|
||||||
|
expect(result.finishReason).toBe("tool-calls")
|
||||||
|
})
|
||||||
|
})
|
||||||
Loading…
x
Reference in New Issue
Block a user