fix: enforced loop detection

This commit is contained in:
Gab
2026-03-29 19:24:02 +11:00
parent cecaf216c2
commit 493b0f924f
3 changed files with 62 additions and 54 deletions

View File

@@ -1,5 +1,5 @@
import type { LanguageModelV2 } from "@ai-sdk/provider"
import { type FetchFunction, withoutTrailingSlash } from "@ai-sdk/provider-utils"
import { type FetchFunction, withoutTrailingSlash, safeParseJSON, EventSourceParserStream } from "@ai-sdk/provider-utils"
import { OpenAICompatibleChatLanguageModel } from "../copilot/chat/openai-compatible-chat-language-model"
import { Log } from "@/util/log"
@@ -117,57 +117,59 @@ export function createToothFairyAI(options: ToothFairyAIProviderSettings = {}):
}
if (res.body && res.headers.get("content-type")?.includes("text/event-stream")) {
const decoder = new TextDecoder()
const encoder = new TextEncoder()
let buffer = ""
const filteredStream = res.body.pipeThrough(
new TransformStream({
transform(chunk, controller) {
buffer += decoder.decode(chunk, { stream: true })
const lines = buffer.split("\n")
buffer = lines.pop() || ""
const filtered: string[] = []
for (const line of lines) {
if (line.startsWith("data: ")) {
const json = line.slice(6).trim()
if (json) {
try {
const parsed = JSON.parse(json)
if (parsed.status === "initialising" || parsed.status === "connected") {
log.debug("filtered connection status", { status: parsed.status })
continue
}
if (parsed.choices?.[0]?.finish_reason) {
log.info("stream finish_reason", {
finish_reason: parsed.choices[0].finish_reason,
})
}
if (parsed.usage) {
log.info("stream usage", {
prompt_tokens: parsed.usage.prompt_tokens,
completion_tokens: parsed.usage.completion_tokens,
total_tokens: parsed.usage.total_tokens,
})
}
} catch {}
}
const filteredStream = res.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream())
.pipeThrough(
new TransformStream({
async transform({ data }, controller) {
if (data === "[DONE]") {
return
}
filtered.push(line)
}
if (filtered.length > 0) {
controller.enqueue(encoder.encode(filtered.join("\n") + "\n"))
}
},
flush(controller) {
if (buffer) {
controller.enqueue(encoder.encode(buffer))
}
},
}),
)
const parsed = await safeParseJSON({ text: data, schema: null })
if (!parsed.success) {
log.error("Failed to parse SSE chunk", {
chunk: data.slice(0, 100),
error: parsed.error,
})
controller.enqueue({ data })
return
}
const value = parsed.value
if (value.status === "initialising" || value.status === "connected") {
log.debug("filtered connection status", { status: value.status })
return
}
if (value.choices?.[0]?.finish_reason) {
log.info("stream finish_reason", {
finish_reason: value.choices[0].finish_reason,
})
}
if (value.usage) {
log.info("stream usage", {
prompt_tokens: value.usage.prompt_tokens,
completion_tokens: value.usage.completion_tokens,
total_tokens: value.usage.total_tokens,
})
}
controller.enqueue({ data })
},
}),
)
.pipeThrough(
new TransformStream({
transform({ data }, controller) {
controller.enqueue(new TextEncoder().encode(`data: ${data}\n\n`))
},
}),
)
return new Response(filteredStream, {
headers: res.headers,