fix(app): model selection persist by session (#17348)

This commit is contained in:
Adam 2026-03-13 11:05:08 -05:00 committed by GitHub
parent 5c7088338c
commit 4ad8116ce3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 981 additions and 452 deletions

View File

@ -95,6 +95,9 @@ async function seedStorage(page: Page, input: { directory: string; extra?: strin
const win = window as E2EWindow const win = window as E2EWindow
win.__opencode_e2e = { win.__opencode_e2e = {
...win.__opencode_e2e, ...win.__opencode_e2e,
model: {
enabled: true,
},
terminal: { terminal: {
enabled: true, enabled: true,
terminals: {}, terminals: {},

View File

@ -13,6 +13,9 @@ export const sessionTodoToggleButtonSelector = '[data-action="session-todo-toggl
export const sessionTodoListSelector = '[data-slot="session-todo-list"]' export const sessionTodoListSelector = '[data-slot="session-todo-list"]'
export const modelVariantCycleSelector = '[data-action="model-variant-cycle"]' export const modelVariantCycleSelector = '[data-action="model-variant-cycle"]'
export const promptAgentSelector = '[data-component="prompt-agent-control"]'
export const promptModelSelector = '[data-component="prompt-model-control"]'
export const promptVariantSelector = '[data-component="prompt-variant-control"]'
export const settingsLanguageSelectSelector = '[data-action="settings-language"]' export const settingsLanguageSelectSelector = '[data-action="settings-language"]'
export const settingsColorSchemeSelector = '[data-action="settings-color-scheme"]' export const settingsColorSchemeSelector = '[data-action="settings-color-scheme"]'
export const settingsThemeSelector = '[data-action="settings-theme"]' export const settingsThemeSelector = '[data-action="settings-theme"]'

View File

@ -0,0 +1,351 @@
import { base64Decode } from "@opencode-ai/util/encode"
import type { Locator, Page } from "@playwright/test"
import { test, expect } from "../fixtures"
import { openSidebar, sessionIDFromUrl, setWorkspacesEnabled, waitSessionIdle, waitSlug } from "../actions"
import {
promptAgentSelector,
promptModelSelector,
promptSelector,
promptVariantSelector,
workspaceItemSelector,
workspaceNewSessionSelector,
} from "../selectors"
import { createSdk, sessionPath } from "../utils"
type Footer = {
agent: string
model: string
variant: string
}
type Probe = {
dir?: string
sessionID?: string
model?: { providerID: string; modelID: string }
}
const escape = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
const text = async (locator: Locator) => ((await locator.textContent()) ?? "").trim()
const modelKey = (state: Probe | null) => (state?.model ? `${state.model.providerID}:${state.model.modelID}` : null)
const dirKey = (state: Probe | null) => state?.dir ?? ""
async function probe(page: Page): Promise<Probe | null> {
return page.evaluate(() => {
const win = window as Window & {
__opencode_e2e?: {
model?: {
current?: Probe
}
}
}
return win.__opencode_e2e?.model?.current ?? null
})
}
async function currentDir(page: Page) {
let hit = ""
await expect
.poll(
async () => {
const next = dirKey(await probe(page))
if (next) hit = next
return next
},
{ timeout: 30_000 },
)
.not.toBe("")
return hit
}
async function read(page: Page): Promise<Footer> {
return {
agent: await text(page.locator(`${promptAgentSelector} [data-slot="select-select-trigger-value"]`).first()),
model: await text(page.locator(`${promptModelSelector} [data-action="prompt-model"] span`).first()),
variant: await text(page.locator(`${promptVariantSelector} [data-slot="select-select-trigger-value"]`).first()),
}
}
async function waitFooter(page: Page, expected: Partial<Footer>) {
let hit: Footer | null = null
await expect
.poll(
async () => {
const state = await read(page)
const ok = Object.entries(expected).every(([key, value]) => state[key as keyof Footer] === value)
if (ok) hit = state
return ok
},
{ timeout: 30_000 },
)
.toBe(true)
if (!hit) throw new Error("Failed to resolve prompt footer state")
return hit
}
async function waitModel(page: Page, value: string) {
await expect.poll(() => probe(page).then(modelKey), { timeout: 30_000 }).toBe(value)
}
async function choose(page: Page, root: string, value: string) {
const select = page.locator(root)
await expect(select).toBeVisible()
await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click()
const item = page
.locator('[data-slot="select-select-item"]')
.filter({ hasText: new RegExp(`^\\s*${escape(value)}\\s*$`) })
.first()
await expect(item).toBeVisible()
await item.click()
}
async function variantCount(page: Page) {
const select = page.locator(promptVariantSelector)
await expect(select).toBeVisible()
await select.locator('[data-slot="select-select-trigger"]').click()
const count = await page.locator('[data-slot="select-select-item"]').count()
await page.keyboard.press("Escape")
return count
}
async function agents(page: Page) {
const select = page.locator(promptAgentSelector)
await expect(select).toBeVisible()
await select.locator('[data-action], [data-slot="select-select-trigger"]').first().click()
const labels = await page.locator('[data-slot="select-select-item-label"]').allTextContents()
await page.keyboard.press("Escape")
return labels.map((item) => item.trim()).filter(Boolean)
}
async function ensureVariant(page: Page, directory: string): Promise<Footer> {
const current = await read(page)
if ((await variantCount(page)) >= 2) return current
const cfg = await createSdk(directory)
.config.get()
.then((x) => x.data)
const visible = new Set(await agents(page))
const entry = Object.entries(cfg?.agent ?? {}).find((item) => {
const value = item[1]
return !!value && typeof value === "object" && "variant" in value && "model" in value && visible.has(item[0])
})
const name = entry?.[0]
test.skip(!name, "no agent with alternate variants available")
if (!name) return current
await choose(page, promptAgentSelector, name)
await expect.poll(() => variantCount(page), { timeout: 30_000 }).toBeGreaterThanOrEqual(2)
return waitFooter(page, { agent: name })
}
async function chooseDifferentVariant(page: Page): Promise<Footer> {
const current = await read(page)
const select = page.locator(promptVariantSelector)
await expect(select).toBeVisible()
await select.locator('[data-slot="select-select-trigger"]').click()
const items = page.locator('[data-slot="select-select-item"]')
const count = await items.count()
if (count < 2) throw new Error("Current model has no alternate variant to select")
for (let i = 0; i < count; i++) {
const item = items.nth(i)
const next = await text(item.locator('[data-slot="select-select-item-label"]').first())
if (!next || next === current.variant) continue
await item.click()
return waitFooter(page, { agent: current.agent, model: current.model, variant: next })
}
throw new Error("Failed to choose a different variant")
}
async function chooseOtherModel(page: Page): Promise<Footer> {
const current = await read(page)
const button = page.locator(`${promptModelSelector} [data-action="prompt-model"]`)
await expect(button).toBeVisible()
await button.click()
const dialog = page.getByRole("dialog")
await expect(dialog).toBeVisible()
const items = dialog.locator('[data-slot="list-item"]')
const count = await items.count()
expect(count).toBeGreaterThan(1)
for (let i = 0; i < count; i++) {
const item = items.nth(i)
const selected = (await item.getAttribute("data-selected")) === "true"
if (selected) continue
await item.click()
await expect(dialog).toHaveCount(0)
await expect.poll(async () => (await read(page)).model !== current.model, { timeout: 30_000 }).toBe(true)
return read(page)
}
throw new Error("Failed to choose a different model")
}
async function goto(page: Page, directory: string, sessionID?: string) {
await page.goto(sessionPath(directory, sessionID))
await expect(page.locator(promptSelector)).toBeVisible()
await expect.poll(async () => dirKey(await probe(page)), { timeout: 30_000 }).toBe(directory)
}
async function submit(page: Page, value: string) {
const prompt = page.locator(promptSelector)
await expect(prompt).toBeVisible()
await prompt.click()
await prompt.fill(value)
await prompt.press("Enter")
await expect.poll(() => sessionIDFromUrl(page.url()) ?? "", { timeout: 30_000 }).not.toBe("")
const id = sessionIDFromUrl(page.url())
if (!id) throw new Error(`Failed to resolve session id from ${page.url()}`)
return id
}
async function waitUser(directory: string, sessionID: string) {
const sdk = createSdk(directory)
await expect
.poll(
async () => {
const items = await sdk.session.messages({ sessionID, limit: 20 }).then((x) => x.data ?? [])
return items.some((item) => item.info.role === "user")
},
{ timeout: 30_000 },
)
.toBe(true)
await sdk.session.abort({ sessionID }).catch(() => undefined)
await waitSessionIdle(sdk, sessionID, 30_000).catch(() => undefined)
}
async function createWorkspace(page: Page, root: string, seen: string[]) {
await openSidebar(page)
await page.getByRole("button", { name: "New workspace" }).first().click()
const slug = await waitSlug(page, [root, ...seen])
const directory = base64Decode(slug)
if (!directory) throw new Error(`Failed to decode workspace slug: ${slug}`)
return { slug, directory }
}
async function waitWorkspace(page: Page, slug: string) {
await openSidebar(page)
await expect
.poll(
async () => {
const item = page.locator(workspaceItemSelector(slug)).first()
try {
await item.hover({ timeout: 500 })
return true
} catch {
return false
}
},
{ timeout: 60_000 },
)
.toBe(true)
}
async function newWorkspaceSession(page: Page, slug: string) {
await waitWorkspace(page, slug)
const item = page.locator(workspaceItemSelector(slug)).first()
await item.hover()
const button = page.locator(workspaceNewSessionSelector(slug)).first()
await expect(button).toBeVisible()
await button.click({ force: true })
const next = await waitSlug(page)
await expect(page).toHaveURL(new RegExp(`/${next}/session(?:[/?#]|$)`))
await expect(page.locator(promptSelector)).toBeVisible()
return currentDir(page)
}
test("session model and variant restore per session without leaking into new sessions", async ({
page,
withProject,
}) => {
await page.setViewportSize({ width: 1440, height: 900 })
await withProject(async ({ directory, gotoSession, trackSession }) => {
await gotoSession()
await ensureVariant(page, directory)
const firstState = await chooseDifferentVariant(page)
const first = await submit(page, `session variant ${Date.now()}`)
trackSession(first)
await waitUser(directory, first)
await page.reload()
await expect(page.locator(promptSelector)).toBeVisible()
await waitFooter(page, firstState)
await gotoSession()
const fresh = await ensureVariant(page, directory)
expect(fresh.variant).not.toBe(firstState.variant)
const secondState = await chooseOtherModel(page)
const second = await submit(page, `session model ${Date.now()}`)
trackSession(second)
await waitUser(directory, second)
await goto(page, directory, first)
await waitFooter(page, firstState)
await goto(page, directory, second)
await waitFooter(page, secondState)
await gotoSession()
await waitFooter(page, fresh)
})
})
test("session model restore across workspaces", async ({ page, withProject }) => {
await page.setViewportSize({ width: 1440, height: 900 })
await withProject(async ({ directory: root, slug, gotoSession, trackDirectory, trackSession }) => {
await gotoSession()
await ensureVariant(page, root)
const firstState = await chooseDifferentVariant(page)
const first = await submit(page, `root session ${Date.now()}`)
trackSession(first, root)
await waitUser(root, first)
await openSidebar(page)
await setWorkspacesEnabled(page, slug, true)
const one = await createWorkspace(page, slug, [])
const oneDir = await newWorkspaceSession(page, one.slug)
trackDirectory(oneDir)
const secondState = await chooseOtherModel(page)
const second = await submit(page, `workspace one ${Date.now()}`)
trackSession(second, oneDir)
await waitUser(oneDir, second)
const two = await createWorkspace(page, slug, [one.slug])
const twoDir = await newWorkspaceSession(page, two.slug)
trackDirectory(twoDir)
await ensureVariant(page, twoDir)
const thirdState = await chooseDifferentVariant(page)
const third = await submit(page, `workspace two ${Date.now()}`)
trackSession(third, twoDir)
await waitUser(twoDir, third)
await goto(page, root, first)
await waitFooter(page, firstState)
await goto(page, oneDir, second)
await waitFooter(page, secondState)
await goto(page, twoDir, third)
await waitFooter(page, thirdState)
await goto(page, root, first)
await waitFooter(page, firstState)
})
})

View File

@ -13,8 +13,10 @@ import { DialogSelectProvider } from "./dialog-select-provider"
import { ModelTooltip } from "./model-tooltip" import { ModelTooltip } from "./model-tooltip"
import { useLanguage } from "@/context/language" import { useLanguage } from "@/context/language"
export const DialogSelectModelUnpaid: Component = () => { type ModelState = ReturnType<typeof useLocal>["model"]
const local = useLocal()
export const DialogSelectModelUnpaid: Component<{ model?: ModelState }> = (props) => {
const model = props.model ?? useLocal().model
const dialog = useDialog() const dialog = useDialog()
const providers = useProviders() const providers = useProviders()
const language = useLanguage() const language = useLanguage()
@ -35,8 +37,8 @@ export const DialogSelectModelUnpaid: Component = () => {
<List <List
class="[&_[data-slot=list-scroll]]:overflow-visible" class="[&_[data-slot=list-scroll]]:overflow-visible"
ref={(ref) => (listRef = ref)} ref={(ref) => (listRef = ref)}
items={local.model.list} items={model.list}
current={local.model.current()} current={model.current()}
key={(x) => `${x.provider.id}:${x.id}`} key={(x) => `${x.provider.id}:${x.id}`}
itemWrapper={(item, node) => ( itemWrapper={(item, node) => (
<Tooltip <Tooltip
@ -55,7 +57,7 @@ export const DialogSelectModelUnpaid: Component = () => {
</Tooltip> </Tooltip>
)} )}
onSelect={(x) => { onSelect={(x) => {
local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
recent: true, recent: true,
}) })
dialog.close() dialog.close()

View File

@ -18,19 +18,22 @@ import { useLanguage } from "@/context/language"
const isFree = (provider: string, cost: { input: number } | undefined) => const isFree = (provider: string, cost: { input: number } | undefined) =>
provider === "opencode" && (!cost || cost.input === 0) provider === "opencode" && (!cost || cost.input === 0)
type ModelState = ReturnType<typeof useLocal>["model"]
const ModelList: Component<{ const ModelList: Component<{
provider?: string provider?: string
class?: string class?: string
onSelect: () => void onSelect: () => void
action?: JSX.Element action?: JSX.Element
model?: ModelState
}> = (props) => { }> = (props) => {
const local = useLocal() const model = props.model ?? useLocal().model
const language = useLanguage() const language = useLanguage()
const models = createMemo(() => const models = createMemo(() =>
local.model model
.list() .list()
.filter((m) => local.model.visible({ modelID: m.id, providerID: m.provider.id })) .filter((m) => model.visible({ modelID: m.id, providerID: m.provider.id }))
.filter((m) => (props.provider ? m.provider.id === props.provider : true)), .filter((m) => (props.provider ? m.provider.id === props.provider : true)),
) )
@ -41,7 +44,7 @@ const ModelList: Component<{
emptyMessage={language.t("dialog.model.empty")} emptyMessage={language.t("dialog.model.empty")}
key={(x) => `${x.provider.id}:${x.id}`} key={(x) => `${x.provider.id}:${x.id}`}
items={models} items={models}
current={local.model.current()} current={model.current()}
filterKeys={["provider.name", "name", "id"]} filterKeys={["provider.name", "name", "id"]}
sortBy={(a, b) => a.name.localeCompare(b.name)} sortBy={(a, b) => a.name.localeCompare(b.name)}
groupBy={(x) => x.provider.name} groupBy={(x) => x.provider.name}
@ -63,7 +66,7 @@ const ModelList: Component<{
</Tooltip> </Tooltip>
)} )}
onSelect={(x) => { onSelect={(x) => {
local.model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, { model.set(x ? { modelID: x.id, providerID: x.provider.id } : undefined, {
recent: true, recent: true,
}) })
props.onSelect() props.onSelect()
@ -88,6 +91,7 @@ type ModelSelectorTriggerProps = Omit<ComponentProps<typeof Kobalte.Trigger>, "a
export function ModelSelectorPopover(props: { export function ModelSelectorPopover(props: {
provider?: string provider?: string
model?: ModelState
children?: JSX.Element children?: JSX.Element
triggerAs?: ValidComponent triggerAs?: ValidComponent
triggerProps?: ModelSelectorTriggerProps triggerProps?: ModelSelectorTriggerProps
@ -151,6 +155,7 @@ export function ModelSelectorPopover(props: {
<Kobalte.Title class="sr-only">{language.t("dialog.model.select.title")}</Kobalte.Title> <Kobalte.Title class="sr-only">{language.t("dialog.model.select.title")}</Kobalte.Title>
<ModelList <ModelList
provider={props.provider} provider={props.provider}
model={props.model}
onSelect={() => setStore("open", false)} onSelect={() => setStore("open", false)}
class="p-1" class="p-1"
action={ action={
@ -184,7 +189,7 @@ export function ModelSelectorPopover(props: {
) )
} }
export const DialogSelectModel: Component<{ provider?: string }> = (props) => { export const DialogSelectModel: Component<{ provider?: string; model?: ModelState }> = (props) => {
const dialog = useDialog() const dialog = useDialog()
const language = useLanguage() const language = useLanguage()
@ -202,7 +207,7 @@ export const DialogSelectModel: Component<{ provider?: string }> = (props) => {
</Button> </Button>
} }
> >
<ModelList provider={props.provider} onSelect={() => dialog.close()} /> <ModelList provider={props.provider} model={props.model} onSelect={() => dialog.close()} />
<Button <Button
variant="ghost" variant="ghost"
class="ml-3 mt-5 mb-6 text-text-base self-start" class="ml-3 mt-5 mb-6 text-text-base self-start"

View File

@ -1430,39 +1430,76 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
<div class="size-4 shrink-0" /> <div class="size-4 shrink-0" />
</div> </div>
<div class="flex items-center gap-1.5 min-w-0 flex-1"> <div class="flex items-center gap-1.5 min-w-0 flex-1">
<TooltipKeybind <div data-component="prompt-agent-control">
placement="top" <TooltipKeybind
gutter={4} placement="top"
title={language.t("command.agent.cycle")} gutter={4}
keybind={command.keybind("agent.cycle")} title={language.t("command.agent.cycle")}
> keybind={command.keybind("agent.cycle")}
<Select >
size="normal" <Select
options={agentNames()} size="normal"
current={local.agent.current()?.name ?? ""} options={agentNames()}
onSelect={local.agent.set} current={local.agent.current()?.name ?? ""}
class="capitalize max-w-[160px] text-text-base" onSelect={local.agent.set}
valueClass="truncate text-13-regular text-text-base" class="capitalize max-w-[160px] text-text-base"
triggerStyle={control()} valueClass="truncate text-13-regular text-text-base"
variant="ghost" triggerStyle={control()}
/> triggerProps={{ "data-action": "prompt-agent" }}
</TooltipKeybind> variant="ghost"
<Show />
when={providers.paid().length > 0} </TooltipKeybind>
fallback={ </div>
<div data-component="prompt-model-control">
<Show
when={providers.paid().length > 0}
fallback={
<TooltipKeybind
placement="top"
gutter={4}
title={language.t("command.model.choose")}
keybind={command.keybind("model.choose")}
>
<Button
data-action="prompt-model"
as="div"
variant="ghost"
size="normal"
class="min-w-0 max-w-[320px] text-13-regular text-text-base group"
style={control()}
onClick={() => dialog.show(() => <DialogSelectModelUnpaid model={local.model} />)}
>
<Show when={local.model.current()?.provider?.id}>
<ProviderIcon
id={local.model.current()!.provider.id}
class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150"
style={{ "will-change": "opacity", transform: "translateZ(0)" }}
/>
</Show>
<span class="truncate">
{local.model.current()?.name ?? language.t("dialog.model.select.title")}
</span>
<Icon name="chevron-down" size="small" class="shrink-0" />
</Button>
</TooltipKeybind>
}
>
<TooltipKeybind <TooltipKeybind
placement="top" placement="top"
gutter={4} gutter={4}
title={language.t("command.model.choose")} title={language.t("command.model.choose")}
keybind={command.keybind("model.choose")} keybind={command.keybind("model.choose")}
> >
<Button <ModelSelectorPopover
as="div" model={local.model}
variant="ghost" triggerAs={Button}
size="normal" triggerProps={{
class="min-w-0 max-w-[320px] text-13-regular text-text-base group" variant: "ghost",
style={control()} size: "normal",
onClick={() => dialog.show(() => <DialogSelectModelUnpaid />)} style: control(),
class: "min-w-0 max-w-[320px] text-13-regular text-text-base group",
"data-action": "prompt-model",
}}
> >
<Show when={local.model.current()?.provider?.id}> <Show when={local.model.current()?.provider?.id}>
<ProviderIcon <ProviderIcon
@ -1475,57 +1512,31 @@ export const PromptInput: Component<PromptInputProps> = (props) => {
{local.model.current()?.name ?? language.t("dialog.model.select.title")} {local.model.current()?.name ?? language.t("dialog.model.select.title")}
</span> </span>
<Icon name="chevron-down" size="small" class="shrink-0" /> <Icon name="chevron-down" size="small" class="shrink-0" />
</Button> </ModelSelectorPopover>
</TooltipKeybind> </TooltipKeybind>
} </Show>
> </div>
<div data-component="prompt-variant-control">
<TooltipKeybind <TooltipKeybind
placement="top" placement="top"
gutter={4} gutter={4}
title={language.t("command.model.choose")} title={language.t("command.model.variant.cycle")}
keybind={command.keybind("model.choose")} keybind={command.keybind("model.variant.cycle")}
> >
<ModelSelectorPopover <Select
triggerAs={Button} size="normal"
triggerProps={{ options={variants()}
variant: "ghost", current={local.model.variant.current() ?? "default"}
size: "normal", label={(x) => (x === "default" ? language.t("common.default") : x)}
style: control(), onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)}
class: "min-w-0 max-w-[320px] text-13-regular text-text-base group", class="capitalize max-w-[160px] text-text-base"
}} valueClass="truncate text-13-regular text-text-base"
> triggerStyle={control()}
<Show when={local.model.current()?.provider?.id}> triggerProps={{ "data-action": "prompt-model-variant" }}
<ProviderIcon variant="ghost"
id={local.model.current()!.provider.id} />
class="size-4 shrink-0 opacity-40 group-hover:opacity-100 transition-opacity duration-150"
style={{ "will-change": "opacity", transform: "translateZ(0)" }}
/>
</Show>
<span class="truncate">
{local.model.current()?.name ?? language.t("dialog.model.select.title")}
</span>
<Icon name="chevron-down" size="small" class="shrink-0" />
</ModelSelectorPopover>
</TooltipKeybind> </TooltipKeybind>
</Show> </div>
<TooltipKeybind
placement="top"
gutter={4}
title={language.t("command.model.variant.cycle")}
keybind={command.keybind("model.variant.cycle")}
>
<Select
size="normal"
options={variants()}
current={local.model.variant.current() ?? "default"}
label={(x) => (x === "default" ? language.t("common.default") : x)}
onSelect={(x) => local.model.variant.set(x === "default" ? undefined : x)}
class="capitalize max-w-[160px] text-text-base"
valueClass="truncate text-13-regular text-text-base"
triggerStyle={control()}
variant="ghost"
/>
</TooltipKeybind>
<TooltipKeybind <TooltipKeybind
placement="top" placement="top"
gutter={8} gutter={8}

View File

@ -17,6 +17,7 @@ const optimistic: Array<{
}> = [] }> = []
const optimisticSeeded: boolean[] = [] const optimisticSeeded: boolean[] = []
const storedSessions: Record<string, Array<{ id: string; title?: string }>> = {} const storedSessions: Record<string, Array<{ id: string; title?: string }>> = {}
const promoted: Array<{ directory: string; sessionID: string }> = []
const sentShell: string[] = [] const sentShell: string[] = []
const syncedDirectories: string[] = [] const syncedDirectories: string[] = []
@ -86,6 +87,11 @@ beforeAll(async () => {
agent: { agent: {
current: () => ({ name: "agent" }), current: () => ({ name: "agent" }),
}, },
session: {
promote(directory: string, sessionID: string) {
promoted.push({ directory, sessionID })
},
},
}), }),
})) }))
@ -201,6 +207,7 @@ beforeEach(() => {
enabledAutoAccept.length = 0 enabledAutoAccept.length = 0
optimistic.length = 0 optimistic.length = 0
optimisticSeeded.length = 0 optimisticSeeded.length = 0
promoted.length = 0
params = {} params = {}
sentShell.length = 0 sentShell.length = 0
syncedDirectories.length = 0 syncedDirectories.length = 0
@ -240,6 +247,11 @@ describe("prompt submit worktree selection", () => {
expect(createdSessions).toEqual(["/repo/worktree-a", "/repo/worktree-b"]) expect(createdSessions).toEqual(["/repo/worktree-a", "/repo/worktree-b"])
expect(sentShell).toEqual(["/repo/worktree-a", "/repo/worktree-b"]) expect(sentShell).toEqual(["/repo/worktree-a", "/repo/worktree-b"])
expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"]) expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"])
expect(promoted).toEqual([
{ directory: "/repo/worktree-a", sessionID: "session-1" },
{ directory: "/repo/worktree-b", sessionID: "session-2" },
])
expect(syncedDirectories).toEqual(["/repo/worktree-a", "/repo/worktree-a", "/repo/worktree-b", "/repo/worktree-b"])
}) })
test("applies auto-accept to newly created sessions", async () => { test("applies auto-accept to newly created sessions", async () => {

View File

@ -296,6 +296,7 @@ export function createPromptSubmit(input: PromptSubmitInput) {
const currentModel = local.model.current() const currentModel = local.model.current()
const currentAgent = local.agent.current() const currentAgent = local.agent.current()
const variant = local.model.variant.current()
if (!currentModel || !currentAgent) { if (!currentModel || !currentAgent) {
showToast({ showToast({
title: language.t("prompt.toast.modelAgentRequired.title"), title: language.t("prompt.toast.modelAgentRequired.title"),
@ -370,6 +371,7 @@ export function createPromptSubmit(input: PromptSubmitInput) {
seed(sessionDirectory, created) seed(sessionDirectory, created)
session = created session = created
if (shouldAutoAccept) permission.enableAutoAccept(session.id, sessionDirectory) if (shouldAutoAccept) permission.enableAutoAccept(session.id, sessionDirectory)
local.session.promote(sessionDirectory, session.id)
layout.handoff.setTabs(base64Encode(sessionDirectory), session.id) layout.handoff.setTabs(base64Encode(sessionDirectory), session.id)
navigate(`/${base64Encode(sessionDirectory)}/session/${session.id}`) navigate(`/${base64Encode(sessionDirectory)}/session/${session.id}`)
} }
@ -387,7 +389,6 @@ export function createPromptSubmit(input: PromptSubmitInput) {
providerID: currentModel.provider.id, providerID: currentModel.provider.id,
} }
const agent = currentAgent.name const agent = currentAgent.name
const variant = local.model.variant.current()
const context = prompt.context.items().slice() const context = prompt.context.items().slice()
const draft: FollowupDraft = { const draft: FollowupDraft = {
sessionID: session.id, sessionID: session.id,

View File

@ -1,252 +1,421 @@
import { createStore } from "solid-js/store"
import { batch, createMemo } from "solid-js"
import { createSimpleContext } from "@opencode-ai/ui/context" import { createSimpleContext } from "@opencode-ai/ui/context"
import { base64Encode } from "@opencode-ai/util/encode"
import { useParams } from "@solidjs/router"
import { batch, createEffect, createMemo, onCleanup } from "solid-js"
import { createStore } from "solid-js/store"
import { useModels } from "@/context/models"
import { useProviders } from "@/hooks/use-providers"
import { modelEnabled, modelProbe } from "@/testing/model-selection"
import { Persist, persisted } from "@/utils/persist"
import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant"
import { useSDK } from "./sdk" import { useSDK } from "./sdk"
import { useSync } from "./sync" import { useSync } from "./sync"
import { base64Encode } from "@opencode-ai/util/encode"
import { useProviders } from "@/hooks/use-providers"
import { useModels } from "@/context/models"
import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } from "./model-variant"
export type ModelKey = { providerID: string; modelID: string } export type ModelKey = { providerID: string; modelID: string }
type State = {
agent?: string
model?: ModelKey
variant?: string | null
}
type Saved = {
session: Record<string, State | undefined>
}
const WORKSPACE_KEY = "__workspace__"
const handoff = new Map<string, State>()
const handoffKey = (dir: string, id: string) => `${dir}\n${id}`
const migrate = (value: unknown) => {
if (!value || typeof value !== "object") return { session: {} }
const item = value as {
session?: Record<string, State | undefined>
pick?: Record<string, State | undefined>
}
if (item.session && typeof item.session === "object") return { session: item.session }
if (!item.pick || typeof item.pick !== "object") return { session: {} }
return {
session: Object.fromEntries(Object.entries(item.pick).filter(([key]) => key !== WORKSPACE_KEY)),
}
}
const clone = (value: State | undefined) => {
if (!value) return undefined
return {
...value,
model: value.model ? { ...value.model } : undefined,
} satisfies State
}
export const { use: useLocal, provider: LocalProvider } = createSimpleContext({ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
name: "Local", name: "Local",
init: () => { init: () => {
const params = useParams()
const sdk = useSDK() const sdk = useSDK()
const sync = useSync() const sync = useSync()
const providers = useProviders() const providers = useProviders()
const connected = createMemo(() => new Set(providers.connected().map((provider) => provider.id))) const models = useModels()
function isModelValid(model: ModelKey) { const id = createMemo(() => params.id || undefined)
const provider = providers.all().find((x) => x.id === model.providerID) const list = createMemo(() => sync.data.agent.filter((item) => item.mode !== "subagent" && !item.hidden))
const connected = createMemo(() => new Set(providers.connected().map((item) => item.id)))
const [saved, setSaved] = persisted(
{
...Persist.workspace(sdk.directory, "model-selection", ["model-selection.v1"]),
migrate,
},
createStore<Saved>({
session: {},
}),
)
const [store, setStore] = createStore<{
current?: string
draft?: State
last?: {
type: "agent" | "model" | "variant"
agent?: string
model?: ModelKey | null
variant?: string | null
}
}>({
current: list()[0]?.name,
draft: undefined,
last: undefined,
})
const validModel = (model: ModelKey) => {
const provider = providers.all().find((item) => item.id === model.providerID)
return !!provider?.models[model.modelID] && connected().has(model.providerID) return !!provider?.models[model.modelID] && connected().has(model.providerID)
} }
function getFirstValidModel(...modelFns: (() => ModelKey | undefined)[]) { const firstModel = (...items: Array<() => ModelKey | undefined>) => {
for (const modelFn of modelFns) { for (const item of items) {
const model = modelFn() const model = item()
if (!model) continue if (!model) continue
if (isModelValid(model)) return model if (validModel(model)) return model
} }
} }
let setModel: (model: ModelKey | undefined, options?: { recent?: boolean }) => void = () => undefined const pickAgent = (name: string | undefined) => {
const items = list()
if (items.length === 0) return undefined
return items.find((item) => item.name === name) ?? items[0]
}
const agent = (() => { createEffect(() => {
const list = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden)) const items = list()
const models = useModels() if (items.length === 0) {
if (store.current !== undefined) setStore("current", undefined)
return
}
if (items.some((item) => item.name === store.current)) return
setStore("current", items[0]?.name)
})
const [store, setStore] = createStore<{ const scope = createMemo<State | undefined>(() => {
current?: string const session = id()
}>({ if (!session) return store.draft
current: list()[0]?.name, return saved.session[session] ?? handoff.get(handoffKey(sdk.directory, session))
})
createEffect(() => {
const session = id()
if (!session) return
const key = handoffKey(sdk.directory, session)
const next = handoff.get(key)
if (!next) return
if (saved.session[session] !== undefined) {
handoff.delete(key)
return
}
setSaved("session", session, clone(next))
handoff.delete(key)
})
const configuredModel = () => {
if (!sync.data.config.model) return
const [providerID, modelID] = sync.data.config.model.split("/")
const model = { providerID, modelID }
if (validModel(model)) return model
}
const recentModel = () => {
for (const item of models.recent.list()) {
if (validModel(item)) return item
}
}
const defaultModel = () => {
const defaults = providers.default()
for (const provider of providers.connected()) {
const configured = defaults[provider.id]
if (configured) {
const model = { providerID: provider.id, modelID: configured }
if (validModel(model)) return model
}
const first = Object.values(provider.models)[0]
if (!first) continue
const model = { providerID: provider.id, modelID: first.id }
if (validModel(model)) return model
}
}
const fallback = createMemo<ModelKey | undefined>(() => configuredModel() ?? recentModel() ?? defaultModel())
const agent = {
list,
current() {
return pickAgent(scope()?.agent ?? store.current)
},
set(name: string | undefined) {
const item = pickAgent(name)
if (!item) {
setStore("current", undefined)
return
}
batch(() => {
setStore("current", item.name)
setStore("last", {
type: "agent",
agent: item.name,
model: item.model,
variant: item.variant ?? null,
})
const next = {
agent: item.name,
model: item.model,
variant: item.variant,
} satisfies State
const session = id()
if (session) {
setSaved("session", session, next)
return
}
setStore("draft", next)
})
},
move(direction: 1 | -1) {
const items = list()
if (items.length === 0) {
setStore("current", undefined)
return
}
let next = items.findIndex((item) => item.name === agent.current()?.name) + direction
if (next < 0) next = items.length - 1
if (next >= items.length) next = 0
const item = items[next]
if (!item) return
agent.set(item.name)
},
}
const current = () => {
const item = firstModel(
() => scope()?.model,
() => agent.current()?.model,
fallback,
)
if (!item) return undefined
return models.find(item)
}
const configured = () => {
const item = agent.current()
const model = current()
if (!item || !model) return undefined
return getConfiguredAgentVariant({
agent: { model: item.model, variant: item.variant },
model: { providerID: model.provider.id, modelID: model.id, variants: model.variants },
}) })
}
const selected = () => scope()?.variant
const snapshot = () => {
const model = current()
return { return {
list, agent: agent.current()?.name,
current() { model: model ? { providerID: model.provider.id, modelID: model.id } : undefined,
const available = list() variant: selected(),
if (available.length === 0) return undefined } satisfies State
return available.find((x) => x.name === store.current) ?? available[0] }
},
set(name: string | undefined) { const write = (next: Partial<State>) => {
const available = list() const state = {
if (available.length === 0) { ...(scope() ?? { agent: agent.current()?.name }),
setStore("current", undefined) ...next,
return } satisfies State
}
const match = name ? available.find((x) => x.name === name) : undefined const session = id()
const value = match ?? available[0] if (session) {
if (!value) return setSaved("session", session, state)
setStore("current", value.name) return
if (!value.model) return
setModel({
providerID: value.model.providerID,
modelID: value.model.modelID,
})
if (value.variant)
models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant)
},
move(direction: 1 | -1) {
const available = list()
if (available.length === 0) {
setStore("current", undefined)
return
}
let next = available.findIndex((x) => x.name === store.current) + direction
if (next < 0) next = available.length - 1
if (next >= available.length) next = 0
const value = available[next]
if (!value) return
setStore("current", value.name)
if (!value.model) return
setModel({
providerID: value.model.providerID,
modelID: value.model.modelID,
})
if (value.variant)
models.variant.set({ providerID: value.model.providerID, modelID: value.model.modelID }, value.variant)
},
} }
})() setStore("draft", state)
}
const model = (() => { const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean))
const models = useModels()
const [ephemeral, setEphemeral] = createStore<{ const model = {
model: Record<string, ModelKey | undefined> ready: models.ready,
}>({ current,
model: {}, recent,
}) list: models.list,
cycle(direction: 1 | -1) {
const items = recent()
const item = current()
if (!item) return
const resolveConfigured = () => { const index = items.findIndex((entry) => entry?.provider.id === item.provider.id && entry?.id === item.id)
if (!sync.data.config.model) return
const [providerID, modelID] = sync.data.config.model.split("/")
const key = { providerID, modelID }
if (isModelValid(key)) return key
}
const resolveRecent = () => {
for (const item of models.recent.list()) {
if (isModelValid(item)) return item
}
}
const resolveDefault = () => {
const defaults = providers.default()
for (const provider of providers.connected()) {
const configured = defaults[provider.id]
if (configured) {
const key = { providerID: provider.id, modelID: configured }
if (isModelValid(key)) return key
}
const first = Object.values(provider.models)[0]
if (!first) continue
const key = { providerID: provider.id, modelID: first.id }
if (isModelValid(key)) return key
}
}
const fallbackModel = createMemo<ModelKey | undefined>(() => {
return resolveConfigured() ?? resolveRecent() ?? resolveDefault()
})
const current = createMemo(() => {
const a = agent.current()
if (!a) return undefined
const key = getFirstValidModel(
() => ephemeral.model[a.name],
() => a.model,
fallbackModel,
)
if (!key) return undefined
return models.find(key)
})
const recent = createMemo(() => models.recent.list().map(models.find).filter(Boolean))
const cycle = (direction: 1 | -1) => {
const recentList = recent()
const currentModel = current()
if (!currentModel) return
const index = recentList.findIndex(
(x) => x?.provider.id === currentModel.provider.id && x?.id === currentModel.id,
)
if (index === -1) return if (index === -1) return
let next = index + direction let next = index + direction
if (next < 0) next = recentList.length - 1 if (next < 0) next = items.length - 1
if (next >= recentList.length) next = 0 if (next >= items.length) next = 0
const val = recentList[next] const entry = items[next]
if (!val) return if (!entry) return
model.set({ providerID: entry.provider.id, modelID: entry.id })
model.set({ },
providerID: val.provider.id, set(item: ModelKey | undefined, options?: { recent?: boolean }) {
modelID: val.id,
})
}
const set = (model: ModelKey | undefined, options?: { recent?: boolean }) => {
batch(() => { batch(() => {
const currentAgent = agent.current() setStore("last", {
const next = model ?? fallbackModel() type: "model",
if (currentAgent) setEphemeral("model", currentAgent.name, next) agent: agent.current()?.name,
if (model) models.setVisibility(model, true) model: item ?? null,
if (options?.recent && model) models.recent.push(model) variant: selected(),
})
write({ model: item })
if (!item) return
models.setVisibility(item, true)
if (!options?.recent) return
models.recent.push(item)
}) })
} },
visible(item: ModelKey) {
setModel = set return models.visible(item)
},
return { setVisibility(item: ModelKey, visible: boolean) {
ready: models.ready, models.setVisibility(item, visible)
current, },
recent, variant: {
list: models.list, configured,
cycle, selected,
set, current() {
visible(model: ModelKey) { return resolveModelVariant({
return models.visible(model) variants: this.list(),
selected: this.selected(),
configured: this.configured(),
})
}, },
setVisibility(model: ModelKey, visible: boolean) { list() {
models.setVisibility(model, visible) const item = current()
if (!item?.variants) return []
return Object.keys(item.variants)
}, },
variant: { set(value: string | undefined) {
configured() { batch(() => {
const a = agent.current() const model = current()
const m = current() setStore("last", {
if (!a || !m) return undefined type: "variant",
return getConfiguredAgentVariant({ agent: agent.current()?.name,
agent: { model: a.model, variant: a.variant }, model: model ? { providerID: model.provider.id, modelID: model.id } : null,
model: { providerID: m.provider.id, modelID: m.id, variants: m.variants }, variant: value ?? null,
}) })
}, write({ variant: value ?? null })
selected() { })
const m = current() },
if (!m) return undefined cycle() {
return models.variant.get({ providerID: m.provider.id, modelID: m.id }) const items = this.list()
}, if (items.length === 0) return
current() { this.set(
return resolveModelVariant({ cycleModelVariant({
variants: this.list(), variants: items,
selected: this.selected(), selected: this.selected(),
configured: this.configured(), configured: this.configured(),
}) }),
}, )
list() {
const m = current()
if (!m) return []
if (!m.variants) return []
return Object.keys(m.variants)
},
set(value: string | undefined) {
const m = current()
if (!m) return
models.variant.set({ providerID: m.provider.id, modelID: m.id }, value)
},
cycle() {
const variants = this.list()
if (variants.length === 0) return
this.set(
cycleModelVariant({
variants,
selected: this.selected(),
configured: this.configured(),
}),
)
},
}, },
} },
})() }
const result = { const result = {
slug: createMemo(() => base64Encode(sdk.directory)), slug: createMemo(() => base64Encode(sdk.directory)),
model, model,
agent, agent,
session: {
reset() {
setStore("draft", undefined)
},
promote(dir: string, session: string) {
const next = clone(snapshot())
if (!next) return
if (dir === sdk.directory) {
setSaved("session", session, next)
setStore("draft", undefined)
return
}
handoff.set(handoffKey(dir, session), next)
setStore("draft", undefined)
},
restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) {
const session = id()
if (!session) return
if (msg.sessionID !== session) return
if (saved.session[session] !== undefined) return
if (handoff.has(handoffKey(sdk.directory, session))) return
setSaved("session", session, {
agent: msg.agent,
model: msg.model,
variant: msg.variant ?? null,
})
},
},
} }
if (modelEnabled()) {
createEffect(() => {
const agent = result.agent.current()
const model = result.model.current()
modelProbe.set({
dir: sdk.directory,
sessionID: id(),
last: store.last,
agent: agent?.name,
model: model
? {
providerID: model.provider.id,
modelID: model.id,
name: model.name,
}
: undefined,
variant: result.model.variant.current() ?? null,
selected: result.model.variant.selected(),
configured: result.model.variant.configured(),
pick: scope(),
base: undefined,
current: store.current,
})
})
onCleanup(() => modelProbe.clear())
}
return result return result
}, },
}) })

View File

@ -44,6 +44,16 @@ describe("model variant", () => {
expect(value).toBe("high") expect(value).toBe("high")
}) })
test("lets an explicit default override the configured variant", () => {
const value = resolveModelVariant({
variants: ["low", "high", "xhigh"],
selected: null,
configured: "xhigh",
})
expect(value).toBeUndefined()
})
test("cycles from configured variant to next", () => { test("cycles from configured variant to next", () => {
const value = cycleModelVariant({ const value = cycleModelVariant({
variants: ["low", "high", "xhigh"], variants: ["low", "high", "xhigh"],
@ -63,4 +73,14 @@ describe("model variant", () => {
expect(value).toBe("low") expect(value).toBe("low")
}) })
test("cycles from an explicit default to the first variant", () => {
const value = cycleModelVariant({
variants: ["low", "high", "xhigh"],
selected: null,
configured: "xhigh",
})
expect(value).toBe("low")
})
}) })

View File

@ -14,7 +14,7 @@ type Model = AgentModel & {
type VariantInput = { type VariantInput = {
variants: string[] variants: string[]
selected: string | undefined selected: string | null | undefined
configured: string | undefined configured: string | undefined
} }
@ -29,6 +29,7 @@ export function getConfiguredAgentVariant(input: { agent: Agent | undefined; mod
} }
export function resolveModelVariant(input: VariantInput) { export function resolveModelVariant(input: VariantInput) {
if (input.selected === null) return undefined
if (input.selected && input.variants.includes(input.selected)) return input.selected if (input.selected && input.variants.includes(input.selected)) return input.selected
if (input.configured && input.variants.includes(input.configured)) return input.configured if (input.configured && input.variants.includes(input.configured)) return input.configured
return undefined return undefined
@ -36,6 +37,7 @@ export function resolveModelVariant(input: VariantInput) {
export function cycleModelVariant(input: VariantInput) { export function cycleModelVariant(input: VariantInput) {
if (input.variants.length === 0) return undefined if (input.variants.length === 0) return undefined
if (input.selected === null) return input.variants[0]
if (input.selected && input.variants.includes(input.selected)) { if (input.selected && input.variants.includes(input.selected)) {
const index = input.variants.indexOf(input.selected) const index = input.variants.indexOf(input.selected)
if (index === input.variants.length - 1) return undefined if (index === input.variants.length - 1) return undefined

View File

@ -80,11 +80,11 @@ export default function Layout(props: ParentProps) {
}) })
return ( return (
<Show when={state.resolved}> <Show when={state.resolved} keyed>
{(resolved) => ( {(resolved) => (
<SDKProvider directory={resolved}> <SDKProvider directory={() => resolved}>
<SyncProvider> <SyncProvider>
<DirectoryDataProvider directory={resolved()}>{props.children}</DirectoryDataProvider> <DirectoryDataProvider directory={resolved}>{props.children}</DirectoryDataProvider>
</SyncProvider> </SyncProvider>
</SDKProvider> </SDKProvider>
)} )}

View File

@ -44,7 +44,7 @@ import { createOpenReviewFile, createSessionTabs, createSizing, focusTerminalByI
import { MessageTimeline } from "@/pages/session/message-timeline" import { MessageTimeline } from "@/pages/session/message-timeline"
import { type DiffStyle, SessionReviewTab, type SessionReviewTabProps } from "@/pages/session/review-tab" import { type DiffStyle, SessionReviewTab, type SessionReviewTabProps } from "@/pages/session/review-tab"
import { useSessionLayout } from "@/pages/session/session-layout" import { useSessionLayout } from "@/pages/session/session-layout"
import { resetSessionModel, syncSessionModel } from "@/pages/session/session-model-helpers" import { syncSessionModel } from "@/pages/session/session-model-helpers"
import { SessionSidePanel } from "@/pages/session/session-side-panel" import { SessionSidePanel } from "@/pages/session/session-side-panel"
import { TerminalPanel } from "@/pages/session/terminal-panel" import { TerminalPanel } from "@/pages/session/terminal-panel"
import { useSessionCommands } from "@/pages/session/use-session-commands" import { useSessionCommands } from "@/pages/session/use-session-commands"
@ -490,7 +490,7 @@ export default function Page() {
(next, prev) => { (next, prev) => {
if (!prev) return if (!prev) return
if (next.dir === prev.dir && next.id === prev.id) return if (next.dir === prev.dir && next.id === prev.id) return
if (!next.id) resetSessionModel(local) if (prev.id && !next.id) local.session.reset()
}, },
{ defer: true }, { defer: true },
), ),

View File

@ -14,145 +14,38 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant"
}) as UserMessage }) as UserMessage
describe("syncSessionModel", () => { describe("syncSessionModel", () => {
test("restores the last message model and variant", () => { test("restores the last message through session state", () => {
const calls: unknown[] = [] const calls: unknown[] = []
syncSessionModel( syncSessionModel(
{ {
agent: { session: {
current() { restore(value) {
return undefined calls.push(value)
},
set(value) {
calls.push(["agent", value])
},
},
model: {
set(value) {
calls.push(["model", value])
},
current() {
return { id: "claude-sonnet-4", provider: { id: "anthropic" } }
},
variant: {
set(value) {
calls.push(["variant", value])
},
}, },
reset() {},
}, },
}, },
message({ variant: "high" }), message({ variant: "high" }),
) )
expect(calls).toEqual([ expect(calls).toEqual([message({ variant: "high" })])
["agent", "build"],
["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
["variant", "high"],
])
})
test("skips variant when the model falls back", () => {
const calls: unknown[] = []
syncSessionModel(
{
agent: {
current() {
return undefined
},
set(value) {
calls.push(["agent", value])
},
},
model: {
set(value) {
calls.push(["model", value])
},
current() {
return { id: "gpt-5", provider: { id: "openai" } }
},
variant: {
set(value) {
calls.push(["variant", value])
},
},
},
},
message({ variant: "high" }),
)
expect(calls).toEqual([
["agent", "build"],
["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
])
}) })
}) })
describe("resetSessionModel", () => { describe("resetSessionModel", () => {
test("restores the current agent defaults", () => { test("clears draft session state", () => {
const calls: unknown[] = [] const calls: string[] = []
resetSessionModel({ resetSessionModel({
agent: { session: {
current() { reset() {
return { calls.push("reset")
model: { providerID: "anthropic", modelID: "claude-sonnet-4" },
variant: "high",
}
},
set() {},
},
model: {
set(value) {
calls.push(["model", value])
},
current() {
return undefined
},
variant: {
set(value) {
calls.push(["variant", value])
},
}, },
restore() {},
}, },
}) })
expect(calls).toEqual([ expect(calls).toEqual(["reset"])
["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
["variant", "high"],
])
})
test("clears the variant when the agent has none", () => {
const calls: unknown[] = []
resetSessionModel({
agent: {
current() {
return {
model: { providerID: "anthropic", modelID: "claude-sonnet-4" },
}
},
set() {},
},
model: {
set(value) {
calls.push(["model", value])
},
current() {
return undefined
},
variant: {
set(value) {
calls.push(["variant", value])
},
},
},
})
expect(calls).toEqual([
["model", { providerID: "anthropic", modelID: "claude-sonnet-4" }],
["variant", undefined],
])
}) })
}) })

View File

@ -1,48 +1,16 @@
import type { UserMessage } from "@opencode-ai/sdk/v2" import type { UserMessage } from "@opencode-ai/sdk/v2"
import { batch } from "solid-js"
type Local = { type Local = {
agent: { session: {
current(): reset(): void
| { restore(msg: UserMessage): void
model?: UserMessage["model"]
variant?: string
}
| undefined
set(name: string | undefined): void
}
model: {
set(model: UserMessage["model"] | undefined): void
current():
| {
id: string
provider: { id: string }
}
| undefined
variant: {
set(value: string | undefined): void
}
} }
} }
export const resetSessionModel = (local: Local) => { export const resetSessionModel = (local: Local) => {
const agent = local.agent.current() local.session.reset()
if (!agent) return
batch(() => {
local.model.set(agent.model)
local.model.variant.set(agent.variant)
})
} }
export const syncSessionModel = (local: Local, msg: UserMessage) => { export const syncSessionModel = (local: Local, msg: UserMessage) => {
batch(() => { local.session.restore(msg)
local.agent.set(msg.agent)
local.model.set(msg.model)
})
const model = local.model.current()
if (!model) return
if (model.provider.id !== msg.model.providerID) return
if (model.id !== msg.model.modelID) return
local.model.variant.set(msg.variant)
} }

View File

@ -351,7 +351,7 @@ export const useSessionCommands = (actions: SessionCommandContext) => {
description: language.t("command.model.choose.description"), description: language.t("command.model.choose.description"),
keybind: "mod+'", keybind: "mod+'",
slash: "model", slash: "model",
onSelect: () => dialog.show(() => <DialogSelectModel />), onSelect: () => dialog.show(() => <DialogSelectModel model={local.model} />),
}), }),
mcpCommand({ mcpCommand({
id: "mcp.toggle", id: "mcp.toggle",

View File

@ -0,0 +1,80 @@
type ModelKey = {
providerID: string
modelID: string
}
type State = {
agent?: string
model?: ModelKey | null
variant?: string | null
}
export type ModelProbeState = {
dir?: string
sessionID?: string
last?: {
type: "agent" | "model" | "variant"
agent?: string
model?: ModelKey | null
variant?: string | null
}
agent?: string
model?: (ModelKey & { name?: string }) | undefined
variant?: string | null
selected?: string | null
configured?: string
pick?: State
base?: State
current?: string
}
export type ModelWindow = Window & {
__opencode_e2e?: {
model?: {
enabled?: boolean
current?: ModelProbeState
}
}
}
const clone = (state?: State) => {
if (!state) return undefined
return {
...state,
model: state.model ? { ...state.model } : state.model,
}
}
export const modelEnabled = () => {
if (typeof window === "undefined") return false
return (window as ModelWindow).__opencode_e2e?.model?.enabled === true
}
const root = () => {
if (!modelEnabled()) return
return (window as ModelWindow).__opencode_e2e?.model
}
export const modelProbe = {
set(input: ModelProbeState) {
const state = root()
if (!state) return
state.current = {
...input,
model: input.model ? { ...input.model } : undefined,
last: input.last
? {
...input.last,
model: input.last.model ? { ...input.last.model } : input.last.model,
}
: undefined,
pick: clone(input.pick),
base: clone(input.base),
}
},
clear() {
const state = root()
if (!state) return
state.current = undefined
},
}

View File

@ -1,3 +1,5 @@
import type { ModelProbeState } from "./model-selection"
export const terminalAttr = "data-pty-id" export const terminalAttr = "data-pty-id"
export type TerminalProbeState = { export type TerminalProbeState = {
@ -13,6 +15,10 @@ type TerminalProbeControl = {
export type E2EWindow = Window & { export type E2EWindow = Window & {
__opencode_e2e?: { __opencode_e2e?: {
model?: {
enabled?: boolean
current?: ModelProbeState
}
terminal?: { terminal?: {
enabled?: boolean enabled?: boolean
terminals?: Record<string, TerminalProbeState> terminals?: Record<string, TerminalProbeState>

View File

@ -19,6 +19,7 @@ export type SelectProps<T> = Omit<ComponentProps<typeof Kobalte<T>>, "value" | "
children?: (item: T | undefined) => JSX.Element children?: (item: T | undefined) => JSX.Element
triggerStyle?: JSX.CSSProperties triggerStyle?: JSX.CSSProperties
triggerVariant?: "settings" triggerVariant?: "settings"
triggerProps?: Record<string, string | number | boolean | undefined>
} }
export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">) { export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">) {
@ -38,6 +39,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">)
"children", "children",
"triggerStyle", "triggerStyle",
"triggerVariant", "triggerVariant",
"triggerProps",
]) ])
const state = { const state = {
@ -131,6 +133,7 @@ export function Select<T>(props: SelectProps<T> & Omit<ButtonProps, "children">)
}} }}
> >
<Kobalte.Trigger <Kobalte.Trigger
{...local.triggerProps}
disabled={props.disabled} disabled={props.disabled}
data-slot="select-select-trigger" data-slot="select-select-trigger"
as={Button} as={Button}