mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-02 15:13:46 +00:00
pdf support in read tool (#5222)
Co-authored-by: ammi1378 <ammi1378@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,17 @@ import type { APICallError, ModelMessage } from "ai"
|
||||
import { unique } from "remeda"
|
||||
import type { JSONSchema } from "zod/v4/core"
|
||||
import type { Provider } from "./provider"
|
||||
import type { ModelsDev } from "./models"
|
||||
|
||||
type Modality = NonNullable<ModelsDev.Model["modalities"]>["input"][number]
|
||||
|
||||
function mimeToModality(mime: string): Modality | undefined {
|
||||
if (mime.startsWith("image/")) return "image"
|
||||
if (mime.startsWith("audio/")) return "audio"
|
||||
if (mime.startsWith("video/")) return "video"
|
||||
if (mime === "application/pdf") return "pdf"
|
||||
return undefined
|
||||
}
|
||||
|
||||
export namespace ProviderTransform {
|
||||
function normalizeMessages(msgs: ModelMessage[], model: Provider.Model): ModelMessage[] {
|
||||
@@ -148,7 +159,32 @@ export namespace ProviderTransform {
|
||||
return msgs
|
||||
}
|
||||
|
||||
function unsupportedParts(msgs: ModelMessage[], model: Provider.Model): ModelMessage[] {
|
||||
return msgs.map((msg) => {
|
||||
if (msg.role !== "user" || !Array.isArray(msg.content)) return msg
|
||||
|
||||
const filtered = msg.content.map((part) => {
|
||||
if (part.type !== "file" && part.type !== "image") return part
|
||||
|
||||
const mime = part.type === "image" ? part.image.toString().split(";")[0].replace("data:", "") : part.mediaType
|
||||
const filename = part.type === "file" ? part.filename : undefined
|
||||
const modality = mimeToModality(mime)
|
||||
if (!modality) return part
|
||||
if (model.capabilities.input[modality]) return part
|
||||
|
||||
const name = filename ? `"${filename}"` : modality
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: `ERROR: Cannot read ${name} (this model does not support ${modality} input). Inform the user.`,
|
||||
}
|
||||
})
|
||||
|
||||
return { ...msg, content: filtered }
|
||||
})
|
||||
}
|
||||
|
||||
export function message(msgs: ModelMessage[], model: Provider.Model) {
|
||||
msgs = unsupportedParts(msgs, model)
|
||||
msgs = normalizeMessages(msgs, model)
|
||||
if (model.providerID === "anthropic" || model.api.id.includes("anthropic") || model.api.id.includes("claude")) {
|
||||
msgs = applyCaching(msgs, model.providerID)
|
||||
|
||||
Reference in New Issue
Block a user