From 171e69c2fc148985af7b9506b47f048d3a34a767 Mon Sep 17 00:00:00 2001 From: Aiden Cline <63023139+rekram1-node@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:36:19 -0500 Subject: [PATCH] feat: integrate support for multi step auth flows for providers that require additional questions (#18035) --- packages/opencode/src/cli/cmd/providers.ts | 8 +- .../cli/cmd/tui/component/dialog-provider.tsx | 72 ++++++++++- packages/opencode/src/plugin/copilot.ts | 2 +- .../opencode/src/provider/auth-service.ts | 93 +++++++++++++- packages/opencode/src/provider/auth.ts | 2 + .../opencode/src/server/routes/provider.ts | 4 +- packages/opencode/src/server/server.ts | 1 + packages/plugin/src/index.ts | 14 +++ packages/sdk/js/src/v2/gen/sdk.gen.ts | 4 + packages/sdk/js/src/v2/gen/types.gen.ts | 34 +++++ packages/sdk/openapi.json | 118 ++++++++++++++++++ 11 files changed, 344 insertions(+), 8 deletions(-) diff --git a/packages/opencode/src/cli/cmd/providers.ts b/packages/opencode/src/cli/cmd/providers.ts index 631ca7811..a2b7c5be1 100644 --- a/packages/opencode/src/cli/cmd/providers.ts +++ b/packages/opencode/src/cli/cmd/providers.ts @@ -46,9 +46,13 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string, const inputs: Record = {} if (method.prompts) { for (const prompt of method.prompts) { - if (prompt.condition && !prompt.condition(inputs)) { - continue + if (prompt.when) { + const value = inputs[prompt.when.key] + if (value === undefined) continue + const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value + if (!matches) continue } + if (prompt.condition && !prompt.condition(inputs)) continue if (prompt.type === "select") { const value = await prompts.select({ message: prompt.message, diff --git a/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx b/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx index f77e4727a..635ed71f5 100644 --- a/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx @@ -8,7 +8,7 @@ import { DialogPrompt } from "../ui/dialog-prompt" import { Link } from "../ui/link" import { useTheme } from "../context/theme" import { TextAttributes } from "@opentui/core" -import type { ProviderAuthAuthorization } from "@opencode-ai/sdk/v2" +import type { ProviderAuthAuthorization, ProviderAuthMethod } from "@opencode-ai/sdk/v2" import { DialogModel } from "./dialog-model" import { useKeyboard } from "@opentui/solid" import { Clipboard } from "@tui/util/clipboard" @@ -27,6 +27,7 @@ export function createDialogProviderOptions() { const sync = useSync() const dialog = useDialog() const sdk = useSDK() + const toast = useToast() const options = createMemo(() => { return pipe( sync.data.provider_next.all, @@ -69,10 +70,29 @@ export function createDialogProviderOptions() { if (index == null) return const method = methods[index] if (method.type === "oauth") { + let inputs: Record | undefined + if (method.prompts?.length) { + const value = await PromptsMethod({ + dialog, + prompts: method.prompts, + }) + if (!value) return + inputs = value + } + const result = await sdk.client.provider.oauth.authorize({ providerID: provider.id, method: index, + inputs, }) + if (result.error) { + toast.show({ + variant: "error", + message: JSON.stringify(result.error), + }) + dialog.clear() + return + } if (result.data?.method === "code") { dialog.replace(() => ( @@ -257,3 +277,53 @@ function ApiMethod(props: ApiMethodProps) { /> ) } + +interface PromptsMethodProps { + dialog: ReturnType + prompts: NonNullable[number][] +} +async function PromptsMethod(props: PromptsMethodProps) { + const inputs: Record = {} + for (const prompt of props.prompts) { + if (prompt.when) { + const value = inputs[prompt.when.key] + if (value === undefined) continue + const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value + if (!matches) continue + } + + if (prompt.type === "select") { + const value = await new Promise((resolve) => { + props.dialog.replace( + () => ( + ({ + title: x.label, + value: x.value, + description: x.hint, + }))} + onSelect={(option) => resolve(option.value)} + /> + ), + () => resolve(null), + ) + }) + if (value === null) return null + inputs[prompt.key] = value + continue + } + + const value = await new Promise((resolve) => { + props.dialog.replace( + () => ( + resolve(value)} /> + ), + () => resolve(null), + ) + }) + if (value === null) return null + inputs[prompt.key] = value + } + return inputs +} diff --git a/packages/opencode/src/plugin/copilot.ts b/packages/opencode/src/plugin/copilot.ts index 31d84532c..44c5289dd 100644 --- a/packages/opencode/src/plugin/copilot.ts +++ b/packages/opencode/src/plugin/copilot.ts @@ -168,7 +168,7 @@ export async function CopilotAuthPlugin(input: PluginInput): Promise { key: "enterpriseUrl", message: "Enter your GitHub Enterprise URL or domain", placeholder: "company.ghe.com or https://company.ghe.com", - condition: (inputs) => inputs.deploymentType === "enterprise", + when: { key: "deploymentType", op: "eq", value: "enterprise" }, validate: (value) => { if (!value) return "URL or domain is required" try { diff --git a/packages/opencode/src/provider/auth-service.ts b/packages/opencode/src/provider/auth-service.ts index 2e9985939..900b31d10 100644 --- a/packages/opencode/src/provider/auth-service.ts +++ b/packages/opencode/src/provider/auth-service.ts @@ -10,6 +10,44 @@ export const Method = z .object({ type: z.union([z.literal("oauth"), z.literal("api")]), label: z.string(), + prompts: z + .array( + z.union([ + z.object({ + type: z.literal("text"), + key: z.string(), + message: z.string(), + placeholder: z.string().optional(), + when: z + .object({ + key: z.string(), + op: z.union([z.literal("eq"), z.literal("neq")]), + value: z.string(), + }) + .optional(), + }), + z.object({ + type: z.literal("select"), + key: z.string(), + message: z.string(), + options: z.array( + z.object({ + label: z.string(), + value: z.string(), + hint: z.string().optional(), + }), + ), + when: z + .object({ + key: z.string(), + op: z.union([z.literal("eq"), z.literal("neq")]), + value: z.string(), + }) + .optional(), + }), + ]), + ) + .optional(), }) .meta({ ref: "ProviderAuthMethod", @@ -43,16 +81,29 @@ export const OauthCodeMissing = NamedError.create( export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({})) +export const ValidationFailed = NamedError.create( + "ProviderAuthValidationFailed", + z.object({ + field: z.string(), + message: z.string(), + }), +) + export type ProviderAuthError = | Auth.AuthServiceError | InstanceType | InstanceType | InstanceType + | InstanceType export namespace ProviderAuthService { export interface Service { readonly methods: () => Effect.Effect> - readonly authorize: (input: { providerID: ProviderID; method: number }) => Effect.Effect + readonly authorize: (input: { + providerID: ProviderID + method: number + inputs?: Record + }) => Effect.Effect readonly callback: (input: { providerID: ProviderID method: number @@ -80,16 +131,52 @@ export class ProviderAuthService extends ServiceMap.Service() const methods = Effect.fn("ProviderAuthService.methods")(function* () { - return Record.map(hooks, (item) => item.methods.map((method): Method => Struct.pick(method, ["type", "label"]))) + return Record.map(hooks, (item) => + item.methods.map( + (method): Method => ({ + type: method.type, + label: method.label, + prompts: method.prompts?.map((prompt) => { + if (prompt.type === "select") { + return { + type: "select" as const, + key: prompt.key, + message: prompt.message, + options: prompt.options, + when: prompt.when, + } + } + return { + type: "text" as const, + key: prompt.key, + message: prompt.message, + placeholder: prompt.placeholder, + when: prompt.when, + } + }), + }), + ), + ) }) const authorize = Effect.fn("ProviderAuthService.authorize")(function* (input: { providerID: ProviderID method: number + inputs?: Record }) { const method = hooks[input.providerID].methods[input.method] if (method.type !== "oauth") return - const result = yield* Effect.promise(() => method.authorize()) + + if (method.prompts && input.inputs) { + for (const prompt of method.prompts) { + if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) { + const error = prompt.validate(input.inputs[prompt.key]) + if (error) return yield* Effect.fail(new ValidationFailed({ field: prompt.key, message: error })) + } + } + } + + const result = yield* Effect.promise(() => method.authorize(input.inputs)) pending.set(input.providerID, result) return { url: result.url, diff --git a/packages/opencode/src/provider/auth.ts b/packages/opencode/src/provider/auth.ts index 15d23c925..912b12076 100644 --- a/packages/opencode/src/provider/auth.ts +++ b/packages/opencode/src/provider/auth.ts @@ -20,6 +20,7 @@ export namespace ProviderAuth { z.object({ providerID: ProviderID.zod, method: z.number(), + inputs: z.record(z.string(), z.string()).optional(), }), async (input): Promise => runPromiseInstance(S.ProviderAuthService.use((service) => service.authorize(input))), @@ -37,4 +38,5 @@ export namespace ProviderAuth { export import OauthMissing = S.OauthMissing export import OauthCodeMissing = S.OauthCodeMissing export import OauthCallbackFailed = S.OauthCallbackFailed + export import ValidationFailed = S.ValidationFailed } diff --git a/packages/opencode/src/server/routes/provider.ts b/packages/opencode/src/server/routes/provider.ts index fc716d25c..3ac3e7c64 100644 --- a/packages/opencode/src/server/routes/provider.ts +++ b/packages/opencode/src/server/routes/provider.ts @@ -109,14 +109,16 @@ export const ProviderRoutes = lazy(() => "json", z.object({ method: z.number().meta({ description: "Auth method index" }), + inputs: z.record(z.string(), z.string()).optional().meta({ description: "Prompt inputs" }), }), ), async (c) => { const providerID = c.req.valid("param").providerID - const { method } = c.req.valid("json") + const { method, inputs } = c.req.valid("json") const result = await ProviderAuth.authorize({ providerID, method, + inputs, }) return c.json(result) }, diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index 677af4da8..1904706a1 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -66,6 +66,7 @@ export namespace Server { let status: ContentfulStatusCode if (err instanceof NotFoundError) status = 404 else if (err instanceof Provider.ModelNotFoundError) status = 400 + else if (err.name === "ProviderAuthValidationFailed") status = 400 else if (err.name.startsWith("Worktree")) status = 400 else status = 500 return c.json(err.toObject(), { status }) diff --git a/packages/plugin/src/index.ts b/packages/plugin/src/index.ts index b78bcae17..7e5ae7a6e 100644 --- a/packages/plugin/src/index.ts +++ b/packages/plugin/src/index.ts @@ -34,6 +34,12 @@ export type PluginInput = { export type Plugin = (input: PluginInput) => Promise +type Rule = { + key: string + op: "eq" | "neq" + value: string +} + export type AuthHook = { provider: string loader?: (auth: () => Promise, provider: Provider) => Promise> @@ -48,7 +54,9 @@ export type AuthHook = { message: string placeholder?: string validate?: (value: string) => string | undefined + /** @deprecated Use `when` instead */ condition?: (inputs: Record) => boolean + when?: Rule } | { type: "select" @@ -59,7 +67,9 @@ export type AuthHook = { value: string hint?: string }> + /** @deprecated Use `when` instead */ condition?: (inputs: Record) => boolean + when?: Rule } > authorize(inputs?: Record): Promise @@ -74,7 +84,9 @@ export type AuthHook = { message: string placeholder?: string validate?: (value: string) => string | undefined + /** @deprecated Use `when` instead */ condition?: (inputs: Record) => boolean + when?: Rule } | { type: "select" @@ -85,7 +97,9 @@ export type AuthHook = { value: string hint?: string }> + /** @deprecated Use `when` instead */ condition?: (inputs: Record) => boolean + when?: Rule } > authorize?(inputs?: Record): Promise< diff --git a/packages/sdk/js/src/v2/gen/sdk.gen.ts b/packages/sdk/js/src/v2/gen/sdk.gen.ts index 27c188838..aa759bb1e 100644 --- a/packages/sdk/js/src/v2/gen/sdk.gen.ts +++ b/packages/sdk/js/src/v2/gen/sdk.gen.ts @@ -2496,6 +2496,9 @@ export class Oauth extends HeyApiClient { directory?: string workspace?: string method?: number + inputs?: { + [key: string]: string + } }, options?: Options, ) { @@ -2508,6 +2511,7 @@ export class Oauth extends HeyApiClient { { in: "query", key: "directory" }, { in: "query", key: "workspace" }, { in: "body", key: "method" }, + { in: "body", key: "inputs" }, ], }, ], diff --git a/packages/sdk/js/src/v2/gen/types.gen.ts b/packages/sdk/js/src/v2/gen/types.gen.ts index 9c5ca274e..41aa24817 100644 --- a/packages/sdk/js/src/v2/gen/types.gen.ts +++ b/packages/sdk/js/src/v2/gen/types.gen.ts @@ -1769,6 +1769,34 @@ export type SubtaskPartInput = { export type ProviderAuthMethod = { type: "oauth" | "api" label: string + prompts?: Array< + | { + type: "text" + key: string + message: string + placeholder?: string + when?: { + key: string + op: "eq" | "neq" + value: string + } + } + | { + type: "select" + key: string + message: string + options: Array<{ + label: string + value: string + hint?: string + }> + when?: { + key: string + op: "eq" | "neq" + value: string + } + } + > } export type ProviderAuthAuthorization = { @@ -3983,6 +4011,12 @@ export type ProviderOauthAuthorizeData = { * Auth method index */ method: number + /** + * Prompt inputs + */ + inputs?: { + [key: string]: string + } } path: { /** diff --git a/packages/sdk/openapi.json b/packages/sdk/openapi.json index c6d79b11e..350395423 100644 --- a/packages/sdk/openapi.json +++ b/packages/sdk/openapi.json @@ -4761,6 +4761,16 @@ "method": { "description": "Auth method index", "type": "number" + }, + "inputs": { + "description": "Prompt inputs", + "type": "object", + "propertyNames": { + "type": "string" + }, + "additionalProperties": { + "type": "string" + } } }, "required": ["method"] @@ -11541,6 +11551,114 @@ }, "label": { "type": "string" + }, + "prompts": { + "type": "array", + "items": { + "anyOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + }, + "key": { + "type": "string" + }, + "message": { + "type": "string" + }, + "placeholder": { + "type": "string" + }, + "when": { + "type": "object", + "properties": { + "key": { + "type": "string" + }, + "op": { + "anyOf": [ + { + "type": "string", + "const": "eq" + }, + { + "type": "string", + "const": "neq" + } + ] + }, + "value": { + "type": "string" + } + }, + "required": ["key", "op", "value"] + } + }, + "required": ["type", "key", "message"] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "select" + }, + "key": { + "type": "string" + }, + "message": { + "type": "string" + }, + "options": { + "type": "array", + "items": { + "type": "object", + "properties": { + "label": { + "type": "string" + }, + "value": { + "type": "string" + }, + "hint": { + "type": "string" + } + }, + "required": ["label", "value"] + } + }, + "when": { + "type": "object", + "properties": { + "key": { + "type": "string" + }, + "op": { + "anyOf": [ + { + "type": "string", + "const": "eq" + }, + { + "type": "string", + "const": "neq" + } + ] + }, + "value": { + "type": "string" + } + }, + "required": ["key", "op", "value"] + } + }, + "required": ["type", "key", "message", "options"] + } + ] + } } }, "required": ["type", "label"]