feat(id): brand ProviderID and ModelID (#17110)

This commit is contained in:
Kit Langton
2026-03-12 09:27:52 -04:00
committed by GitHub
parent 2eeba53b07
commit c45467964c
23 changed files with 157 additions and 107 deletions

View File

@@ -137,8 +137,8 @@ test("custom agent from config creates new agent", async () => {
fn: async () => {
const custom = await Agent.get("my_custom_agent")
expect(custom).toBeDefined()
expect(custom?.model?.providerID).toBe("openai")
expect(custom?.model?.modelID).toBe("gpt-4")
expect(String(custom?.model?.providerID)).toBe("openai")
expect(String(custom?.model?.modelID)).toBe("gpt-4")
expect(custom?.description).toBe("My custom agent")
expect(custom?.temperature).toBe(0.5)
expect(custom?.topP).toBe(0.9)
@@ -166,8 +166,8 @@ test("custom agent config overrides native agent properties", async () => {
fn: async () => {
const build = await Agent.get("build")
expect(build).toBeDefined()
expect(build?.model?.providerID).toBe("anthropic")
expect(build?.model?.modelID).toBe("claude-3")
expect(String(build?.model?.providerID)).toBe("anthropic")
expect(String(build?.model?.modelID)).toBe("claude-3")
expect(build?.description).toBe("Custom build agent")
expect(build?.temperature).toBe(0.7)
expect(build?.color).toBe("#FF0000")

View File

@@ -302,8 +302,8 @@ test("getModel returns model for valid provider/model", async () => {
fn: async () => {
const model = await Provider.getModel("anthropic", "claude-sonnet-4-20250514")
expect(model).toBeDefined()
expect(model.providerID).toBe("anthropic")
expect(model.id).toBe("claude-sonnet-4-20250514")
expect(String(model.providerID)).toBe("anthropic")
expect(String(model.id)).toBe("claude-sonnet-4-20250514")
const language = await Provider.getLanguage(model)
expect(language).toBeDefined()
},
@@ -353,14 +353,14 @@ test("getModel throws ModelNotFoundError for invalid provider", async () => {
test("parseModel correctly parses provider/model string", () => {
const result = Provider.parseModel("anthropic/claude-sonnet-4")
expect(result.providerID).toBe("anthropic")
expect(result.modelID).toBe("claude-sonnet-4")
expect(String(result.providerID)).toBe("anthropic")
expect(String(result.modelID)).toBe("claude-sonnet-4")
})
test("parseModel handles model IDs with slashes", () => {
const result = Provider.parseModel("openrouter/anthropic/claude-3-opus")
expect(result.providerID).toBe("openrouter")
expect(result.modelID).toBe("anthropic/claude-3-opus")
expect(String(result.providerID)).toBe("openrouter")
expect(String(result.modelID)).toBe("anthropic/claude-3-opus")
})
test("defaultModel returns first available model when no config set", async () => {
@@ -406,8 +406,8 @@ test("defaultModel respects config model setting", async () => {
},
fn: async () => {
const model = await Provider.defaultModel()
expect(model.providerID).toBe("anthropic")
expect(model.modelID).toBe("claude-sonnet-4-20250514")
expect(String(model.providerID)).toBe("anthropic")
expect(String(model.modelID)).toBe("claude-sonnet-4-20250514")
},
})
})
@@ -632,7 +632,7 @@ test("getModel uses realIdByKey for aliased models", async () => {
const model = await Provider.getModel("anthropic", "my-sonnet")
expect(model).toBeDefined()
expect(model.id).toBe("my-sonnet")
expect(String(model.id)).toBe("my-sonnet")
expect(model.name).toBe("My Sonnet Alias")
},
})
@@ -960,8 +960,8 @@ test("getSmallModel respects config small_model override", async () => {
fn: async () => {
const model = await Provider.getSmallModel("anthropic")
expect(model).toBeDefined()
expect(model?.providerID).toBe("anthropic")
expect(model?.id).toBe("claude-sonnet-4-20250514")
expect(String(model?.providerID)).toBe("anthropic")
expect(String(model?.id)).toBe("claude-sonnet-4-20250514")
},
})
})
@@ -1605,7 +1605,7 @@ test("getProvider returns provider info", async () => {
fn: async () => {
const provider = await Provider.getProvider("anthropic")
expect(provider).toBeDefined()
expect(provider?.id).toBe("anthropic")
expect(String(provider?.id)).toBe("anthropic")
},
})
})

View File

@@ -1,5 +1,6 @@
import { describe, expect, test } from "bun:test"
import { ProviderTransform } from "../../src/provider/transform"
import { ModelID, ProviderID } from "../../src/provider/schema"
const OUTPUT_TOKEN_MAX = 32000
@@ -740,8 +741,8 @@ describe("ProviderTransform.message - DeepSeek reasoning content", () => {
const result = ProviderTransform.message(
msgs,
{
id: "deepseek/deepseek-chat",
providerID: "deepseek",
id: ModelID.make("deepseek/deepseek-chat"),
providerID: ProviderID.make("deepseek"),
api: {
id: "deepseek-chat",
url: "https://api.deepseek.com",
@@ -802,8 +803,8 @@ describe("ProviderTransform.message - DeepSeek reasoning content", () => {
const result = ProviderTransform.message(
msgs,
{
id: "openai/gpt-4",
providerID: "openai",
id: ModelID.make("openai/gpt-4"),
providerID: ProviderID.make("openai"),
api: {
id: "gpt-4",
url: "https://api.openai.com",

View File

@@ -7,6 +7,7 @@ import { Instance } from "../../src/project/instance"
import { Provider } from "../../src/provider/provider"
import { ProviderTransform } from "../../src/provider/transform"
import { ModelsDev } from "../../src/provider/models"
import { ProviderID } from "../../src/provider/schema"
import { Filesystem } from "../../src/util/filesystem"
import { tmpdir } from "../fixture/fixture"
import type { Agent } from "../../src/agent/agent"
@@ -282,7 +283,7 @@ describe("session.llm.stream", () => {
role: "user",
time: { created: Date.now() },
agent: agent.name,
model: { providerID, modelID: resolved.id },
model: { providerID: ProviderID.make(providerID), modelID: resolved.id },
variant: "high",
} satisfies MessageV2.User
@@ -411,7 +412,7 @@ describe("session.llm.stream", () => {
role: "user",
time: { created: Date.now() },
agent: agent.name,
model: { providerID: "openai", modelID: resolved.id },
model: { providerID: ProviderID.make("openai"), modelID: resolved.id },
variant: "high",
} satisfies MessageV2.User
@@ -534,7 +535,7 @@ describe("session.llm.stream", () => {
role: "user",
time: { created: Date.now() },
agent: agent.name,
model: { providerID, modelID: resolved.id },
model: { providerID: ProviderID.make(providerID), modelID: resolved.id },
} satisfies MessageV2.User
const stream = await LLM.stream({
@@ -635,7 +636,7 @@ describe("session.llm.stream", () => {
role: "user",
time: { created: Date.now() },
agent: agent.name,
model: { providerID, modelID: resolved.id },
model: { providerID: ProviderID.make(providerID), modelID: resolved.id },
} satisfies MessageV2.User
const stream = await LLM.stream({

View File

@@ -2,12 +2,14 @@ import { describe, expect, test } from "bun:test"
import { APICallError } from "ai"
import { MessageV2 } from "../../src/session/message-v2"
import type { Provider } from "../../src/provider/provider"
import { ModelID, ProviderID } from "../../src/provider/schema"
import { SessionID, MessageID, PartID } from "../../src/session/schema"
const sessionID = SessionID.make("session")
const providerID = ProviderID.make("test")
const model: Provider.Model = {
id: "test-model",
providerID: "test",
id: ModelID.make("test-model"),
providerID,
api: {
id: "test-model",
url: "https://example.com",
@@ -61,7 +63,7 @@ function userInfo(id: string): MessageV2.User {
role: "user",
time: { created: 0 },
agent: "user",
model: { providerID: "test", modelID: "test" },
model: { providerID, modelID: ModelID.make("test") },
tools: {},
mode: "",
} as unknown as MessageV2.User
@@ -795,7 +797,7 @@ describe("session.message-v2.fromError", () => {
code: "context_length_exceeded",
},
}
const result = MessageV2.fromError(input, { providerID: "test" })
const result = MessageV2.fromError(input, { providerID })
expect(result).toStrictEqual({
name: "ContextOverflowError",
@@ -830,7 +832,7 @@ describe("session.message-v2.fromError", () => {
message: item.code === "invalid_prompt" ? item.message : undefined,
},
}
const result = MessageV2.fromError(input, { providerID: "test" })
const result = MessageV2.fromError(input, { providerID })
expect(result).toStrictEqual({
name: "APIError",
@@ -862,7 +864,7 @@ describe("session.message-v2.fromError", () => {
responseHeaders: { "content-type": "application/json" },
isRetryable: false,
})
const result = MessageV2.fromError(error, { providerID: "test" })
const result = MessageV2.fromError(error, { providerID })
expect(MessageV2.ContextOverflowError.isInstance(result)).toBe(true)
})
})
@@ -877,14 +879,14 @@ describe("session.message-v2.fromError", () => {
responseHeaders: { "content-type": "application/json" },
isRetryable: false,
}),
{ providerID: "test" },
{ providerID },
)
expect(MessageV2.ContextOverflowError.isInstance(result)).toBe(false)
expect(MessageV2.APIError.isInstance(result)).toBe(true)
})
test("serializes unknown inputs", () => {
const result = MessageV2.fromError(123, { providerID: "test" })
const result = MessageV2.fromError(123, { providerID })
expect(result).toStrictEqual({
name: "UnknownError",

View File

@@ -2,6 +2,7 @@ import path from "path"
import { describe, expect, test } from "bun:test"
import { fileURLToPath } from "url"
import { Instance } from "../../src/project/instance"
import { ModelID, ProviderID } from "../../src/provider/schema"
import { Session } from "../../src/session"
import { MessageV2 } from "../../src/session/message-v2"
import { SessionPrompt } from "../../src/session/prompt"
@@ -173,7 +174,7 @@ describe("session.prompt agent variant", () => {
const other = await SessionPrompt.prompt({
sessionID: session.id,
agent: "build",
model: { providerID: "opencode", modelID: "kimi-k2.5-free" },
model: { providerID: ProviderID.make("opencode"), modelID: ModelID.make("kimi-k2.5-free") },
noReply: true,
parts: [{ type: "text", text: "hello" }],
})
@@ -187,7 +188,7 @@ describe("session.prompt agent variant", () => {
parts: [{ type: "text", text: "hello again" }],
})
if (match.info.role !== "user") throw new Error("expected user message")
expect(match.info.model).toEqual({ providerID: "openai", modelID: "gpt-5.2" })
expect(match.info.model).toEqual({ providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-5.2") })
expect(match.info.variant).toBe("xhigh")
const override = await SessionPrompt.prompt({

View File

@@ -4,6 +4,9 @@ import { APICallError } from "ai"
import { setTimeout as sleep } from "node:timers/promises"
import { SessionRetry } from "../../src/session/retry"
import { MessageV2 } from "../../src/session/message-v2"
import { ProviderID } from "../../src/provider/schema"
const providerID = ProviderID.make("test")
function apiError(headers?: Record<string, string>): MessageV2.APIError {
return new MessageV2.APIError({
@@ -150,7 +153,7 @@ describe("session.message-v2.fromError", () => {
.then((res) => res.text())
.catch((e) => e)
const result = MessageV2.fromError(error, { providerID: "test" })
const result = MessageV2.fromError(error, { providerID })
expect(MessageV2.APIError.isInstance(result)).toBe(true)
expect((result as MessageV2.APIError).data.isRetryable).toBe(true)
@@ -183,7 +186,7 @@ describe("session.message-v2.fromError", () => {
responseBody: '{"error":"boom"}',
isRetryable: false,
})
const result = MessageV2.fromError(error, { providerID: "openai" }) as MessageV2.APIError
const result = MessageV2.fromError(error, { providerID: ProviderID.make("openai") }) as MessageV2.APIError
expect(result.data.isRetryable).toBe(true)
})
})

View File

@@ -1,6 +1,7 @@
import { describe, expect, test, beforeEach, afterEach } from "bun:test"
import path from "path"
import { Session } from "../../src/session"
import { ModelID, ProviderID } from "../../src/provider/schema"
import { SessionRevert } from "../../src/session/revert"
import { SessionCompaction } from "../../src/session/compaction"
import { MessageV2 } from "../../src/session/message-v2"
@@ -29,8 +30,8 @@ describe("revert + compact workflow", () => {
sessionID,
agent: "default",
model: {
providerID: "openai",
modelID: "gpt-4",
providerID: ProviderID.make("openai"),
modelID: ModelID.make("gpt-4"),
},
time: {
created: Date.now(),
@@ -64,8 +65,8 @@ describe("revert + compact workflow", () => {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: "gpt-4",
providerID: "openai",
modelID: ModelID.make("gpt-4"),
providerID: ProviderID.make("openai"),
parentID: userMsg1.id,
time: {
created: Date.now(),
@@ -90,8 +91,8 @@ describe("revert + compact workflow", () => {
sessionID,
agent: "default",
model: {
providerID: "openai",
modelID: "gpt-4",
providerID: ProviderID.make("openai"),
modelID: ModelID.make("gpt-4"),
},
time: {
created: Date.now(),
@@ -124,8 +125,8 @@ describe("revert + compact workflow", () => {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: "gpt-4",
providerID: "openai",
modelID: ModelID.make("gpt-4"),
providerID: ProviderID.make("openai"),
parentID: userMsg2.id,
time: {
created: Date.now(),
@@ -205,8 +206,8 @@ describe("revert + compact workflow", () => {
sessionID,
agent: "default",
model: {
providerID: "openai",
modelID: "gpt-4",
providerID: ProviderID.make("openai"),
modelID: ModelID.make("gpt-4"),
},
time: {
created: Date.now(),
@@ -238,8 +239,8 @@ describe("revert + compact workflow", () => {
reasoning: 0,
cache: { read: 0, write: 0 },
},
modelID: "gpt-4",
providerID: "openai",
modelID: ModelID.make("gpt-4"),
providerID: ProviderID.make("openai"),
parentID: userMsg.id,
time: {
created: Date.now(),