feat: add variants toggle (#6325)

Co-authored-by: Github Action <action@github.com>
This commit is contained in:
Aiden Cline
2025-12-29 19:43:50 -08:00
committed by GitHub
parent e1dd9c4ccb
commit ed0c0d90be
14 changed files with 339 additions and 52 deletions

View File

@@ -167,6 +167,13 @@ export function Prompt(props: PromptProps) {
if (!props.disabled) input.cursorColor = theme.text
})
const lastUserMessage = createMemo(() => {
if (!props.sessionID) return undefined
const messages = sync.data.message[props.sessionID]
if (!messages) return undefined
return messages.findLast((m) => m.role === "user")
})
const [store, setStore] = createStore<{
prompt: PromptInfo
mode: "normal" | "shell"
@@ -184,6 +191,26 @@ export function Prompt(props: PromptProps) {
interrupt: 0,
})
createEffect(() => {
const msg = lastUserMessage()
if (!msg) return
// Set agent from last message
if (msg.agent) {
local.agent.set(msg.agent)
}
// Set model from last message
if (msg.model) {
local.model.set(msg.model)
}
// Set variant from last message
if (msg.variant) {
local.model.variant.set(msg.variant)
}
})
command.register(() => {
return [
{
@@ -562,6 +589,7 @@ export function Prompt(props: PromptProps) {
// Capture mode before it gets reset
const currentMode = store.mode
const variant = local.model.variant.current()
if (store.mode === "shell") {
sdk.client.session.shell({
@@ -590,6 +618,7 @@ export function Prompt(props: PromptProps) {
agent: local.agent.current().name,
model: `${selectedModel.providerID}/${selectedModel.modelID}`,
messageID,
variant,
})
} else {
sdk.client.session.prompt({
@@ -598,6 +627,7 @@ export function Prompt(props: PromptProps) {
messageID,
agent: local.agent.current().name,
model: selectedModel,
variant,
parts: [
{
id: Identifier.ascending("part"),
@@ -718,6 +748,13 @@ export function Prompt(props: PromptProps) {
return local.agent.color(local.agent.current().name)
})
const showVariant = createMemo(() => {
const variants = local.model.variant.list()
if (variants.length === 0) return false
const current = local.model.variant.current()
return !!current
})
const spinnerDef = createMemo(() => {
const color = local.agent.color(local.agent.current().name)
return {
@@ -843,6 +880,12 @@ export function Prompt(props: PromptProps) {
return
}
}
if (keybind.match("variant_cycle", e)) {
e.preventDefault()
if (local.model.variant.list().length === 0) return
local.model.variant.cycle()
return
}
if (store.mode === "normal") autocomplete.onKeyDown(e)
if (!autocomplete.visible) {
if (
@@ -958,6 +1001,12 @@ export function Prompt(props: PromptProps) {
{local.model.parsed().model}
</text>
<text fg={theme.textMuted}>{local.model.parsed().provider}</text>
<Show when={showVariant()}>
<text fg={theme.textMuted}>·</text>
<text>
<span style={{ fg: theme.warning, bold: true }}>{local.model.variant.current()}</span>
</text>
</Show>
</box>
</Show>
</box>

View File

@@ -33,24 +33,6 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
}
}
// Automatically update model when agent changes
createEffect(() => {
const value = agent.current()
if (value.model) {
if (isModelValid(value.model))
model.set({
providerID: value.model.providerID,
modelID: value.model.modelID,
})
else
toast.show({
variant: "warning",
message: `Agent ${value.name}'s configured model ${value.model.providerID}/${value.model.modelID} is not valid`,
duration: 3000,
})
}
})
const agent = iife(() => {
const agents = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden))
const [agentStore, setAgentStore] = createStore<{
@@ -120,11 +102,13 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
providerID: string
modelID: string
}[]
variant: Record<string, string | undefined>
}>({
ready: false,
model: {},
recent: [],
favorite: [],
variant: {},
})
const file = Bun.file(path.join(Global.Path.state, "model.json"))
@@ -135,6 +119,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
JSON.stringify({
recent: modelStore.recent,
favorite: modelStore.favorite,
variant: modelStore.variant,
}),
)
}
@@ -144,6 +129,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
.then((x) => {
if (Array.isArray(x.recent)) setModelStore("recent", x.recent)
if (Array.isArray(x.favorite)) setModelStore("favorite", x.favorite)
if (typeof x.variant === "object" && x.variant !== null) setModelStore("variant", x.variant)
})
.catch(() => {})
.finally(() => {
@@ -218,6 +204,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
return {
provider: "Connect a provider",
model: "No provider selected",
reasoning: false,
}
}
const provider = sync.data.provider.find((x) => x.id === value.providerID)
@@ -225,6 +212,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
return {
provider: provider?.name ?? value.providerID,
model: info?.name ?? value.modelID,
reasoning: info?.capabilities?.reasoning ?? false,
}
}),
cycle(direction: 1 | -1) {
@@ -309,6 +297,46 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
save()
})
},
variant: {
current() {
const m = currentModel()
if (!m) return undefined
const key = `${m.providerID}/${m.modelID}`
return modelStore.variant[key]
},
list() {
const m = currentModel()
if (!m) return []
const provider = sync.data.provider.find((x) => x.id === m.providerID)
const info = provider?.models[m.modelID]
if (!info?.variants) return []
return Object.entries(info.variants)
.filter(([_, v]) => !v.disabled)
.map(([name]) => name)
},
set(value: string | undefined) {
const m = currentModel()
if (!m) return
const key = `${m.providerID}/${m.modelID}`
setModelStore("variant", key, value)
save()
},
cycle() {
const variants = this.list()
if (variants.length === 0) return
const current = this.current()
if (!current) {
this.set(variants[0])
return
}
const index = variants.indexOf(current)
if (index === -1 || index === variants.length - 1) {
this.set(undefined)
return
}
this.set(variants[index + 1])
},
},
}
})
@@ -329,6 +357,24 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
},
}
// Automatically update model when agent changes
createEffect(() => {
const value = agent.current()
if (value.model) {
if (isModelValid(value.model))
model.set({
providerID: value.model.providerID,
modelID: value.model.modelID,
})
else
toast.show({
variant: "warning",
message: `Agent ${value.name}'s configured model ${value.model.providerID}/${value.model.modelID} is not valid`,
duration: 3000,
})
}
})
const result = {
model,
agent,