mirror of
https://gitea.toothfairyai.com/ToothFairyAI/tf_code.git
synced 2026-03-31 22:32:28 +00:00
144 lines
3.7 KiB
TypeScript
144 lines
3.7 KiB
TypeScript
import { Instance } from "@/project/instance"
|
|
import { Plugin } from "../plugin"
|
|
import { map, filter, pipe, fromEntries, mapValues } from "remeda"
|
|
import z from "zod"
|
|
import { fn } from "@/util/fn"
|
|
import type { AuthOuathResult, Hooks } from "@opencode-ai/plugin"
|
|
import { NamedError } from "@opencode-ai/util/error"
|
|
import { Auth } from "@/auth"
|
|
|
|
export namespace ProviderAuth {
|
|
const state = Instance.state(async () => {
|
|
const methods = pipe(
|
|
await Plugin.list(),
|
|
filter((x) => x.auth?.provider !== undefined),
|
|
map((x) => [x.auth!.provider, x.auth!] as const),
|
|
fromEntries(),
|
|
)
|
|
return { methods, pending: {} as Record<string, AuthOuathResult> }
|
|
})
|
|
|
|
export const Method = z
|
|
.object({
|
|
type: z.union([z.literal("oauth"), z.literal("api")]),
|
|
label: z.string(),
|
|
})
|
|
.meta({
|
|
ref: "ProviderAuthMethod",
|
|
})
|
|
export type Method = z.infer<typeof Method>
|
|
|
|
export async function methods() {
|
|
const s = await state().then((x) => x.methods)
|
|
return mapValues(s, (x) =>
|
|
x.methods.map(
|
|
(y): Method => ({
|
|
type: y.type,
|
|
label: y.label,
|
|
}),
|
|
),
|
|
)
|
|
}
|
|
|
|
export const Authorization = z
|
|
.object({
|
|
url: z.string(),
|
|
method: z.union([z.literal("auto"), z.literal("code")]),
|
|
instructions: z.string(),
|
|
})
|
|
.meta({
|
|
ref: "ProviderAuthAuthorization",
|
|
})
|
|
export type Authorization = z.infer<typeof Authorization>
|
|
|
|
export const authorize = fn(
|
|
z.object({
|
|
providerID: z.string(),
|
|
method: z.number(),
|
|
}),
|
|
async (input): Promise<Authorization | undefined> => {
|
|
const auth = await state().then((s) => s.methods[input.providerID])
|
|
const method = auth.methods[input.method]
|
|
if (method.type === "oauth") {
|
|
const result = await method.authorize()
|
|
await state().then((s) => (s.pending[input.providerID] = result))
|
|
return {
|
|
url: result.url,
|
|
method: result.method,
|
|
instructions: result.instructions,
|
|
}
|
|
}
|
|
},
|
|
)
|
|
|
|
export const callback = fn(
|
|
z.object({
|
|
providerID: z.string(),
|
|
method: z.number(),
|
|
code: z.string().optional(),
|
|
}),
|
|
async (input) => {
|
|
const match = await state().then((s) => s.pending[input.providerID])
|
|
if (!match) throw new OauthMissing({ providerID: input.providerID })
|
|
let result
|
|
|
|
if (match.method === "code") {
|
|
if (!input.code) throw new OauthCodeMissing({ providerID: input.providerID })
|
|
result = await match.callback(input.code)
|
|
}
|
|
|
|
if (match.method === "auto") {
|
|
result = await match.callback()
|
|
}
|
|
|
|
if (result?.type === "success") {
|
|
if ("key" in result) {
|
|
await Auth.set(input.providerID, {
|
|
type: "api",
|
|
key: result.key,
|
|
})
|
|
}
|
|
if ("refresh" in result) {
|
|
await Auth.set(input.providerID, {
|
|
type: "oauth",
|
|
access: result.access,
|
|
refresh: result.refresh,
|
|
expires: result.expires,
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
throw new OauthCallbackFailed({})
|
|
},
|
|
)
|
|
|
|
export const api = fn(
|
|
z.object({
|
|
providerID: z.string(),
|
|
key: z.string(),
|
|
}),
|
|
async (input) => {
|
|
await Auth.set(input.providerID, {
|
|
type: "api",
|
|
key: input.key,
|
|
})
|
|
},
|
|
)
|
|
|
|
export const OauthMissing = NamedError.create(
|
|
"ProviderAuthOauthMissing",
|
|
z.object({
|
|
providerID: z.string(),
|
|
}),
|
|
)
|
|
export const OauthCodeMissing = NamedError.create(
|
|
"ProviderAuthOauthCodeMissing",
|
|
z.object({
|
|
providerID: z.string(),
|
|
}),
|
|
)
|
|
|
|
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
|
|
}
|