mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-03-31 22:32:28 +00:00
feat: integrate support for multi step auth flows for providers that require additional questions (#18035)
This commit is contained in:
@@ -46,9 +46,13 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string,
|
||||
const inputs: Record<string, string> = {}
|
||||
if (method.prompts) {
|
||||
for (const prompt of method.prompts) {
|
||||
if (prompt.condition && !prompt.condition(inputs)) {
|
||||
continue
|
||||
if (prompt.when) {
|
||||
const value = inputs[prompt.when.key]
|
||||
if (value === undefined) continue
|
||||
const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
|
||||
if (!matches) continue
|
||||
}
|
||||
if (prompt.condition && !prompt.condition(inputs)) continue
|
||||
if (prompt.type === "select") {
|
||||
const value = await prompts.select({
|
||||
message: prompt.message,
|
||||
|
||||
@@ -8,7 +8,7 @@ import { DialogPrompt } from "../ui/dialog-prompt"
|
||||
import { Link } from "../ui/link"
|
||||
import { useTheme } from "../context/theme"
|
||||
import { TextAttributes } from "@opentui/core"
|
||||
import type { ProviderAuthAuthorization } from "@opencode-ai/sdk/v2"
|
||||
import type { ProviderAuthAuthorization, ProviderAuthMethod } from "@opencode-ai/sdk/v2"
|
||||
import { DialogModel } from "./dialog-model"
|
||||
import { useKeyboard } from "@opentui/solid"
|
||||
import { Clipboard } from "@tui/util/clipboard"
|
||||
@@ -27,6 +27,7 @@ export function createDialogProviderOptions() {
|
||||
const sync = useSync()
|
||||
const dialog = useDialog()
|
||||
const sdk = useSDK()
|
||||
const toast = useToast()
|
||||
const options = createMemo(() => {
|
||||
return pipe(
|
||||
sync.data.provider_next.all,
|
||||
@@ -69,10 +70,29 @@ export function createDialogProviderOptions() {
|
||||
if (index == null) return
|
||||
const method = methods[index]
|
||||
if (method.type === "oauth") {
|
||||
let inputs: Record<string, string> | undefined
|
||||
if (method.prompts?.length) {
|
||||
const value = await PromptsMethod({
|
||||
dialog,
|
||||
prompts: method.prompts,
|
||||
})
|
||||
if (!value) return
|
||||
inputs = value
|
||||
}
|
||||
|
||||
const result = await sdk.client.provider.oauth.authorize({
|
||||
providerID: provider.id,
|
||||
method: index,
|
||||
inputs,
|
||||
})
|
||||
if (result.error) {
|
||||
toast.show({
|
||||
variant: "error",
|
||||
message: JSON.stringify(result.error),
|
||||
})
|
||||
dialog.clear()
|
||||
return
|
||||
}
|
||||
if (result.data?.method === "code") {
|
||||
dialog.replace(() => (
|
||||
<CodeMethod providerID={provider.id} title={method.label} index={index} authorization={result.data!} />
|
||||
@@ -257,3 +277,53 @@ function ApiMethod(props: ApiMethodProps) {
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
interface PromptsMethodProps {
|
||||
dialog: ReturnType<typeof useDialog>
|
||||
prompts: NonNullable<ProviderAuthMethod["prompts"]>[number][]
|
||||
}
|
||||
async function PromptsMethod(props: PromptsMethodProps) {
|
||||
const inputs: Record<string, string> = {}
|
||||
for (const prompt of props.prompts) {
|
||||
if (prompt.when) {
|
||||
const value = inputs[prompt.when.key]
|
||||
if (value === undefined) continue
|
||||
const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
|
||||
if (!matches) continue
|
||||
}
|
||||
|
||||
if (prompt.type === "select") {
|
||||
const value = await new Promise<string | null>((resolve) => {
|
||||
props.dialog.replace(
|
||||
() => (
|
||||
<DialogSelect
|
||||
title={prompt.message}
|
||||
options={prompt.options.map((x) => ({
|
||||
title: x.label,
|
||||
value: x.value,
|
||||
description: x.hint,
|
||||
}))}
|
||||
onSelect={(option) => resolve(option.value)}
|
||||
/>
|
||||
),
|
||||
() => resolve(null),
|
||||
)
|
||||
})
|
||||
if (value === null) return null
|
||||
inputs[prompt.key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
const value = await new Promise<string | null>((resolve) => {
|
||||
props.dialog.replace(
|
||||
() => (
|
||||
<DialogPrompt title={prompt.message} placeholder={prompt.placeholder} onConfirm={(value) => resolve(value)} />
|
||||
),
|
||||
() => resolve(null),
|
||||
)
|
||||
})
|
||||
if (value === null) return null
|
||||
inputs[prompt.key] = value
|
||||
}
|
||||
return inputs
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ export async function CopilotAuthPlugin(input: PluginInput): Promise<Hooks> {
|
||||
key: "enterpriseUrl",
|
||||
message: "Enter your GitHub Enterprise URL or domain",
|
||||
placeholder: "company.ghe.com or https://company.ghe.com",
|
||||
condition: (inputs) => inputs.deploymentType === "enterprise",
|
||||
when: { key: "deploymentType", op: "eq", value: "enterprise" },
|
||||
validate: (value) => {
|
||||
if (!value) return "URL or domain is required"
|
||||
try {
|
||||
|
||||
@@ -10,6 +10,44 @@ export const Method = z
|
||||
.object({
|
||||
type: z.union([z.literal("oauth"), z.literal("api")]),
|
||||
label: z.string(),
|
||||
prompts: z
|
||||
.array(
|
||||
z.union([
|
||||
z.object({
|
||||
type: z.literal("text"),
|
||||
key: z.string(),
|
||||
message: z.string(),
|
||||
placeholder: z.string().optional(),
|
||||
when: z
|
||||
.object({
|
||||
key: z.string(),
|
||||
op: z.union([z.literal("eq"), z.literal("neq")]),
|
||||
value: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
z.object({
|
||||
type: z.literal("select"),
|
||||
key: z.string(),
|
||||
message: z.string(),
|
||||
options: z.array(
|
||||
z.object({
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
hint: z.string().optional(),
|
||||
}),
|
||||
),
|
||||
when: z
|
||||
.object({
|
||||
key: z.string(),
|
||||
op: z.union([z.literal("eq"), z.literal("neq")]),
|
||||
value: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
]),
|
||||
)
|
||||
.optional(),
|
||||
})
|
||||
.meta({
|
||||
ref: "ProviderAuthMethod",
|
||||
@@ -43,16 +81,29 @@ export const OauthCodeMissing = NamedError.create(
|
||||
|
||||
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
|
||||
|
||||
export const ValidationFailed = NamedError.create(
|
||||
"ProviderAuthValidationFailed",
|
||||
z.object({
|
||||
field: z.string(),
|
||||
message: z.string(),
|
||||
}),
|
||||
)
|
||||
|
||||
export type ProviderAuthError =
|
||||
| Auth.AuthServiceError
|
||||
| InstanceType<typeof OauthMissing>
|
||||
| InstanceType<typeof OauthCodeMissing>
|
||||
| InstanceType<typeof OauthCallbackFailed>
|
||||
| InstanceType<typeof ValidationFailed>
|
||||
|
||||
export namespace ProviderAuthService {
|
||||
export interface Service {
|
||||
readonly methods: () => Effect.Effect<Record<string, Method[]>>
|
||||
readonly authorize: (input: { providerID: ProviderID; method: number }) => Effect.Effect<Authorization | undefined>
|
||||
readonly authorize: (input: {
|
||||
providerID: ProviderID
|
||||
method: number
|
||||
inputs?: Record<string, string>
|
||||
}) => Effect.Effect<Authorization | undefined, ProviderAuthError>
|
||||
readonly callback: (input: {
|
||||
providerID: ProviderID
|
||||
method: number
|
||||
@@ -80,16 +131,52 @@ export class ProviderAuthService extends ServiceMap.Service<ProviderAuthService,
|
||||
const pending = new Map<ProviderID, AuthOuathResult>()
|
||||
|
||||
const methods = Effect.fn("ProviderAuthService.methods")(function* () {
|
||||
return Record.map(hooks, (item) => item.methods.map((method): Method => Struct.pick(method, ["type", "label"])))
|
||||
return Record.map(hooks, (item) =>
|
||||
item.methods.map(
|
||||
(method): Method => ({
|
||||
type: method.type,
|
||||
label: method.label,
|
||||
prompts: method.prompts?.map((prompt) => {
|
||||
if (prompt.type === "select") {
|
||||
return {
|
||||
type: "select" as const,
|
||||
key: prompt.key,
|
||||
message: prompt.message,
|
||||
options: prompt.options,
|
||||
when: prompt.when,
|
||||
}
|
||||
}
|
||||
return {
|
||||
type: "text" as const,
|
||||
key: prompt.key,
|
||||
message: prompt.message,
|
||||
placeholder: prompt.placeholder,
|
||||
when: prompt.when,
|
||||
}
|
||||
}),
|
||||
}),
|
||||
),
|
||||
)
|
||||
})
|
||||
|
||||
const authorize = Effect.fn("ProviderAuthService.authorize")(function* (input: {
|
||||
providerID: ProviderID
|
||||
method: number
|
||||
inputs?: Record<string, string>
|
||||
}) {
|
||||
const method = hooks[input.providerID].methods[input.method]
|
||||
if (method.type !== "oauth") return
|
||||
const result = yield* Effect.promise(() => method.authorize())
|
||||
|
||||
if (method.prompts && input.inputs) {
|
||||
for (const prompt of method.prompts) {
|
||||
if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) {
|
||||
const error = prompt.validate(input.inputs[prompt.key])
|
||||
if (error) return yield* Effect.fail(new ValidationFailed({ field: prompt.key, message: error }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const result = yield* Effect.promise(() => method.authorize(input.inputs))
|
||||
pending.set(input.providerID, result)
|
||||
return {
|
||||
url: result.url,
|
||||
|
||||
@@ -20,6 +20,7 @@ export namespace ProviderAuth {
|
||||
z.object({
|
||||
providerID: ProviderID.zod,
|
||||
method: z.number(),
|
||||
inputs: z.record(z.string(), z.string()).optional(),
|
||||
}),
|
||||
async (input): Promise<Authorization | undefined> =>
|
||||
runPromiseInstance(S.ProviderAuthService.use((service) => service.authorize(input))),
|
||||
@@ -37,4 +38,5 @@ export namespace ProviderAuth {
|
||||
export import OauthMissing = S.OauthMissing
|
||||
export import OauthCodeMissing = S.OauthCodeMissing
|
||||
export import OauthCallbackFailed = S.OauthCallbackFailed
|
||||
export import ValidationFailed = S.ValidationFailed
|
||||
}
|
||||
|
||||
@@ -109,14 +109,16 @@ export const ProviderRoutes = lazy(() =>
|
||||
"json",
|
||||
z.object({
|
||||
method: z.number().meta({ description: "Auth method index" }),
|
||||
inputs: z.record(z.string(), z.string()).optional().meta({ description: "Prompt inputs" }),
|
||||
}),
|
||||
),
|
||||
async (c) => {
|
||||
const providerID = c.req.valid("param").providerID
|
||||
const { method } = c.req.valid("json")
|
||||
const { method, inputs } = c.req.valid("json")
|
||||
const result = await ProviderAuth.authorize({
|
||||
providerID,
|
||||
method,
|
||||
inputs,
|
||||
})
|
||||
return c.json(result)
|
||||
},
|
||||
|
||||
@@ -66,6 +66,7 @@ export namespace Server {
|
||||
let status: ContentfulStatusCode
|
||||
if (err instanceof NotFoundError) status = 404
|
||||
else if (err instanceof Provider.ModelNotFoundError) status = 400
|
||||
else if (err.name === "ProviderAuthValidationFailed") status = 400
|
||||
else if (err.name.startsWith("Worktree")) status = 400
|
||||
else status = 500
|
||||
return c.json(err.toObject(), { status })
|
||||
|
||||
Reference in New Issue
Block a user