From 12ae1cb9b5526580298abd9a8a944b52de64013e Mon Sep 17 00:00:00 2001 From: Gab Date: Thu, 26 Mar 2026 15:07:30 +1100 Subject: [PATCH] feat: better preidctiosn --- bun.lock | 2 +- packages/tfcode/package.json | 2 +- .../cli/cmd/tui/component/dialog-tf-hooks.tsx | 179 ++++++++ .../cli/cmd/tui/component/prompt/index.tsx | 12 + .../sdk/toothfairyai/toothfairyai-provider.ts | 51 ++- .../tfcode/test/provider/toothfairyai.test.ts | 423 ++++++++++++++++++ 6 files changed, 647 insertions(+), 22 deletions(-) create mode 100644 packages/tfcode/src/cli/cmd/tui/component/dialog-tf-hooks.tsx create mode 100644 packages/tfcode/test/provider/toothfairyai.test.ts diff --git a/bun.lock b/bun.lock index 001f76823..33264cb29 100644 --- a/bun.lock +++ b/bun.lock @@ -381,7 +381,7 @@ }, "packages/tfcode": { "name": "tfcode", - "version": "1.0.6", + "version": "1.0.7", "bin": { "tfcode": "./bin/tfcode", }, diff --git a/packages/tfcode/package.json b/packages/tfcode/package.json index 805f0f706..d51733857 100644 --- a/packages/tfcode/package.json +++ b/packages/tfcode/package.json @@ -1,6 +1,6 @@ { "$schema": "https://json.schemastore.org/package.json", - "version": "1.0.6", + "version": "1.0.7", "name": "tfcode", "type": "module", "license": "MIT", diff --git a/packages/tfcode/src/cli/cmd/tui/component/dialog-tf-hooks.tsx b/packages/tfcode/src/cli/cmd/tui/component/dialog-tf-hooks.tsx new file mode 100644 index 000000000..f510c1d11 --- /dev/null +++ b/packages/tfcode/src/cli/cmd/tui/component/dialog-tf-hooks.tsx @@ -0,0 +1,179 @@ +import { createMemo, createSignal, createResource } from "solid-js" +import { DialogSelect, type DialogSelectRef, type DialogSelectOption } from "@tui/ui/dialog-select" +import { useTheme } from "../context/theme" +import { TextAttributes } from "@opentui/core" +import { Global } from "@/global" +import path from "path" +import { useToast } from "../ui/toast" +import { Keybind } from "@/util/keybind" + +interface Hook { + id: string + name: string + description?: string + remote_environment_name?: string + code_execution_instructions?: string + predefined_code_snippet?: string + execute_as_static_script?: boolean + execute_as_python_tool?: boolean + secrets_injected?: boolean + created_at?: string + updated_at?: string +} + +const TF_DEFAULT_REGION = "au" + +const REGION_API_URLS: Record = { + dev: "https://api.toothfairylab.link", + au: "https://api.toothfairyai.com", + eu: "https://api.eu.toothfairyai.com", + us: "https://api.us.toothfairyai.com", +} + +function HookStatus(props: { hook: Hook }) { + const { theme } = useTheme() + const features = [] + if (props.hook.execute_as_static_script) features.push("Static") + if (props.hook.execute_as_python_tool) features.push("Python Tool") + if (props.hook.secrets_injected) features.push("Secrets") + if (props.hook.remote_environment_name) features.push("Remote") + + if (features.length === 0) { + return Default + } + + return ( + + {features.map((f, i) => ( + <> + {i > 0 && ·} + {f} + + ))} + + ) +} + +export function DialogTfHooks() { + const toast = useToast() + const { theme } = useTheme() + const [, setRef] = createSignal>() + const [loading, setLoading] = createSignal(false) + const [refreshKey, setRefreshKey] = createSignal(0) + + const [credentials] = createResource(refreshKey, async () => { + try { + const credPath = path.join(Global.Path.data, ".tfcode", "credentials.json") + const data = (await Bun.file(credPath).json()) as { + api_key?: string + workspace_id?: string + region?: string + } + return data + } catch { + return {} + } + }) + + const [hooks, { refetch: refetchHooks }] = createResource(refreshKey, async () => { + const creds = credentials() + if (!creds?.api_key || !creds?.workspace_id) { + return [] + } + + setLoading(true) + try { + const region = creds.region || TF_DEFAULT_REGION + const baseUrl = REGION_API_URLS[region] || REGION_API_URLS[TF_DEFAULT_REGION] + + const response = await fetch(`${baseUrl}/hook/list`, { + method: "GET", + headers: { + "x-api-key": creds.api_key, + }, + }) + + if (!response.ok) { + const errorText = await response.text() + throw new Error(`HTTP ${response.status}: ${errorText}`) + } + + const data = (await response.json()) as { success?: boolean; hooks?: Hook[] } + return data.hooks || [] + } catch (error) { + toast.show({ + variant: "error", + message: `Failed to fetch hooks: ${error}`, + duration: 5000, + }) + return [] + } finally { + setLoading(false) + } + }) + + const options = createMemo[]>(() => { + const hooksList = hooks() || [] + const isLoading = loading() + + if (isLoading) { + return [ + { + value: { id: "loading", name: "Loading..." } as Hook, + title: "Loading hooks...", + description: "Fetching CodeExecutionEnvironments from ToothFairyAI", + footer: , + category: "Code Execution Environments", + }, + ] + } + + if (hooksList.length === 0) { + return [ + { + value: { id: "empty", name: "No hooks found" } as Hook, + title: "No CodeExecutionEnvironments found", + description: "Create hooks in ToothFairyAI Settings > Hooks", + footer: , + category: "Code Execution Environments", + }, + ] + } + + return hooksList.map((hook) => ({ + value: hook, + title: hook.name, + description: hook.description || "No description", + footer: , + category: "Code Execution Environments", + })) + }) + + const keybinds = createMemo(() => [ + { + keybind: Keybind.parse("space")[0], + title: "refresh", + onTrigger: async () => { + setRefreshKey((k) => k + 1) + toast.show({ + variant: "info", + message: "Refreshing hooks...", + duration: 2000, + }) + }, + }, + ]) + + return ( + { + if (option.value.id === "loading" || option.value.id === "empty") return + // Don't close on select - just show the hook details + }} + /> + ) +} diff --git a/packages/tfcode/src/cli/cmd/tui/component/prompt/index.tsx b/packages/tfcode/src/cli/cmd/tui/component/prompt/index.tsx index cae64d12f..ab059e20c 100644 --- a/packages/tfcode/src/cli/cmd/tui/component/prompt/index.tsx +++ b/packages/tfcode/src/cli/cmd/tui/component/prompt/index.tsx @@ -36,6 +36,7 @@ import { useKV } from "../../context/kv" import { useTextareaKeybindings } from "../textarea-keybindings" import { DialogSkill } from "../dialog-skill" import { DialogTfMcp } from "../dialog-tf-mcp" +import { DialogTfHooks } from "../dialog-tf-hooks" export type PromptProps = { sessionID?: string @@ -365,6 +366,17 @@ export function Prompt(props: PromptProps) { dialog.replace(() => ) }, }, + { + title: "Hooks", + value: "prompt.hooks", + category: "Prompt", + slash: { + name: "hooks", + }, + onSelect: () => { + dialog.replace(() => ) + }, + }, ] }) diff --git a/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts b/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts index dc96f6f8f..454112669 100644 --- a/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts +++ b/packages/tfcode/src/provider/sdk/toothfairyai/toothfairyai-provider.ts @@ -136,30 +136,41 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}): for (const line of lines) { if (line.startsWith("data: ")) { const json = line.slice(6).trim() - if (json && !json.startsWith('{"status":')) { - filtered.push(line) - // Log tool calls and finish_reason + // Filter out connection status messages like {"status":"initialising"}, {"status":"connected"} + // These are internal progress indicators, not OpenAI-format chunks + if (json) { try { const parsed = JSON.parse(json) - if (parsed.choices?.[0]?.delta?.tool_calls) { - log.debug("stream tool_calls", { - tool_calls: parsed.choices[0].delta.tool_calls, - }) + if (parsed.status === "initialising" || parsed.status === "connected") { + log.debug("filtered connection status", { status: parsed.status }) + continue } - if (parsed.choices?.[0]?.finish_reason) { - log.info("stream finish_reason", { - finish_reason: parsed.choices[0].finish_reason, - }) - } - if (parsed.usage) { - log.info("stream usage", { - prompt_tokens: parsed.usage.prompt_tokens, - completion_tokens: parsed.usage.completion_tokens, - total_tokens: parsed.usage.total_tokens, - }) - } - } catch {} + } catch { + // Not valid JSON, keep the line + } } + filtered.push(line) + // Log tool calls and finish_reason + try { + const parsed = JSON.parse(json) + if (parsed.choices?.[0]?.delta?.tool_calls) { + log.debug("stream tool_calls", { + tool_calls: parsed.choices[0].delta.tool_calls, + }) + } + if (parsed.choices?.[0]?.finish_reason) { + log.info("stream finish_reason", { + finish_reason: parsed.choices[0].finish_reason, + }) + } + if (parsed.usage) { + log.info("stream usage", { + prompt_tokens: parsed.usage.prompt_tokens, + completion_tokens: parsed.usage.completion_tokens, + total_tokens: parsed.usage.total_tokens, + }) + } + } catch {} } else { filtered.push(line) } diff --git a/packages/tfcode/test/provider/toothfairyai.test.ts b/packages/tfcode/test/provider/toothfairyai.test.ts new file mode 100644 index 000000000..4308b7e21 --- /dev/null +++ b/packages/tfcode/test/provider/toothfairyai.test.ts @@ -0,0 +1,423 @@ +import { test, expect, describe } from "bun:test" +import { createToothFairyAI } from "../../src/provider/sdk/toothfairyai" +import type { LanguageModelV2StreamPart } from "@ai-sdk/provider" +import type { FetchFunction } from "@ai-sdk/provider-utils" + +async function collectStream(stream: ReadableStream): Promise { + const chunks: LanguageModelV2StreamPart[] = [] + const reader = stream.getReader() + while (true) { + const { done, value } = await reader.read() + if (done) break + chunks.push(value) + } + return chunks +} + +function createValidResponse(): Response { + return new Response( + JSON.stringify({ + id: "test-id", + object: "chat.completion", + created: 123, + model: "test", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "Test response", + }, + finish_reason: "stop", + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + }), + { + status: 200, + headers: { "content-type": "application/json" }, + }, + ) +} + +function createSSEResponse(chunks: string[]): Response { + const encoder = new TextEncoder() + const stream = new ReadableStream({ + start(controller) { + for (const chunk of chunks) { + controller.enqueue(encoder.encode(chunk)) + } + controller.close() + }, + }) + + return new Response(stream, { + status: 200, + headers: { "content-type": "text/event-stream" }, + }) +} + +describe("ToothFairyAI Provider", () => { + test("creates provider with default settings", () => { + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + }) + expect(provider).toBeDefined() + expect(typeof provider).toBe("function") + expect(provider.languageModel).toBeDefined() + }) + + test("creates provider with custom region", () => { + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + region: "us", + }) + expect(provider).toBeDefined() + }) + + test("redirects /chat/completions to /predictions", async () => { + let capturedUrl: string | undefined + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (async (input) => { + capturedUrl = typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url + return createValidResponse() + }) as FetchFunction, + }) + + const model = provider("test-model") + await model.doGenerate({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + expect(capturedUrl).toBeDefined() + const url = new URL(capturedUrl!) + expect(url.pathname).toBe("/predictions") + }) + + test("adds workspaceid to request body", async () => { + let capturedBody: string | undefined + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (async (_input, init) => { + capturedBody = init?.body as string + return createValidResponse() + }) as FetchFunction, + }) + + const model = provider("test-model") + await model.doGenerate({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + expect(capturedBody).toBeDefined() + const body = JSON.parse(capturedBody!) + expect(body.workspaceid).toBe("test-workspace") + }) + + test("sets x-api-key header", async () => { + let capturedHeaders: Headers | undefined + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (async (_input, init) => { + capturedHeaders = new Headers(init?.headers as HeadersInit) + return createValidResponse() + }) as FetchFunction, + }) + + const model = provider("test-model") + await model.doGenerate({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + expect(capturedHeaders).toBeDefined() + expect(capturedHeaders!.get("x-api-key")).toBe("test-key") + }) + + test("removes Authorization header", async () => { + let capturedHeaders: Headers | undefined + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (async (_input, init) => { + capturedHeaders = new Headers(init?.headers as HeadersInit) + return createValidResponse() + }) as FetchFunction, + }) + + const model = provider("test-model") + await model.doGenerate({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + headers: { Authorization: "Bearer should-be-removed" }, + }) + + expect(capturedHeaders).toBeDefined() + expect(capturedHeaders!.has("Authorization")).toBe(false) + }) +}) + +describe("ToothFairyAI Streaming", () => { + test("filters connection status messages", async () => { + const sseChunks = [ + 'data: {"status": "initialising"}\n\n', + 'data: {"status": "connected"}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n', + "data: [DONE]\n\n", + ] + + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doStream({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const chunks = await collectStream(result.stream) + const textDeltas = chunks.filter((c) => c.type === "text-delta") + expect(textDeltas.length).toBeGreaterThan(0) + }) + + test("handles reasoning_content in stream", async () => { + const sseChunks = [ + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"reasoning_content": "Let me think..."}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Answer"}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n', + "data: [DONE]\n\n", + ] + + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doStream({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const chunks = await collectStream(result.stream) + const reasoningDeltas = chunks.filter((c) => c.type === "reasoning-delta") + expect(reasoningDeltas.length).toBeGreaterThan(0) + }) + + test("handles tool_calls in stream", async () => { + const sseChunks = [ + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": "{\\"location\\": \\"Sydney\\"}"}}]}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}]}\n\n', + "data: [DONE]\n\n", + ] + + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doStream({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const chunks = await collectStream(result.stream) + const toolCalls = chunks.filter((c) => c.type === "tool-call") + expect(toolCalls.length).toBe(1) + }) + + test("handles error chunks", async () => { + const sseChunks = [ + 'data: {"error": {"message": "Model parameter is required", "type": "invalid_request_error", "code": 400}}\n\n', + ] + + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doStream({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const chunks = await collectStream(result.stream) + const errorChunks = chunks.filter((c) => c.type === "error") + expect(errorChunks.length).toBeGreaterThan(0) + }) + + test("handles usage information", async () => { + const sseChunks = [ + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}\n\n', + "data: [DONE]\n\n", + ] + + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doStream({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const chunks = await collectStream(result.stream) + const finishChunks = chunks.filter((c) => c.type === "finish") + expect(finishChunks.length).toBe(1) + const finish = finishChunks[0] as any + expect(finish.usage?.inputTokens).toBe(10) + expect(finish.usage?.outputTokens).toBe(5) + }) + + test("handles different finish reasons", async () => { + const finishReasons = ["stop", "length", "tool_calls"] + + for (const reason of finishReasons) { + const sseChunks = [ + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n', + 'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n', + `data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "${reason}"}]}\n\n`, + "data: [DONE]\n\n", + ] + + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doStream({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const chunks = await collectStream(result.stream) + const finishChunks = chunks.filter((c) => c.type === "finish") + expect(finishChunks.length).toBe(1) + } + }) +}) + +describe("ToothFairyAI Non-streaming", () => { + test("handles non-streaming response with reasoning_content", async () => { + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => + new Response( + JSON.stringify({ + id: "test-id", + object: "chat.completion", + created: 123, + model: "test", + choices: [ + { + index: 0, + message: { + role: "assistant", + reasoning_content: "Let me think about this...", + content: "The answer is 42.", + }, + finish_reason: "stop", + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30, + }, + }), + { + status: 200, + headers: { "content-type": "application/json" }, + }, + )) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doGenerate({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const reasoning = result.content.find((c) => c.type === "reasoning") + expect(reasoning).toBeDefined() + expect((reasoning as any).text).toBe("Let me think about this...") + + const text = result.content.find((c) => c.type === "text") + expect(text).toBeDefined() + expect((text as any).text).toBe("The answer is 42.") + }) + + test("handles non-streaming response with tool_calls", async () => { + const provider = createToothFairyAI({ + apiKey: "test-key", + workspaceId: "test-workspace", + fetch: (() => + new Response( + JSON.stringify({ + id: "test-id", + object: "chat.completion", + created: 123, + model: "test", + choices: [ + { + index: 0, + message: { + role: "assistant", + tool_calls: [ + { + id: "call_123", + type: "function", + function: { + name: "get_weather", + arguments: '{"location": "Sydney"}', + }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + }), + { + status: 200, + headers: { "content-type": "application/json" }, + }, + )) as unknown as FetchFunction, + }) + + const model = provider("test-model") + const result = await model.doGenerate({ + prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }], + }) + + const toolCall = result.content.find((c) => c.type === "tool-call") + expect(toolCall).toBeDefined() + expect((toolCall as any).toolName).toBe("get_weather") + expect((toolCall as any).input).toBe('{"location": "Sydney"}') + expect(result.finishReason).toBe("tool-calls") + }) +})