mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-03-30 13:54:01 +00:00
feat: integrate support for multi step auth flows for providers that require additional questions (#18035)
This commit is contained in:
parent
822bb7b336
commit
171e69c2fc
@ -46,9 +46,13 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string,
|
||||
const inputs: Record<string, string> = {}
|
||||
if (method.prompts) {
|
||||
for (const prompt of method.prompts) {
|
||||
if (prompt.condition && !prompt.condition(inputs)) {
|
||||
continue
|
||||
if (prompt.when) {
|
||||
const value = inputs[prompt.when.key]
|
||||
if (value === undefined) continue
|
||||
const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
|
||||
if (!matches) continue
|
||||
}
|
||||
if (prompt.condition && !prompt.condition(inputs)) continue
|
||||
if (prompt.type === "select") {
|
||||
const value = await prompts.select({
|
||||
message: prompt.message,
|
||||
|
||||
@ -8,7 +8,7 @@ import { DialogPrompt } from "../ui/dialog-prompt"
|
||||
import { Link } from "../ui/link"
|
||||
import { useTheme } from "../context/theme"
|
||||
import { TextAttributes } from "@opentui/core"
|
||||
import type { ProviderAuthAuthorization } from "@opencode-ai/sdk/v2"
|
||||
import type { ProviderAuthAuthorization, ProviderAuthMethod } from "@opencode-ai/sdk/v2"
|
||||
import { DialogModel } from "./dialog-model"
|
||||
import { useKeyboard } from "@opentui/solid"
|
||||
import { Clipboard } from "@tui/util/clipboard"
|
||||
@ -27,6 +27,7 @@ export function createDialogProviderOptions() {
|
||||
const sync = useSync()
|
||||
const dialog = useDialog()
|
||||
const sdk = useSDK()
|
||||
const toast = useToast()
|
||||
const options = createMemo(() => {
|
||||
return pipe(
|
||||
sync.data.provider_next.all,
|
||||
@ -69,10 +70,29 @@ export function createDialogProviderOptions() {
|
||||
if (index == null) return
|
||||
const method = methods[index]
|
||||
if (method.type === "oauth") {
|
||||
let inputs: Record<string, string> | undefined
|
||||
if (method.prompts?.length) {
|
||||
const value = await PromptsMethod({
|
||||
dialog,
|
||||
prompts: method.prompts,
|
||||
})
|
||||
if (!value) return
|
||||
inputs = value
|
||||
}
|
||||
|
||||
const result = await sdk.client.provider.oauth.authorize({
|
||||
providerID: provider.id,
|
||||
method: index,
|
||||
inputs,
|
||||
})
|
||||
if (result.error) {
|
||||
toast.show({
|
||||
variant: "error",
|
||||
message: JSON.stringify(result.error),
|
||||
})
|
||||
dialog.clear()
|
||||
return
|
||||
}
|
||||
if (result.data?.method === "code") {
|
||||
dialog.replace(() => (
|
||||
<CodeMethod providerID={provider.id} title={method.label} index={index} authorization={result.data!} />
|
||||
@ -257,3 +277,53 @@ function ApiMethod(props: ApiMethodProps) {
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
interface PromptsMethodProps {
|
||||
dialog: ReturnType<typeof useDialog>
|
||||
prompts: NonNullable<ProviderAuthMethod["prompts"]>[number][]
|
||||
}
|
||||
async function PromptsMethod(props: PromptsMethodProps) {
|
||||
const inputs: Record<string, string> = {}
|
||||
for (const prompt of props.prompts) {
|
||||
if (prompt.when) {
|
||||
const value = inputs[prompt.when.key]
|
||||
if (value === undefined) continue
|
||||
const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
|
||||
if (!matches) continue
|
||||
}
|
||||
|
||||
if (prompt.type === "select") {
|
||||
const value = await new Promise<string | null>((resolve) => {
|
||||
props.dialog.replace(
|
||||
() => (
|
||||
<DialogSelect
|
||||
title={prompt.message}
|
||||
options={prompt.options.map((x) => ({
|
||||
title: x.label,
|
||||
value: x.value,
|
||||
description: x.hint,
|
||||
}))}
|
||||
onSelect={(option) => resolve(option.value)}
|
||||
/>
|
||||
),
|
||||
() => resolve(null),
|
||||
)
|
||||
})
|
||||
if (value === null) return null
|
||||
inputs[prompt.key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
const value = await new Promise<string | null>((resolve) => {
|
||||
props.dialog.replace(
|
||||
() => (
|
||||
<DialogPrompt title={prompt.message} placeholder={prompt.placeholder} onConfirm={(value) => resolve(value)} />
|
||||
),
|
||||
() => resolve(null),
|
||||
)
|
||||
})
|
||||
if (value === null) return null
|
||||
inputs[prompt.key] = value
|
||||
}
|
||||
return inputs
|
||||
}
|
||||
|
||||
@ -168,7 +168,7 @@ export async function CopilotAuthPlugin(input: PluginInput): Promise<Hooks> {
|
||||
key: "enterpriseUrl",
|
||||
message: "Enter your GitHub Enterprise URL or domain",
|
||||
placeholder: "company.ghe.com or https://company.ghe.com",
|
||||
condition: (inputs) => inputs.deploymentType === "enterprise",
|
||||
when: { key: "deploymentType", op: "eq", value: "enterprise" },
|
||||
validate: (value) => {
|
||||
if (!value) return "URL or domain is required"
|
||||
try {
|
||||
|
||||
@ -10,6 +10,44 @@ export const Method = z
|
||||
.object({
|
||||
type: z.union([z.literal("oauth"), z.literal("api")]),
|
||||
label: z.string(),
|
||||
prompts: z
|
||||
.array(
|
||||
z.union([
|
||||
z.object({
|
||||
type: z.literal("text"),
|
||||
key: z.string(),
|
||||
message: z.string(),
|
||||
placeholder: z.string().optional(),
|
||||
when: z
|
||||
.object({
|
||||
key: z.string(),
|
||||
op: z.union([z.literal("eq"), z.literal("neq")]),
|
||||
value: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
z.object({
|
||||
type: z.literal("select"),
|
||||
key: z.string(),
|
||||
message: z.string(),
|
||||
options: z.array(
|
||||
z.object({
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
hint: z.string().optional(),
|
||||
}),
|
||||
),
|
||||
when: z
|
||||
.object({
|
||||
key: z.string(),
|
||||
op: z.union([z.literal("eq"), z.literal("neq")]),
|
||||
value: z.string(),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
]),
|
||||
)
|
||||
.optional(),
|
||||
})
|
||||
.meta({
|
||||
ref: "ProviderAuthMethod",
|
||||
@ -43,16 +81,29 @@ export const OauthCodeMissing = NamedError.create(
|
||||
|
||||
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
|
||||
|
||||
export const ValidationFailed = NamedError.create(
|
||||
"ProviderAuthValidationFailed",
|
||||
z.object({
|
||||
field: z.string(),
|
||||
message: z.string(),
|
||||
}),
|
||||
)
|
||||
|
||||
export type ProviderAuthError =
|
||||
| Auth.AuthServiceError
|
||||
| InstanceType<typeof OauthMissing>
|
||||
| InstanceType<typeof OauthCodeMissing>
|
||||
| InstanceType<typeof OauthCallbackFailed>
|
||||
| InstanceType<typeof ValidationFailed>
|
||||
|
||||
export namespace ProviderAuthService {
|
||||
export interface Service {
|
||||
readonly methods: () => Effect.Effect<Record<string, Method[]>>
|
||||
readonly authorize: (input: { providerID: ProviderID; method: number }) => Effect.Effect<Authorization | undefined>
|
||||
readonly authorize: (input: {
|
||||
providerID: ProviderID
|
||||
method: number
|
||||
inputs?: Record<string, string>
|
||||
}) => Effect.Effect<Authorization | undefined, ProviderAuthError>
|
||||
readonly callback: (input: {
|
||||
providerID: ProviderID
|
||||
method: number
|
||||
@ -80,16 +131,52 @@ export class ProviderAuthService extends ServiceMap.Service<ProviderAuthService,
|
||||
const pending = new Map<ProviderID, AuthOuathResult>()
|
||||
|
||||
const methods = Effect.fn("ProviderAuthService.methods")(function* () {
|
||||
return Record.map(hooks, (item) => item.methods.map((method): Method => Struct.pick(method, ["type", "label"])))
|
||||
return Record.map(hooks, (item) =>
|
||||
item.methods.map(
|
||||
(method): Method => ({
|
||||
type: method.type,
|
||||
label: method.label,
|
||||
prompts: method.prompts?.map((prompt) => {
|
||||
if (prompt.type === "select") {
|
||||
return {
|
||||
type: "select" as const,
|
||||
key: prompt.key,
|
||||
message: prompt.message,
|
||||
options: prompt.options,
|
||||
when: prompt.when,
|
||||
}
|
||||
}
|
||||
return {
|
||||
type: "text" as const,
|
||||
key: prompt.key,
|
||||
message: prompt.message,
|
||||
placeholder: prompt.placeholder,
|
||||
when: prompt.when,
|
||||
}
|
||||
}),
|
||||
}),
|
||||
),
|
||||
)
|
||||
})
|
||||
|
||||
const authorize = Effect.fn("ProviderAuthService.authorize")(function* (input: {
|
||||
providerID: ProviderID
|
||||
method: number
|
||||
inputs?: Record<string, string>
|
||||
}) {
|
||||
const method = hooks[input.providerID].methods[input.method]
|
||||
if (method.type !== "oauth") return
|
||||
const result = yield* Effect.promise(() => method.authorize())
|
||||
|
||||
if (method.prompts && input.inputs) {
|
||||
for (const prompt of method.prompts) {
|
||||
if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) {
|
||||
const error = prompt.validate(input.inputs[prompt.key])
|
||||
if (error) return yield* Effect.fail(new ValidationFailed({ field: prompt.key, message: error }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const result = yield* Effect.promise(() => method.authorize(input.inputs))
|
||||
pending.set(input.providerID, result)
|
||||
return {
|
||||
url: result.url,
|
||||
|
||||
@ -20,6 +20,7 @@ export namespace ProviderAuth {
|
||||
z.object({
|
||||
providerID: ProviderID.zod,
|
||||
method: z.number(),
|
||||
inputs: z.record(z.string(), z.string()).optional(),
|
||||
}),
|
||||
async (input): Promise<Authorization | undefined> =>
|
||||
runPromiseInstance(S.ProviderAuthService.use((service) => service.authorize(input))),
|
||||
@ -37,4 +38,5 @@ export namespace ProviderAuth {
|
||||
export import OauthMissing = S.OauthMissing
|
||||
export import OauthCodeMissing = S.OauthCodeMissing
|
||||
export import OauthCallbackFailed = S.OauthCallbackFailed
|
||||
export import ValidationFailed = S.ValidationFailed
|
||||
}
|
||||
|
||||
@ -109,14 +109,16 @@ export const ProviderRoutes = lazy(() =>
|
||||
"json",
|
||||
z.object({
|
||||
method: z.number().meta({ description: "Auth method index" }),
|
||||
inputs: z.record(z.string(), z.string()).optional().meta({ description: "Prompt inputs" }),
|
||||
}),
|
||||
),
|
||||
async (c) => {
|
||||
const providerID = c.req.valid("param").providerID
|
||||
const { method } = c.req.valid("json")
|
||||
const { method, inputs } = c.req.valid("json")
|
||||
const result = await ProviderAuth.authorize({
|
||||
providerID,
|
||||
method,
|
||||
inputs,
|
||||
})
|
||||
return c.json(result)
|
||||
},
|
||||
|
||||
@ -66,6 +66,7 @@ export namespace Server {
|
||||
let status: ContentfulStatusCode
|
||||
if (err instanceof NotFoundError) status = 404
|
||||
else if (err instanceof Provider.ModelNotFoundError) status = 400
|
||||
else if (err.name === "ProviderAuthValidationFailed") status = 400
|
||||
else if (err.name.startsWith("Worktree")) status = 400
|
||||
else status = 500
|
||||
return c.json(err.toObject(), { status })
|
||||
|
||||
@ -34,6 +34,12 @@ export type PluginInput = {
|
||||
|
||||
export type Plugin = (input: PluginInput) => Promise<Hooks>
|
||||
|
||||
type Rule = {
|
||||
key: string
|
||||
op: "eq" | "neq"
|
||||
value: string
|
||||
}
|
||||
|
||||
export type AuthHook = {
|
||||
provider: string
|
||||
loader?: (auth: () => Promise<Auth>, provider: Provider) => Promise<Record<string, any>>
|
||||
@ -48,7 +54,9 @@ export type AuthHook = {
|
||||
message: string
|
||||
placeholder?: string
|
||||
validate?: (value: string) => string | undefined
|
||||
/** @deprecated Use `when` instead */
|
||||
condition?: (inputs: Record<string, string>) => boolean
|
||||
when?: Rule
|
||||
}
|
||||
| {
|
||||
type: "select"
|
||||
@ -59,7 +67,9 @@ export type AuthHook = {
|
||||
value: string
|
||||
hint?: string
|
||||
}>
|
||||
/** @deprecated Use `when` instead */
|
||||
condition?: (inputs: Record<string, string>) => boolean
|
||||
when?: Rule
|
||||
}
|
||||
>
|
||||
authorize(inputs?: Record<string, string>): Promise<AuthOuathResult>
|
||||
@ -74,7 +84,9 @@ export type AuthHook = {
|
||||
message: string
|
||||
placeholder?: string
|
||||
validate?: (value: string) => string | undefined
|
||||
/** @deprecated Use `when` instead */
|
||||
condition?: (inputs: Record<string, string>) => boolean
|
||||
when?: Rule
|
||||
}
|
||||
| {
|
||||
type: "select"
|
||||
@ -85,7 +97,9 @@ export type AuthHook = {
|
||||
value: string
|
||||
hint?: string
|
||||
}>
|
||||
/** @deprecated Use `when` instead */
|
||||
condition?: (inputs: Record<string, string>) => boolean
|
||||
when?: Rule
|
||||
}
|
||||
>
|
||||
authorize?(inputs?: Record<string, string>): Promise<
|
||||
|
||||
@ -2496,6 +2496,9 @@ export class Oauth extends HeyApiClient {
|
||||
directory?: string
|
||||
workspace?: string
|
||||
method?: number
|
||||
inputs?: {
|
||||
[key: string]: string
|
||||
}
|
||||
},
|
||||
options?: Options<never, ThrowOnError>,
|
||||
) {
|
||||
@ -2508,6 +2511,7 @@ export class Oauth extends HeyApiClient {
|
||||
{ in: "query", key: "directory" },
|
||||
{ in: "query", key: "workspace" },
|
||||
{ in: "body", key: "method" },
|
||||
{ in: "body", key: "inputs" },
|
||||
],
|
||||
},
|
||||
],
|
||||
|
||||
@ -1769,6 +1769,34 @@ export type SubtaskPartInput = {
|
||||
export type ProviderAuthMethod = {
|
||||
type: "oauth" | "api"
|
||||
label: string
|
||||
prompts?: Array<
|
||||
| {
|
||||
type: "text"
|
||||
key: string
|
||||
message: string
|
||||
placeholder?: string
|
||||
when?: {
|
||||
key: string
|
||||
op: "eq" | "neq"
|
||||
value: string
|
||||
}
|
||||
}
|
||||
| {
|
||||
type: "select"
|
||||
key: string
|
||||
message: string
|
||||
options: Array<{
|
||||
label: string
|
||||
value: string
|
||||
hint?: string
|
||||
}>
|
||||
when?: {
|
||||
key: string
|
||||
op: "eq" | "neq"
|
||||
value: string
|
||||
}
|
||||
}
|
||||
>
|
||||
}
|
||||
|
||||
export type ProviderAuthAuthorization = {
|
||||
@ -3983,6 +4011,12 @@ export type ProviderOauthAuthorizeData = {
|
||||
* Auth method index
|
||||
*/
|
||||
method: number
|
||||
/**
|
||||
* Prompt inputs
|
||||
*/
|
||||
inputs?: {
|
||||
[key: string]: string
|
||||
}
|
||||
}
|
||||
path: {
|
||||
/**
|
||||
|
||||
@ -4761,6 +4761,16 @@
|
||||
"method": {
|
||||
"description": "Auth method index",
|
||||
"type": "number"
|
||||
},
|
||||
"inputs": {
|
||||
"description": "Prompt inputs",
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"type": "string"
|
||||
},
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["method"]
|
||||
@ -11541,6 +11551,114 @@
|
||||
},
|
||||
"label": {
|
||||
"type": "string"
|
||||
},
|
||||
"prompts": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "text"
|
||||
},
|
||||
"key": {
|
||||
"type": "string"
|
||||
},
|
||||
"message": {
|
||||
"type": "string"
|
||||
},
|
||||
"placeholder": {
|
||||
"type": "string"
|
||||
},
|
||||
"when": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string"
|
||||
},
|
||||
"op": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"const": "eq"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"const": "neq"
|
||||
}
|
||||
]
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["key", "op", "value"]
|
||||
}
|
||||
},
|
||||
"required": ["type", "key", "message"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "select"
|
||||
},
|
||||
"key": {
|
||||
"type": "string"
|
||||
},
|
||||
"message": {
|
||||
"type": "string"
|
||||
},
|
||||
"options": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
},
|
||||
"hint": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["label", "value"]
|
||||
}
|
||||
},
|
||||
"when": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string"
|
||||
},
|
||||
"op": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"const": "eq"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"const": "neq"
|
||||
}
|
||||
]
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["key", "op", "value"]
|
||||
}
|
||||
},
|
||||
"required": ["type", "key", "message", "options"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["type", "label"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user