feat: better predictions

This commit is contained in:
Gab 2026-03-26 15:07:30 +11:00
parent a42f8fa99f
commit 12ae1cb9b5
6 changed files with 647 additions and 22 deletions

View File

@ -381,7 +381,7 @@
}, },
"packages/tfcode": { "packages/tfcode": {
"name": "tfcode", "name": "tfcode",
"version": "1.0.6", "version": "1.0.7",
"bin": { "bin": {
"tfcode": "./bin/tfcode", "tfcode": "./bin/tfcode",
}, },

View File

@ -1,6 +1,6 @@
{ {
"$schema": "https://json.schemastore.org/package.json", "$schema": "https://json.schemastore.org/package.json",
"version": "1.0.6", "version": "1.0.7",
"name": "tfcode", "name": "tfcode",
"type": "module", "type": "module",
"license": "MIT", "license": "MIT",

View File

@ -0,0 +1,179 @@
import { createMemo, createSignal, createResource } from "solid-js"
import { DialogSelect, type DialogSelectRef, type DialogSelectOption } from "@tui/ui/dialog-select"
import { useTheme } from "../context/theme"
import { TextAttributes } from "@opentui/core"
import { Global } from "@/global"
import path from "path"
import { useToast } from "../ui/toast"
import { Keybind } from "@/util/keybind"
// A ToothFairyAI "hook" (CodeExecutionEnvironment) record as returned by the
// `/hook/list` endpoint. Only `id` and `name` are always present; every other
// field is optional in the API payload.
interface Hook {
id: string
name: string
description?: string
// Remote environment the hook executes in — drives the "Remote" badge.
remote_environment_name?: string
code_execution_instructions?: string
predefined_code_snippet?: string
// Execution-mode flags surfaced as badges by HookStatus.
execute_as_static_script?: boolean
execute_as_python_tool?: boolean
secrets_injected?: boolean
// Timestamps from the API — presumably ISO 8601, TODO confirm against the backend.
created_at?: string
updated_at?: string
}
// Region key used when the credentials file does not specify one.
const TF_DEFAULT_REGION = "au"
// Base API URL for each ToothFairyAI deployment region; unknown regions
// fall back to TF_DEFAULT_REGION at the call site.
const REGION_API_URLS: Record<string, string> = {
dev: "https://api.toothfairylab.link",
au: "https://api.toothfairyai.com",
eu: "https://api.eu.toothfairyai.com",
us: "https://api.us.toothfairyai.com",
}
/**
 * Compact capability summary for a hook, used as the option footer.
 * Renders a muted "Default" when no capability flags are set; otherwise the
 * enabled capability labels separated by a muted "·".
 */
function HookStatus(props: { hook: Hook }) {
  const { theme } = useTheme()
  const labels: string[] = []
  if (props.hook.execute_as_static_script) labels.push("Static")
  if (props.hook.execute_as_python_tool) labels.push("Python Tool")
  if (props.hook.secrets_injected) labels.push("Secrets")
  if (props.hook.remote_environment_name) labels.push("Remote")
  if (!labels.length) {
    return <span style={{ fg: theme.textMuted }}>Default</span>
  }
  return (
    <span style={{ fg: theme.success }}>
      {labels.map((label, index) => (
        <>
          {index > 0 && <span style={{ fg: theme.textMuted }}>·</span>}
          {label}
        </>
      ))}
    </span>
  )
}
/**
 * Dialog listing the ToothFairyAI CodeExecutionEnvironments ("hooks") for the
 * workspace configured in <data>/.tfcode/credentials.json.
 *
 * Press <space> to refresh the list. Selecting a loading/empty placeholder is
 * a no-op; selecting a real hook intentionally keeps the dialog open.
 */
export function DialogTfHooks() {
  const toast = useToast()
  const { theme } = useTheme()
  const [, setRef] = createSignal<DialogSelectRef<unknown>>()
  const [loading, setLoading] = createSignal(false)
  // Bumping this key re-runs the hooks resource (manual refresh via <space>).
  const [refreshKey, setRefreshKey] = createSignal(0)

  // Reads credentials from disk on every fetch. This deliberately lives inside
  // the hooks flow rather than a separate createResource: resource fetchers are
  // untracked in Solid, so reading a sibling credentials resource here would
  // race its first async load (seeing `undefined`), return [] once, and never
  // re-run when the credentials actually resolved — leaving the list empty.
  const readCredentials = async (): Promise<{ api_key?: string; workspace_id?: string; region?: string }> => {
    try {
      const credPath = path.join(Global.Path.data, ".tfcode", "credentials.json")
      return (await Bun.file(credPath).json()) as {
        api_key?: string
        workspace_id?: string
        region?: string
      }
    } catch {
      // Missing or malformed credentials file — treated as "not configured".
      return {}
    }
  }

  const [hooks] = createResource(refreshKey, async () => {
    const creds = await readCredentials()
    if (!creds.api_key || !creds.workspace_id) {
      return []
    }
    setLoading(true)
    try {
      const region = creds.region || TF_DEFAULT_REGION
      const baseUrl = REGION_API_URLS[region] || REGION_API_URLS[TF_DEFAULT_REGION]
      const response = await fetch(`${baseUrl}/hook/list`, {
        method: "GET",
        headers: {
          "x-api-key": creds.api_key,
        },
      })
      if (!response.ok) {
        const errorText = await response.text()
        throw new Error(`HTTP ${response.status}: ${errorText}`)
      }
      const data = (await response.json()) as { success?: boolean; hooks?: Hook[] }
      return data.hooks || []
    } catch (error) {
      // Surface fetch/HTTP failures to the user but keep the dialog usable.
      toast.show({
        variant: "error",
        message: `Failed to fetch hooks: ${error}`,
        duration: 5000,
      })
      return []
    } finally {
      setLoading(false)
    }
  })

  // Maps the hooks list into dialog options, with placeholder rows for the
  // loading and empty states (tagged via sentinel ids "loading"/"empty").
  const options = createMemo<DialogSelectOption<Hook>[]>(() => {
    const hooksList = hooks() || []
    const isLoading = loading()
    if (isLoading) {
      return [
        {
          value: { id: "loading", name: "Loading..." } as Hook,
          title: "Loading hooks...",
          description: "Fetching CodeExecutionEnvironments from ToothFairyAI",
          footer: <span style={{ fg: theme.textMuted }}></span>,
          category: "Code Execution Environments",
        },
      ]
    }
    if (hooksList.length === 0) {
      return [
        {
          value: { id: "empty", name: "No hooks found" } as Hook,
          title: "No CodeExecutionEnvironments found",
          description: "Create hooks in ToothFairyAI Settings > Hooks",
          footer: <span style={{ fg: theme.textMuted }}></span>,
          category: "Code Execution Environments",
        },
      ]
    }
    return hooksList.map((hook) => ({
      value: hook,
      title: hook.name,
      description: hook.description || "No description",
      footer: <HookStatus hook={hook} />,
      category: "Code Execution Environments",
    }))
  })

  const keybinds = createMemo(() => [
    {
      keybind: Keybind.parse("space")[0],
      title: "refresh",
      onTrigger: async () => {
        // Re-keys the resource, which re-reads credentials and refetches.
        setRefreshKey((k) => k + 1)
        toast.show({
          variant: "info",
          message: "Refreshing hooks...",
          duration: 2000,
        })
      },
    },
  ])

  return (
    <DialogSelect
      ref={setRef}
      title="Code Execution Environments"
      options={options()}
      keybind={keybinds()}
      onSelect={(option) => {
        if (option.value.id === "loading" || option.value.id === "empty") return
        // Don't close on select - just show the hook details
      }}
    />
  )
}

View File

@ -36,6 +36,7 @@ import { useKV } from "../../context/kv"
import { useTextareaKeybindings } from "../textarea-keybindings" import { useTextareaKeybindings } from "../textarea-keybindings"
import { DialogSkill } from "../dialog-skill" import { DialogSkill } from "../dialog-skill"
import { DialogTfMcp } from "../dialog-tf-mcp" import { DialogTfMcp } from "../dialog-tf-mcp"
import { DialogTfHooks } from "../dialog-tf-hooks"
export type PromptProps = { export type PromptProps = {
sessionID?: string sessionID?: string
@ -365,6 +366,17 @@ export function Prompt(props: PromptProps) {
dialog.replace(() => <DialogTfMcp />) dialog.replace(() => <DialogTfMcp />)
}, },
}, },
{
title: "Hooks",
value: "prompt.hooks",
category: "Prompt",
slash: {
name: "hooks",
},
onSelect: () => {
dialog.replace(() => <DialogTfHooks />)
},
},
] ]
}) })

View File

@ -136,30 +136,41 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
for (const line of lines) { for (const line of lines) {
if (line.startsWith("data: ")) { if (line.startsWith("data: ")) {
const json = line.slice(6).trim() const json = line.slice(6).trim()
if (json && !json.startsWith('{"status":')) { // Filter out connection status messages like {"status":"initialising"}, {"status":"connected"}
filtered.push(line) // These are internal progress indicators, not OpenAI-format chunks
// Log tool calls and finish_reason if (json) {
try { try {
const parsed = JSON.parse(json) const parsed = JSON.parse(json)
if (parsed.choices?.[0]?.delta?.tool_calls) { if (parsed.status === "initialising" || parsed.status === "connected") {
log.debug("stream tool_calls", { log.debug("filtered connection status", { status: parsed.status })
tool_calls: parsed.choices[0].delta.tool_calls, continue
})
} }
if (parsed.choices?.[0]?.finish_reason) { } catch {
log.info("stream finish_reason", { // Not valid JSON, keep the line
finish_reason: parsed.choices[0].finish_reason, }
})
}
if (parsed.usage) {
log.info("stream usage", {
prompt_tokens: parsed.usage.prompt_tokens,
completion_tokens: parsed.usage.completion_tokens,
total_tokens: parsed.usage.total_tokens,
})
}
} catch {}
} }
filtered.push(line)
// Log tool calls and finish_reason
try {
const parsed = JSON.parse(json)
if (parsed.choices?.[0]?.delta?.tool_calls) {
log.debug("stream tool_calls", {
tool_calls: parsed.choices[0].delta.tool_calls,
})
}
if (parsed.choices?.[0]?.finish_reason) {
log.info("stream finish_reason", {
finish_reason: parsed.choices[0].finish_reason,
})
}
if (parsed.usage) {
log.info("stream usage", {
prompt_tokens: parsed.usage.prompt_tokens,
completion_tokens: parsed.usage.completion_tokens,
total_tokens: parsed.usage.total_tokens,
})
}
} catch {}
} else { } else {
filtered.push(line) filtered.push(line)
} }

View File

@ -0,0 +1,423 @@
import { test, expect, describe } from "bun:test"
import { createToothFairyAI } from "../../src/provider/sdk/toothfairyai"
import type { LanguageModelV2StreamPart } from "@ai-sdk/provider"
import type { FetchFunction } from "@ai-sdk/provider-utils"
/** Drains a stream-part ReadableStream to completion, returning every part in arrival order. */
async function collectStream(stream: ReadableStream<LanguageModelV2StreamPart>): Promise<LanguageModelV2StreamPart[]> {
  const reader = stream.getReader()
  const parts: LanguageModelV2StreamPart[] = []
  for (;;) {
    const result = await reader.read()
    if (result.done) break
    parts.push(result.value)
  }
  return parts
}
/** Builds a canned non-streaming chat-completion Response for provider tests. */
function createValidResponse(): Response {
  const payload = {
    id: "test-id",
    object: "chat.completion",
    created: 123,
    model: "test",
    choices: [
      {
        index: 0,
        message: {
          role: "assistant",
          content: "Test response",
        },
        finish_reason: "stop",
      },
    ],
    usage: {
      prompt_tokens: 10,
      completion_tokens: 5,
      total_tokens: 15,
    },
  }
  const init = {
    status: 200,
    headers: { "content-type": "application/json" },
  }
  return new Response(JSON.stringify(payload), init)
}
/** Wraps pre-formatted SSE lines in a streaming text/event-stream Response. */
function createSSEResponse(chunks: string[]): Response {
  const encoder = new TextEncoder()
  const body = new ReadableStream({
    start(controller) {
      chunks.forEach((chunk) => controller.enqueue(encoder.encode(chunk)))
      controller.close()
    },
  })
  return new Response(body, {
    status: 200,
    headers: { "content-type": "text/event-stream" },
  })
}
// Provider construction and request-shaping tests. The exact literals here
// are the contract — they pin how createToothFairyAI rewrites outgoing
// OpenAI-style requests for the ToothFairyAI API (endpoint, body, headers).
describe("ToothFairyAI Provider", () => {
// Smoke test: the factory returns a callable provider exposing languageModel.
test("creates provider with default settings", () => {
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
})
expect(provider).toBeDefined()
expect(typeof provider).toBe("function")
expect(provider.languageModel).toBeDefined()
})
// A non-default region must be accepted at construction time.
test("creates provider with custom region", () => {
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
region: "us",
})
expect(provider).toBeDefined()
})
// The AI SDK targets /chat/completions; the provider must rewrite the path.
test("redirects /chat/completions to /predictions", async () => {
let capturedUrl: string | undefined
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (async (input) => {
// fetch may receive a string, URL, or Request — normalize to a string URL.
capturedUrl = typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url
return createValidResponse()
}) as FetchFunction,
})
const model = provider("test-model")
await model.doGenerate({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
expect(capturedUrl).toBeDefined()
const url = new URL(capturedUrl!)
expect(url.pathname).toBe("/predictions")
})
// The workspace id must be injected into the JSON body as `workspaceid`.
test("adds workspaceid to request body", async () => {
let capturedBody: string | undefined
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (async (_input, init) => {
capturedBody = init?.body as string
return createValidResponse()
}) as FetchFunction,
})
const model = provider("test-model")
await model.doGenerate({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
expect(capturedBody).toBeDefined()
const body = JSON.parse(capturedBody!)
expect(body.workspaceid).toBe("test-workspace")
})
// Authentication uses an x-api-key header, not a bearer token.
test("sets x-api-key header", async () => {
let capturedHeaders: Headers | undefined
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (async (_input, init) => {
capturedHeaders = new Headers(init?.headers as HeadersInit)
return createValidResponse()
}) as FetchFunction,
})
const model = provider("test-model")
await model.doGenerate({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
expect(capturedHeaders).toBeDefined()
expect(capturedHeaders!.get("x-api-key")).toBe("test-key")
})
// Any caller-supplied Authorization header must be stripped so the bearer
// token never leaks to the ToothFairyAI backend.
test("removes Authorization header", async () => {
let capturedHeaders: Headers | undefined
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (async (_input, init) => {
capturedHeaders = new Headers(init?.headers as HeadersInit)
return createValidResponse()
}) as FetchFunction,
})
const model = provider("test-model")
await model.doGenerate({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
headers: { Authorization: "Bearer should-be-removed" },
})
expect(capturedHeaders).toBeDefined()
expect(capturedHeaders!.has("Authorization")).toBe(false)
})
})
// Streaming tests: feed hand-built SSE bodies through the provider and assert
// on the LanguageModelV2StreamPart sequence it emits. The SSE payload strings
// are exact fixtures — do not reformat them.
describe("ToothFairyAI Streaming", () => {
// ToothFairyAI emits {"status": ...} progress frames before OpenAI-format
// chunks; these must be filtered out while real deltas still flow through.
test("filters connection status messages", async () => {
const sseChunks = [
'data: {"status": "initialising"}\n\n',
'data: {"status": "connected"}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n',
"data: [DONE]\n\n",
]
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doStream({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const chunks = await collectStream(result.stream)
const textDeltas = chunks.filter((c) => c.type === "text-delta")
expect(textDeltas.length).toBeGreaterThan(0)
})
// Non-standard `reasoning_content` deltas must surface as reasoning-delta parts.
test("handles reasoning_content in stream", async () => {
const sseChunks = [
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"reasoning_content": "Let me think..."}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Answer"}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n',
"data: [DONE]\n\n",
]
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doStream({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const chunks = await collectStream(result.stream)
const reasoningDeltas = chunks.filter((c) => c.type === "reasoning-delta")
expect(reasoningDeltas.length).toBeGreaterThan(0)
})
// A streamed OpenAI tool_calls delta must be assembled into one tool-call part.
test("handles tool_calls in stream", async () => {
const sseChunks = [
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": "{\\"location\\": \\"Sydney\\"}"}}]}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}]}\n\n',
"data: [DONE]\n\n",
]
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doStream({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const chunks = await collectStream(result.stream)
const toolCalls = chunks.filter((c) => c.type === "tool-call")
expect(toolCalls.length).toBe(1)
})
// SSE error payloads must be surfaced as error stream parts, not swallowed.
test("handles error chunks", async () => {
const sseChunks = [
'data: {"error": {"message": "Model parameter is required", "type": "invalid_request_error", "code": 400}}\n\n',
]
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doStream({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const chunks = await collectStream(result.stream)
const errorChunks = chunks.filter((c) => c.type === "error")
expect(errorChunks.length).toBeGreaterThan(0)
})
// Usage on the final chunk must map to inputTokens/outputTokens on the
// single finish part.
test("handles usage information", async () => {
const sseChunks = [
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}\n\n',
"data: [DONE]\n\n",
]
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doStream({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const chunks = await collectStream(result.stream)
const finishChunks = chunks.filter((c) => c.type === "finish")
expect(finishChunks.length).toBe(1)
const finish = finishChunks[0] as any
expect(finish.usage?.inputTokens).toBe(10)
expect(finish.usage?.outputTokens).toBe(5)
})
// Every supported finish_reason must produce exactly one finish part.
test("handles different finish reasons", async () => {
const finishReasons = ["stop", "length", "tool_calls"]
for (const reason of finishReasons) {
const sseChunks = [
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
`data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "${reason}"}]}\n\n`,
"data: [DONE]\n\n",
]
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doStream({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const chunks = await collectStream(result.stream)
const finishChunks = chunks.filter((c) => c.type === "finish")
expect(finishChunks.length).toBe(1)
}
})
})
// Non-streaming (doGenerate) tests: the provider must map ToothFairyAI's
// JSON completion payloads onto AI SDK content parts.
describe("ToothFairyAI Non-streaming", () => {
// `reasoning_content` on the message must become a separate "reasoning"
// content part alongside the normal "text" part.
test("handles non-streaming response with reasoning_content", async () => {
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() =>
new Response(
JSON.stringify({
id: "test-id",
object: "chat.completion",
created: 123,
model: "test",
choices: [
{
index: 0,
message: {
role: "assistant",
reasoning_content: "Let me think about this...",
content: "The answer is 42.",
},
finish_reason: "stop",
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
},
}),
{
status: 200,
headers: { "content-type": "application/json" },
},
)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doGenerate({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const reasoning = result.content.find((c) => c.type === "reasoning")
expect(reasoning).toBeDefined()
expect((reasoning as any).text).toBe("Let me think about this...")
const text = result.content.find((c) => c.type === "text")
expect(text).toBeDefined()
expect((text as any).text).toBe("The answer is 42.")
})
// OpenAI-style tool_calls must map to a "tool-call" part, and the
// finish_reason "tool_calls" must be normalized to the SDK's "tool-calls".
test("handles non-streaming response with tool_calls", async () => {
const provider = createToothFairyAI({
apiKey: "test-key",
workspaceId: "test-workspace",
fetch: (() =>
new Response(
JSON.stringify({
id: "test-id",
object: "chat.completion",
created: 123,
model: "test",
choices: [
{
index: 0,
message: {
role: "assistant",
tool_calls: [
{
id: "call_123",
type: "function",
function: {
name: "get_weather",
arguments: '{"location": "Sydney"}',
},
},
],
},
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}),
{
status: 200,
headers: { "content-type": "application/json" },
},
)) as unknown as FetchFunction,
})
const model = provider("test-model")
const result = await model.doGenerate({
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
})
const toolCall = result.content.find((c) => c.type === "tool-call")
expect(toolCall).toBeDefined()
expect((toolCall as any).toolName).toBe("get_weather")
expect((toolCall as any).input).toBe('{"location": "Sydney"}')
expect(result.finishReason).toBe("tool-calls")
})
})