mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-04-03 23:53:46 +00:00
fix(app): model sticks to session
This commit is contained in:
@@ -6,10 +6,19 @@ let createPromptSubmit: typeof import("./submit").createPromptSubmit
|
||||
const createdClients: string[] = []
|
||||
const createdSessions: string[] = []
|
||||
const enabledAutoAccept: Array<{ sessionID: string; directory: string }> = []
|
||||
const optimistic: Array<{
|
||||
message: {
|
||||
agent: string
|
||||
model: { providerID: string; modelID: string }
|
||||
variant?: string
|
||||
}
|
||||
}> = []
|
||||
const sentShell: string[] = []
|
||||
const syncedDirectories: string[] = []
|
||||
|
||||
let params: { id?: string } = {}
|
||||
let selected = "/repo/worktree-a"
|
||||
let variant: string | undefined
|
||||
|
||||
const promptValue: Prompt = [{ type: "text", content: "ls", start: 0, end: 2 }]
|
||||
|
||||
@@ -26,6 +35,7 @@ const clientFor = (directory: string) => {
|
||||
return { data: undefined }
|
||||
},
|
||||
prompt: async () => ({ data: undefined }),
|
||||
promptAsync: async () => ({ data: undefined }),
|
||||
command: async () => ({ data: undefined }),
|
||||
abort: async () => ({ data: undefined }),
|
||||
},
|
||||
@@ -40,7 +50,7 @@ beforeAll(async () => {
|
||||
|
||||
mock.module("@solidjs/router", () => ({
|
||||
useNavigate: () => () => undefined,
|
||||
useParams: () => ({}),
|
||||
useParams: () => params,
|
||||
}))
|
||||
|
||||
mock.module("@opencode-ai/sdk/v2/client", () => ({
|
||||
@@ -62,7 +72,7 @@ beforeAll(async () => {
|
||||
useLocal: () => ({
|
||||
model: {
|
||||
current: () => ({ id: "model", provider: { id: "provider" } }),
|
||||
variant: { current: () => undefined },
|
||||
variant: { current: () => variant },
|
||||
},
|
||||
agent: {
|
||||
current: () => ({ name: "agent" }),
|
||||
@@ -118,7 +128,11 @@ beforeAll(async () => {
|
||||
data: { command: [] },
|
||||
session: {
|
||||
optimistic: {
|
||||
add: () => undefined,
|
||||
add: (value: {
|
||||
message: { agent: string; model: { providerID: string; modelID: string }; variant?: string }
|
||||
}) => {
|
||||
optimistic.push(value)
|
||||
},
|
||||
remove: () => undefined,
|
||||
},
|
||||
},
|
||||
@@ -155,9 +169,12 @@ beforeEach(() => {
|
||||
createdClients.length = 0
|
||||
createdSessions.length = 0
|
||||
enabledAutoAccept.length = 0
|
||||
optimistic.length = 0
|
||||
params = {}
|
||||
sentShell.length = 0
|
||||
syncedDirectories.length = 0
|
||||
selected = "/repo/worktree-a"
|
||||
variant = undefined
|
||||
})
|
||||
|
||||
describe("prompt submit worktree selection", () => {
|
||||
@@ -219,4 +236,39 @@ describe("prompt submit worktree selection", () => {
|
||||
|
||||
expect(enabledAutoAccept).toEqual([{ sessionID: "session-1", directory: "/repo/worktree-a" }])
|
||||
})
|
||||
|
||||
test("includes the selected variant on optimistic prompts", async () => {
|
||||
params = { id: "session-1" }
|
||||
variant = "high"
|
||||
|
||||
const submit = createPromptSubmit({
|
||||
info: () => ({ id: "session-1" }),
|
||||
imageAttachments: () => [],
|
||||
commentCount: () => 0,
|
||||
autoAccept: () => false,
|
||||
mode: () => "normal",
|
||||
working: () => false,
|
||||
editor: () => undefined,
|
||||
queueScroll: () => undefined,
|
||||
promptLength: (value) => value.reduce((sum, part) => sum + ("content" in part ? part.content.length : 0), 0),
|
||||
addToHistory: () => undefined,
|
||||
resetHistoryNavigation: () => undefined,
|
||||
setMode: () => undefined,
|
||||
setPopover: () => undefined,
|
||||
onSubmit: () => undefined,
|
||||
})
|
||||
|
||||
const event = { preventDefault: () => undefined } as unknown as Event
|
||||
|
||||
await submit.handleSubmit(event)
|
||||
|
||||
expect(optimistic).toHaveLength(1)
|
||||
expect(optimistic[0]).toMatchObject({
|
||||
message: {
|
||||
agent: "agent",
|
||||
model: { providerID: "provider", modelID: "model" },
|
||||
variant: "high",
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -316,6 +316,7 @@ export function createPromptSubmit(input: PromptSubmitInput) {
|
||||
time: { created: Date.now() },
|
||||
agent,
|
||||
model,
|
||||
variant,
|
||||
}
|
||||
|
||||
const addOptimisticMessage = () =>
|
||||
|
||||
@@ -199,6 +199,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
|
||||
parts: Part[]
|
||||
agent: string
|
||||
model: { providerID: string; modelID: string }
|
||||
variant?: string
|
||||
}) {
|
||||
const message: Message = {
|
||||
id: input.messageID,
|
||||
@@ -207,6 +208,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
|
||||
time: { created: Date.now() },
|
||||
agent: input.agent,
|
||||
model: input.model,
|
||||
variant: input.variant,
|
||||
}
|
||||
const [, setStore] = target()
|
||||
setOptimisticAdd(setStore as (...args: unknown[]) => void, {
|
||||
|
||||
@@ -36,6 +36,7 @@ import { createSessionComposerState, SessionComposerRegion } from "@/pages/sessi
|
||||
import { createOpenReviewFile, createSizing } from "@/pages/session/helpers"
|
||||
import { MessageTimeline } from "@/pages/session/message-timeline"
|
||||
import { type DiffStyle, SessionReviewTab, type SessionReviewTabProps } from "@/pages/session/review-tab"
|
||||
import { syncSessionModel } from "@/pages/session/session-model-helpers"
|
||||
import { createScrollSpy } from "@/pages/session/scroll-spy"
|
||||
import { SessionMobileTabs } from "@/pages/session/session-mobile-tabs"
|
||||
import { SessionSidePanel } from "@/pages/session/session-side-panel"
|
||||
@@ -418,11 +419,7 @@ export default function Page() {
|
||||
() => {
|
||||
const msg = lastUserMessage()
|
||||
if (!msg) return
|
||||
if (msg.agent) {
|
||||
local.agent.set(msg.agent)
|
||||
if (local.agent.current()?.model) return
|
||||
}
|
||||
if (msg.model) local.model.set(msg.model)
|
||||
syncSessionModel(local, msg)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
83
packages/app/src/pages/session/session-model-helpers.test.ts
Normal file
83
packages/app/src/pages/session/session-model-helpers.test.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import { describe, expect, test } from "bun:test"
|
||||
import type { UserMessage } from "@opencode-ai/sdk/v2"
|
||||
import { syncSessionModel } from "./session-model-helpers"
|
||||
|
||||
const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant">>) =>
|
||||
({
|
||||
id: "msg",
|
||||
sessionID: "session",
|
||||
role: "user",
|
||||
time: { created: 1 },
|
||||
agent: input?.agent ?? "build",
|
||||
model: input?.model ?? { providerID: "anthropic", modelID: "claude-sonnet-4" },
|
||||
variant: input?.variant,
|
||||
}) as UserMessage
|
||||
|
||||
describe("syncSessionModel", () => {
|
||||
test("restores the last message model and variant", () => {
|
||||
const calls: unknown[] = []
|
||||
|
||||
syncSessionModel(
|
||||
{
|
||||
agent: {
|
||||
set(value) {
|
||||
calls.push(["agent", value])
|
||||
},
|
||||
},
|
||||
model: {
|
||||
set(value) {
|
||||
calls.push(["model", value])
|
||||
},
|
||||
current() {
|
||||
return { id: "claude-sonnet-4", provider: { id: "anthropic" } }
|
||||
},
|
||||
variant: {
|
||||
set(value) {
|
||||
calls.push(["variant", value])
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
message({ variant: "high" }),
|
||||
)
|
||||
|
||||
expect(calls).toEqual([
|
||||
["agent", "build"],
|
||||
["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
|
||||
["variant", "high"],
|
||||
])
|
||||
})
|
||||
|
||||
test("skips variant when the model falls back", () => {
|
||||
const calls: unknown[] = []
|
||||
|
||||
syncSessionModel(
|
||||
{
|
||||
agent: {
|
||||
set(value) {
|
||||
calls.push(["agent", value])
|
||||
},
|
||||
},
|
||||
model: {
|
||||
set(value) {
|
||||
calls.push(["model", value])
|
||||
},
|
||||
current() {
|
||||
return { id: "gpt-5", provider: { id: "openai" } }
|
||||
},
|
||||
variant: {
|
||||
set(value) {
|
||||
calls.push(["variant", value])
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
message({ variant: "high" }),
|
||||
)
|
||||
|
||||
expect(calls).toEqual([
|
||||
["agent", "build"],
|
||||
["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
|
||||
])
|
||||
})
|
||||
})
|
||||
33
packages/app/src/pages/session/session-model-helpers.ts
Normal file
33
packages/app/src/pages/session/session-model-helpers.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import type { UserMessage } from "@opencode-ai/sdk/v2"
|
||||
import { batch } from "solid-js"
|
||||
|
||||
type Local = {
|
||||
agent: {
|
||||
set(name: string | undefined): void
|
||||
}
|
||||
model: {
|
||||
set(model: UserMessage["model"] | undefined): void
|
||||
current():
|
||||
| {
|
||||
id: string
|
||||
provider: { id: string }
|
||||
}
|
||||
| undefined
|
||||
variant: {
|
||||
set(value: string | undefined): void
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const syncSessionModel = (local: Local, msg: UserMessage) => {
|
||||
batch(() => {
|
||||
local.agent.set(msg.agent)
|
||||
local.model.set(msg.model)
|
||||
})
|
||||
|
||||
const model = local.model.current()
|
||||
if (!model) return
|
||||
if (model.provider.id !== msg.model.providerID) return
|
||||
if (model.id !== msg.model.modelID) return
|
||||
local.model.variant.set(msg.variant)
|
||||
}
|
||||
Reference in New Issue
Block a user