mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-10 18:58:48 +00:00
feat: better preidctiosn
This commit is contained in:
423
packages/tfcode/test/provider/toothfairyai.test.ts
Normal file
423
packages/tfcode/test/provider/toothfairyai.test.ts
Normal file
@@ -0,0 +1,423 @@
|
||||
import { test, expect, describe } from "bun:test"
|
||||
import { createToothFairyAI } from "../../src/provider/sdk/toothfairyai"
|
||||
import type { LanguageModelV2StreamPart } from "@ai-sdk/provider"
|
||||
import type { FetchFunction } from "@ai-sdk/provider-utils"
|
||||
|
||||
async function collectStream(stream: ReadableStream<LanguageModelV2StreamPart>): Promise<LanguageModelV2StreamPart[]> {
|
||||
const chunks: LanguageModelV2StreamPart[] = []
|
||||
const reader = stream.getReader()
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
chunks.push(value)
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
function createValidResponse(): Response {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
id: "test-id",
|
||||
object: "chat.completion",
|
||||
created: 123,
|
||||
model: "test",
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: "Test response",
|
||||
},
|
||||
finish_reason: "stop",
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "content-type": "application/json" },
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
function createSSEResponse(chunks: string[]): Response {
|
||||
const encoder = new TextEncoder()
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
for (const chunk of chunks) {
|
||||
controller.enqueue(encoder.encode(chunk))
|
||||
}
|
||||
controller.close()
|
||||
},
|
||||
})
|
||||
|
||||
return new Response(stream, {
|
||||
status: 200,
|
||||
headers: { "content-type": "text/event-stream" },
|
||||
})
|
||||
}
|
||||
|
||||
describe("ToothFairyAI Provider", () => {
|
||||
test("creates provider with default settings", () => {
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
})
|
||||
expect(provider).toBeDefined()
|
||||
expect(typeof provider).toBe("function")
|
||||
expect(provider.languageModel).toBeDefined()
|
||||
})
|
||||
|
||||
test("creates provider with custom region", () => {
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
region: "us",
|
||||
})
|
||||
expect(provider).toBeDefined()
|
||||
})
|
||||
|
||||
test("redirects /chat/completions to /predictions", async () => {
|
||||
let capturedUrl: string | undefined
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (async (input) => {
|
||||
capturedUrl = typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url
|
||||
return createValidResponse()
|
||||
}) as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
await model.doGenerate({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
expect(capturedUrl).toBeDefined()
|
||||
const url = new URL(capturedUrl!)
|
||||
expect(url.pathname).toBe("/predictions")
|
||||
})
|
||||
|
||||
test("adds workspaceid to request body", async () => {
|
||||
let capturedBody: string | undefined
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (async (_input, init) => {
|
||||
capturedBody = init?.body as string
|
||||
return createValidResponse()
|
||||
}) as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
await model.doGenerate({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
expect(capturedBody).toBeDefined()
|
||||
const body = JSON.parse(capturedBody!)
|
||||
expect(body.workspaceid).toBe("test-workspace")
|
||||
})
|
||||
|
||||
test("sets x-api-key header", async () => {
|
||||
let capturedHeaders: Headers | undefined
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (async (_input, init) => {
|
||||
capturedHeaders = new Headers(init?.headers as HeadersInit)
|
||||
return createValidResponse()
|
||||
}) as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
await model.doGenerate({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
expect(capturedHeaders).toBeDefined()
|
||||
expect(capturedHeaders!.get("x-api-key")).toBe("test-key")
|
||||
})
|
||||
|
||||
test("removes Authorization header", async () => {
|
||||
let capturedHeaders: Headers | undefined
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (async (_input, init) => {
|
||||
capturedHeaders = new Headers(init?.headers as HeadersInit)
|
||||
return createValidResponse()
|
||||
}) as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
await model.doGenerate({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
headers: { Authorization: "Bearer should-be-removed" },
|
||||
})
|
||||
|
||||
expect(capturedHeaders).toBeDefined()
|
||||
expect(capturedHeaders!.has("Authorization")).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("ToothFairyAI Streaming", () => {
|
||||
test("filters connection status messages", async () => {
|
||||
const sseChunks = [
|
||||
'data: {"status": "initialising"}\n\n',
|
||||
'data: {"status": "connected"}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doStream({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const chunks = await collectStream(result.stream)
|
||||
const textDeltas = chunks.filter((c) => c.type === "text-delta")
|
||||
expect(textDeltas.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
test("handles reasoning_content in stream", async () => {
|
||||
const sseChunks = [
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"reasoning_content": "Let me think..."}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Answer"}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doStream({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const chunks = await collectStream(result.stream)
|
||||
const reasoningDeltas = chunks.filter((c) => c.type === "reasoning-delta")
|
||||
expect(reasoningDeltas.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
test("handles tool_calls in stream", async () => {
|
||||
const sseChunks = [
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": "{\\"location\\": \\"Sydney\\"}"}}]}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}]}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doStream({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const chunks = await collectStream(result.stream)
|
||||
const toolCalls = chunks.filter((c) => c.type === "tool-call")
|
||||
expect(toolCalls.length).toBe(1)
|
||||
})
|
||||
|
||||
test("handles error chunks", async () => {
|
||||
const sseChunks = [
|
||||
'data: {"error": {"message": "Model parameter is required", "type": "invalid_request_error", "code": 400}}\n\n',
|
||||
]
|
||||
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doStream({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const chunks = await collectStream(result.stream)
|
||||
const errorChunks = chunks.filter((c) => c.type === "error")
|
||||
expect(errorChunks.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
test("handles usage information", async () => {
|
||||
const sseChunks = [
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doStream({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const chunks = await collectStream(result.stream)
|
||||
const finishChunks = chunks.filter((c) => c.type === "finish")
|
||||
expect(finishChunks.length).toBe(1)
|
||||
const finish = finishChunks[0] as any
|
||||
expect(finish.usage?.inputTokens).toBe(10)
|
||||
expect(finish.usage?.outputTokens).toBe(5)
|
||||
})
|
||||
|
||||
test("handles different finish reasons", async () => {
|
||||
const finishReasons = ["stop", "length", "tool_calls"]
|
||||
|
||||
for (const reason of finishReasons) {
|
||||
const sseChunks = [
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}\n\n',
|
||||
'data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}\n\n',
|
||||
`data: {"id": "test-id", "object": "chat.completion.chunk", "created": 123, "model": "test", "choices": [{"index": 0, "delta": {}, "finish_reason": "${reason}"}]}\n\n`,
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() => createSSEResponse(sseChunks)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doStream({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const chunks = await collectStream(result.stream)
|
||||
const finishChunks = chunks.filter((c) => c.type === "finish")
|
||||
expect(finishChunks.length).toBe(1)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe("ToothFairyAI Non-streaming", () => {
|
||||
test("handles non-streaming response with reasoning_content", async () => {
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() =>
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
id: "test-id",
|
||||
object: "chat.completion",
|
||||
created: 123,
|
||||
model: "test",
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: {
|
||||
role: "assistant",
|
||||
reasoning_content: "Let me think about this...",
|
||||
content: "The answer is 42.",
|
||||
},
|
||||
finish_reason: "stop",
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 20,
|
||||
total_tokens: 30,
|
||||
},
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "content-type": "application/json" },
|
||||
},
|
||||
)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doGenerate({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const reasoning = result.content.find((c) => c.type === "reasoning")
|
||||
expect(reasoning).toBeDefined()
|
||||
expect((reasoning as any).text).toBe("Let me think about this...")
|
||||
|
||||
const text = result.content.find((c) => c.type === "text")
|
||||
expect(text).toBeDefined()
|
||||
expect((text as any).text).toBe("The answer is 42.")
|
||||
})
|
||||
|
||||
test("handles non-streaming response with tool_calls", async () => {
|
||||
const provider = createToothFairyAI({
|
||||
apiKey: "test-key",
|
||||
workspaceId: "test-workspace",
|
||||
fetch: (() =>
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
id: "test-id",
|
||||
object: "chat.completion",
|
||||
created: 123,
|
||||
model: "test",
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: {
|
||||
role: "assistant",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "call_123",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "get_weather",
|
||||
arguments: '{"location": "Sydney"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: "tool_calls",
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: { "content-type": "application/json" },
|
||||
},
|
||||
)) as unknown as FetchFunction,
|
||||
})
|
||||
|
||||
const model = provider("test-model")
|
||||
const result = await model.doGenerate({
|
||||
prompt: [{ role: "user", content: [{ type: "text", text: "test" }] }],
|
||||
})
|
||||
|
||||
const toolCall = result.content.find((c) => c.type === "tool-call")
|
||||
expect(toolCall).toBeDefined()
|
||||
expect((toolCall as any).toolName).toBe("get_weather")
|
||||
expect((toolCall as any).input).toBe('{"location": "Sydney"}')
|
||||
expect(result.finishReason).toBe("tool-calls")
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user