mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-03-30 13:54:01 +00:00
324 lines
10 KiB
TypeScript
324 lines
10 KiB
TypeScript
import { Installation } from "@/installation"
|
|
import { Provider } from "@/provider/provider"
|
|
import { Log } from "@/util/log"
|
|
import {
|
|
streamText,
|
|
wrapLanguageModel,
|
|
type ModelMessage,
|
|
type StreamTextResult,
|
|
type Tool,
|
|
type ToolSet,
|
|
tool,
|
|
jsonSchema,
|
|
} from "ai"
|
|
import { mergeDeep, pipe } from "remeda"
|
|
import { GitLabWorkflowLanguageModel } from "gitlab-ai-provider"
|
|
import { ProviderTransform } from "@/provider/transform"
|
|
import { Config } from "@/config/config"
|
|
import { Instance } from "@/project/instance"
|
|
import type { Agent } from "@/agent/agent"
|
|
import type { MessageV2 } from "./message-v2"
|
|
import { Plugin } from "@/plugin"
|
|
import { SystemPrompt } from "./system"
|
|
import { Flag } from "@/flag/flag"
|
|
import { Permission } from "@/permission"
|
|
import { Auth } from "@/auth"
|
|
|
|
export namespace LLM {
  const log = Log.create({ service: "llm" })
  export const OUTPUT_TOKEN_MAX = ProviderTransform.OUTPUT_TOKEN_MAX

  /** Inputs for a single {@link stream} call. */
  export type StreamInput = {
    user: MessageV2.User
    sessionID: string
    model: Provider.Model
    agent: Agent.Info
    /** Extra permission rules, merged with the agent's own rules when filtering tools. */
    permission?: Permission.Ruleset
    /** Extra system prompt parts, appended after the agent/provider prompt. */
    system: string[]
    /** Forwarded to `streamText` and to workflow-model tool executions. */
    abort: AbortSignal
    messages: ModelMessage[]
    /** When true, small-model options are used and the user's variant is ignored. */
    small?: boolean
    /** Candidate tools; not mutated — a filtered copy is what gets sent to the model. */
    tools: Record<string, Tool>
    /** Passed to `streamText` as `maxRetries`; defaults to 0. */
    retries?: number
    toolChoice?: "auto" | "required" | "none"
  }

  export type StreamOutput = StreamTextResult<ToolSet, unknown>

  /**
   * Starts a streaming model call for one session turn.
   *
   * Resolves the language model, config, provider and auth; assembles the system
   * prompt; merges provider/model/agent/variant options; lets plugins adjust the
   * call params (`chat.params`) and headers (`chat.headers`); filters tools by
   * permissions and the user message; then delegates to the AI SDK `streamText`.
   *
   * @param input - see {@link StreamInput}
   * @returns the `streamText` result; errors during streaming are logged via `onError`
   */
  export async function stream(input: StreamInput) {
    const l = log
      .clone()
      .tag("providerID", input.model.providerID)
      .tag("modelID", input.model.id)
      .tag("sessionID", input.sessionID)
      .tag("small", (input.small ?? false).toString())
      .tag("agent", input.agent.name)
      .tag("mode", input.agent.mode)
    l.info("stream", {
      modelID: input.model.id,
      providerID: input.model.providerID,
    })
    const [language, cfg, provider, auth] = await Promise.all([
      Provider.getLanguage(input.model),
      Config.get(),
      Provider.getProvider(input.model.providerID),
      Auth.get(input.model.providerID),
    ])
    // TODO: move this to a proper hook
    // OpenAI OAuth sends the system prompt as `options.instructions` instead of as
    // system messages (both handled below).
    const isOpenaiOauth = provider.id === "openai" && auth?.type === "oauth"

    const system = await buildSystem(input)

    // An unknown variant name falls back to "no variant options".
    const variant =
      !input.small && input.model.variants && input.user.variant
        ? (input.model.variants[input.user.variant] ?? {})
        : {}
    const base = input.small
      ? ProviderTransform.smallOptions(input.model)
      : ProviderTransform.options({
          model: input.model,
          sessionID: input.sessionID,
          providerOptions: provider.options,
        })
    // Later sources win: base < model < agent < variant.
    const options: Record<string, any> = pipe(
      base,
      mergeDeep(input.model.options),
      mergeDeep(input.agent.options),
      mergeDeep(variant),
    )

    // Remove TF-specific options for non-ToothFairyAI providers
    if (input.model.providerID !== "toothfairyai") {
      delete options.tf_agent_id
      delete options.tf_auth_via
    }

    if (isOpenaiOauth) {
      options.instructions = system.join("\n")
    }

    const messages = isOpenaiOauth
      ? input.messages
      : [
          ...system.map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...input.messages,
        ]

    const params = await Plugin.trigger(
      "chat.params",
      {
        sessionID: input.sessionID,
        agent: input.agent,
        model: input.model,
        provider,
        message: input.user,
      },
      {
        temperature: input.model.capabilities.temperature
          ? (input.agent.temperature ?? ProviderTransform.temperature(input.model))
          : undefined,
        topP: input.agent.topP ?? ProviderTransform.topP(input.model),
        topK: ProviderTransform.topK(input.model),
        options,
      },
    )

    const { headers } = await Plugin.trigger(
      "chat.headers",
      {
        sessionID: input.sessionID,
        agent: input.agent,
        model: input.model,
        provider,
        message: input.user,
      },
      {
        headers: {},
      },
    )

    const maxOutputTokens =
      isOpenaiOauth || provider.id.includes("github-copilot")
        ? undefined
        : ProviderTransform.maxOutputTokens(input.model)

    // A fresh, filtered map: additions below (e.g. `_noop`) never leak into input.tools.
    const tools = await resolveTools(input)

    // LiteLLM and some Anthropic proxies require the tools parameter to be present
    // when message history contains tool calls, even if no tools are being used.
    // Add a dummy tool that is never called to satisfy this validation.
    // This is enabled for:
    // 1. Providers with "litellm" in their ID or API ID (auto-detected)
    // 2. Providers with explicit "litellmProxy: true" option (opt-in for custom gateways)
    const isLiteLLMProxy =
      provider.options?.["litellmProxy"] === true ||
      input.model.providerID.toLowerCase().includes("litellm") ||
      input.model.api.id.toLowerCase().includes("litellm")

    if (isLiteLLMProxy && Object.keys(tools).length === 0 && hasToolCalls(input.messages)) {
      tools["_noop"] = tool({
        description:
          "Placeholder for LiteLLM/Anthropic proxy compatibility - required when message history contains tool calls but no active tools are needed",
        inputSchema: jsonSchema({ type: "object", properties: {} }),
        execute: async () => ({ output: "", title: "", metadata: {} }),
      })
    }

    if (language instanceof GitLabWorkflowLanguageModel) {
      attachWorkflowToolExecutor(language, tools, input)
    }

    return streamText({
      onError(error) {
        l.error("stream error", {
          error,
        })
      },
      async experimental_repairToolCall(failed) {
        // Models sometimes emit a tool name with the wrong case; retry with the
        // lowercase name when such a tool exists.
        const lower = failed.toolCall.toolName.toLowerCase()
        if (lower !== failed.toolCall.toolName && tools[lower]) {
          l.info("repairing tool call", {
            tool: failed.toolCall.toolName,
            repaired: lower,
          })
          return {
            ...failed.toolCall,
            toolName: lower,
          }
        }
        // Otherwise reroute the call to the "invalid" tool, carrying the original
        // tool name and error message as its input.
        return {
          ...failed.toolCall,
          input: JSON.stringify({
            tool: failed.toolCall.toolName,
            error: failed.error.message,
          }),
          toolName: "invalid",
        }
      },
      temperature: params.temperature,
      topP: params.topP,
      topK: params.topK,
      providerOptions: ProviderTransform.providerOptions(input.model, params.options),
      activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
      tools,
      toolChoice: input.toolChoice,
      maxOutputTokens,
      abortSignal: input.abort,
      headers: {
        ...(input.model.providerID.startsWith("opencode")
          ? {
              "x-opencode-project": Instance.project.id,
              "x-opencode-session": input.sessionID,
              "x-opencode-request": input.user.id,
              "x-opencode-client": Flag.OPENCODE_CLIENT,
            }
          : {
              "User-Agent": `opencode/${Installation.VERSION}`,
            }),
        ...input.model.headers,
        ...headers,
      },
      maxRetries: input.retries ?? 0,
      messages,
      model: wrapLanguageModel({
        model: language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.model, options)
              }
              return args.params
            },
          },
        ],
      }),
      experimental_telemetry: {
        isEnabled: cfg.experimental?.openTelemetry,
        metadata: {
          userId: cfg.username ?? "unknown",
          sessionId: input.sessionID,
        },
      },
    })
  }

  /**
   * Assembles the system prompt parts for a stream call.
   *
   * The first part joins (newline-separated, empties dropped) the agent prompt, or the
   * provider's default prompt when the agent has none, then `input.system`, then the
   * last user message's `system`. Plugins may then edit the array in place via
   * `experimental.chat.system.transform`.
   *
   * @returns the system parts; if the plugin left the first part unchanged but produced
   *   more than two parts, everything after the first is folded into one second part.
   */
  async function buildSystem(input: StreamInput): Promise<string[]> {
    const header = [
      // use agent prompt otherwise provider prompt
      ...(input.agent.prompt ? [input.agent.prompt] : SystemPrompt.provider(input.model)),
      // any custom prompt passed into this call
      ...input.system,
      // any custom prompt from last user message
      ...(input.user.system ? [input.user.system] : []),
    ]
      .filter((x) => x)
      .join("\n")
    const system: string[] = [header]

    await Plugin.trigger(
      "experimental.chat.system.transform",
      { sessionID: input.sessionID, model: input.model },
      { system },
    )

    // rejoin to maintain 2-part structure for caching if header unchanged
    if (system.length > 2 && system[0] === header) {
      const rest = system.slice(1)
      system.length = 0
      system.push(header, rest.join("\n"))
    }
    return system
  }

  /**
   * Installs a `toolExecutor` on a GitLab workflow (DWS) model. Tool calls coming
   * from the workflow service are then executed through opencode's tool system,
   * and the results are sent back over the WebSocket.
   *
   * Tool failures, including JSON parse errors on the arguments, are reported as
   * `{ result: "", error }` instead of being thrown.
   */
  function attachWorkflowToolExecutor(
    model: GitLabWorkflowLanguageModel,
    tools: Record<string, Tool>,
    input: Pick<StreamInput, "messages" | "abort">,
  ) {
    model.toolExecutor = async (toolName, argsJson, requestID) => {
      const t = tools[toolName]
      if (!t || !t.execute) {
        return { result: "", error: `Unknown tool: ${toolName}` }
      }
      try {
        const result = await t.execute!(JSON.parse(argsJson), {
          toolCallId: requestID,
          messages: input.messages,
          abortSignal: input.abort,
        })
        // Tools may return a plain string or an object; objects without an `output`
        // field are serialized whole.
        const output = typeof result === "string" ? result : (result?.output ?? JSON.stringify(result))
        return {
          result: output,
          metadata: typeof result === "object" ? result?.metadata : undefined,
          title: typeof result === "object" ? result?.title : undefined,
        }
      } catch (e) {
        return { result: "", error: errorMessage(e) }
      }
    }
  }

  /**
   * Best-effort message for any thrown value. Never throws itself, even for
   * thrown `null`/`undefined`.
   */
  function errorMessage(e: unknown): string {
    if (e instanceof Error) return e.message
    if (typeof e === "object" && e !== null && "message" in e && typeof e.message === "string") return e.message
    return String(e)
  }

  /**
   * Returns a new map holding the tools from `input.tools` that are neither denied
   * by permissions nor disabled on the user message.
   *
   * Permissions are the agent's rules merged with `input.permission`. A tool is
   * disabled on the user message when `input.user.tools[name] === false`.
   * `input.tools` itself is left untouched.
   */
  async function resolveTools(
    input: Pick<StreamInput, "tools" | "agent" | "permission" | "user">,
  ): Promise<Record<string, Tool>> {
    const disabled = Permission.disabled(
      Object.keys(input.tools),
      Permission.merge(input.agent.permission, input.permission ?? []),
    )
    const resolved: Record<string, Tool> = {}
    for (const [name, t] of Object.entries(input.tools)) {
      if (input.user.tools?.[name] === false || disabled.has(name)) continue
      resolved[name] = t
    }
    return resolved
  }

  /**
   * Returns true when any message has array content containing a `tool-call` or
   * `tool-result` part.
   *
   * Used to decide whether a placeholder tool is needed for LiteLLM/proxy compatibility.
   */
  export function hasToolCalls(messages: ModelMessage[]): boolean {
    for (const msg of messages) {
      if (!Array.isArray(msg.content)) continue
      for (const part of msg.content) {
        if (part.type === "tool-call" || part.type === "tool-result") return true
      }
    }
    return false
  }
}
|