diff --git a/packages/opencode/src/auth/index.ts b/packages/opencode/src/auth/index.ts deleted file mode 100644 index b9c8a78caf9..00000000000 --- a/packages/opencode/src/auth/index.ts +++ /dev/null @@ -1,70 +0,0 @@ -import path from "path" -import { Global } from "../global" -import fs from "fs/promises" -import z from "zod" - -export namespace Auth { - export const Oauth = z - .object({ - type: z.literal("oauth"), - refresh: z.string(), - access: z.string(), - expires: z.number(), - enterpriseUrl: z.string().optional(), - }) - .meta({ ref: "OAuth" }) - - export const Api = z - .object({ - type: z.literal("api"), - key: z.string(), - }) - .meta({ ref: "ApiAuth" }) - - export const WellKnown = z - .object({ - type: z.literal("wellknown"), - key: z.string(), - token: z.string(), - }) - .meta({ ref: "WellKnownAuth" }) - - export const Info = z.discriminatedUnion("type", [Oauth, Api, WellKnown]).meta({ ref: "Auth" }) - export type Info = z.infer - - const filepath = path.join(Global.Path.data, "auth.json") - - export async function get(providerID: string) { - const auth = await all() - return auth[providerID] - } - - export async function all(): Promise> { - const file = Bun.file(filepath) - const data = await file.json().catch(() => ({}) as Record) - return Object.entries(data).reduce( - (acc, [key, value]) => { - const parsed = Info.safeParse(value) - if (!parsed.success) return acc - acc[key] = parsed.data - return acc - }, - {} as Record, - ) - } - - export async function set(key: string, info: Info) { - const file = Bun.file(filepath) - const data = await all() - await Bun.write(file, JSON.stringify({ ...data, [key]: info }, null, 2)) - await fs.chmod(file.name!, 0o600) - } - - export async function remove(key: string) { - const file = Bun.file(filepath) - const data = await all() - delete data[key] - await Bun.write(file, JSON.stringify(data, null, 2)) - await fs.chmod(file.name!, 0o600) - } -} diff --git a/packages/opencode/src/cli/cmd/auth.ts b/packages/opencode/src/cli/cmd/auth.ts index 658329fb6ef..dd1c1d2c45d 100644 --- a/packages/opencode/src/cli/cmd/auth.ts +++ b/packages/opencode/src/cli/cmd/auth.ts @@ -1,4 +1,3 @@ -import { Auth } from "../../auth" import { cmd } from "./cmd" import * as prompts from "@clack/prompts" import { UI } from "../ui" @@ -8,12 +7,69 @@ import path from "path" import os from "os" import { Config } from "../../config/config" import { Global } from "../../global" -import { Plugin } from "../../plugin" import { Instance } from "../../project/instance" import type { Hooks } from "@opencode-ai/plugin" +import { CredentialStore, CredentialsMigrate } from "../../credentials" +import { ProviderAuthRegistry } from "../../provider-auth/registry" +import { VaultKey } from "../../vault/key" +import { VaultFS } from "../../vault/fs" +import fs from "fs/promises" type PluginAuth = NonNullable +async function storeOAuthCredential(args: { + providerId: string + access: string + refresh?: string + expires?: number + namespace?: string + label?: string + extra?: Record +}) { + await CredentialsMigrate.migrateIfNeeded() + const config = await Config.get() + const namespace = (args.namespace ?? config.provider?.[args.providerId]?.auth?.namespace ?? "default").trim() || "default" + const existingOauth = (await CredentialStore.findByProvider(args.providerId, namespace)).filter((r) => r.meta.kind === "oauth") + const existingLabels = new Set(existingOauth.map((r) => r.meta.label ?? "")) + const labelBase = args.label?.split("\n")[0]?.trim() || undefined + + const label = (() => { + if (labelBase) { + if (!existingLabels.has(labelBase)) return labelBase + let n = 2 + while (existingLabels.has(`${labelBase}-${n}`)) n++ + return `${labelBase}-${n}` + } + + const hasDefault = existingLabels.has("default") + return hasDefault ? `${args.providerId}-${new Date().toISOString()}` : "default" + })() + + await CredentialStore.put({ + providerId: args.providerId, + namespace, + kind: "oauth", + label, + secret: { + accessToken: args.access, + refreshToken: args.refresh || undefined, + expiresAt: args.expires || undefined, + extra: args.extra, + }, + }) +} + +async function storeApiCredential(args: { providerId: string; apiKey: string }) { + await CredentialsMigrate.migrateIfNeeded() + await CredentialStore.upsertSingleton({ + providerId: args.providerId, + namespace: "default", + kind: "api", + label: "default", + secret: { apiKey: args.apiKey }, + }) +} + /** * Handle plugin-based authentication flow. * Returns true if auth was handled, false if it should fall through to default handling. @@ -35,6 +91,26 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): } const method = plugin.auth.methods[index] + let namespace: string | undefined + let label: string | undefined + if (method.type === "oauth") { + const config = await Config.get() + const defaultNs = config.provider?.[provider]?.auth?.namespace ?? "default" + const rawNamespace = await prompts.text({ + message: "Namespace (optional)", + placeholder: defaultNs, + }) + if (prompts.isCancel(rawNamespace)) throw new UI.CancelledError() + namespace = rawNamespace.split("\n")[0]?.trim() || defaultNs + + const rawLabel = await prompts.text({ + message: "Account label (optional)", + placeholder: "default", + }) + if (prompts.isCancel(rawLabel)) throw new UI.CancelledError() + label = rawLabel.split("\n")[0]?.trim() || undefined + } + // Handle prompts for all auth types await new Promise((resolve) => setTimeout(resolve, 10)) const inputs: Record = {} @@ -79,24 +155,23 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): if (result.type === "failed") { spinner.stop("Failed to authorize", 1) } - if (result.type === "success") { - const saveProvider = result.provider ?? provider - if ("refresh" in result) { - const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { - type: "oauth", - refresh, - access, - expires, - ...extraFields, - }) - } - if ("key" in result) { - await Auth.set(saveProvider, { - type: "api", - key: result.key, - }) - } + if (result.type === "success") { + const saveProvider = result.provider ?? provider + if ("refresh" in result) { + const { type: _, provider: __, refresh, access, expires, ...extraFields } = result + await storeOAuthCredential({ + providerId: saveProvider, + refresh, + access, + expires, + namespace, + label, + extra: extraFields, + }) + } + if ("key" in result) { + await storeApiCredential({ providerId: saveProvider, apiKey: result.key }) + } spinner.stop("Login successful") } } @@ -111,27 +186,26 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): if (result.type === "failed") { prompts.log.error("Failed to authorize") } - if (result.type === "success") { - const saveProvider = result.provider ?? provider - if ("refresh" in result) { - const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { - type: "oauth", - refresh, - access, - expires, - ...extraFields, - }) - } - if ("key" in result) { - await Auth.set(saveProvider, { - type: "api", - key: result.key, - }) - } - prompts.log.success("Login successful") - } - } + if (result.type === "success") { + const saveProvider = result.provider ?? provider + if ("refresh" in result) { + const { type: _, provider: __, refresh, access, expires, ...extraFields } = result + await storeOAuthCredential({ + providerId: saveProvider, + refresh, + access, + expires, + namespace, + label, + extra: extraFields, + }) + } + if ("key" in result) { + await storeApiCredential({ providerId: saveProvider, apiKey: result.key }) + } + prompts.log.success("Login successful") + } + } prompts.outro("Done") return true @@ -143,14 +217,11 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): if (result.type === "failed") { prompts.log.error("Failed to authorize") } - if (result.type === "success") { - const saveProvider = result.provider ?? provider - await Auth.set(saveProvider, { - type: "api", - key: result.key, - }) - prompts.log.success("Login successful") - } + if (result.type === "success") { + const saveProvider = result.provider ?? provider + await storeApiCredential({ providerId: saveProvider, apiKey: result.key }) + prompts.log.success("Login successful") + } prompts.outro("Done") return true } @@ -163,29 +234,169 @@ export const AuthCommand = cmd({ command: "auth", describe: "manage credentials", builder: (yargs) => - yargs.command(AuthLoginCommand).command(AuthLogoutCommand).command(AuthListCommand).demandCommand(), + yargs.command(AuthLoginCommand).command(AuthLogoutCommand).command(AuthListCommand).command(AuthVaultCommand).demandCommand(), + async handler() {}, +}) + +export const AuthVaultCommand = cmd({ + command: "vault", + describe: "manage the local encryption key for credentials", + builder: (yargs) => + yargs + .command(AuthVaultInitCommand) + .command(AuthVaultExportCommand) + .command(AuthVaultImportCommand) + .demandCommand(), async handler() {}, }) +export const AuthVaultInitCommand = cmd({ + command: "init", + describe: "initialize the local vault key (creates Global.Path.config/vault.key)", + builder: (yargs) => + yargs.option("force", { + type: "boolean", + describe: "Overwrite existing vault key file", + default: false, + }), + async handler(args) { + UI.empty() + prompts.intro("Vault init") + + const keyPath = VaultKey.keyPath() + const exists = await VaultFS.exists(keyPath) + if (exists && args.force) { + const confirm = await prompts.confirm({ + message: `Overwrite existing vault key at ${keyPath}? (This will break decryption of existing credentials unless you back up the current key)`, + initialValue: false, + }) + if (prompts.isCancel(confirm)) throw new UI.CancelledError() + if (!confirm) { + prompts.outro("Cancelled") + return + } + } + + const result = await VaultKey.init({ force: Boolean(args.force) }) + prompts.log.info(`Key path: ${result.path}`) + prompts.log.info(`Source: ${result.source}`) + prompts.log.success(result.created ? "Vault key created" : "Vault key already exists") + prompts.outro("Done") + }, +}) + +export const AuthVaultExportCommand = cmd({ + command: "export", + describe: "export the vault key (base64) for backup/migration", + builder: (yargs) => + yargs.option("output", { + type: "string", + describe: "Write the key to a file instead of stdout", + }), + async handler(args) { + const key = await VaultKey.exportBase64() + const output = (args.output as string | undefined)?.trim() + if (output) { + await VaultFS.atomicWriteText(output, key + "\n", 0o600) + await fs.chmod(output, 0o600).catch(() => {}) + UI.empty() + prompts.intro("Vault export") + prompts.log.success(`Wrote key to ${output}`) + prompts.outro("Done") + return + } + + // Default: print only the key to stdout for easy piping. + process.stdout.write(key + "\n") + }, +}) + +export const AuthVaultImportCommand = cmd({ + command: "import", + describe: "import a vault key (base64) to Global.Path.config/vault.key", + builder: (yargs) => + yargs + .option("file", { + type: "string", + describe: "Read the key from a file", + }) + .option("key", { + type: "string", + describe: "Provide the base64 key directly", + }), + async handler(args) { + UI.empty() + prompts.intro("Vault import") + + const keyPath = VaultKey.keyPath() + const exists = await VaultFS.exists(keyPath) + if (exists) { + const confirm = await prompts.confirm({ + message: `Overwrite existing vault key at ${keyPath}?`, + initialValue: false, + }) + if (prompts.isCancel(confirm)) throw new UI.CancelledError() + if (!confirm) { + prompts.outro("Cancelled") + return + } + } + + let key: string | undefined + const filePath = (args.file as string | undefined)?.trim() + const inline = (args.key as string | undefined)?.trim() + + if (filePath) { + key = (await fs.readFile(filePath, "utf8")).trim() + } else if (inline) { + key = inline + } else { + const entered = await prompts.password({ + message: "Paste vault key (base64)", + validate: (x) => (x && x.length > 0 ? undefined : "Required"), + }) + if (prompts.isCancel(entered)) throw new UI.CancelledError() + key = entered.trim() + } + + await VaultKey.importBase64(key) + prompts.log.success(`Imported key to ${keyPath}`) + prompts.outro("Done") + }, +}) + export const AuthListCommand = cmd({ command: "list", aliases: ["ls"], describe: "list providers", async handler() { UI.empty() - const authPath = path.join(Global.Path.data, "auth.json") + const authPath = path.join(Global.Path.data, "credentials") const homedir = os.homedir() const displayPath = authPath.startsWith(homedir) ? authPath.replace(homedir, "~") : authPath prompts.intro(`Credentials ${UI.Style.TEXT_DIM}${displayPath}`) - const results = Object.entries(await Auth.all()) + await CredentialsMigrate.migrateIfNeeded() + const { records, errors } = await CredentialStore.listAll() const database = await ModelsDev.get() - for (const [providerID, result] of results) { - const name = database[providerID]?.name || providerID - prompts.log.info(`${name} ${UI.Style.TEXT_DIM}${result.type}`) + const sorted = [...records].sort((a, b) => { + if (a.meta.providerId !== b.meta.providerId) return a.meta.providerId.localeCompare(b.meta.providerId) + if (a.meta.namespace !== b.meta.namespace) return a.meta.namespace.localeCompare(b.meta.namespace) + return a.meta.createdAt - b.meta.createdAt + }) + for (const record of sorted) { + const name = database[record.meta.providerId]?.name || record.meta.providerId + const label = record.meta.label + ? `${record.meta.namespace}/${record.meta.label}` + : `${record.meta.namespace}/${record.meta.id}` + prompts.log.info(`${name} ${UI.Style.TEXT_DIM}${record.meta.kind} ${label}`) + } + + if (errors.length > 0) { + prompts.log.warn(`${errors.length} credential file(s) could not be read/validated.`) } - prompts.outro(`${results.length} credentials`) + prompts.outro(`${records.length} credential record` + (records.length === 1 ? "" : "s")) // Environment variables section const activeEnvVars: Array<{ provider: string; envVar: string }> = [] @@ -240,16 +451,19 @@ export const AuthLoginCommand = cmd({ prompts.log.error("Failed") prompts.outro("Done") return - } - const token = await new Response(proc.stdout).text() - await Auth.set(args.url, { - type: "wellknown", - key: wellknown.auth.env, - token: token.trim(), - }) - prompts.log.success("Logged into " + args.url) - prompts.outro("Done") - return + } + const token = await new Response(proc.stdout).text() + await CredentialsMigrate.migrateIfNeeded() + await CredentialStore.upsertSingleton({ + providerId: args.url, + namespace: "default", + kind: "wellknown", + label: "default", + secret: { envKey: wellknown.auth.env, token: token.trim() }, + }) + prompts.log.success("Logged into " + args.url) + prompts.outro("Done") + return } await ModelsDev.refresh().catch(() => {}) @@ -302,15 +516,15 @@ export const AuthLoginCommand = cmd({ label: "Other", }, ], - }) + }) - if (prompts.isCancel(provider)) throw new UI.CancelledError() + if (prompts.isCancel(provider)) throw new UI.CancelledError() - const plugin = await Plugin.list().then((x) => x.find((x) => x.auth?.provider === provider)) - if (plugin && plugin.auth) { - const handled = await handlePluginAuth({ auth: plugin.auth }, provider) - if (handled) return - } + const core = ProviderAuthRegistry.getAuthHook(provider) + if (core) { + const handled = await handlePluginAuth({ auth: core as any }, provider) + if (handled) return + } if (provider === "other") { provider = await prompts.text({ @@ -318,15 +532,14 @@ export const AuthLoginCommand = cmd({ validate: (x) => (x && x.match(/^[0-9a-z-]+$/) ? undefined : "a-z, 0-9 and hyphens only"), }) if (prompts.isCancel(provider)) throw new UI.CancelledError() - provider = provider.replace(/^@ai-sdk\//, "") - if (prompts.isCancel(provider)) throw new UI.CancelledError() + provider = provider.replace(/^@ai-sdk\//, "") + if (prompts.isCancel(provider)) throw new UI.CancelledError() - // Check if a plugin provides auth for this custom provider - const customPlugin = await Plugin.list().then((x) => x.find((x) => x.auth?.provider === provider)) - if (customPlugin && customPlugin.auth) { - const handled = await handlePluginAuth({ auth: customPlugin.auth }, provider) - if (handled) return - } + const core = ProviderAuthRegistry.getAuthHook(provider) + if (core) { + const handled = await handlePluginAuth({ auth: core as any }, provider) + if (handled) return + } prompts.log.warn( `This only stores a credential for ${provider} - you will need configure it in opencode.json, check the docs for examples.`, @@ -349,17 +562,14 @@ export const AuthLoginCommand = cmd({ prompts.log.info("You can create an api key at https://vercel.link/ai-gateway-token") } - const key = await prompts.password({ - message: "Enter your API key", - validate: (x) => (x && x.length > 0 ? undefined : "Required"), - }) - if (prompts.isCancel(key)) throw new UI.CancelledError() - await Auth.set(provider, { - type: "api", - key, - }) + const key = await prompts.password({ + message: "Enter your API key", + validate: (x) => (x && x.length > 0 ? undefined : "Required"), + }) + if (prompts.isCancel(key)) throw new UI.CancelledError() + await storeApiCredential({ providerId: provider, apiKey: key }) - prompts.outro("Done") + prompts.outro("Done") }, }) }, @@ -370,22 +580,55 @@ export const AuthLogoutCommand = cmd({ describe: "log out from a configured provider", async handler() { UI.empty() - const credentials = await Auth.all().then((x) => Object.entries(x)) + await CredentialsMigrate.migrateIfNeeded() + const { records } = await CredentialStore.listAll() + const providers = Array.from(new Set(records.map((r) => r.meta.providerId))) prompts.intro("Remove credential") - if (credentials.length === 0) { + if (providers.length === 0) { prompts.log.error("No credentials found") return } const database = await ModelsDev.get() const providerID = await prompts.select({ message: "Select provider", - options: credentials.map(([key, value]) => ({ - label: (database[key]?.name || key) + UI.Style.TEXT_DIM + " (" + value.type + ")", - value: key, - })), + options: providers.map((key) => { + const name = database[key]?.name || key + const count = records.filter((r) => r.meta.providerId === key).length + return { + label: `${name} ${UI.Style.TEXT_DIM}(${count})`, + value: key, + } + }), }) - if (prompts.isCancel(providerID)) throw new UI.CancelledError() - await Auth.remove(providerID) - prompts.outro("Logout successful") - }, -}) + if (prompts.isCancel(providerID)) throw new UI.CancelledError() + + const matches = records + .filter((r) => r.meta.providerId === providerID) + .sort((a, b) => { + if (a.meta.namespace !== b.meta.namespace) return a.meta.namespace.localeCompare(b.meta.namespace) + if ((a.meta.label ?? "") !== (b.meta.label ?? "")) return (a.meta.label ?? "").localeCompare(b.meta.label ?? "") + return a.meta.createdAt - b.meta.createdAt + }) + + if (matches.length === 0) { + prompts.log.error("No credentials found for provider") + prompts.outro("Done") + return + } + + const selected = await prompts.multiselect({ + message: "Select credential(s) to remove", + options: matches.map((r) => { + const label = r.meta.label ? `${r.meta.namespace}/${r.meta.label}` : `${r.meta.namespace}/${r.meta.id}` + return { + label, + value: r.meta.id, + hint: r.meta.kind, + } + }), + }) + if (prompts.isCancel(selected)) throw new UI.CancelledError() + await Promise.all(selected.map((id) => CredentialStore.remove(id))) + prompts.outro("Logout successful") + }, + }) diff --git a/packages/opencode/src/cli/cmd/mcp.ts b/packages/opencode/src/cli/cmd/mcp.ts index 9ca4b3bff8b..fda823f8b48 100644 --- a/packages/opencode/src/cli/cmd/mcp.ts +++ b/packages/opencode/src/cli/cmd/mcp.ts @@ -4,7 +4,7 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/ import * as prompts from "@clack/prompts" import { UI } from "../ui" import { MCP } from "../../mcp" -import { McpAuth } from "../../mcp/auth" +import { McpCredentials } from "../../mcp/credentials" import { Config } from "../../config/config" import { Instance } from "../../project/instance" import path from "path" @@ -222,8 +222,7 @@ export const McpLogoutCommand = cmd({ UI.empty() prompts.intro("MCP OAuth Logout") - const authPath = path.join(Global.Path.data, "mcp-auth.json") - const credentials = await McpAuth.all() + const credentials = await McpCredentials.all() const serverNames = Object.keys(credentials) if (serverNames.length === 0) { diff --git a/packages/opencode/src/cli/cmd/tui/component/dialog-credentials.tsx b/packages/opencode/src/cli/cmd/tui/component/dialog-credentials.tsx new file mode 100644 index 00000000000..d753c55b118 --- /dev/null +++ b/packages/opencode/src/cli/cmd/tui/component/dialog-credentials.tsx @@ -0,0 +1,133 @@ +import { createMemo } from "solid-js" +import { useSync } from "@tui/context/sync" +import { DialogSelect, type DialogSelectOption, type DialogSelectRef } from "@tui/ui/dialog-select" +import { useTheme } from "../context/theme" +import { Keybind } from "@/util/keybind" +import { useSDK } from "@tui/context/sdk" +import { useDialog } from "@tui/ui/dialog" +import { DialogPrompt } from "../ui/dialog-prompt" +import { TextAttributes } from "@opentui/core" + +function formatCredentialLabel(meta: { namespace?: string; label?: string; id: string }) { + const namespace = meta.namespace ?? "default" + const label = meta.label ? meta.label : meta.id + return `${namespace}/${label}` +} + +export function DialogCredentials() { + const sync = useSync() + const sdk = useSDK() + const dialog = useDialog() + const { theme } = useTheme() + + const providerNames = createMemo(() => { + const map = new Map() + for (const p of sync.data.provider_next.all) map.set(p.id, p.name) + return map + }) + + const countsByProvider = createMemo(() => { + const counts = new Map() + for (const c of sync.data.credential) { + counts.set(c.providerId, (counts.get(c.providerId) ?? 0) + 1) + } + return counts + }) + + const options = createMemo(() => { + const now = Date.now() + const counts = countsByProvider() + const names = providerNames() + + return [...sync.data.credential] + .toSorted((a, b) => { + if (a.providerId !== b.providerId) return a.providerId.localeCompare(b.providerId) + const aNs = a.namespace ?? "default" + const bNs = b.namespace ?? "default" + if (aNs !== bNs) return aNs.localeCompare(bNs) + if ((a.label ?? "") !== (b.label ?? "")) return (a.label ?? "").localeCompare(b.label ?? "") + return a.createdAt - b.createdAt + }) + .map((c) => { + const cooldownUntil = c.health?.cooldownUntil + const lastStatusCode = c.health?.lastStatusCode + const providerTitle = names.get(c.providerId) ?? c.providerId + const category = `${providerTitle} (${counts.get(c.providerId) ?? 0})` + const cooldown = + cooldownUntil && cooldownUntil > now ? cooldownUntil - now : 0 + const footer = cooldown + ? `${c.kind ?? "unknown"} • cooldown ${(cooldown / 1000).toFixed(0)}s` + : `${c.kind ?? "unknown"}${lastStatusCode ? ` • last ${lastStatusCode}` : ""}` + + return { + value: c.id, + title: providerTitle, + description: formatCredentialLabel(c), + category, + footer: ( + + {footer} + + ), + } satisfies DialogSelectOption + }) + }) + + const keybinds = createMemo(() => [ + { + keybind: Keybind.parse("r")[0], + title: "rename", + onTrigger: async (option: DialogSelectOption) => { + const current = sync.data.credential.find((c) => c.id === option.value) + const initial = current?.label ?? "" + const next = await DialogPrompt.show(dialog, "New label", { + placeholder: "default", + value: initial, + description: () => ( + + Renames this credential label (namespace stays the same). + + ), + }) + dialog.replace(() => ) + if (next === null) return + const label = next.split("\n")[0]?.trim() + if (!label) return + await sdk.client.credential.update({ credentialID: option.value, label }) + await sync.bootstrap() + }, + }, + { + keybind: Keybind.parse("d")[0], + title: "delete", + onTrigger: async (option: DialogSelectOption) => { + const confirm = await DialogPrompt.show(dialog, "Type DELETE to remove", { + placeholder: "DELETE", + description: () => ( + + + This cannot be undone. + + + ), + }) + dialog.replace(() => ) + if (confirm !== "DELETE") return + await sdk.client.credential.remove({ credentialID: option.value }) + await sync.bootstrap() + }, + }, + ]) + + return ( + ) => {}} + title="Credentials" + options={options()} + keybind={keybinds()} + onSelect={() => { + // no-op: actions are via keybinds + }} + /> + ) +} diff --git a/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx b/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx index 5cc114f92f0..2bd99b3d124 100644 --- a/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx @@ -9,6 +9,7 @@ import { useTheme } from "../context/theme" import { TextAttributes } from "@opentui/core" import type { ProviderAuthAuthorization } from "@opencode-ai/sdk/v2" import { DialogModel } from "./dialog-model" +import { DialogCredentials } from "./dialog-credentials" const PROVIDER_PRIORITY: Record = { opencode: 0, @@ -23,17 +24,35 @@ export function createDialogProviderOptions() { const sync = useSync() const dialog = useDialog() const sdk = useSDK() + const { theme } = useTheme() + + const oauthCounts = createMemo(() => { + const counts = new Map() + for (const c of sync.data.credential) { + if (c.kind !== "oauth") continue + counts.set(c.providerId, (counts.get(c.providerId) ?? 0) + 1) + } + return counts + }) + const options = createMemo(() => { - return pipe( + const counts = oauthCounts() + const providerOptions = pipe( sync.data.provider_next.all, sortBy((x) => PROVIDER_PRIORITY[x.id] ?? 99), map((provider) => ({ title: provider.name, value: provider.id, - description: { - opencode: "(Recommended)", - anthropic: "(Claude Max or API key)", - }[provider.id], + description: (() => { + const base = + { + opencode: "(Recommended)", + anthropic: "(Claude Max or API key)", + }[provider.id] ?? "" + const n = counts.get(provider.id) ?? 0 + const suffix = n > 0 ? ` (${n} account${n === 1 ? "" : "s"})` : "" + return `${base}${suffix}`.trim() || undefined + })(), category: provider.id in PROVIDER_PRIORITY ? "Popular" : "Other", async onSelect() { const methods = sync.data.provider_auth[provider.id] ?? [ @@ -63,27 +82,78 @@ export function createDialogProviderOptions() { if (index == null) return const method = methods[index] if (method.type === "oauth") { + const rawNamespace = await DialogPrompt.show(dialog, "Namespace (optional)", { + placeholder: "default", + description: () => ( + + Leave blank to use the default namespace. + + ), + }) + if (rawNamespace === null) return + const namespace = rawNamespace.split("\n")[0]?.trim() || undefined + + const rawLabel = await DialogPrompt.show(dialog, "Account label (optional)", { + placeholder: "default", + description: () => ( + + Leave blank to auto-generate. + + ), + }) + if (rawLabel === null) return + const label = rawLabel.split("\n")[0]?.trim() || undefined + const result = await sdk.client.provider.oauth.authorize({ providerID: provider.id, method: index, + namespace, + label, }) if (result.data?.method === "code") { dialog.replace(() => ( - + )) } if (result.data?.method === "auto") { dialog.replace(() => ( - + )) } } if (method.type === "api") { - return dialog.replace(() => ) + dialog.replace(() => ) } }, })), ) + + return [ + { + title: "Manage connected accounts", + value: "__manage__", + description: "View, rename, or remove stored credentials", + category: "Manage", + async onSelect() { + dialog.replace(() => ) + }, + }, + ...providerOptions, + ] }) return options } @@ -98,6 +168,8 @@ interface AutoMethodProps { providerID: string title: string authorization: ProviderAuthAuthorization + namespace?: string + label?: string } function AutoMethod(props: AutoMethodProps) { const { theme } = useTheme() @@ -109,6 +181,8 @@ function AutoMethod(props: AutoMethodProps) { const result = await sdk.client.provider.oauth.callback({ providerID: props.providerID, method: props.index, + namespace: props.namespace, + label: props.label, }) if (result.error) { dialog.clear() @@ -141,6 +215,8 @@ interface CodeMethodProps { title: string providerID: string authorization: ProviderAuthAuthorization + namespace?: string + label?: string } function CodeMethod(props: CodeMethodProps) { const { theme } = useTheme() @@ -158,6 +234,8 @@ function CodeMethod(props: CodeMethodProps) { providerID: props.providerID, method: props.index, code: value, + namespace: props.namespace, + label: props.label, }) if (!error) { await sdk.client.instance.dispose() diff --git a/packages/opencode/src/cli/cmd/tui/component/dialog-status.tsx b/packages/opencode/src/cli/cmd/tui/component/dialog-status.tsx index b85cd5c6542..bf53797a057 100644 --- a/packages/opencode/src/cli/cmd/tui/component/dialog-status.tsx +++ b/packages/opencode/src/cli/cmd/tui/component/dialog-status.tsx @@ -10,6 +10,14 @@ export function DialogStatus() { const { theme } = useTheme() const enabledFormatters = createMemo(() => sync.data.formatter.filter((f) => f.enabled)) + const rotationStats = createMemo(() => sync.data.rotation_stats) + const topRotations = createMemo(() => { + const stats = rotationStats() + if (!stats) return [] + return Object.entries(stats.byProvider) + .toSorted((a, b) => (b[1].rotations ?? 0) - (a[1].rotations ?? 0)) + .slice(0, 6) + }) const plugins = createMemo(() => { const list = sync.data.config.plugin ?? [] @@ -134,6 +142,29 @@ export function DialogStatus() { + No Rotation Stats}> + {(stats) => { + const ageMinutes = Math.max(0, Math.floor((Date.now() - stats().since) / 60_000)) + return ( + + OAuth Rotation (since {ageMinutes}m) + + {stats().totals.requests} req • {stats().totals.attempts} attempts • {stats().totals.rotations} rotations •{" "} + {stats().totals.refreshSuccess}/{stats().totals.refreshAttempts} refresh + + 0}> + + {([providerId, counts]) => ( + + {providerId}: {counts.rotations} rot • {counts.rateLimited} 429 • {counts.authExpired} auth + + )} + + + + ) + }} + 0} fallback={No Plugins}> {plugins().length} Plugins diff --git a/packages/opencode/src/cli/cmd/tui/context/sync.tsx b/packages/opencode/src/cli/cmd/tui/context/sync.tsx index f74f787db8c..4428677fe21 100644 --- a/packages/opencode/src/cli/cmd/tui/context/sync.tsx +++ b/packages/opencode/src/cli/cmd/tui/context/sync.tsx @@ -14,6 +14,8 @@ import type { SessionStatus, ProviderListResponse, ProviderAuthMethod, + CredentialRecordMeta, + RotationStatsSnapshot, VcsInfo, } from "@opencode-ai/sdk/v2" import { createStore, produce, reconcile } from "solid-js/store" @@ -35,6 +37,8 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ provider_default: Record provider_next: ProviderListResponse provider_auth: Record + credential: CredentialRecordMeta[] + rotation_stats: RotationStatsSnapshot | undefined agent: Agent[] command: Command[] permission: { @@ -71,6 +75,8 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ connected: [], }, provider_auth: {}, + credential: [], + rotation_stats: undefined, config: {}, status: "loading", agent: [], @@ -288,6 +294,8 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ sdk.client.formatter.status().then((x) => setStore("formatter", x.data!)), sdk.client.session.status().then((x) => setStore("session_status", x.data!)), sdk.client.provider.auth().then((x) => setStore("provider_auth", x.data ?? {})), + sdk.client.credential.list().then((x) => setStore("credential", x.data ?? [])), + sdk.client.debug.rotation().then((x) => setStore("rotation_stats", x.data)), sdk.client.vcs.get().then((x) => setStore("vcs", x.data)), sdk.client.path.get().then((x) => setStore("path", x.data!)), ]).then(() => { diff --git a/packages/opencode/src/config/config.ts b/packages/opencode/src/config/config.ts index 73fc1c77352..f9d8fff0bdd 100644 --- a/packages/opencode/src/config/config.ts +++ b/packages/opencode/src/config/config.ts @@ -11,13 +11,13 @@ import fs from "fs/promises" import { lazy } from "../util/lazy" import { NamedError } from "@opencode-ai/util/error" import { Flag } from "../flag/flag" -import { Auth } from "../auth" import { type ParseError as JsoncParseError, parse as parseJsonc, printParseErrorCode } from "jsonc-parser" import { Instance } from "../project/instance" import { LSPServer } from "../lsp/server" import { BunProc } from "@/bun" import { Installation } from "@/installation" import { ConfigMarkdown } from "./markdown" +import { CredentialStore, CredentialsMigrate } from "@/credentials" export namespace Config { const log = Log.create({ service: "config" }) @@ -34,7 +34,8 @@ export namespace Config { } export const state = Instance.state(async () => { - const auth = await Auth.all() + await CredentialsMigrate.migrateIfNeeded() + const { records: credentialRecords } = await CredentialStore.listAll() let result = await global() // Override with custom config if provided @@ -55,12 +56,19 @@ export namespace Config { log.debug("loaded custom config from OPENCODE_CONFIG_CONTENT") } - for (const [key, value] of Object.entries(auth)) { - if (value.type === "wellknown") { - process.env[value.key] = value.token - const wellknown = (await fetch(`${key}/.well-known/opencode`).then((x) => x.json())) as any - result = mergeConfigWithPlugins(result, await load(JSON.stringify(wellknown.config ?? {}), process.cwd())) - } + for (const record of credentialRecords) { + if (record.meta.kind !== "wellknown") continue + const secret = await CredentialStore.decryptSecret(record) + if (!secret || typeof secret !== "object") continue + if (!("envKey" in secret) || !("token" in secret)) continue + + const envKey = String((secret as any).envKey) + const token = String((secret as any).token) + process.env[envKey] = token + + const providerUrl = record.meta.providerId + const wellknown = (await fetch(`${providerUrl}/.well-known/opencode`).then((x) => x.json())) as any + result = mergeConfigWithPlugins(result, await load(JSON.stringify(wellknown.config ?? {}), process.cwd())) } result.agent = result.agent || {} @@ -584,16 +592,44 @@ export namespace Config { }) export type Layout = z.infer - export const Provider = ModelsDev.Provider.partial() - .extend({ - whitelist: z.array(z.string()).optional(), - blacklist: z.array(z.string()).optional(), - models: z.record(z.string(), ModelsDev.Model.partial()).optional(), - options: z - .object({ - apiKey: z.string().optional(), - baseURL: z.string().optional(), - enterpriseUrl: z.string().optional().describe("GitHub Enterprise URL for copilot authentication"), + export const Provider = ModelsDev.Provider.partial() + .extend({ + whitelist: z.array(z.string()).optional(), + blacklist: z.array(z.string()).optional(), + models: z.record(z.string(), ModelsDev.Model.partial()).optional(), + auth: z + .object({ + mode: z + .enum(["auto", "api", "subscription"]) + .optional() + .describe( + "Auth mode for this provider. 'auto' uses subscription OAuth when available, otherwise API key/env. 'api' forces API key/env. 'subscription' forces OAuth rotation.", + ), + namespace: z + .string() + .optional() + .describe("Credential namespace to use for this provider (default: 'default')."), + maxAttempts: z + .number() + .int() + .positive() + .optional() + .describe("Max credentials to try per request when rotating (default: try all eligible)."), + }) + .strict() + .optional() + .describe("Authentication settings for subscription OAuth rotation and API keys."), + options: z + .object({ + apiKey: z.string().optional(), + baseURL: z.string().optional(), + enterpriseUrl: z.string().optional().describe("GitHub Enterprise URL for copilot authentication"), + discoverModels: z + .boolean() + .optional() + .describe( + "Enable dynamic model discovery for OpenAI/OpenAI-compatible providers by calling the upstream /models endpoint and caching results under Global.Path.cache.", + ), setCacheKey: z.boolean().optional().describe("Enable promptCacheKey for this provider (default false)"), timeout: z .union([ diff --git a/packages/opencode/src/credentials/index.ts b/packages/opencode/src/credentials/index.ts new file mode 100644 index 00000000000..266ae66bb50 --- /dev/null +++ b/packages/opencode/src/credentials/index.ts @@ -0,0 +1,4 @@ +export { Credentials } from "./types" +export { CredentialStore } from "./store" +export { CredentialsMigrate } from "./migrate" +export { CredentialPool } from "./pool" diff --git a/packages/opencode/src/credentials/migrate.ts b/packages/opencode/src/credentials/migrate.ts new file mode 100644 index 00000000000..97c9c90f459 --- /dev/null +++ b/packages/opencode/src/credentials/migrate.ts @@ -0,0 +1,176 @@ +import path from "path" +import crypto from "crypto" +import z from "zod" +import { Global } from "@/global" +import { VaultFS } from "@/vault/fs" +import { CredentialStore } from "./store" + +const LegacyAuthOauth = z.object({ + type: z.literal("oauth"), + refresh: z.string(), + access: z.string(), + expires: z.number(), + enterpriseUrl: z.string().optional(), +}) + +const LegacyAuthApi = z.object({ + type: z.literal("api"), + key: z.string(), +}) + +const LegacyAuthWellKnown = z.object({ + type: z.literal("wellknown"), + key: z.string(), + token: z.string(), +}) + +const LegacyAuthInfo = z.discriminatedUnion("type", [LegacyAuthOauth, LegacyAuthApi, LegacyAuthWellKnown]) + +const LegacyMcpAuthEntry = z + .object({ + tokens: z + .object({ + accessToken: z.string(), + refreshToken: z.string().optional(), + expiresAt: z.number().optional(), + scope: z.string().optional(), + }) + .optional(), + clientInfo: z + .object({ + clientId: z.string(), + clientSecret: z.string().optional(), + clientIdIssuedAt: z.number().optional(), + clientSecretExpiresAt: z.number().optional(), + }) + .optional(), + codeVerifier: z.string().optional(), + }) + .strict() + +export namespace CredentialsMigrate { + const LEGACY_AUTH_PATH = path.join(Global.Path.data, "auth.json") + const LEGACY_MCP_AUTH_PATH = path.join(Global.Path.data, "mcp-auth.json") + const MIGRATIONS_PATH = path.join(Global.Path.data, "credentials", "migrations.json") + let didRun = false + let inFlight: Promise | undefined + + const MigrationState = z + .object({ + version: z.literal(1), + legacyAuth: z.record(z.string(), z.string()).default({}), + legacyMcpAuth: z.record(z.string(), z.string()).default({}), + }) + .strict() + type MigrationState = z.infer + + function fingerprint(raw: unknown): string { + const json = JSON.stringify(raw) ?? "" + return crypto.createHash("sha256").update(json).digest("hex") + } + + async function loadState(): Promise { + const json = await VaultFS.readJson(MIGRATIONS_PATH) + const parsed = MigrationState.safeParse(json) + if (parsed.success) return parsed.data + return { version: 1, legacyAuth: {}, legacyMcpAuth: {} } + } + + async function saveState(state: MigrationState): Promise { + await VaultFS.atomicWriteJson(MIGRATIONS_PATH, state, 0o600) + } + + export async function migrateIfNeeded(): Promise { + if (didRun) return + if (inFlight) return inFlight + inFlight = (async () => { + const state = await loadState() + const changed = (await migrateLegacyAuth(state)) || (await migrateLegacyMcpAuth(state)) + if (changed) { + await saveState(state) + } + })() + try { + await inFlight + didRun = true + } finally { + inFlight = undefined + } + } + + async function migrateLegacyAuth(state: MigrationState): Promise { + const legacy = await VaultFS.readJson>(LEGACY_AUTH_PATH) + if (!legacy) return false + let changed = false + + for (const [providerId, raw] of Object.entries(legacy)) { + const parsed = LegacyAuthInfo.safeParse(raw) + if (!parsed.success) continue + + const info = parsed.data + const sig = fingerprint(raw) + if (state.legacyAuth[providerId] === sig) continue + + if (info.type === "api") { + await CredentialStore.upsertSingleton({ + providerId, + namespace: "default", + kind: "api", + label: "default", + secret: { apiKey: info.key }, + }) + } else if (info.type === "oauth") { + await CredentialStore.upsertSingleton({ + providerId, + namespace: "default", + kind: "oauth", + label: "migrated", + secret: { + accessToken: info.access, + refreshToken: info.refresh || undefined, + expiresAt: info.expires, + extra: info.enterpriseUrl ? { enterpriseUrl: info.enterpriseUrl } : undefined, + }, + }) + } else if (info.type === "wellknown") { + await CredentialStore.upsertSingleton({ + providerId, + namespace: "default", + kind: "wellknown", + label: "default", + secret: { envKey: info.key, token: info.token }, + }) + } + + state.legacyAuth[providerId] = sig + changed = true + } + return changed + } + + async function migrateLegacyMcpAuth(state: MigrationState): Promise { + const legacy = await VaultFS.readJson>(LEGACY_MCP_AUTH_PATH) + if (!legacy) return false + let changed = false + + for (const [mcpName, raw] of Object.entries(legacy)) { + const parsed = LegacyMcpAuthEntry.safeParse(raw) + if (!parsed.success) continue + + const sig = fingerprint(raw) + if (state.legacyMcpAuth[mcpName] === sig) continue + + await CredentialStore.upsertSingleton({ + providerId: `mcp:${mcpName}`, + namespace: "default", + kind: "mcp", + label: mcpName, + secret: { entry: parsed.data }, + }) + + state.legacyMcpAuth[mcpName] = sig + changed = true + } + return changed + } +} diff --git a/packages/opencode/src/credentials/pool.ts b/packages/opencode/src/credentials/pool.ts new file mode 100644 index 00000000000..6381532f15d --- /dev/null +++ b/packages/opencode/src/credentials/pool.ts @@ -0,0 +1,55 @@ +import path from "path" +import { Global } from "@/global" +import { VaultFS } from "@/vault/fs" +import { VaultLock } from "@/vault/lock" + +export namespace CredentialPool { + const ROOT = path.join(Global.Path.data, "credentials") + const POOLS_DIR = path.join(ROOT, "pools") + const LOCK_PATH = path.join(ROOT, ".lock") + + function poolPath(providerId: string, namespace: string) { + const safeProvider = encodeURIComponent(providerId) + const safeNamespace = encodeURIComponent(namespace) + return path.join(POOLS_DIR, safeProvider, `${safeNamespace}.json`) + } + + async function loadIds(filePath: string): Promise { + const json = await VaultFS.readJson(filePath) + if (!Array.isArray(json)) return undefined + return json.filter((x) => typeof x === "string") as string[] + } + + export async function getOrderedIds(providerId: string, namespace: string, eligibleIds: string[]): Promise { + const filePath = poolPath(providerId, namespace) + const eligible = new Set(eligibleIds) + + return VaultLock.withLock(LOCK_PATH, async () => { + await VaultFS.ensureDir(path.dirname(filePath)) + const current = (await loadIds(filePath)) ?? [] + + const next: string[] = [] + for (const id of current) { + if (eligible.has(id)) next.push(id) + } + for (const id of eligibleIds) { + if (!next.includes(id)) next.push(id) + } + + await VaultFS.atomicWriteJson(filePath, next, 0o600) + return next + }) + } + + export async function moveToBack(providerId: string, namespace: string, id: string): Promise { + const filePath = poolPath(providerId, namespace) + await VaultLock.withLock(LOCK_PATH, async () => { + await VaultFS.ensureDir(path.dirname(filePath)) + const current = (await loadIds(filePath)) ?? [] + const filtered = current.filter((x) => x !== id) + filtered.push(id) + await VaultFS.atomicWriteJson(filePath, filtered, 0o600) + }) + } +} + diff --git a/packages/opencode/src/credentials/store.ts b/packages/opencode/src/credentials/store.ts new file mode 100644 index 00000000000..5086f0c6a17 --- /dev/null +++ b/packages/opencode/src/credentials/store.ts @@ -0,0 +1,383 @@ +import path from "path" +import fs from "fs/promises" +import { ulid } from "ulid" +import { Global } from "@/global" +import { VaultKey } from "@/vault/key" +import { VaultCrypto } from "@/vault/crypto" +import { VaultFS } from "@/vault/fs" +import { VaultLock } from "@/vault/lock" +import { Credentials } from "./types" +import { Log } from "@/util/log" +import z from "zod" + +type PutInput = { + id?: string + providerId: string + namespace?: string + label?: string + kind: Credentials.Kind + secret: Credentials.Secret +} + +type UpsertSingletonInput = { + providerId: string + namespace: string + label: string + kind: Credentials.Kind + secret: Credentials.Secret +} + +export namespace CredentialStore { + const log = Log.create({ service: "credentials.store" }) + + const ROOT = path.join(Global.Path.data, "credentials") + const RECORDS_DIR = path.join(ROOT, "records") + const LOCK_PATH = path.join(ROOT, ".lock") + const INDEX_PATH = path.join(ROOT, "index.json") + + const IndexFile = z + .object({ + version: z.literal(1), + byProvider: z.record(z.string(), z.record(z.string(), z.array(z.string()))), + }) + .strict() + type IndexFile = z.infer + + const DEFAULT_INDEX: IndexFile = { version: 1, byProvider: {} } + const INDEX_CACHE_TTL_MS = 2_000 + let indexCache: { loadedAt: number; value: IndexFile } | undefined + + async function ensureDirs() { + await VaultFS.ensureDir(RECORDS_DIR) + } + + function recordPath(id: string) { + return path.join(RECORDS_DIR, `${id}.json`) + } + + export async function hasAnyRecords(): Promise { + await ensureDirs() + const entries = await fs.readdir(RECORDS_DIR).catch(() => []) + return entries.some((x) => x.endsWith(".json")) + } + + function cacheIndex(next: IndexFile): IndexFile { + indexCache = { loadedAt: Date.now(), value: next } + return next + } + + async function readIndexFromDisk(): Promise { + const json = await VaultFS.readJson(INDEX_PATH) + const parsed = IndexFile.safeParse(json) + return parsed.success ? parsed.data : undefined + } + + async function rebuildIndex(): Promise { + const { records } = await listAll() + const byProvider: IndexFile["byProvider"] = {} + for (const record of records) { + const provider = record.meta.providerId + const ns = record.meta.namespace + byProvider[provider] ??= {} + byProvider[provider][ns] ??= [] + byProvider[provider][ns].push(record.meta.id) + } + + const index: IndexFile = { version: 1, byProvider } + await VaultLock.withLock(LOCK_PATH, async () => { + await VaultFS.atomicWriteJson(INDEX_PATH, index, 0o600) + }) + return cacheIndex(index) + } + + async function rebuildIndexLocked(): Promise { + const { records } = await listAll() + const byProvider: IndexFile["byProvider"] = {} + for (const record of records) { + const provider = record.meta.providerId + const ns = record.meta.namespace + byProvider[provider] ??= {} + byProvider[provider][ns] ??= [] + byProvider[provider][ns].push(record.meta.id) + } + const index: IndexFile = { version: 1, byProvider } + await VaultFS.atomicWriteJson(INDEX_PATH, index, 0o600) + return cacheIndex(index) + } + + async function loadIndex(opts?: { force?: boolean }): Promise { + const now = Date.now() + if (!opts?.force && indexCache && now - indexCache.loadedAt < INDEX_CACHE_TTL_MS) { + return indexCache.value + } + + const onDisk = await readIndexFromDisk() + if (onDisk) return cacheIndex(onDisk) + return rebuildIndex() + } + + async function loadIndexLocked(): Promise { + const onDisk = await readIndexFromDisk() + if (onDisk) return cacheIndex(onDisk) + return rebuildIndexLocked() + } + + function indexAdd(index: IndexFile, input: { providerId: string; namespace: string; id: string }) { + index.byProvider[input.providerId] ??= {} + index.byProvider[input.providerId][input.namespace] ??= [] + const ids = index.byProvider[input.providerId][input.namespace] + if (!ids.includes(input.id)) ids.push(input.id) + } + + function indexRemove(index: IndexFile, input: { providerId: string; namespace: string; id: string }) { + const ns = index.byProvider[input.providerId]?.[input.namespace] + if (!ns) return + index.byProvider[input.providerId][input.namespace] = ns.filter((x) => x !== input.id) + } + + export async function listAll(): Promise<{ + records: Credentials.RecordFile[] + errors: Array<{ file: string; error: unknown }> + }> { + await ensureDirs() + const glob = new Bun.Glob("*.json") + const records: Credentials.RecordFile[] = [] + const errors: Array<{ file: string; error: unknown }> = [] + + for await (const rel of glob.scan({ cwd: RECORDS_DIR, dot: false, onlyFiles: true })) { + const file = path.join(RECORDS_DIR, rel) + const json = await VaultFS.readJson(file) + const parsed = Credentials.RecordFile.safeParse(json) + if (!parsed.success) { + errors.push({ file, error: parsed.error }) + continue + } + records.push(parsed.data) + } + + if (errors.length > 0) { + log.error("credential record parse errors", { count: errors.length }) + } + + return { records, errors } + } + + export async function getRecordFile(id: string): Promise { + await ensureDirs() + const json = await VaultFS.readJson(recordPath(id)) + const parsed = Credentials.RecordFile.safeParse(json) + if (!parsed.success) return undefined + return parsed.data + } + + export async function decryptSecret(record: Credentials.RecordFile): Promise { + const key = await VaultKey.load() + return VaultCrypto.decryptJson(key, record.secret) as Credentials.Secret + } + + export async function put(input: PutInput): Promise { + await ensureDirs() + const now = Date.now() + const id = input.id ?? ulid() + const namespace = input.namespace ?? "default" + const key = await VaultKey.load() + + const record: Credentials.RecordFile = { + meta: { + id, + providerId: input.providerId, + namespace, + label: input.label, + kind: input.kind, + createdAt: now, + updatedAt: now, + health: { + successCount: 0, + failureCount: 0, + }, + }, + secret: VaultCrypto.encryptJson(key, input.secret), + } + + await VaultLock.withLock(LOCK_PATH, async () => { + await VaultFS.atomicWriteJson(recordPath(id), record, 0o600) + const index = await loadIndexLocked() + indexAdd(index, { providerId: input.providerId, namespace, id }) + await VaultFS.atomicWriteJson(INDEX_PATH, index, 0o600) + cacheIndex(index) + }) + + return record + } + + export async function update( + id: string, + patch: Partial> & { meta?: Partial }, + ) { + return updateWith(id, (existing) => ({ + ...existing, + ...patch, + meta: { + ...existing.meta, + ...(patch.meta ?? {}), + }, + })) + } + + export async function remove(id: string): Promise { + await ensureDirs() + await VaultLock.withLock(LOCK_PATH, async () => { + const before = await getRecordFile(id) + await fs.rm(recordPath(id), { force: true }) + if (!before) { + await rebuildIndexLocked() + return + } + + const index = await loadIndexLocked() + indexRemove(index, { providerId: before.meta.providerId, namespace: before.meta.namespace, id }) + await VaultFS.atomicWriteJson(INDEX_PATH, index, 0o600) + cacheIndex(index) + }) + } + + export async function findByProvider(providerId: string, namespace?: string): Promise { + const index = await loadIndex() + const namespaces = index.byProvider[providerId] ?? {} + const ids = namespace ? namespaces[namespace] ?? [] : Object.values(namespaces).flat() + if (ids.length === 0) return [] + + const out: Credentials.RecordFile[] = [] + for (const id of ids) { + const record = await getRecordFile(id) + if (record) out.push(record) + } + return out + } + + export async function upsertSingleton(input: UpsertSingletonInput): Promise { + await ensureDirs() + const now = Date.now() + const key = await VaultKey.load() + + return VaultLock.withLock(LOCK_PATH, async () => { + let existing: Credentials.RecordFile | undefined + + const index = await loadIndexLocked() + const ids = index.byProvider[input.providerId]?.[input.namespace] ?? [] + for (const id of ids) { + const record = await getRecordFile(id) + if (!record) continue + if ( + record.meta.providerId === input.providerId && + record.meta.namespace === input.namespace && + record.meta.kind === input.kind && + (record.meta.label ?? "") === input.label + ) { + existing = record + break + } + } + + const id = existing?.meta.id ?? ulid() + const record: Credentials.RecordFile = { + meta: { + id, + providerId: input.providerId, + namespace: input.namespace, + label: input.label, + kind: input.kind, + createdAt: existing?.meta.createdAt ?? now, + updatedAt: now, + health: existing?.meta.health ?? { successCount: 0, failureCount: 0 }, + }, + secret: VaultCrypto.encryptJson(key, input.secret), + } + + await VaultFS.atomicWriteJson(recordPath(id), record, 0o600) + indexAdd(index, { providerId: input.providerId, namespace: input.namespace, id }) + await VaultFS.atomicWriteJson(INDEX_PATH, index, 0o600) + cacheIndex(index) + return record + }) + } + + export async function updateSecret(id: string, secret: Credentials.Secret) { + const key = await VaultKey.load() + return updateWith(id, (existing) => ({ ...existing, secret: VaultCrypto.encryptJson(key, secret) })) + } + + export async function updateHealth(id: string, patch: Partial) { + return updateWith(id, (existing) => ({ + ...existing, + meta: { + ...existing.meta, + health: { ...existing.meta.health, ...patch }, + }, + })) + } + + export async function recordOutcome(input: { + id: string + statusCode: number + ok: boolean + cooldownUntil?: number + }) { + return updateWith(input.id, (existing) => { + const now = Date.now() + const prevCooldown = + existing.meta.health.cooldownUntil && existing.meta.health.cooldownUntil > now + ? existing.meta.health.cooldownUntil + : undefined + const cooldownUntil = input.ok ? undefined : input.cooldownUntil ?? prevCooldown + + return { + ...existing, + meta: { + ...existing.meta, + health: { + ...existing.meta.health, + cooldownUntil, + lastStatusCode: input.statusCode, + lastErrorAt: input.ok ? undefined : now, + successCount: existing.meta.health.successCount + (input.ok ? 1 : 0), + failureCount: existing.meta.health.failureCount + (input.ok ? 0 : 1), + }, + }, + } + }) + } + + export async function updateWith( + id: string, + updater: (existing: Credentials.RecordFile) => Credentials.RecordFile | Promise, + ): Promise { + await ensureDirs() + return VaultLock.withLock(LOCK_PATH, async () => { + const json = await VaultFS.readJson(recordPath(id)) + const parsed = Credentials.RecordFile.safeParse(json) + if (!parsed.success) { + if (!json) return undefined + throw new Error(`Invalid credential record: ${recordPath(id)}`) + } + + const existing = parsed.data + const nextRaw = await updater(existing) + const next: Credentials.RecordFile = { + ...nextRaw, + meta: { + ...nextRaw.meta, + id: existing.meta.id, + providerId: existing.meta.providerId, + namespace: existing.meta.namespace, + kind: existing.meta.kind, + createdAt: existing.meta.createdAt, + updatedAt: Date.now(), + }, + } + + await VaultFS.atomicWriteJson(recordPath(id), next, 0o600) + return next + }) + } +} diff --git a/packages/opencode/src/credentials/types.ts b/packages/opencode/src/credentials/types.ts new file mode 100644 index 00000000000..1bf724044f6 --- /dev/null +++ b/packages/opencode/src/credentials/types.ts @@ -0,0 +1,78 @@ +import z from "zod" +import { type VaultEncryptedBlobV1 } from "@/vault/crypto" + +export namespace Credentials { + export const Kind = z.enum(["oauth", "api", "wellknown", "mcp"]).meta({ ref: "CredentialKind" }) + export type Kind = z.infer + + export const Health = z + .object({ + cooldownUntil: z.number().optional(), + lastStatusCode: z.number().optional(), + lastErrorAt: z.number().optional(), + successCount: z.number().default(0), + failureCount: z.number().default(0), + }) + .strict() + .default(() => ({ successCount: 0, failureCount: 0 })) + .meta({ ref: "CredentialHealth" }) + export type Health = z.infer + + export const RecordMeta = z + .object({ + id: z.string(), + providerId: z.string(), + namespace: z.string().default("default"), + label: z.string().optional(), + kind: Kind, + createdAt: z.number(), + updatedAt: z.number(), + health: Health, + }) + .strict() + .meta({ ref: "CredentialRecordMeta" }) + export type RecordMeta = z.infer + + export const EncryptedBlob = z + .object({ + v: z.literal(1), + alg: z.literal("AES-256-GCM"), + nonce_b64: z.string(), + tag_b64: z.string(), + data_b64: z.string(), + }) + .strict() + .meta({ ref: "VaultEncryptedBlobV1" }) + export type EncryptedBlob = VaultEncryptedBlobV1 + + export const RecordFile = z + .object({ + meta: RecordMeta, + secret: EncryptedBlob, + }) + .strict() + .meta({ ref: "CredentialRecordFile" }) + export type RecordFile = z.infer + + export type OAuthSecret = { + accessToken: string + refreshToken?: string + expiresAt?: number + extra?: Record + } + + export type ApiSecret = { + apiKey: string + } + + export type WellknownSecret = { + envKey: string + token: string + } + + export type McpSecret = { + entry: unknown + } + + export type Secret = OAuthSecret | ApiSecret | WellknownSecret | McpSecret +} diff --git a/packages/opencode/src/inference/rotating-fetch.ts b/packages/opencode/src/inference/rotating-fetch.ts new file mode 100644 index 00000000000..e9dd29cdc92 --- /dev/null +++ b/packages/opencode/src/inference/rotating-fetch.ts @@ -0,0 +1,262 @@ +import { CredentialPool, CredentialStore, CredentialsMigrate } from "@/credentials" +import type { RotateDecision } from "@/provider-auth/adapter" +import { ProviderAuthRegistry } from "@/provider-auth/registry" +import { RotationStats } from "@/inference/rotation-stats" +import { Log } from "@/util/log" + +const log = Log.create({ service: "inference.rotating-fetch" }) + +function buildRequestWithUrlAndHeaders(original: Request, url: URL, headers: Headers): Request { + const init: RequestInit = { + method: original.method, + headers, + signal: original.signal, + } + + if (original.body) { + init.body = original.body as any + // Node fetch requires duplex when sending a stream body. + ;(init as any).duplex = (original as any).duplex ?? "half" + } + + return new Request(url, init as any) +} + +function parseRetryAfterMs(resp: Response): number | undefined { + const msHeader = resp.headers.get("retry-after-ms") ?? resp.headers.get("Retry-After-Ms") + if (msHeader) { + const ms = Number(msHeader.trim()) + if (Number.isFinite(ms) && ms >= 0) return Math.floor(ms) + } + const raw = resp.headers.get("retry-after") ?? resp.headers.get("Retry-After") + if (!raw) return undefined + const trimmed = raw.trim() + if (!trimmed) return undefined + const seconds = Number(trimmed) + if (Number.isFinite(seconds) && seconds >= 0) return Math.floor(seconds * 1000) + const date = Date.parse(trimmed) + if (!Number.isNaN(date)) return Math.max(0, date - Date.now()) + return undefined +} + +function stripApiKeyHeaders(headers: Headers): void { + for (const name of ["authorization", "x-api-key", "api-key", "x-goog-api-key"]) { + headers.delete(name) + } +} + +function defaultClassifyResponse(resp: Response): RotateDecision { + if (resp.status === 401 || resp.status === 403) { + return { rotatable: true, isAuthExpired: true, reason: `auth_expired:${resp.status}` } + } + if (resp.status === 429) { + return { rotatable: true, cooldownMs: parseRetryAfterMs(resp) ?? 30_000, reason: "rate_limited" } + } + if (resp.status === 503 || resp.status === 529) { + return { rotatable: true, cooldownMs: parseRetryAfterMs(resp) ?? 30_000, reason: `overloaded:${resp.status}` } + } + return { rotatable: false, reason: `status:${resp.status}` } +} + +async function classifyResponse( + adapter: { classifyResponse?: (r: Response) => Promise | RotateDecision }, + resp: Response, +): Promise { + if (!adapter.classifyResponse) return defaultClassifyResponse(resp) + try { + return await adapter.classifyResponse(resp.clone()) + } catch { + return defaultClassifyResponse(resp) + } +} + +export namespace RotatingFetch { + export type FetchFn = (input: RequestInfo | URL, init?: RequestInit) => Promise + + export type Options = { + providerId: string + namespace: string + maxAttempts?: number + } + + export function create( + baseFetch: FetchFn, + opts: Options, + ): FetchFn { + return async (input, init) => { + await CredentialsMigrate.migrateIfNeeded() + + const canonicalProviderId = ProviderAuthRegistry.resolveProviderId(opts.providerId) + const adapter = ProviderAuthRegistry.getAdapter(canonicalProviderId) + if (!adapter) return baseFetch(input, init) + + const providerIds = ProviderAuthRegistry.equivalentProviderIds(opts.providerId) + const records = ( + await Promise.all(providerIds.map((id) => CredentialStore.findByProvider(id, opts.namespace))) + ) + .flat() + .filter((r) => r.meta.kind === "oauth") + if (records.length === 0) return baseFetch(input, init) + + RotationStats.recordRequest(canonicalProviderId) + + records.sort((a, b) => { + const aDefault = (a.meta.label ?? "") === "default" + const bDefault = (b.meta.label ?? "") === "default" + if (aDefault !== bDefault) return aDefault ? -1 : 1 + if (a.meta.createdAt !== b.meta.createdAt) return a.meta.createdAt - b.meta.createdAt + return a.meta.id.localeCompare(b.meta.id) + }) + + const eligibleIds = records.map((r) => r.meta.id) + const orderedIds = await CredentialPool.getOrderedIds(canonicalProviderId, opts.namespace, eligibleIds) + + const recordById = new Map(records.map((r) => [r.meta.id, r] as const)) + const attempted = new Set() + const refreshed = new Set() + const maxAttempts = Math.min(opts.maxAttempts ?? orderedIds.length, orderedIds.length) + + const original = new Request(input, init) + let lastResponse: Response | undefined + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + RotationStats.recordAttempt(canonicalProviderId) + + const now = Date.now() + const nextId = orderedIds.find((id) => { + if (attempted.has(id)) return false + const r = recordById.get(id) + if (!r) return false + if (r.meta.health.cooldownUntil && r.meta.health.cooldownUntil > now) return false + return true + }) + + const chosenId = + nextId ?? + orderedIds.find((id) => { + if (attempted.has(id)) return false + return recordById.has(id) + }) + + if (!chosenId) break + attempted.add(chosenId) + + const record = recordById.get(chosenId)! + const secret = await CredentialStore.decryptSecret(record) + + const request = original.clone() + const url = new URL(request.url) + const headers = new Headers(request.headers) + // Avoid leaking API keys if present; subscription auth should win for this attempt. + stripApiKeyHeaders(headers) + adapter.prepareRequest?.({ url, headers, request, secret }) + adapter.applyAuth(headers, secret) + + const attemptRequest = buildRequestWithUrlAndHeaders(request, url, headers) + const resp = await baseFetch(attemptRequest) + lastResponse = resp + + let activeResp = resp + let activeDecision = await classifyResponse(adapter, resp) + + if (activeDecision.rotatable && activeDecision.isAuthExpired) { + if (adapter.refresh && (secret as any)?.refreshToken && !refreshed.has(chosenId)) { + refreshed.add(chosenId) + RotationStats.recordRefreshAttempt(canonicalProviderId) + log.info("refreshing oauth credential", { + providerId: opts.providerId, + canonicalProviderId, + namespace: opts.namespace, + credentialId: chosenId, + label: record.meta.label, + }) + try { + const nextSecret = await adapter.refresh(secret) + await CredentialStore.updateSecret(chosenId, nextSecret) + RotationStats.recordRefreshSuccess(canonicalProviderId) + log.info("refreshed oauth credential", { + providerId: opts.providerId, + canonicalProviderId, + namespace: opts.namespace, + credentialId: chosenId, + label: record.meta.label, + }) + try { + resp.body?.cancel() + } catch {} + + const retryReq = original.clone() + const retryUrl = new URL(retryReq.url) + const retryHeaders = new Headers(retryReq.headers) + stripApiKeyHeaders(retryHeaders) + adapter.prepareRequest?.({ url: retryUrl, headers: retryHeaders, request: retryReq, secret: nextSecret }) + adapter.applyAuth(retryHeaders, nextSecret) + RotationStats.recordAttempt(canonicalProviderId) + const retryResp = await baseFetch(buildRequestWithUrlAndHeaders(retryReq, retryUrl, retryHeaders)) + lastResponse = retryResp + + const retryDecision = await classifyResponse(adapter, retryResp) + activeResp = retryResp + activeDecision = retryDecision + + if (!activeDecision.rotatable) { + await CredentialStore.recordOutcome({ id: chosenId, statusCode: activeResp.status, ok: activeResp.ok }) + return activeResp + } + } catch { + RotationStats.recordRefreshFailure(canonicalProviderId) + log.warn("oauth credential refresh failed; will rotate", { + providerId: opts.providerId, + canonicalProviderId, + namespace: opts.namespace, + credentialId: chosenId, + label: record.meta.label, + }) + // fall through to rotate on auth failure + } + } + } + + if (activeDecision.rotatable) { + RotationStats.recordRotation(canonicalProviderId, activeDecision.reason) + const cooldownMs = + activeDecision.cooldownMs ?? + (activeDecision.isAuthExpired ? 5 * 60_000 : parseRetryAfterMs(activeResp) ?? 30_000) + const cooldownUntil = Date.now() + cooldownMs + try { + activeResp.body?.cancel() + } catch {} + log.warn("rotating oauth credential", { + providerId: opts.providerId, + canonicalProviderId, + namespace: opts.namespace, + credentialId: chosenId, + label: record.meta.label, + status: activeResp.status, + reason: activeDecision.reason, + cooldownMs, + attempt: attempt + 1, + maxAttempts, + path: new URL(original.url).pathname, + }) + await CredentialStore.recordOutcome({ id: chosenId, statusCode: activeResp.status, ok: false, cooldownUntil }) + await CredentialPool.moveToBack(canonicalProviderId, opts.namespace, chosenId) + continue + } + + await CredentialStore.recordOutcome({ id: chosenId, statusCode: activeResp.status, ok: activeResp.ok }) + return activeResp + } + + log.error("oauth rotation exhausted credentials", { + providerId: opts.providerId, + canonicalProviderId, + namespace: opts.namespace, + attempted: Array.from(attempted), + maxAttempts, + }) + RotationStats.recordExhausted(canonicalProviderId) + return lastResponse ?? baseFetch(input, init) + } + } +} diff --git a/packages/opencode/src/inference/rotation-stats.ts b/packages/opencode/src/inference/rotation-stats.ts new file mode 100644 index 00000000000..7adb1070977 --- /dev/null +++ b/packages/opencode/src/inference/rotation-stats.ts @@ -0,0 +1,126 @@ +import z from "zod" + +type Counts = { + requests: number + attempts: number + rotations: number + exhausted: number + refreshAttempts: number + refreshSuccess: number + refreshFailure: number + rateLimited: number + authExpired: number +} + +function emptyCounts(): Counts { + return { + requests: 0, + attempts: 0, + rotations: 0, + exhausted: 0, + refreshAttempts: 0, + refreshSuccess: 0, + refreshFailure: 0, + rateLimited: 0, + authExpired: 0, + } +} + +function cloneCounts(c: Counts): Counts { + return { ...c } +} + +export namespace RotationStats { + export const Counts = z + .object({ + requests: z.number().int(), + attempts: z.number().int(), + rotations: z.number().int(), + exhausted: z.number().int(), + refreshAttempts: z.number().int(), + refreshSuccess: z.number().int(), + refreshFailure: z.number().int(), + rateLimited: z.number().int(), + authExpired: z.number().int(), + }) + .strict() + .meta({ ref: "RotationStatsCounts" }) + + export const Snapshot = z + .object({ + since: z.number().int(), + totals: Counts, + byProvider: z.record(z.string(), Counts), + }) + .strict() + .meta({ ref: "RotationStatsSnapshot" }) + export type Snapshot = z.infer + + const since = Date.now() + const totals: Counts = emptyCounts() + const byProvider = new Map() + + function getProvider(providerId: string): Counts { + const existing = byProvider.get(providerId) + if (existing) return existing + const next = emptyCounts() + byProvider.set(providerId, next) + return next + } + + export function recordRequest(providerId: string) { + totals.requests++ + getProvider(providerId).requests++ + } + + export function recordAttempt(providerId: string) { + totals.attempts++ + getProvider(providerId).attempts++ + } + + export function recordRotation(providerId: string, reason: string) { + totals.rotations++ + const p = getProvider(providerId) + p.rotations++ + + if (reason.includes("rate")) { + totals.rateLimited++ + p.rateLimited++ + } + if (reason.includes("auth_expired")) { + totals.authExpired++ + p.authExpired++ + } + } + + export function recordExhausted(providerId: string) { + totals.exhausted++ + getProvider(providerId).exhausted++ + } + + export function recordRefreshAttempt(providerId: string) { + totals.refreshAttempts++ + getProvider(providerId).refreshAttempts++ + } + + export function recordRefreshSuccess(providerId: string) { + totals.refreshSuccess++ + getProvider(providerId).refreshSuccess++ + } + + export function recordRefreshFailure(providerId: string) { + totals.refreshFailure++ + getProvider(providerId).refreshFailure++ + } + + export function snapshot(): Snapshot { + const out: Record> = {} + for (const [k, v] of byProvider.entries()) out[k] = cloneCounts(v) + return { + since, + totals: cloneCounts(totals), + byProvider: out, + } + } +} + diff --git a/packages/opencode/src/mcp/auth.ts b/packages/opencode/src/mcp/credentials.ts similarity index 64% rename from packages/opencode/src/mcp/auth.ts rename to packages/opencode/src/mcp/credentials.ts index 6ebb95698d7..d4ded4152fe 100644 --- a/packages/opencode/src/mcp/auth.ts +++ b/packages/opencode/src/mcp/credentials.ts @@ -1,9 +1,7 @@ -import path from "path" -import fs from "fs/promises" import z from "zod" -import { Global } from "../global" +import { CredentialStore, CredentialsMigrate } from "@/credentials" -export namespace McpAuth { +export namespace McpCredentials { export const Tokens = z.object({ accessToken: z.string(), refreshToken: z.string().optional(), @@ -29,11 +27,25 @@ export namespace McpAuth { }) export type Entry = z.infer - const filepath = path.join(Global.Path.data, "mcp-auth.json") + const KIND = "mcp" as const + + async function ensureMigrated() { + await CredentialsMigrate.migrateIfNeeded() + } + + function providerId(mcpName: string) { + return `mcp:${mcpName}` + } export async function get(mcpName: string): Promise { - const data = await all() - return data[mcpName] + await ensureMigrated() + const matches = await CredentialStore.findByProvider(providerId(mcpName), "default") + const record = matches.find((r) => r.meta.kind === KIND) + if (!record) return undefined + const secret = await CredentialStore.decryptSecret(record) + if (!secret || typeof secret !== "object" || !("entry" in secret)) return undefined + const parsed = Entry.safeParse((secret as any).entry) + return parsed.success ? parsed.data : undefined } /** @@ -54,27 +66,37 @@ export namespace McpAuth { } export async function all(): Promise> { - const file = Bun.file(filepath) - return file.json().catch(() => ({})) + await ensureMigrated() + const { records } = await CredentialStore.listAll() + const result: Record = {} + for (const record of records) { + if (record.meta.kind !== KIND) continue + if (!record.meta.providerId.startsWith("mcp:")) continue + const mcpName = record.meta.providerId.slice("mcp:".length) + const entry = await get(mcpName) + if (entry) result[mcpName] = entry + } + return result } export async function set(mcpName: string, entry: Entry, serverUrl?: string): Promise { - const file = Bun.file(filepath) - const data = await all() - // Always update serverUrl if provided + await ensureMigrated() if (serverUrl) { entry.serverUrl = serverUrl } - await Bun.write(file, JSON.stringify({ ...data, [mcpName]: entry }, null, 2)) - await fs.chmod(file.name!, 0o600) + await CredentialStore.upsertSingleton({ + providerId: providerId(mcpName), + namespace: "default", + kind: KIND, + label: mcpName, + secret: { entry }, + }) } export async function remove(mcpName: string): Promise { - const file = Bun.file(filepath) - const data = await all() - delete data[mcpName] - await Bun.write(file, JSON.stringify(data, null, 2)) - await fs.chmod(file.name!, 0o600) + await ensureMigrated() + const matches = await CredentialStore.findByProvider(providerId(mcpName), "default") + await Promise.all(matches.filter((r) => r.meta.kind === KIND).map((r) => CredentialStore.remove(r.meta.id))) } export async function updateTokens(mcpName: string, tokens: Tokens, serverUrl?: string): Promise { @@ -97,10 +119,9 @@ export namespace McpAuth { export async function clearCodeVerifier(mcpName: string): Promise { const entry = await get(mcpName) - if (entry) { - delete entry.codeVerifier - await set(mcpName, entry) - } + if (!entry) return + delete entry.codeVerifier + await set(mcpName, entry) } export async function updateOAuthState(mcpName: string, oauthState: string): Promise { @@ -122,3 +143,4 @@ export namespace McpAuth { } } } + diff --git a/packages/opencode/src/mcp/index.ts b/packages/opencode/src/mcp/index.ts index 625809af9a8..24d3f5464a2 100644 --- a/packages/opencode/src/mcp/index.ts +++ b/packages/opencode/src/mcp/index.ts @@ -12,7 +12,7 @@ import { Instance } from "../project/instance" import { withTimeout } from "@/util/timeout" import { McpOAuthProvider } from "./oauth-provider" import { McpOAuthCallback } from "./oauth-callback" -import { McpAuth } from "./auth" +import { McpCredentials } from "./credentials" import open from "open" export namespace MCP { @@ -441,7 +441,7 @@ export namespace MCP { const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32))) .map((b) => b.toString(16).padStart(2, "0")) .join("") - await McpAuth.updateOAuthState(mcpName, oauthState) + await McpCredentials.updateOAuthState(mcpName, oauthState) // Create a new auth provider for this flow // OAuth config is optional - if not provided, we'll use auto-discovery @@ -499,7 +499,7 @@ export namespace MCP { } // Get the state that was already generated and stored in startAuth() - const oauthState = await McpAuth.getOAuthState(mcpName) + const oauthState = await McpCredentials.getOAuthState(mcpName) if (!oauthState) { throw new Error("OAuth state not found - this should not happen") } @@ -513,13 +513,13 @@ export namespace MCP { const code = await McpOAuthCallback.waitForCallback(oauthState) // Validate and clear the state - const storedState = await McpAuth.getOAuthState(mcpName) + const storedState = await McpCredentials.getOAuthState(mcpName) if (storedState !== oauthState) { - await McpAuth.clearOAuthState(mcpName) + await McpCredentials.clearOAuthState(mcpName) throw new Error("OAuth state mismatch - potential CSRF attack") } - await McpAuth.clearOAuthState(mcpName) + await McpCredentials.clearOAuthState(mcpName) // Finish auth return finishAuth(mcpName, code) @@ -540,7 +540,7 @@ export namespace MCP { await transport.finishAuth(authorizationCode) // Clear the code verifier after successful auth - await McpAuth.clearCodeVerifier(mcpName) + await McpCredentials.clearCodeVerifier(mcpName) // Now try to reconnect const cfg = await Config.get() @@ -569,10 +569,10 @@ export namespace MCP { * Remove OAuth credentials for an MCP server. */ export async function removeAuth(mcpName: string): Promise { - await McpAuth.remove(mcpName) + await McpCredentials.remove(mcpName) McpOAuthCallback.cancelPending(mcpName) pendingOAuthTransports.delete(mcpName) - await McpAuth.clearOAuthState(mcpName) + await McpCredentials.clearOAuthState(mcpName) log.info("removed oauth credentials", { mcpName }) } @@ -589,7 +589,7 @@ export namespace MCP { * Check if an MCP server has stored OAuth tokens. */ export async function hasStoredTokens(mcpName: string): Promise { - const entry = await McpAuth.get(mcpName) + const entry = await McpCredentials.get(mcpName) return !!entry?.tokens } } diff --git a/packages/opencode/src/mcp/oauth-callback.ts b/packages/opencode/src/mcp/oauth-callback.ts index bb3b56f2e95..288f402aea1 100644 --- a/packages/opencode/src/mcp/oauth-callback.ts +++ b/packages/opencode/src/mcp/oauth-callback.ts @@ -1,200 +1,20 @@ -import { Log } from "../util/log" -import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider" - -const log = Log.create({ service: "mcp.oauth-callback" }) - -const HTML_SUCCESS = ` - - - OpenCode - Authorization Successful - - - -
-

Authorization Successful

-

You can close this window and return to OpenCode.

-
- - -` - -const HTML_ERROR = (error: string) => ` - - - OpenCode - Authorization Failed - - - -
-

Authorization Failed

-

An error occurred during authorization.

-
${error}
-
- -` - -interface PendingAuth { - resolve: (code: string) => void - reject: (error: Error) => void - timeout: ReturnType -} +import { OAuthCallback } from "../oauth/callback" +import { OAUTH_CALLBACK_PATH, OAUTH_CALLBACK_PORT } from "./oauth-provider" export namespace McpOAuthCallback { - let server: ReturnType | undefined - const pendingAuths = new Map() - - const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes - export async function ensureRunning(): Promise { - if (server) return - - const running = await isPortInUse() - if (running) { - log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT }) - return - } - - server = Bun.serve({ - port: OAUTH_CALLBACK_PORT, - fetch(req) { - const url = new URL(req.url) - - if (url.pathname !== OAUTH_CALLBACK_PATH) { - return new Response("Not found", { status: 404 }) - } - - const code = url.searchParams.get("code") - const state = url.searchParams.get("state") - const error = url.searchParams.get("error") - const errorDescription = url.searchParams.get("error_description") - - log.info("received oauth callback", { hasCode: !!code, state, error }) - - // Enforce state parameter presence - if (!state) { - const errorMsg = "Missing required state parameter - potential CSRF attack" - log.error("oauth callback missing state parameter", { url: url.toString() }) - return new Response(HTML_ERROR(errorMsg), { - status: 400, - headers: { "Content-Type": "text/html" }, - }) - } - - if (error) { - const errorMsg = errorDescription || error - if (pendingAuths.has(state)) { - const pending = pendingAuths.get(state)! - clearTimeout(pending.timeout) - pendingAuths.delete(state) - pending.reject(new Error(errorMsg)) - } - return new Response(HTML_ERROR(errorMsg), { - headers: { "Content-Type": "text/html" }, - }) - } - - if (!code) { - return new Response(HTML_ERROR("No authorization code provided"), { - status: 400, - headers: { "Content-Type": "text/html" }, - }) - } - - // Validate state parameter - if (!pendingAuths.has(state)) { - const errorMsg = "Invalid or expired state parameter - potential CSRF attack" - log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) - return new Response(HTML_ERROR(errorMsg), { - status: 400, - headers: { "Content-Type": "text/html" }, - }) - } - - const pending = pendingAuths.get(state)! - - clearTimeout(pending.timeout) - pendingAuths.delete(state) - pending.resolve(code) - - return new Response(HTML_SUCCESS, { - headers: { "Content-Type": "text/html" }, - }) - }, - }) - - log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) + await OAuthCallback.ensureRunning({ port: OAUTH_CALLBACK_PORT, pathname: OAUTH_CALLBACK_PATH }) } - export function waitForCallback(oauthState: string): Promise { - return new Promise((resolve, reject) => { - const timeout = setTimeout(() => { - if (pendingAuths.has(oauthState)) { - pendingAuths.delete(oauthState) - reject(new Error("OAuth callback timeout - authorization took too long")) - } - }, CALLBACK_TIMEOUT_MS) - - pendingAuths.set(oauthState, { resolve, reject, timeout }) - }) - } - - export function cancelPending(mcpName: string): void { - const pending = pendingAuths.get(mcpName) - if (pending) { - clearTimeout(pending.timeout) - pendingAuths.delete(mcpName) - pending.reject(new Error("Authorization cancelled")) - } + export function waitForCallback(key: string): Promise { + return OAuthCallback.waitForCallback({ port: OAUTH_CALLBACK_PORT, pathname: OAUTH_CALLBACK_PATH, key }) } - export async function isPortInUse(): Promise { - return new Promise((resolve) => { - Bun.connect({ - hostname: "127.0.0.1", - port: OAUTH_CALLBACK_PORT, - socket: { - open(socket) { - socket.end() - resolve(true) - }, - error() { - resolve(false) - }, - data() {}, - close() {}, - }, - }).catch(() => { - resolve(false) - }) - }) + export function cancelPending(key: string): void { + OAuthCallback.cancelPending({ port: OAUTH_CALLBACK_PORT, pathname: OAUTH_CALLBACK_PATH, key }) } export async function stop(): Promise { - if (server) { - server.stop() - server = undefined - log.info("oauth callback server stopped") - } - - for (const [name, pending] of pendingAuths) { - clearTimeout(pending.timeout) - pending.reject(new Error("OAuth callback server stopped")) - } - pendingAuths.clear() - } - - export function isRunning(): boolean { - return server !== undefined + await OAuthCallback.stop({ port: OAUTH_CALLBACK_PORT, pathname: OAUTH_CALLBACK_PATH }) } } diff --git a/packages/opencode/src/mcp/oauth-provider.ts b/packages/opencode/src/mcp/oauth-provider.ts index 35ead25e8be..b39fb6efa2b 100644 --- a/packages/opencode/src/mcp/oauth-provider.ts +++ b/packages/opencode/src/mcp/oauth-provider.ts @@ -5,7 +5,7 @@ import type { OAuthClientInformation, OAuthClientInformationFull, } from "@modelcontextprotocol/sdk/shared/auth.js" -import { McpAuth } from "./auth" +import { McpCredentials } from "./credentials" import { Log } from "../util/log" const log = Log.create({ service: "mcp.oauth" }) @@ -57,7 +57,7 @@ export class McpOAuthProvider implements OAuthClientProvider { // Check stored client info (from dynamic registration) // Use getForUrl to validate credentials are for the current server URL - const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl) + const entry = await McpCredentials.getForUrl(this.mcpName, this.serverUrl) if (entry?.clientInfo) { // Check if client secret has expired if (entry.clientInfo.clientSecretExpiresAt && entry.clientInfo.clientSecretExpiresAt < Date.now() / 1000) { @@ -75,7 +75,7 @@ export class McpOAuthProvider implements OAuthClientProvider { } async saveClientInformation(info: OAuthClientInformationFull): Promise { - await McpAuth.updateClientInfo( + await McpCredentials.updateClientInfo( this.mcpName, { clientId: info.client_id, @@ -93,7 +93,7 @@ export class McpOAuthProvider implements OAuthClientProvider { async tokens(): Promise { // Use getForUrl to validate tokens are for the current server URL - const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl) + const entry = await McpCredentials.getForUrl(this.mcpName, this.serverUrl) if (!entry?.tokens) return undefined return { @@ -108,7 +108,7 @@ export class McpOAuthProvider implements OAuthClientProvider { } async saveTokens(tokens: OAuthTokens): Promise { - await McpAuth.updateTokens( + await McpCredentials.updateTokens( this.mcpName, { accessToken: tokens.access_token, @@ -127,11 +127,11 @@ export class McpOAuthProvider implements OAuthClientProvider { } async saveCodeVerifier(codeVerifier: string): Promise { - await McpAuth.updateCodeVerifier(this.mcpName, codeVerifier) + await McpCredentials.updateCodeVerifier(this.mcpName, codeVerifier) } async codeVerifier(): Promise { - const entry = await McpAuth.get(this.mcpName) + const entry = await McpCredentials.get(this.mcpName) if (!entry?.codeVerifier) { throw new Error(`No code verifier saved for MCP server: ${this.mcpName}`) } @@ -139,11 +139,11 @@ export class McpOAuthProvider implements OAuthClientProvider { } async saveState(state: string): Promise { - await McpAuth.updateOAuthState(this.mcpName, state) + await McpCredentials.updateOAuthState(this.mcpName, state) } async state(): Promise { - const entry = await McpAuth.get(this.mcpName) + const entry = await McpCredentials.get(this.mcpName) if (!entry?.oauthState) { throw new Error(`No OAuth state saved for MCP server: ${this.mcpName}`) } diff --git a/packages/opencode/src/oauth/callback.ts b/packages/opencode/src/oauth/callback.ts new file mode 100644 index 00000000000..37ad9f884ca --- /dev/null +++ b/packages/opencode/src/oauth/callback.ts @@ -0,0 +1,214 @@ +import { Log } from "@/util/log" + +const log = Log.create({ service: "oauth.callback" }) + +const HTML_SUCCESS = ` + + + OpenCode - Authorization Successful + + + +
+

Authorization Successful

+

You can close this window and return to OpenCode.

+
+ + +` + +const HTML_ERROR = (error: string) => ` + + + OpenCode - Authorization Failed + + + +
+

Authorization Failed

+

An error occurred during authorization.

+
${error}
+
+ +` + +type PendingAuth = { + resolve: (code: string) => void + reject: (error: Error) => void + timeout: ReturnType +} + +type ServerKey = `${number}:${string}` + +type ServerState = { + server?: ReturnType + pending: Map +} + +function serverKey(port: number, pathname: string): ServerKey { + return `${port}:${pathname}` as const +} + +async function sleep(ms: number) { + await new Promise((r) => setTimeout(r, ms)) +} + +async function isPortInUse(port: number): Promise { + return new Promise((resolve) => { + Bun.connect({ + hostname: "127.0.0.1", + port, + socket: { + open(socket) { + socket.end() + resolve(true) + }, + error() { + resolve(false) + }, + data() {}, + close() {}, + }, + }).catch(() => resolve(false)) + }) +} + +export namespace OAuthCallback { + const servers = new Map() + const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000 + + export async function ensureRunning(opts: { port: number; pathname: string }): Promise { + const key = serverKey(opts.port, opts.pathname) + const existing = servers.get(key) + if (existing?.server) return + + const running = await isPortInUse(opts.port) + if (running) { + log.info("oauth callback port already in use", { port: opts.port, pathname: opts.pathname }) + servers.set(key, { server: undefined, pending: new Map() }) + return + } + + const state: ServerState = { pending: new Map() } + state.server = Bun.serve({ + port: opts.port, + fetch(req) { + const url = new URL(req.url) + if (url.pathname !== opts.pathname) return new Response("Not found", { status: 404 }) + + const code = url.searchParams.get("code") + const stateParam = url.searchParams.get("state") + const error = url.searchParams.get("error") + const errorDescription = url.searchParams.get("error_description") + + log.info("oauth callback received", { port: opts.port, pathname: opts.pathname, hasCode: !!code, stateParam, error }) + + if (error) { + const errorMsg = errorDescription || error + if (stateParam && state.pending.has(stateParam)) { + const pending = state.pending.get(stateParam)! + clearTimeout(pending.timeout) + state.pending.delete(stateParam) + pending.reject(new Error(errorMsg)) + } + return new Response(HTML_ERROR(errorMsg), { headers: { "Content-Type": "text/html" } }) + } + + if (!code) { + return new Response(HTML_ERROR("No authorization code provided"), { + status: 400, + headers: { "Content-Type": "text/html" }, + }) + } + + // Find pending auth by state, or fallback to single pending auth. + let pending: PendingAuth | undefined + let pendingKey: string | undefined + if (stateParam && state.pending.has(stateParam)) { + pending = state.pending.get(stateParam)! + pendingKey = stateParam + } else if (!stateParam && state.pending.size === 1) { + const [k, v] = state.pending.entries().next().value as [string, PendingAuth] + pending = v + pendingKey = k + log.info("oauth callback missing state; using single pending auth", { key: k }) + } + + if (!pending || !pendingKey) { + const errorMsg = !stateParam + ? "No state parameter provided and multiple pending authorizations" + : "Unknown or expired authorization request" + return new Response(HTML_ERROR(errorMsg), { status: 400, headers: { "Content-Type": "text/html" } }) + } + + clearTimeout(pending.timeout) + state.pending.delete(pendingKey) + pending.resolve(code) + + return new Response(HTML_SUCCESS, { headers: { "Content-Type": "text/html" } }) + }, + }) + + servers.set(key, state) + log.info("oauth callback server started", { port: opts.port, pathname: opts.pathname }) + } + + export async function waitForCallback(opts: { port: number; pathname: string; key: string; timeoutMs?: number }): Promise { + const key = serverKey(opts.port, opts.pathname) + const state = servers.get(key) + if (!state) throw new Error("OAuth callback server not initialized. Call ensureRunning() first.") + + const timeoutMs = opts.timeoutMs ?? DEFAULT_TIMEOUT_MS + return new Promise((resolve, reject) => { + const timeout = setTimeout(() => { + if (state.pending.has(opts.key)) { + state.pending.delete(opts.key) + reject(new Error("OAuth callback timeout - authorization took too long")) + } + }, timeoutMs) + + state.pending.set(opts.key, { resolve, reject, timeout }) + }) + } + + export function cancelPending(opts: { port: number; pathname: string; key: string }): void { + const key = serverKey(opts.port, opts.pathname) + const state = servers.get(key) + if (!state) return + const pending = state.pending.get(opts.key) + if (!pending) return + clearTimeout(pending.timeout) + state.pending.delete(opts.key) + pending.reject(new Error("Authorization cancelled")) + } + + export async function stop(opts: { port: number; pathname: string }): Promise { + const key = serverKey(opts.port, opts.pathname) + const state = servers.get(key) + if (!state) return + state.server?.stop() + state.server = undefined + + for (const [, pending] of state.pending) { + clearTimeout(pending.timeout) + pending.reject(new Error("OAuth callback server stopped")) + } + state.pending.clear() + servers.delete(key) + + // give Bun time to release the port + await sleep(10) + } +} + diff --git a/packages/opencode/src/oauth/pkce.ts b/packages/opencode/src/oauth/pkce.ts new file mode 100644 index 00000000000..db570af99af --- /dev/null +++ b/packages/opencode/src/oauth/pkce.ts @@ -0,0 +1,27 @@ +import crypto from "crypto" + +function base64UrlEncode(buf: Buffer): string { + return buf + .toString("base64") + .replace(/\+/g, "-") + .replace(/\//g, "_") + .replace(/=+$/g, "") +} + +export namespace PKCE { + export function generateVerifier(byteLength: number = 32): string { + return base64UrlEncode(crypto.randomBytes(byteLength)) + } + + export function challengeFromVerifier(verifier: string): string { + const hash = crypto.createHash("sha256").update(verifier).digest() + return base64UrlEncode(hash) + } +} + +export namespace OAuthState { + export function generate(byteLength: number = 16): string { + return base64UrlEncode(crypto.randomBytes(byteLength)) + } +} + diff --git a/packages/opencode/src/plugin/index.ts b/packages/opencode/src/plugin/index.ts index b492c7179e6..04104bd3bba 100644 --- a/packages/opencode/src/plugin/index.ts +++ b/packages/opencode/src/plugin/index.ts @@ -6,7 +6,6 @@ import { createOpencodeClient } from "@opencode-ai/sdk" import { Server } from "../server/server" import { BunProc } from "../bun" import { Instance } from "../project/instance" -import { Flag } from "../flag/flag" export namespace Plugin { const log = Log.create({ service: "plugin" }) @@ -27,10 +26,6 @@ export namespace Plugin { $: Bun.$, } const plugins = [...(config.plugin ?? [])] - if (!Flag.OPENCODE_DISABLE_DEFAULT_PLUGINS) { - plugins.push("opencode-copilot-auth@0.0.9") - plugins.push("opencode-anthropic-auth@0.0.5") - } for (let plugin of plugins) { log.info("loading plugin", { path: plugin }) if (!plugin.startsWith("file://")) { diff --git a/packages/opencode/src/provider-auth/adapter.ts b/packages/opencode/src/provider-auth/adapter.ts new file mode 100644 index 00000000000..f0cdd44e56b --- /dev/null +++ b/packages/opencode/src/provider-auth/adapter.ts @@ -0,0 +1,53 @@ +import type { AuthOuathResult, Hooks } from "@opencode-ai/plugin" +import type { Credentials } from "@/credentials" + +export type ProviderAuthMethod = NonNullable["methods"][number] + +export type RotateDecision = { + rotatable: boolean + isAuthExpired?: boolean + cooldownMs?: number + reason: string +} + +export type PrepareRequestArgs = { + url: URL + headers: Headers + request: Request + secret: Credentials.Secret +} + +export interface ProviderAuthAdapter { + providerId: string + + authMethods(): ProviderAuthMethod[] + + /** + * Optional request transformation hook for subscription auth. + * Implementations may mutate `url` and `headers` in-place. + */ + prepareRequest?(args: PrepareRequestArgs): void + + /** + * Apply authentication for inference calls. + * Implementations should mutate `headers` in-place (e.g. set Authorization). + */ + applyAuth(headers: Headers, secret: Credentials.Secret): void + + /** + * If supported, refresh an OAuth credential. + * Returns the updated secret fields to persist. + */ + refresh?(secret: Credentials.Secret): Promise + + /** + * Classify a response as rotatable (rate limit, quota exhausted, etc). + */ + classifyResponse?(response: Response): Promise | RotateDecision +} + +export function isOAuthSuccessResult( + result: Awaited>, +): result is { type: "success"; access: string; refresh: string; expires: number; provider?: string } { + return result.type === "success" && "access" in result && "refresh" in result && "expires" in result +} diff --git a/packages/opencode/src/provider-auth/providers/anthropic.ts b/packages/opencode/src/provider-auth/providers/anthropic.ts new file mode 100644 index 00000000000..5a55814bd92 --- /dev/null +++ b/packages/opencode/src/provider-auth/providers/anthropic.ts @@ -0,0 +1,225 @@ +import type { AuthOuathResult, Hooks } from "@opencode-ai/plugin" +import { OAuthCallback } from "@/oauth/callback" +import { PKCE, OAuthState } from "@/oauth/pkce" +import type { ProviderAuthAdapter, ProviderAuthMethod, RotateDecision } from "../adapter" + +const AUTH_URL = process.env["ANTHROPIC_AUTH_URL"] ?? "https://claude.ai/oauth/authorize" +const TOKEN_URL = process.env["ANTHROPIC_TOKEN_URL"] ?? "https://console.anthropic.com/v1/oauth/token" +const CLIENT_ID = process.env["ANTHROPIC_CLIENT_ID"] ?? "9d1c250a-e61b-44d9-88ed-5944d1962f5e" +const REDIRECT_URI = process.env["ANTHROPIC_REDIRECT_URI"] ?? "http://localhost:54545/callback" +const SCOPES = process.env["ANTHROPIC_SCOPES"] ?? "org:create_api_key user:profile user:inference" + +const REDIRECT = new URL(REDIRECT_URI) +const CALLBACK_PORT = Number(REDIRECT.port || "54545") +const CALLBACK_PATH = REDIRECT.pathname || "/callback" + +function parseRetryAfterMs(resp: Response): number | undefined { + const msHeader = resp.headers.get("retry-after-ms") ?? resp.headers.get("Retry-After-Ms") + if (msHeader) { + const ms = Number(msHeader.trim()) + if (Number.isFinite(ms) && ms >= 0) return Math.floor(ms) + } + const raw = resp.headers.get("retry-after") ?? resp.headers.get("Retry-After") + if (!raw) return undefined + const trimmed = raw.trim() + if (!trimmed) return undefined + const seconds = Number(trimmed) + if (Number.isFinite(seconds) && seconds >= 0) return Math.floor(seconds * 1000) + const date = Date.parse(trimmed) + if (!Number.isNaN(date)) return Math.max(0, date - Date.now()) + return undefined +} + +function buildAuthUrl(state: string, codeChallenge: string): string { + const url = new URL(AUTH_URL) + url.searchParams.set("code", "true") + url.searchParams.set("client_id", CLIENT_ID) + url.searchParams.set("response_type", "code") + url.searchParams.set("redirect_uri", REDIRECT_URI) + url.searchParams.set("scope", SCOPES) + url.searchParams.set("code_challenge", codeChallenge) + url.searchParams.set("code_challenge_method", "S256") + url.searchParams.set("state", state) + return url.toString() +} + +async function exchangeCode(args: { code: string; codeVerifier: string; state: string }) { + const body = { + code: args.code.includes("#") ? args.code.split("#")[0] : args.code, + grant_type: "authorization_code", + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + code_verifier: args.codeVerifier, + state: args.state, + } + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, + body: JSON.stringify(body), + }) + + const json = await resp.json().catch(() => ({})) + if (!resp.ok) { + const err = (json as any)?.error ?? "token_exchange_failed" + const desc = (json as any)?.error_description ?? resp.statusText + throw new Error(`${err}: ${desc}`) + } + + const access = (json as any)?.access_token + const refresh = (json as any)?.refresh_token + const expiresIn = (json as any)?.expires_in + if (!access) throw new Error("token_exchange_failed: missing access_token") + + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? "", expires } +} + +async function refreshAccessToken(refresh_token: string) { + const body = { + grant_type: "refresh_token", + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + refresh_token, + } + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, + body: JSON.stringify(body), + }) + + const json = await resp.json().catch(() => ({})) + if (!resp.ok) { + const err = (json as any)?.error ?? "refresh_failed" + const desc = (json as any)?.error_description ?? resp.statusText + throw new Error(`${err}: ${desc}`) + } + + const access = (json as any)?.access_token + const refresh = (json as any)?.refresh_token + const expiresIn = (json as any)?.expires_in + if (!access) throw new Error("refresh_failed: missing access_token") + + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? refresh_token, expires } +} + +export const AnthropicSubscriptionAdapter: ProviderAuthAdapter = { + providerId: "anthropic", + + authMethods(): ProviderAuthMethod[] { + const makeOAuth = (mode: "auto" | "code"): ProviderAuthMethod => { + return { + type: "oauth", + label: mode === "auto" ? "Claude Max (OAuth)" : "Claude Max (OAuth - paste code)", + async authorize(): Promise { + const codeVerifier = PKCE.generateVerifier() + const codeChallenge = PKCE.challengeFromVerifier(codeVerifier) + const state = OAuthState.generate() + const url = buildAuthUrl(state, codeChallenge) + + if (mode === "auto") { + await OAuthCallback.ensureRunning({ port: CALLBACK_PORT, pathname: CALLBACK_PATH }) + } + + return { + url, + instructions: "Complete login in your browser. Return here when finished.", + method: mode, + async callback(code?: string) { + try { + const resolvedCode = + mode === "code" + ? (code ?? "") + : await OAuthCallback.waitForCallback({ + port: CALLBACK_PORT, + pathname: CALLBACK_PATH, + key: state, + }) + + if (!resolvedCode) return { type: "failed" } + const tokens = await exchangeCode({ code: resolvedCode, codeVerifier, state }) + return { type: "success", access: tokens.access, refresh: tokens.refresh, expires: tokens.expires } + } catch { + return { type: "failed" } + } + }, + } as AuthOuathResult + }, + } + } + + return [makeOAuth("auto"), makeOAuth("code")] + }, + + applyAuth(headers: Headers, secret: any) { + if (secret && typeof secret === "object" && "accessToken" in secret) { + headers.set("Authorization", `Bearer ${String((secret as any).accessToken)}`) + } + }, + + async refresh(secret: any) { + const refresh = secret?.refreshToken + if (!refresh) return secret + const t = await refreshAccessToken(String(refresh)) + return { ...secret, accessToken: t.access, refreshToken: t.refresh, expiresAt: t.expires } + }, + + async classifyResponse(resp: Response): Promise { + if (resp.ok) return { rotatable: false, reason: "ok" } + + if (resp.status === 401 || resp.status === 403) { + return { rotatable: true, isAuthExpired: true, reason: `auth_expired:${resp.status}` } + } + + const retryAfterMs = parseRetryAfterMs(resp) + if (resp.status === 429) { + return { rotatable: true, cooldownMs: retryAfterMs ?? 30_000, reason: "rate_limited" } + } + if (resp.status === 503 || resp.status === 529) { + return { rotatable: true, cooldownMs: retryAfterMs ?? 30_000, reason: `overloaded:${resp.status}` } + } + + const json = await resp.json().catch(() => undefined) + const errorType = String((json as any)?.error?.type ?? (json as any)?.error_type ?? (json as any)?.type ?? "") + .trim() + .toLowerCase() + const message = String((json as any)?.error?.message ?? (json as any)?.message ?? "").trim().toLowerCase() + + const isRateLimited = + errorType.includes("rate_limit") || + errorType.includes("throttle") || + message.includes("rate limit") || + message.includes("too many requests") + + const isOverloaded = + errorType.includes("overloaded") || + message.includes("overloaded") || + message.includes("capacity") || + message.includes("try again") + + const isQuota = + errorType.includes("quota") || + errorType.includes("billing") || + errorType.includes("credit") || + message.includes("quota") || + message.includes("billing") || + message.includes("credits") || + message.includes("usage limit") + + if (isRateLimited || isOverloaded || isQuota) { + const cooldownMs = retryAfterMs ?? (isQuota ? 5 * 60_000 : 30_000) + const reason = `anthropic:${errorType || "throttled"}` + return { rotatable: true, cooldownMs, reason } + } + + return { rotatable: false, reason: `status:${resp.status}` } + }, +} diff --git a/packages/opencode/src/provider-auth/providers/cursor.ts b/packages/opencode/src/provider-auth/providers/cursor.ts new file mode 100644 index 00000000000..de0dc731427 --- /dev/null +++ b/packages/opencode/src/provider-auth/providers/cursor.ts @@ -0,0 +1,111 @@ +import type { AuthOuathResult } from "@opencode-ai/plugin" +import crypto from "crypto" +import type { ProviderAuthAdapter, ProviderAuthMethod } from "../adapter" + +const AUTHENTICATOR_URL = process.env["CURSOR_AUTHENTICATOR_URL"] ?? "https://authenticator.cursor.sh" +const API_URL = process.env["CURSOR_API_URL"] ?? "https://api2.cursor.sh" +const LOGIN_URL = process.env["CURSOR_LOGIN_URL"] ?? `${AUTHENTICATOR_URL}/login` +const POLL_URL = process.env["CURSOR_POLL_URL"] ?? `${API_URL}/auth/poll` +const REFRESH_URL = process.env["CURSOR_REFRESH_URL"] ?? `${API_URL}/auth/refresh` + +type CursorTokenPayload = { + accessToken?: string + refreshToken?: string + access_token?: string + refresh_token?: string + expiresIn?: number + expires_in?: number +} + +function startSession() { + const uuid = crypto.randomUUID() + const verifier = crypto.randomBytes(32).toString("hex") + const url = new URL(LOGIN_URL) + url.searchParams.set("uuid", uuid) + url.searchParams.set("verifier", verifier) + return { uuid, verifier, loginUrl: url.toString() } +} + +async function pollForToken(session: { uuid: string; verifier: string }, maxWaitMs: number = 15 * 60 * 1000) { + const deadline = Date.now() + maxWaitMs + while (Date.now() < deadline) { + const url = new URL(POLL_URL) + url.searchParams.set("uuid", session.uuid) + url.searchParams.set("verifier", session.verifier) + const resp = await fetch(url.toString(), { headers: { Accept: "application/json" } }) + if (resp.status === 200) { + const json = (await resp.json().catch(() => ({}))) as CursorTokenPayload + const access = json.accessToken ?? json.access_token + const refresh = json.refreshToken ?? json.refresh_token + if (!access) throw new Error("poll_failed: missing access token") + const expiresIn = json.expiresIn ?? json.expires_in + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? "", expires } + } + if (resp.status === 401 || resp.status === 404 || resp.status >= 500) { + await new Promise((r) => setTimeout(r, 1500)) + continue + } + throw new Error(`poll_failed: HTTP ${resp.status}`) + } + throw new Error("timeout: polling timed out") +} + +async function refreshToken(refresh_token: string) { + const resp = await fetch(REFRESH_URL, { + method: "POST", + headers: { "Content-Type": "application/json", Accept: "application/json" }, + body: JSON.stringify({ refresh_token }), + }) + if (!resp.ok) throw new Error(`refresh_failed: HTTP ${resp.status}`) + const json = (await resp.json().catch(() => ({}))) as CursorTokenPayload + const access = json.accessToken ?? json.access_token + const refresh = json.refreshToken ?? json.refresh_token + if (!access) throw new Error("refresh_failed: missing access token") + const expiresIn = json.expiresIn ?? json.expires_in + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? refresh_token, expires } +} + +export const CursorSubscriptionAdapter: ProviderAuthAdapter = { + providerId: "cursor", + + authMethods(): ProviderAuthMethod[] { + return [ + { + type: "oauth", + label: "Cursor (Device Link Login)", + async authorize(): Promise { + const session = startSession() + return { + url: session.loginUrl, + instructions: "Complete login in your browser. Return here when finished.", + method: "auto", + async callback() { + try { + const token = await pollForToken(session) + return { type: "success", access: token.access, refresh: token.refresh, expires: token.expires } + } catch { + return { type: "failed" } + } + }, + } as AuthOuathResult + }, + }, + ] + }, + + applyAuth(headers: Headers, secret: any) { + if (secret && typeof secret === "object" && "accessToken" in secret) { + headers.set("Authorization", `Bearer ${String((secret as any).accessToken)}`) + } + }, + + async refresh(secret: any) { + const refresh = secret?.refreshToken + if (!refresh) return secret + const t = await refreshToken(String(refresh)) + return { ...secret, accessToken: t.access, refreshToken: t.refresh, expiresAt: t.expires } + }, +} + diff --git a/packages/opencode/src/provider-auth/providers/github-copilot.ts b/packages/opencode/src/provider-auth/providers/github-copilot.ts new file mode 100644 index 00000000000..d3132092067 --- /dev/null +++ b/packages/opencode/src/provider-auth/providers/github-copilot.ts @@ -0,0 +1,208 @@ +import type { AuthOuathResult } from "@opencode-ai/plugin" +import type { ProviderAuthAdapter, ProviderAuthMethod } from "../adapter" + +const CLIENT_ID = process.env["COPILOT_CLIENT_ID"] ?? "Iv1.b507a08c87ecfe98" +const DEVICE_CODE_URL = process.env["COPILOT_DEVICE_CODE_URL"] ?? "https://github.com/login/device/code" +const TOKEN_URL = process.env["COPILOT_TOKEN_URL"] ?? "https://github.com/login/oauth/access_token" +const COPILOT_API_TOKEN_URL = process.env["COPILOT_API_TOKEN_URL"] ?? "https://api.github.com/copilot_internal/v2/token" +const SCOPE = process.env["COPILOT_SCOPE"] ?? "user:email" + +const COPILOT_EDITOR_VERSION = process.env["COPILOT_EDITOR_VERSION"] ?? "vscode/1.85.1" +const COPILOT_EDITOR_PLUGIN_VERSION = process.env["COPILOT_EDITOR_PLUGIN_VERSION"] ?? "copilot/1.155.0" +const COPILOT_USER_AGENT = process.env["COPILOT_USER_AGENT"] ?? "GithubCopilot/1.155.0" + +type DeviceCodeResponse = { + device_code: string + user_code: string + verification_uri: string + expires_in: number + interval: number +} + +type CopilotApiTokenResponse = { + token?: string + expires_at?: number + endpoints?: { + api?: string + [k: string]: unknown + } + [k: string]: unknown +} + +async function requestDeviceCode(): Promise { + const body = new URLSearchParams() + body.set("client_id", CLIENT_ID) + body.set("scope", SCOPE) + + const resp = await fetch(DEVICE_CODE_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded", Accept: "application/json" }, + body: body.toString(), + }) + if (!resp.ok) throw new Error(`device_code_failed: HTTP ${resp.status}`) + return (await resp.json()) as DeviceCodeResponse +} + +async function pollForToken(device: DeviceCodeResponse, timeoutMs: number = 15 * 60 * 1000): Promise<{ access: string }> { + const deadline = Date.now() + Math.min(timeoutMs, device.expires_in * 1000) + let intervalMs = Math.max(5_000, device.interval * 1000) + + while (Date.now() < deadline) { + const body = new URLSearchParams() + body.set("client_id", CLIENT_ID) + body.set("device_code", device.device_code) + body.set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded", Accept: "application/json" }, + body: body.toString(), + }) + + const json = await resp.json().catch(() => ({})) + const err = (json as any)?.error + if (err === "authorization_pending") { + await new Promise((r) => setTimeout(r, intervalMs)) + continue + } + if (err === "slow_down") { + intervalMs += 5_000 + await new Promise((r) => setTimeout(r, intervalMs)) + continue + } + if (err) throw new Error(`${err}: ${(json as any)?.error_description ?? ""}`) + + const access = (json as any)?.access_token + if (access) return { access } + + await new Promise((r) => setTimeout(r, intervalMs)) + } + + throw new Error("timeout: device authorization timed out") +} + +function defaultCopilotHeaders(githubAccessToken?: string): Record { + const headers: Record = { + accept: "application/json", + "editor-version": COPILOT_EDITOR_VERSION, + "editor-plugin-version": COPILOT_EDITOR_PLUGIN_VERSION, + "user-agent": COPILOT_USER_AGENT, + "accept-encoding": "gzip,deflate,br", + } + if (githubAccessToken) headers["authorization"] = `token ${githubAccessToken}` + return headers +} + +async function exchangeForCopilotToken(githubAccessToken: string): Promise<{ + copilotToken: string + expiresAtMs: number + endpoints?: CopilotApiTokenResponse["endpoints"] +}> { + const resp = await fetch(COPILOT_API_TOKEN_URL, { + method: "GET", + headers: defaultCopilotHeaders(githubAccessToken), + }) + const json = (await resp.json().catch(() => ({}))) as CopilotApiTokenResponse + + if (!resp.ok) { + throw new Error(`copilot_token_failed: HTTP ${resp.status}`) + } + + const token = json.token + if (!token) throw new Error("copilot_token_failed: missing token") + + const expiresAtMs = + typeof json.expires_at === "number" && Number.isFinite(json.expires_at) + ? Math.floor(json.expires_at * 1000) + : Date.now() + 60 * 60 * 1000 + + return { + copilotToken: token, + expiresAtMs, + endpoints: json.endpoints, + } +} + +export const GitHubCopilotSubscriptionAdapter: ProviderAuthAdapter = { + providerId: "github-copilot", + + authMethods(): ProviderAuthMethod[] { + return [ + { + type: "oauth", + label: "GitHub Copilot (Device Login)", + async authorize(): Promise { + const device = await requestDeviceCode() + const url = device.verification_uri + const instructions = `Enter code: ${device.user_code}` + + return { + url, + instructions, + method: "auto", + async callback() { + try { + const github = await pollForToken(device) + // GitHub Copilot uses a derived token for inference. + const copilot = await exchangeForCopilotToken(github.access) + return { + type: "success", + access: copilot.copilotToken, + // Store GitHub token as the refresh token so we can derive new copilot tokens on-demand. + refresh: github.access, + expires: copilot.expiresAtMs, + endpoints: copilot.endpoints, + } + } catch { + return { type: "failed" } + } + }, + } as AuthOuathResult + }, + }, + ] + }, + + prepareRequest({ url, secret }) { + const api = (secret as any)?.extra?.endpoints?.api + if (typeof api !== "string" || !api) return + try { + const base = new URL(api) + url.protocol = base.protocol + url.host = base.host + + const prefix = base.pathname && base.pathname !== "/" ? base.pathname.replace(/\/$/, "") : "" + if (prefix) { + const existing = url.pathname.startsWith("/") ? url.pathname : `/${url.pathname}` + url.pathname = `${prefix}${existing}` + } + } catch { + // ignore invalid endpoint + } + }, + + applyAuth(headers: Headers, secret: any) { + if (secret && typeof secret === "object" && "accessToken" in secret) { + headers.set("Authorization", `Bearer ${String((secret as any).accessToken)}`) + if (!headers.has("accept")) headers.set("accept", "application/json") + if (!headers.has("editor-version")) headers.set("editor-version", COPILOT_EDITOR_VERSION) + if (!headers.has("editor-plugin-version")) headers.set("editor-plugin-version", COPILOT_EDITOR_PLUGIN_VERSION) + if (!headers.has("user-agent")) headers.set("user-agent", COPILOT_USER_AGENT) + } + }, + + async refresh(secret: any) { + const githubToken = secret?.refreshToken + if (!githubToken) return secret + const copilot = await exchangeForCopilotToken(String(githubToken)) + return { + ...secret, + accessToken: copilot.copilotToken, + expiresAt: copilot.expiresAtMs, + extra: { + ...(secret?.extra ?? {}), + endpoints: copilot.endpoints ?? (secret?.extra?.endpoints ?? undefined), + }, + } + }, +} diff --git a/packages/opencode/src/provider-auth/providers/google.ts b/packages/opencode/src/provider-auth/providers/google.ts new file mode 100644 index 00000000000..d0674c8df5b --- /dev/null +++ b/packages/opencode/src/provider-auth/providers/google.ts @@ -0,0 +1,159 @@ +import type { AuthOuathResult } from "@opencode-ai/plugin" +import { OAuthCallback } from "@/oauth/callback" +import { PKCE, OAuthState } from "@/oauth/pkce" +import type { ProviderAuthAdapter, ProviderAuthMethod } from "../adapter" + +const CLIENT_ID = + process.env["GEMINI_CLIENT_ID"] ?? + "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" +const CLIENT_SECRET = process.env["GEMINI_CLIENT_SECRET"] +const AUTHORIZE_URL = process.env["GEMINI_AUTHORIZE_URL"] ?? "https://accounts.google.com/o/oauth2/v2/auth" +const TOKEN_URL = process.env["GEMINI_TOKEN_URL"] ?? "https://oauth2.googleapis.com/token" +const REDIRECT_URI = process.env["GEMINI_REDIRECT_URI"] ?? "http://localhost:8085/oauth2callback" +const SCOPES = + process.env["GEMINI_SCOPES"] ?? + "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" + +const REDIRECT = new URL(REDIRECT_URI) +const CALLBACK_PORT = Number(REDIRECT.port || "8085") +const CALLBACK_PATH = REDIRECT.pathname || "/oauth2callback" + +function buildAuthUrl(state: string, codeChallenge: string): string { + const url = new URL(AUTHORIZE_URL) + url.searchParams.set("client_id", CLIENT_ID) + url.searchParams.set("redirect_uri", REDIRECT_URI) + url.searchParams.set("response_type", "code") + url.searchParams.set("scope", SCOPES) + url.searchParams.set("state", state) + url.searchParams.set("access_type", "offline") + url.searchParams.set("prompt", "consent") + url.searchParams.set("code_challenge", codeChallenge) + url.searchParams.set("code_challenge_method", "S256") + return url.toString() +} + +async function exchangeCode(args: { code: string; codeVerifier: string; state: string }) { + const data = new URLSearchParams() + data.set("client_id", CLIENT_ID) + data.set("code", args.code) + data.set("grant_type", "authorization_code") + data.set("redirect_uri", REDIRECT_URI) + data.set("state", args.state) + data.set("code_verifier", args.codeVerifier) + if (CLIENT_SECRET) data.set("client_secret", CLIENT_SECRET) + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded", Accept: "application/json" }, + body: data.toString(), + }) + + const json = await resp.json().catch(() => ({})) + if (!resp.ok) { + const err = (json as any)?.error ?? "token_exchange_failed" + const desc = (json as any)?.error_description ?? resp.statusText + throw new Error(`${err}: ${desc}`) + } + + const access = (json as any)?.access_token + const refresh = (json as any)?.refresh_token + const expiresIn = (json as any)?.expires_in + if (!access) throw new Error("token_exchange_failed: missing access_token") + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? "", expires } +} + +async function refreshAccessToken(refresh_token: string) { + const data = new URLSearchParams() + data.set("client_id", CLIENT_ID) + data.set("grant_type", "refresh_token") + data.set("refresh_token", refresh_token) + if (CLIENT_SECRET) data.set("client_secret", CLIENT_SECRET) + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded", Accept: "application/json" }, + body: data.toString(), + }) + + const json = await resp.json().catch(() => ({})) + if (!resp.ok) { + const err = (json as any)?.error ?? "refresh_failed" + const desc = (json as any)?.error_description ?? resp.statusText + throw new Error(`${err}: ${desc}`) + } + + const access = (json as any)?.access_token + const expiresIn = (json as any)?.expires_in + if (!access) throw new Error("refresh_failed: missing access_token") + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, expires } +} + +export const GoogleGeminiSubscriptionAdapter: ProviderAuthAdapter = { + providerId: "google", + + authMethods(): ProviderAuthMethod[] { + const makeOAuth = (mode: "auto" | "code"): ProviderAuthMethod => { + return { + type: "oauth", + label: mode === "auto" ? "Google (Gemini Code Assist OAuth)" : "Google (OAuth - paste code)", + async authorize(): Promise { + const codeVerifier = PKCE.generateVerifier() + const codeChallenge = PKCE.challengeFromVerifier(codeVerifier) + const state = OAuthState.generate() + const url = buildAuthUrl(state, codeChallenge) + + if (mode === "auto") { + await OAuthCallback.ensureRunning({ port: CALLBACK_PORT, pathname: CALLBACK_PATH }) + } + + return { + url, + instructions: "Complete login in your browser. Return here when finished.", + method: mode, + async callback(code?: string) { + try { + const resolvedCode = + mode === "code" + ? (code ?? "") + : await OAuthCallback.waitForCallback({ + port: CALLBACK_PORT, + pathname: CALLBACK_PATH, + key: state, + }) + if (!resolvedCode) return { type: "failed" } + const tokens = await exchangeCode({ code: resolvedCode, codeVerifier, state }) + return { type: "success", access: tokens.access, refresh: tokens.refresh, expires: tokens.expires } + } catch { + return { type: "failed" } + } + }, + } as AuthOuathResult + }, + } + } + + return [makeOAuth("auto"), makeOAuth("code")] + }, + + prepareRequest({ url }) { + // Avoid leaking API keys when the underlying SDK uses ?key= for API auth. + url.searchParams.delete("key") + url.searchParams.delete("api_key") + url.searchParams.delete("api-key") + }, + + applyAuth(headers: Headers, secret: any) { + if (secret && typeof secret === "object" && "accessToken" in secret) { + headers.set("Authorization", `Bearer ${String((secret as any).accessToken)}`) + } + }, + + async refresh(secret: any) { + const refresh = secret?.refreshToken + if (!refresh) return secret + const t = await refreshAccessToken(String(refresh)) + return { ...secret, accessToken: t.access, expiresAt: t.expires } + }, +} diff --git a/packages/opencode/src/provider-auth/providers/openai.ts b/packages/opencode/src/provider-auth/providers/openai.ts new file mode 100644 index 00000000000..79890a0a82f --- /dev/null +++ b/packages/opencode/src/provider-auth/providers/openai.ts @@ -0,0 +1,157 @@ +import type { AuthOuathResult } from "@opencode-ai/plugin" +import { OAuthCallback } from "@/oauth/callback" +import { PKCE, OAuthState } from "@/oauth/pkce" +import type { ProviderAuthAdapter, ProviderAuthMethod } from "../adapter" + +const AUTH_URL = process.env["OPENAI_AUTH_URL"] ?? "https://auth.openai.com/oauth/authorize" +const TOKEN_URL = process.env["OPENAI_TOKEN_URL"] ?? "https://auth.openai.com/oauth/token" +const CLIENT_ID = process.env["OPENAI_CLIENT_ID"] ?? "app_EMoamEEZ73f0CkXaXp7hrann" +const REDIRECT_URI = process.env["OPENAI_REDIRECT_URI"] ?? "http://localhost:1455/auth/callback" +const SCOPES = process.env["OPENAI_SCOPES"] ?? "openid email profile offline_access" + +const REDIRECT = new URL(REDIRECT_URI) +const CALLBACK_PORT = Number(REDIRECT.port || "1455") +const CALLBACK_PATH = REDIRECT.pathname || "/auth/callback" + +function buildAuthUrl(state: string, codeChallenge: string): string { + const url = new URL(AUTH_URL) + url.searchParams.set("client_id", CLIENT_ID) + url.searchParams.set("response_type", "code") + url.searchParams.set("redirect_uri", REDIRECT_URI) + url.searchParams.set("scope", SCOPES) + url.searchParams.set("state", state) + url.searchParams.set("code_challenge", codeChallenge) + url.searchParams.set("code_challenge_method", "S256") + url.searchParams.set("prompt", "login") + url.searchParams.set("id_token_add_organizations", "true") + url.searchParams.set("codex_cli_simplified_flow", "true") + return url.toString() +} + +async function exchangeCode(args: { code: string; codeVerifier: string; state: string }) { + const data = new URLSearchParams() + data.set("grant_type", "authorization_code") + data.set("client_id", CLIENT_ID) + data.set("code", args.code) + data.set("redirect_uri", REDIRECT_URI) + data.set("code_verifier", args.codeVerifier) + data.set("state", args.state) + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Accept: "application/json", + }, + body: data.toString(), + }) + + const json = await resp.json().catch(() => ({})) + if (!resp.ok) { + const err = (json as any)?.error ?? "token_exchange_failed" + const desc = (json as any)?.error_description ?? resp.statusText + throw new Error(`${err}: ${desc}`) + } + + const access = (json as any)?.access_token + const refresh = (json as any)?.refresh_token + const expiresIn = (json as any)?.expires_in + if (!access) throw new Error("token_exchange_failed: missing access_token") + + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? "", expires } +} + +async function refreshAccessToken(refresh_token: string) { + const data = new URLSearchParams() + data.set("grant_type", "refresh_token") + data.set("client_id", CLIENT_ID) + data.set("refresh_token", refresh_token) + data.set("redirect_uri", REDIRECT_URI) + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Accept: "application/json", + }, + body: data.toString(), + }) + + const json = await resp.json().catch(() => ({})) + if (!resp.ok) { + const err = (json as any)?.error ?? "refresh_failed" + const desc = (json as any)?.error_description ?? resp.statusText + throw new Error(`${err}: ${desc}`) + } + + const access = (json as any)?.access_token + const refresh = (json as any)?.refresh_token + const expiresIn = (json as any)?.expires_in + if (!access) throw new Error("refresh_failed: missing access_token") + + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? refresh_token, expires } +} + +export const OpenAISubscriptionAdapter: ProviderAuthAdapter = { + providerId: "openai", + + authMethods(): ProviderAuthMethod[] { + const makeOAuth = (mode: "auto" | "code"): ProviderAuthMethod => { + return { + type: "oauth", + label: mode === "auto" ? "ChatGPT (OAuth)" : "ChatGPT (OAuth - paste code)", + async authorize(): Promise { + const codeVerifier = PKCE.generateVerifier() + const codeChallenge = PKCE.challengeFromVerifier(codeVerifier) + const state = OAuthState.generate() + const url = buildAuthUrl(state, codeChallenge) + + if (mode === "auto") { + await OAuthCallback.ensureRunning({ port: CALLBACK_PORT, pathname: CALLBACK_PATH }) + } + + return { + url, + instructions: "Complete login in your browser. Return here when finished.", + method: mode, + async callback(code?: string) { + try { + const resolvedCode = + mode === "code" + ? (code ?? "") + : await OAuthCallback.waitForCallback({ + port: CALLBACK_PORT, + pathname: CALLBACK_PATH, + key: state, + }) + + if (!resolvedCode) return { type: "failed" } + const tokens = await exchangeCode({ code: resolvedCode, codeVerifier, state }) + return { type: "success", access: tokens.access, refresh: tokens.refresh, expires: tokens.expires } + } catch { + return { type: "failed" } + } + }, + } as AuthOuathResult + }, + } + } + + return [makeOAuth("auto"), makeOAuth("code")] + }, + + applyAuth(headers: Headers, secret: any) { + if (secret && typeof secret === "object" && "accessToken" in secret) { + headers.set("Authorization", `Bearer ${String((secret as any).accessToken)}`) + } + }, + + async refresh(secret: any) { + const refresh = secret?.refreshToken + if (!refresh) return secret + const t = await refreshAccessToken(String(refresh)) + return { ...secret, accessToken: t.access, refreshToken: t.refresh, expiresAt: t.expires } + }, +} diff --git a/packages/opencode/src/provider-auth/providers/qwen.ts b/packages/opencode/src/provider-auth/providers/qwen.ts new file mode 100644 index 00000000000..86c4694d2af --- /dev/null +++ b/packages/opencode/src/provider-auth/providers/qwen.ts @@ -0,0 +1,155 @@ +import type { AuthOuathResult } from "@opencode-ai/plugin" +import { PKCE } from "@/oauth/pkce" +import type { ProviderAuthAdapter, ProviderAuthMethod } from "../adapter" + +const DEVICE_CODE_URL = process.env["QWEN_DEVICE_CODE_ENDPOINT"] ?? "https://chat.qwen.ai/api/v1/oauth2/device/code" +const TOKEN_URL = process.env["QWEN_TOKEN_ENDPOINT"] ?? "https://chat.qwen.ai/api/v1/oauth2/token" +const CLIENT_ID = process.env["QWEN_CLIENT_ID"] ?? "f0304373b74a44d2b584a3fb70ca9e56" +const SCOPES = process.env["QWEN_SCOPES"] ?? "openid profile email model.completion" +const GRANT_TYPE_DEVICE = "urn:ietf:params:oauth:grant-type:device_code" + +type DeviceCodeResponse = { + device_code: string + user_code: string + verification_uri: string + verification_uri_complete?: string + expires_in: number + interval: number +} + +async function initiateDeviceFlow() { + const verifier = PKCE.generateVerifier() + const challenge = PKCE.challengeFromVerifier(verifier) + + const body = new URLSearchParams() + body.set("client_id", CLIENT_ID) + body.set("scope", SCOPES) + body.set("code_challenge", challenge) + body.set("code_challenge_method", "S256") + + const resp = await fetch(DEVICE_CODE_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded", Accept: "application/json" }, + body: body.toString(), + }) + if (!resp.ok) throw new Error(`device_code_failed: HTTP ${resp.status}`) + const json = (await resp.json()) as DeviceCodeResponse + if (!json.device_code) throw new Error("device_code_failed: missing device_code") + return { device: json, verifier } +} + +async function pollForToken(args: { device: DeviceCodeResponse; verifier: string; maxWaitMs: number }) { + let intervalMs = Math.max(1_000, args.device.interval * 1000) + const deadline = Date.now() + Math.min(args.maxWaitMs, args.device.expires_in * 1000) + + while (Date.now() < deadline) { + const body = new URLSearchParams() + body.set("grant_type", GRANT_TYPE_DEVICE) + body.set("client_id", CLIENT_ID) + body.set("device_code", args.device.device_code) + body.set("code_verifier", args.verifier) + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded", Accept: "application/json" }, + body: body.toString(), + }) + + const json = await resp.json().catch(() => ({})) + if (resp.ok && (json as any)?.access_token) { + const access = (json as any).access_token as string + const refresh = ((json as any).refresh_token as string | undefined) ?? "" + const expiresIn = (json as any)?.expires_in + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh, expires } + } + + const err = (json as any)?.error + if (err === "authorization_pending") { + await new Promise((r) => setTimeout(r, intervalMs)) + continue + } + if (err === "slow_down") { + intervalMs = Math.min(Math.floor(intervalMs * 1.5), 10_000) + await new Promise((r) => setTimeout(r, intervalMs)) + continue + } + if (err) throw new Error(`${err}: ${(json as any)?.error_description ?? ""}`) + + await new Promise((r) => setTimeout(r, intervalMs)) + } + + throw new Error("timeout: device authorization timed out") +} + +async function refreshAccessToken(refresh_token: string) { + const body = new URLSearchParams() + body.set("grant_type", "refresh_token") + body.set("client_id", CLIENT_ID) + body.set("refresh_token", refresh_token) + + const resp = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded", Accept: "application/json" }, + body: body.toString(), + }) + + const json = await resp.json().catch(() => ({})) + if (!resp.ok) { + const err = (json as any)?.error ?? "refresh_failed" + const desc = (json as any)?.error_description ?? resp.statusText + throw new Error(`${err}: ${desc}`) + } + + const access = (json as any)?.access_token + const refresh = (json as any)?.refresh_token + const expiresIn = (json as any)?.expires_in + if (!access) throw new Error("refresh_failed: missing access_token") + const expires = typeof expiresIn === "number" ? Date.now() + expiresIn * 1000 : Date.now() + 60 * 60 * 1000 + return { access, refresh: refresh ?? refresh_token, expires } +} + +export const QwenSubscriptionAdapter: ProviderAuthAdapter = { + providerId: "qwen", + + authMethods(): ProviderAuthMethod[] { + return [ + { + type: "oauth", + label: "Qwen (Device Login)", + async authorize(): Promise { + const { device, verifier } = await initiateDeviceFlow() + const url = device.verification_uri_complete || device.verification_uri + const instructions = `Enter code: ${device.user_code}` + + return { + url, + instructions, + method: "auto", + async callback() { + try { + const token = await pollForToken({ device, verifier, maxWaitMs: 15 * 60 * 1000 }) + return { type: "success", access: token.access, refresh: token.refresh, expires: token.expires } + } catch { + return { type: "failed" } + } + }, + } as AuthOuathResult + }, + }, + ] + }, + + applyAuth(headers: Headers, secret: any) { + if (secret && typeof secret === "object" && "accessToken" in secret) { + headers.set("Authorization", `Bearer ${String((secret as any).accessToken)}`) + } + }, + + async refresh(secret: any) { + const refresh = secret?.refreshToken + if (!refresh) return secret + const t = await refreshAccessToken(String(refresh)) + return { ...secret, accessToken: t.access, refreshToken: t.refresh, expiresAt: t.expires } + }, +} diff --git a/packages/opencode/src/provider-auth/registry.ts b/packages/opencode/src/provider-auth/registry.ts new file mode 100644 index 00000000000..70c0daa989e --- /dev/null +++ b/packages/opencode/src/provider-auth/registry.ts @@ -0,0 +1,70 @@ +import type { Hooks } from "@opencode-ai/plugin" +import type { ProviderAuthAdapter } from "./adapter" +import { AnthropicSubscriptionAdapter } from "./providers/anthropic" +import { OpenAISubscriptionAdapter } from "./providers/openai" +import { GoogleGeminiSubscriptionAdapter } from "./providers/google" +import { GitHubCopilotSubscriptionAdapter } from "./providers/github-copilot" +import { QwenSubscriptionAdapter } from "./providers/qwen" +import { CursorSubscriptionAdapter } from "./providers/cursor" + +const ADAPTERS: ProviderAuthAdapter[] = [ + AnthropicSubscriptionAdapter, + OpenAISubscriptionAdapter, + GoogleGeminiSubscriptionAdapter, + GitHubCopilotSubscriptionAdapter, + QwenSubscriptionAdapter, + CursorSubscriptionAdapter, +] + +const ALIASES: Record = { + "github-copilot-enterprise": "github-copilot", +} + +export namespace ProviderAuthRegistry { + export function equivalentProviderIds(providerId: string): string[] { + const canonical = resolveProviderId(providerId) + const ids = new Set([canonical, providerId]) + for (const [alias, target] of Object.entries(ALIASES)) { + if (target === canonical) ids.add(alias) + } + return Array.from(ids) + .filter(Boolean) + .sort((a, b) => { + if (a === canonical) return -1 + if (b === canonical) return 1 + return a.localeCompare(b) + }) + } + + export function listProviderIds(): string[] { + const ids = new Set() + for (const adapter of ADAPTERS) ids.add(adapter.providerId) + for (const alias of Object.keys(ALIASES)) ids.add(alias) + return Array.from(ids).sort() + } + + export function resolveProviderId(providerId: string): string { + return ALIASES[providerId] ?? providerId + } + + export function getAdapter(providerId: string): ProviderAuthAdapter | undefined { + const resolved = resolveProviderId(providerId) + return ADAPTERS.find((a) => a.providerId === resolved) + } + + export function getAuthHook(providerId: string): Hooks["auth"] | undefined { + const adapter = getAdapter(providerId) + if (!adapter) return undefined + const methods = [...adapter.authMethods()] + if (!methods.some((m) => m.type === "api")) { + methods.push({ + type: "api", + label: "API key", + } as any) + } + return { + provider: providerId, + methods, + } + } +} diff --git a/packages/opencode/src/provider/auth.ts b/packages/opencode/src/provider/auth.ts index d06253ab4ad..032e88150d4 100644 --- a/packages/opencode/src/provider/auth.ts +++ b/packages/opencode/src/provider/auth.ts @@ -1,21 +1,33 @@ import { Instance } from "@/project/instance" -import { Plugin } from "../plugin" -import { map, filter, pipe, fromEntries, mapValues } from "remeda" +import { mapValues } from "remeda" import z from "zod" import { fn } from "@/util/fn" import type { AuthOuathResult, Hooks } from "@opencode-ai/plugin" import { NamedError } from "@opencode-ai/util/error" -import { Auth } from "@/auth" +import { ProviderAuthRegistry } from "@/provider-auth/registry" +import { CredentialStore, CredentialsMigrate } from "@/credentials" +import { Config } from "@/config/config" export namespace ProviderAuth { const state = Instance.state(async () => { - const methods = pipe( - await Plugin.list(), - filter((x) => x.auth?.provider !== undefined), - map((x) => [x.auth!.provider, x.auth!] as const), - fromEntries(), - ) - return { methods, pending: {} as Record } + const methods: Record> = {} + for (const providerId of ProviderAuthRegistry.listProviderIds()) { + const core = ProviderAuthRegistry.getAuthHook(providerId) + if (!core) continue + methods[providerId] = core as NonNullable + } + + return { + methods, + pending: {} as Record< + string, + { + oauth: AuthOuathResult + namespace?: string + label?: string + } + >, + } }) export const Method = z @@ -55,13 +67,22 @@ export namespace ProviderAuth { z.object({ providerID: z.string(), method: z.number(), + namespace: z.string().optional(), + label: z.string().optional(), }), async (input): Promise => { const auth = await state().then((s) => s.methods[input.providerID]) + if (!auth) return undefined const method = auth.methods[input.method] if (method.type === "oauth") { const result = await method.authorize() - await state().then((s) => (s.pending[input.providerID] = result)) + await state().then((s) => { + s.pending[input.providerID] = { + oauth: result, + namespace: input.namespace, + label: input.label, + } + }) return { url: result.url, method: result.method, @@ -76,9 +97,12 @@ export namespace ProviderAuth { providerID: z.string(), method: z.number(), code: z.string().optional(), + namespace: z.string().optional(), + label: z.string().optional(), }), async (input) => { - const match = await state().then((s) => s.pending[input.providerID]) + const pending = await state().then((s) => s.pending[input.providerID]) + const match = pending?.oauth if (!match) throw new OauthMissing({ providerID: input.providerID }) let result @@ -93,19 +117,55 @@ export namespace ProviderAuth { if (result?.type === "success") { if ("key" in result) { - await Auth.set(input.providerID, { - type: "api", - key: result.key, + await CredentialsMigrate.migrateIfNeeded() + await CredentialStore.upsertSingleton({ + providerId: input.providerID, + namespace: "default", + kind: "api", + label: "default", + secret: { apiKey: result.key }, }) } if ("refresh" in result) { - await Auth.set(input.providerID, { - type: "oauth", - access: result.access, - refresh: result.refresh, - expires: result.expires, + await CredentialsMigrate.migrateIfNeeded() + const config = await Config.get() + const namespace = (input.namespace ?? pending?.namespace ?? config.provider?.[input.providerID]?.auth?.namespace ?? "default") + .trim() || "default" + const desiredLabel = (input.label ?? pending?.label)?.trim() + const existingOauth = (await CredentialStore.findByProvider(input.providerID, namespace)).filter( + (r) => r.meta.kind === "oauth", + ) + const existingLabels = new Set(existingOauth.map((r) => r.meta.label ?? "")) + const labelBase = desiredLabel?.split("\n")[0]?.trim() || undefined + + const label = (() => { + if (labelBase) { + if (!existingLabels.has(labelBase)) return labelBase + let n = 2 + while (existingLabels.has(`${labelBase}-${n}`)) n++ + return `${labelBase}-${n}` + } + + const hasDefault = existingLabels.has("default") + return hasDefault ? `${input.providerID}-${new Date().toISOString()}` : "default" + })() + + const { type: _, provider: __, access, refresh, expires, ...extraFields } = result as any + + await CredentialStore.put({ + providerId: input.providerID, + namespace, + kind: "oauth", + label, + secret: { + accessToken: access, + refreshToken: refresh || undefined, + expiresAt: expires || undefined, + extra: Object.keys(extraFields).length > 0 ? extraFields : undefined, + }, }) } + await state().then((s) => delete s.pending[input.providerID]) return } @@ -119,9 +179,13 @@ export namespace ProviderAuth { key: z.string(), }), async (input) => { - await Auth.set(input.providerID, { - type: "api", - key: input.key, + await CredentialsMigrate.migrateIfNeeded() + await CredentialStore.upsertSingleton({ + providerId: input.providerID, + namespace: "default", + kind: "api", + label: "default", + secret: { apiKey: input.key }, }) }, ) diff --git a/packages/opencode/src/provider/model-discovery.ts b/packages/opencode/src/provider/model-discovery.ts new file mode 100644 index 00000000000..f4622dac70c --- /dev/null +++ b/packages/opencode/src/provider/model-discovery.ts @@ -0,0 +1,142 @@ +import path from "path" +import z from "zod" +import { Global } from "@/global" +import { VaultFS } from "@/vault/fs" +import { Log } from "@/util/log" +import { RotatingFetch } from "@/inference/rotating-fetch" +import { CredentialStore } from "@/credentials" +import { ProviderAuthRegistry } from "@/provider-auth/registry" + +const log = Log.create({ service: "provider.model-discovery" }) + +const CacheFile = z + .object({ + version: z.literal(1), + baseURL: z.string(), + fetchedAt: z.number(), + modelIds: z.array(z.string()), + }) + .strict() +type CacheFile = z.infer + +function cachePath(providerId: string, namespace: string): string { + const safeProvider = encodeURIComponent(providerId) + const safeNamespace = encodeURIComponent(namespace) + return path.join(Global.Path.cache, "model-discovery", `${safeProvider}__${safeNamespace}.json`) +} + +function withoutTrailingSlash(url: string): string { + return url.endsWith("/") ? url.replace(/\/+$/, "") : url +} + +function baseHasV1(baseURL: string): boolean { + return /\/v1($|\/)/.test(baseURL) +} + +async function tryFetchModels(args: { + fetchFn: (input: RequestInfo | URL, init?: RequestInit) => Promise + url: string + headers: Headers + timeoutMs: number +}): Promise { + const resp = await args.fetchFn(args.url, { + method: "GET", + headers: args.headers, + signal: AbortSignal.timeout(args.timeoutMs), + }) + + if (!resp.ok) return undefined + const json = await resp.json().catch(() => undefined) + const data = (json as any)?.data + if (!Array.isArray(data)) return undefined + + const ids = data + .map((item: any) => (item && typeof item === "object" ? item.id : undefined)) + .filter((id: unknown): id is string => typeof id === "string" && id.trim().length > 0) + + return Array.from(new Set(ids)).sort() +} + +export namespace ModelDiscovery { + const DEFAULT_TTL_MS = 12 * 60 * 60 * 1000 // 12h + const DEFAULT_TIMEOUT_MS = 10_000 + + export type Options = { + providerId: string + namespace: string + baseURL: string + authMode: "auto" | "api" | "subscription" + apiKey?: string + headers?: Record + ttlMs?: number + timeoutMs?: number + maxAttempts?: number + } + + export async function discover(opts: Options): Promise { + const namespace = opts.namespace.trim() || "default" + const baseURL = withoutTrailingSlash(opts.baseURL.trim()) + if (!baseURL) return [] + + const ttlMs = opts.ttlMs ?? DEFAULT_TTL_MS + const timeoutMs = opts.timeoutMs ?? DEFAULT_TIMEOUT_MS + + const filePath = cachePath(opts.providerId, namespace) + const cachedRaw = await VaultFS.readJson(filePath) + const cached = CacheFile.safeParse(cachedRaw) + if (cached.success) { + const fresh = Date.now() - cached.data.fetchedAt < ttlMs + const sameBase = cached.data.baseURL === baseURL + if (fresh && sameBase) return cached.data.modelIds + } + + const canonicalProviderId = ProviderAuthRegistry.resolveProviderId(opts.providerId) + const adapter = ProviderAuthRegistry.getAdapter(canonicalProviderId) + + const headers = new Headers({ + Accept: "application/json", + ...(opts.headers ?? {}), + }) + + let fetchFn: (input: RequestInfo | URL, init?: RequestInit) => Promise = fetch + + const providerIds = ProviderAuthRegistry.equivalentProviderIds(opts.providerId) + const oauthRecords = ( + await Promise.all(providerIds.map((id) => CredentialStore.findByProvider(id, namespace))) + ) + .flat() + .filter((r) => r.meta.kind === "oauth") + + const shouldUseSubscription = opts.authMode !== "api" && adapter && oauthRecords.length > 0 + + if (shouldUseSubscription) { + fetchFn = RotatingFetch.create(fetchFn, { + providerId: opts.providerId, + namespace, + maxAttempts: opts.maxAttempts, + }) + } else if (opts.apiKey) { + headers.set("Authorization", `Bearer ${opts.apiKey}`) + } + + const candidates: string[] = [] + candidates.push(`${baseURL}/models`) + if (!baseHasV1(baseURL)) { + candidates.push(`${baseURL}/v1/models`) + } + + for (const url of candidates) { + const ids = await tryFetchModels({ fetchFn, url, headers, timeoutMs }) + if (ids && ids.length > 0) { + const next: CacheFile = { version: 1, baseURL, fetchedAt: Date.now(), modelIds: ids } + await VaultFS.atomicWriteJson(filePath, next, 0o600).catch((e) => { + log.debug("failed to write model discovery cache", { providerId: opts.providerId, error: String(e) }) + }) + return ids + } + } + + return [] + } +} + diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index b8d4dadbd65..2ce9f7ea0e2 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -5,14 +5,16 @@ import { mapValues, mergeDeep, sortBy } from "remeda" import { NoSuchModelError, type Provider as SDK } from "ai" import { Log } from "../util/log" import { BunProc } from "../bun" -import { Plugin } from "../plugin" import { ModelsDev } from "./models" import { NamedError } from "@opencode-ai/util/error" -import { Auth } from "../auth" import { Env } from "../env" import { Instance } from "../project/instance" import { Flag } from "../flag/flag" import { iife } from "@/util/iife" +import { RotatingFetch } from "@/inference/rotating-fetch" +import { CredentialStore, CredentialsMigrate } from "@/credentials" +import { ProviderAuthRegistry } from "@/provider-auth/registry" +import { ModelDiscovery } from "./model-discovery" // Direct imports for bundled providers import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock" @@ -66,7 +68,9 @@ export namespace Provider { const hasKey = await (async () => { const env = Env.all() if (input.env.some((item) => env[item])) return true - if (await Auth.get(input.id)) return true + await CredentialsMigrate.migrateIfNeeded() + const records = await CredentialStore.findByProvider(input.id, "default") + if (records.some((r) => r.meta.kind === "api")) return true return false })() @@ -286,16 +290,23 @@ export namespace Provider { } }, "sap-ai-core": async () => { - const auth = await Auth.get("sap-ai-core") - const envServiceKey = iife(() => { - const envAICoreServiceKey = Env.get("AICORE_SERVICE_KEY") - if (envAICoreServiceKey) return envAICoreServiceKey - if (auth?.type === "api") { - Env.set("AICORE_SERVICE_KEY", auth.key) - return auth.key + const envAICoreServiceKey = Env.get("AICORE_SERVICE_KEY") + let envServiceKey = envAICoreServiceKey + if (!envServiceKey) { + await CredentialsMigrate.migrateIfNeeded() + const records = await CredentialStore.findByProvider("sap-ai-core", "default") + const apiRecord = + records.find((r) => r.meta.kind === "api" && (r.meta.label ?? "") === "default") ?? + records.find((r) => r.meta.kind === "api") + if (apiRecord) { + const secret = await CredentialStore.decryptSecret(apiRecord) + const apiKey = secret && typeof secret === "object" && "apiKey" in secret ? String((secret as any).apiKey) : undefined + if (apiKey) { + Env.set("AICORE_SERVICE_KEY", apiKey) + envServiceKey = apiKey + } } - return undefined - }) + } const deploymentId = Env.get("AICORE_DEPLOYMENT_ID") const resourceGroup = Env.get("AICORE_RESOURCE_GROUP") @@ -625,61 +636,18 @@ export namespace Provider { } // load apikeys - for (const [providerID, provider] of Object.entries(await Auth.all())) { + await CredentialsMigrate.migrateIfNeeded() + const { records: credentialRecords } = await CredentialStore.listAll() + for (const record of credentialRecords) { + if (record.meta.kind !== "api") continue + const providerID = record.meta.providerId if (disabled.has(providerID)) continue - if (provider.type === "api") { - mergeProvider(providerID, { - source: "api", - key: provider.key, - }) - } - } - - for (const plugin of await Plugin.list()) { - if (!plugin.auth) continue - const providerID = plugin.auth.provider - if (disabled.has(providerID)) continue - - // For github-copilot plugin, check if auth exists for either github-copilot or github-copilot-enterprise - let hasAuth = false - const auth = await Auth.get(providerID) - if (auth) hasAuth = true - - // Special handling for github-copilot: also check for enterprise auth - if (providerID === "github-copilot" && !hasAuth) { - const enterpriseAuth = await Auth.get("github-copilot-enterprise") - if (enterpriseAuth) hasAuth = true - } - - if (!hasAuth) continue - if (!plugin.auth.loader) continue - - // Load for the main provider if auth exists - if (auth) { - const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider]) - mergeProvider(plugin.auth.provider, { - source: "custom", - options: options, - }) - } - - // If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists - if (providerID === "github-copilot") { - const enterpriseProviderID = "github-copilot-enterprise" - if (!disabled.has(enterpriseProviderID)) { - const enterpriseAuth = await Auth.get(enterpriseProviderID) - if (enterpriseAuth) { - const enterpriseOptions = await plugin.auth.loader( - () => Auth.get(enterpriseProviderID) as any, - database[enterpriseProviderID], - ) - mergeProvider(enterpriseProviderID, { - source: "custom", - options: enterpriseOptions, - }) - } - } - } + const secret = await CredentialStore.decryptSecret(record) + if (!secret || typeof secret !== "object" || !("apiKey" in secret)) continue + mergeProvider(providerID, { + source: "api", + key: String((secret as any).apiKey), + }) } for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) { @@ -721,6 +689,79 @@ export namespace Provider { const configProvider = config.provider?.[providerID] + if (configProvider?.options && configProvider.options["discoverModels"] === true) { + const template = Object.values(provider.models)[0] + const baseURL = String((provider.options as any)?.baseURL ?? template?.api?.url ?? "") + const apiKey = typeof (provider.options as any)?.apiKey === "string" ? String((provider.options as any).apiKey) : provider.key + const discovered = await ModelDiscovery.discover({ + providerId: providerID, + namespace: configProvider?.auth?.namespace ?? "default", + baseURL, + authMode: (configProvider?.auth?.mode ?? "auto") as any, + apiKey, + headers: (provider.options as any)?.headers, + timeoutMs: typeof (provider.options as any)?.timeout === "number" ? (provider.options as any).timeout : undefined, + maxAttempts: configProvider?.auth?.maxAttempts, + }) + + if (template && discovered.length > 0) { + const npm = template.api.npm + const url = template.api.url ?? baseURL + const release_date = new Date().toISOString().slice(0, 10) + for (const id of discovered) { + if (provider.models[id]) continue + provider.models[id] = { + id, + providerID, + name: id, + family: "", + api: { + id, + npm, + url, + }, + status: "beta", + capabilities: { + temperature: true, + reasoning: true, + attachment: false, + toolcall: true, + input: { + text: true, + audio: false, + image: false, + video: false, + pdf: false, + }, + output: { + text: true, + audio: false, + image: false, + video: false, + pdf: false, + }, + interleaved: false, + }, + cost: { + input: 0, + output: 0, + cache: { + read: 0, + write: 0, + }, + }, + limit: { + context: 200_000, + output: 8_192, + }, + options: {}, + headers: {}, + release_date, + } + } + } + } + for (const [modelID, model] of Object.entries(provider.models)) { model.api.id = model.api.id ?? model.id ?? modelID if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat")) @@ -753,58 +794,105 @@ export namespace Provider { return state().then((state) => state.providers) } - async function getSDK(model: Model) { - try { - using _ = log.time("getSDK", { - providerID: model.providerID, - }) - const s = await state() - const provider = s.providers[model.providerID] - const options = { ...provider.options } - - if (model.api.npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) { - options["includeUsage"] = true - } - - if (!options["baseURL"]) options["baseURL"] = model.api.url - if (options["apiKey"] === undefined && provider.key) options["apiKey"] = provider.key - if (model.headers) - options["headers"] = { - ...options["headers"], - ...model.headers, - } + async function getSDK(model: Model) { + try { + using _ = log.time("getSDK", { + providerID: model.providerID, + }) + const s = await state() + const config = await Config.get() + const provider = s.providers[model.providerID] + const options = { ...provider.options } + const configProvider = config.provider?.[model.providerID] + const authMode = configProvider?.auth?.mode ?? "auto" + const authNamespace = configProvider?.auth?.namespace ?? "default" + const authMaxAttempts = configProvider?.auth?.maxAttempts + + if (model.api.npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) { + options["includeUsage"] = true + } + + if (!options["baseURL"]) options["baseURL"] = model.api.url + if (authMode === "subscription") { + options["apiKey"] = "unused" + } else { + if (options["apiKey"] === undefined && provider.key) options["apiKey"] = provider.key + } + if (model.headers) + options["headers"] = { + ...options["headers"], + ...model.headers, + } const key = Bun.hash.xxHash32(JSON.stringify({ npm: model.api.npm, options })) const existing = s.sdk.get(key) if (existing) return existing - const customFetch = options["fetch"] - - options["fetch"] = async (input: any, init?: BunFetchRequestInit) => { - // Preserve custom fetch if it exists, wrap it with timeout logic - const fetchFn = customFetch ?? fetch - const opts = init ?? {} - - if (options["timeout"] !== undefined && options["timeout"] !== null) { - const signals: AbortSignal[] = [] - if (opts.signal) signals.push(opts.signal) - if (options["timeout"] !== false) signals.push(AbortSignal.timeout(options["timeout"])) - - const combined = signals.length > 1 ? AbortSignal.any(signals) : signals[0] - - opts.signal = combined - } - - return fetchFn(input, { - ...opts, - // @ts-ignore see here: https://github.com/oven-sh/bun/issues/16682 - timeout: false, - }) - } - - // Special case: google-vertex-anthropic uses a subpath import - const bundledKey = - model.providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : model.api.npm + const customFetch = options["fetch"] + const fetchFn = customFetch ?? fetch + + const fetchWithTimeout = async (input: any, init?: BunFetchRequestInit) => { + const opts: BunFetchRequestInit = { + ...(init ?? {}), + } + + if (!opts.signal && input instanceof Request && input.signal) { + opts.signal = input.signal + } + + if (options["timeout"] !== undefined && options["timeout"] !== null) { + const signals: AbortSignal[] = [] + if (opts.signal) signals.push(opts.signal) + if (options["timeout"] !== false) signals.push(AbortSignal.timeout(options["timeout"])) + + const combined = signals.length > 1 ? AbortSignal.any(signals) : signals[0] + + opts.signal = combined + } + + return fetchFn(input, { + ...opts, + // @ts-ignore see here: https://github.com/oven-sh/bun/issues/16682 + timeout: false, + }) + } + + if (authMode === "subscription") { + await CredentialsMigrate.migrateIfNeeded() + const canonicalProviderId = ProviderAuthRegistry.resolveProviderId(model.providerID) + const adapter = ProviderAuthRegistry.getAdapter(canonicalProviderId) + if (!adapter) { + throw new Error( + `Provider '${model.providerID}' does not support subscription OAuth in this build (no adapter).`, + ) + } + const providerIds = ProviderAuthRegistry.equivalentProviderIds(model.providerID) + const oauthRecords = ( + await Promise.all(providerIds.map((id) => CredentialStore.findByProvider(id, authNamespace))) + ) + .flat() + .filter((r) => r.meta.kind === "oauth") + if (oauthRecords.length === 0) { + throw new Error( + `No OAuth credentials found for provider '${model.providerID}' in namespace '${authNamespace}'. Run opencode connect and choose an OAuth method.`, + ) + } + } + + if (authMode === "api") { + options["fetch"] = fetchWithTimeout as any + } else { + const fetchWithRotation = RotatingFetch.create(fetchWithTimeout as any, { + providerId: model.providerID, + namespace: authNamespace, + maxAttempts: authMaxAttempts, + }) + options["fetch"] = fetchWithRotation as any + } + + // Special case: google-vertex-anthropic uses a subpath import + const bundledKey = + model.providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : model.api.npm const bundledFn = BUNDLED_PROVIDERS[bundledKey] if (bundledFn) { log.info("using bundled provider", { providerID: model.providerID, pkg: bundledKey }) diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index 8957228ad41..6fd4378fec2 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -24,7 +24,8 @@ import { Permission } from "../permission" import { Instance } from "../project/instance" import { Vcs } from "../project/vcs" import { Agent } from "../agent/agent" -import { Auth } from "../auth" +import { CredentialStore, Credentials, CredentialsMigrate } from "@/credentials" +import { RotationStats } from "@/inference/rotation-stats" import { Command } from "../command" import { ProviderAuth } from "../provider/auth" import { Global } from "../global" @@ -1582,14 +1583,18 @@ export namespace Server { "json", z.object({ method: z.number().meta({ description: "Auth method index" }), + namespace: z.string().optional().meta({ description: "Credential namespace (optional)" }), + label: z.string().optional().meta({ description: "Credential label (optional)" }), }), ), async (c) => { const providerID = c.req.valid("param").providerID - const { method } = c.req.valid("json") + const { method, namespace, label } = c.req.valid("json") const result = await ProviderAuth.authorize({ providerID, method, + namespace, + label, }) return c.json(result) }, @@ -1623,15 +1628,19 @@ export namespace Server { z.object({ method: z.number().meta({ description: "Auth method index" }), code: z.string().optional().meta({ description: "OAuth authorization code" }), + namespace: z.string().optional().meta({ description: "Credential namespace (optional)" }), + label: z.string().optional().meta({ description: "Credential label (optional)" }), }), ), async (c) => { const providerID = c.req.valid("param").providerID - const { method, code } = c.req.valid("json") + const { method, code, namespace, label } = c.req.valid("json") await ProviderAuth.callback({ providerID, method, code, + namespace, + label, }) return c.json(true) }, @@ -2439,14 +2448,161 @@ export namespace Server { providerID: z.string(), }), ), - validator("json", Auth.Info), + validator( + "json", + z + .discriminatedUnion("type", [ + z + .object({ + type: z.literal("api"), + key: z.string(), + }) + .meta({ ref: "ApiAuth" }), + z + .object({ + type: z.literal("wellknown"), + key: z.string(), + token: z.string(), + }) + .meta({ ref: "WellKnownAuth" }), + ]) + .meta({ ref: "Auth" }), + ), async (c) => { const providerID = c.req.valid("param").providerID const info = c.req.valid("json") - await Auth.set(providerID, info) + await CredentialsMigrate.migrateIfNeeded() + if (info.type === "api") { + await CredentialStore.upsertSingleton({ + providerId: providerID, + namespace: "default", + kind: "api", + label: "default", + secret: { apiKey: info.key }, + }) + } else { + await CredentialStore.upsertSingleton({ + providerId: providerID, + namespace: "default", + kind: "wellknown", + label: "default", + secret: { envKey: info.key, token: info.token }, + }) + } + return c.json(true) + }, + ) + .get( + "/credential", + describeRoute({ + summary: "List credentials", + description: "List credential records (metadata only).", + operationId: "credential.list", + responses: { + 200: { + description: "Credential records", + content: { + "application/json": { + schema: resolver(Credentials.RecordMeta.array()), + }, + }, + }, + }, + }), + async (c) => { + await CredentialsMigrate.migrateIfNeeded() + const { records } = await CredentialStore.listAll() + return c.json(records.map((r) => r.meta)) + }, + ) + .patch( + "/credential/:credentialID", + describeRoute({ + summary: "Update credential", + description: "Update a credential record's metadata (e.g. label).", + operationId: "credential.update", + responses: { + 200: { + description: "Updated credential meta", + content: { + "application/json": { + schema: resolver(Credentials.RecordMeta), + }, + }, + }, + ...errors(404), + ...errors(400), + }, + }), + validator("param", z.object({ credentialID: z.string() })), + validator( + "json", + z.object({ + label: z.string().min(1).meta({ description: "New label" }), + }), + ), + async (c) => { + await CredentialsMigrate.migrateIfNeeded() + const id = c.req.valid("param").credentialID + const { label } = c.req.valid("json") + const updated = await CredentialStore.update(id, { meta: { label } }) + if (!updated) { + throw new Storage.NotFoundError({ message: "Credential not found" }) + } + return c.json(updated.meta) + }, + ) + .delete( + "/credential/:credentialID", + describeRoute({ + summary: "Remove credential", + description: "Remove a credential record.", + operationId: "credential.remove", + responses: { + 200: { + description: "Removed", + content: { + "application/json": { + schema: resolver(z.boolean()), + }, + }, + }, + ...errors(404), + }, + }), + validator("param", z.object({ credentialID: z.string() })), + async (c) => { + await CredentialsMigrate.migrateIfNeeded() + const id = c.req.valid("param").credentialID + const record = await CredentialStore.getRecordFile(id) + if (!record) { + throw new Storage.NotFoundError({ message: "Credential not found" }) + } + await CredentialStore.remove(id) return c.json(true) }, ) + .get( + "/debug/rotation", + describeRoute({ + summary: "Rotation stats", + description: "In-memory counters for same-request OAuth rotation events.", + operationId: "debug.rotation", + responses: { + 200: { + description: "Rotation stats snapshot", + content: { + "application/json": { + schema: resolver(RotationStats.Snapshot), + }, + }, + }, + }, + }), + async (c) => { + return c.json(RotationStats.snapshot()) + }, + ) .get( "/event", describeRoute({ diff --git a/packages/opencode/src/vault/crypto.ts b/packages/opencode/src/vault/crypto.ts new file mode 100644 index 00000000000..af45af0a2f6 --- /dev/null +++ b/packages/opencode/src/vault/crypto.ts @@ -0,0 +1,53 @@ +import crypto from "crypto" + +export type VaultEncryptedBlobV1 = { + v: 1 + alg: "AES-256-GCM" + nonce_b64: string + tag_b64: string + data_b64: string +} + +export namespace VaultCrypto { + const NONCE_BYTES = 12 + const KEY_BYTES = 32 + + export function encryptJson(key: Buffer, payload: unknown): VaultEncryptedBlobV1 { + if (key.length !== KEY_BYTES) { + throw new Error(`Invalid vault key length. Expected ${KEY_BYTES} bytes, got ${key.length}.`) + } + + const nonce = crypto.randomBytes(NONCE_BYTES) + const cipher = crypto.createCipheriv("aes-256-gcm", key, nonce) + const plaintext = Buffer.from(JSON.stringify(payload), "utf8") + const ciphertext = Buffer.concat([cipher.update(plaintext), cipher.final()]) + const tag = cipher.getAuthTag() + + return { + v: 1, + alg: "AES-256-GCM", + nonce_b64: nonce.toString("base64"), + tag_b64: tag.toString("base64"), + data_b64: ciphertext.toString("base64"), + } + } + + export function decryptJson(key: Buffer, blob: VaultEncryptedBlobV1): unknown { + if (key.length !== KEY_BYTES) { + throw new Error(`Invalid vault key length. Expected ${KEY_BYTES} bytes, got ${key.length}.`) + } + if (blob.v !== 1 || blob.alg !== "AES-256-GCM") { + throw new Error("Unsupported vault blob format.") + } + + const nonce = Buffer.from(blob.nonce_b64, "base64") + const tag = Buffer.from(blob.tag_b64, "base64") + const data = Buffer.from(blob.data_b64, "base64") + + const decipher = crypto.createDecipheriv("aes-256-gcm", key, nonce) + decipher.setAuthTag(tag) + const plaintext = Buffer.concat([decipher.update(data), decipher.final()]) + return JSON.parse(plaintext.toString("utf8")) + } +} + diff --git a/packages/opencode/src/vault/fs.ts b/packages/opencode/src/vault/fs.ts new file mode 100644 index 00000000000..b2452bca7aa --- /dev/null +++ b/packages/opencode/src/vault/fs.ts @@ -0,0 +1,49 @@ +import fs from "fs/promises" +import path from "path" + +export namespace VaultFS { + export async function ensureDir(dir: string) { + await fs.mkdir(dir, { recursive: true }) + } + + export async function exists(filePath: string): Promise { + try { + await fs.stat(filePath) + return true + } catch { + return false + } + } + + export async function readJson(filePath: string): Promise { + try { + const text = await fs.readFile(filePath, "utf8") + return JSON.parse(text) as T + } catch { + return undefined + } + } + + export async function atomicWriteJson(filePath: string, value: unknown, mode: number = 0o600): Promise { + await atomicWriteText(filePath, JSON.stringify(value, null, 2), mode) + } + + export async function atomicWriteText(filePath: string, text: string, mode: number = 0o600): Promise { + const dir = path.dirname(filePath) + await ensureDir(dir) + + const tmpPath = `${filePath}.tmp.${process.pid}.${Math.random().toString(16).slice(2)}` + + const handle = await fs.open(tmpPath, "w", mode) + try { + await handle.writeFile(text, "utf8") + await handle.sync() + } finally { + await handle.close() + } + + await fs.rename(tmpPath, filePath) + await fs.chmod(filePath, mode) + } +} + diff --git a/packages/opencode/src/vault/key.ts b/packages/opencode/src/vault/key.ts new file mode 100644 index 00000000000..900a07870a8 --- /dev/null +++ b/packages/opencode/src/vault/key.ts @@ -0,0 +1,84 @@ +import path from "path" +import fs from "fs/promises" +import crypto from "crypto" +import { Global } from "@/global" +import { VaultFS } from "./fs" + +export namespace VaultKey { + const KEY_ENV = "OPENCODE_VAULT_KEY" + const KEY_PATH = path.join(Global.Path.config, "vault.key") + const KEY_BYTES = 32 + let cached: Buffer | undefined + + export function envVarName(): string { + return KEY_ENV + } + + export function keyPath(): string { + return KEY_PATH + } + + function decodeBase64Key(input: string): Buffer { + const buf = Buffer.from(input.trim(), "base64") + if (buf.length !== KEY_BYTES) { + throw new Error(`Invalid vault key length. Expected ${KEY_BYTES} bytes, got ${buf.length}.`) + } + return buf + } + + async function loadFromFile(): Promise { + if (!(await VaultFS.exists(KEY_PATH))) return undefined + const raw = await fs.readFile(KEY_PATH, "utf8") + return decodeBase64Key(raw) + } + + async function writeKeyToFile(key: Buffer): Promise { + await VaultFS.atomicWriteText(KEY_PATH, key.toString("base64"), 0o600) + // Bun can ignore mode on write; harden after. + await fs.chmod(KEY_PATH, 0o600).catch(() => {}) + } + + export async function load(): Promise { + const env = process.env[KEY_ENV] + if (env) return decodeBase64Key(env) + + if (cached) return cached + const fromFile = await loadFromFile() + if (fromFile) { + cached = fromFile + return cached + } + + cached = crypto.randomBytes(KEY_BYTES) + await writeKeyToFile(cached) + return cached + } + + export async function init(opts?: { force?: boolean }): Promise<{ path: string; created: boolean; source: "env" | "generated" | "existing" }> { + const env = process.env[KEY_ENV] + const envKey = env ? decodeBase64Key(env) : undefined + + const existing = await loadFromFile() + if (existing && !opts?.force) { + cached = existing + return { path: KEY_PATH, created: false, source: "existing" } + } + + const next = envKey ?? crypto.randomBytes(KEY_BYTES) + await writeKeyToFile(next) + cached = next + return { path: KEY_PATH, created: true, source: envKey ? "env" : "generated" } + } + + export async function exportBase64(): Promise { + const key = await load() + return key.toString("base64") + } + + export async function importBase64(input: string): Promise<{ path: string }> { + const key = decodeBase64Key(input) + await writeKeyToFile(key) + cached = key + return { path: KEY_PATH } + } +} diff --git a/packages/opencode/src/vault/lock.ts b/packages/opencode/src/vault/lock.ts new file mode 100644 index 00000000000..50e53adbeda --- /dev/null +++ b/packages/opencode/src/vault/lock.ts @@ -0,0 +1,70 @@ +import fs from "fs/promises" + +type LockContents = { + pid: number + expiresAt: number +} + +export namespace VaultLock { + const DEFAULT_TTL_MS = 30_000 + const DEFAULT_WAIT_MS = 5_000 + const SPIN_MS = 50 + + async function sleep(ms: number) { + await new Promise((r) => setTimeout(r, ms)) + } + + async function tryAcquire(lockPath: string, ttlMs: number): Promise { + try { + const handle = await fs.open(lockPath, "wx", 0o600) + try { + const contents: LockContents = { pid: process.pid, expiresAt: Date.now() + ttlMs } + await handle.writeFile(JSON.stringify(contents), "utf8") + } finally { + await handle.close() + } + return true + } catch (e: any) { + if (e?.code !== "EEXIST") throw e + return false + } + } + + async function breakIfExpired(lockPath: string): Promise { + try { + const raw = await fs.readFile(lockPath, "utf8") + const parsed = JSON.parse(raw) as Partial + if (!parsed.expiresAt || typeof parsed.expiresAt !== "number") return + if (Date.now() > parsed.expiresAt) { + await fs.rm(lockPath, { force: true }) + } + } catch { + // ignore + } + } + + export async function withLock( + lockPath: string, + fn: () => Promise, + opts?: { ttlMs?: number; waitMs?: number }, + ): Promise { + const ttlMs = opts?.ttlMs ?? DEFAULT_TTL_MS + const waitMs = opts?.waitMs ?? DEFAULT_WAIT_MS + const deadline = Date.now() + waitMs + + while (Date.now() < deadline) { + await breakIfExpired(lockPath) + if (await tryAcquire(lockPath, ttlMs)) { + try { + return await fn() + } finally { + await fs.rm(lockPath, { force: true }).catch(() => {}) + } + } + await sleep(SPIN_MS) + } + + throw new Error(`Timed out waiting for lock: ${lockPath}`) + } +} + diff --git a/packages/opencode/test/config/config.test.ts b/packages/opencode/test/config/config.test.ts index 2ff8c01cdb0..07d52732e77 100644 --- a/packages/opencode/test/config/config.test.ts +++ b/packages/opencode/test/config/config.test.ts @@ -40,6 +40,37 @@ test("loads JSON config file", async () => { }) }) +test("parses provider auth settings", async () => { + await using tmp = await tmpdir({ + init: async (dir) => { + await Bun.write( + path.join(dir, "opencode.json"), + JSON.stringify({ + $schema: "https://opencode.ai/config.json", + provider: { + anthropic: { + auth: { + mode: "subscription", + namespace: "work", + maxAttempts: 2, + }, + }, + }, + }), + ) + }, + }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const config = await Config.get() + expect(config.provider?.anthropic?.auth?.mode).toBe("subscription") + expect(config.provider?.anthropic?.auth?.namespace).toBe("work") + expect(config.provider?.anthropic?.auth?.maxAttempts).toBe(2) + }, + }) +}) + test("loads JSONC config file", async () => { await using tmp = await tmpdir({ init: async (dir) => { diff --git a/packages/opencode/test/credentials/store.test.ts b/packages/opencode/test/credentials/store.test.ts new file mode 100644 index 00000000000..678212c347c --- /dev/null +++ b/packages/opencode/test/credentials/store.test.ts @@ -0,0 +1,91 @@ +import { describe, expect, test } from "bun:test" +import fs from "fs/promises" +import path from "path" +import { Global } from "../../src/global" +import { CredentialStore } from "../../src/credentials/store" + +async function resetCredentialDir() { + await fs.rm(path.join(Global.Path.data, "credentials"), { recursive: true, force: true }) +} + +describe("CredentialStore", () => { + test("stores secrets encrypted at rest", async () => { + await resetCredentialDir() + const accessToken = "token-plaintext" + const refreshToken = "refresh-plaintext" + + const record = await CredentialStore.put({ + providerId: "openai", + namespace: "default", + kind: "oauth", + label: "default", + secret: { accessToken, refreshToken }, + }) + + const onDisk = await Bun.file(path.join(Global.Path.data, "credentials", "records", `${record.meta.id}.json`)).text() + expect(onDisk).not.toContain(accessToken) + expect(onDisk).not.toContain(refreshToken) + + const decrypted = await CredentialStore.decryptSecret(record) + expect(decrypted).toEqual({ accessToken, refreshToken }) + }) + + test("upsertSingleton updates existing record", async () => { + await resetCredentialDir() + + const first = await CredentialStore.upsertSingleton({ + providerId: "anthropic", + namespace: "default", + kind: "api", + label: "default", + secret: { apiKey: "key-1" }, + }) + const second = await CredentialStore.upsertSingleton({ + providerId: "anthropic", + namespace: "default", + kind: "api", + label: "default", + secret: { apiKey: "key-2" }, + }) + + expect(second.meta.id).toEqual(first.meta.id) + const decrypted = await CredentialStore.decryptSecret(second) + expect(decrypted).toEqual({ apiKey: "key-2" }) + }) + + test("maintains on-disk index for provider lookups", async () => { + await resetCredentialDir() + + const id = "cred-idx-1" + await CredentialStore.put({ + id, + providerId: "openai", + namespace: "default", + kind: "oauth", + label: "default", + secret: { accessToken: "t", refreshToken: "r" }, + }) + + const indexPath = path.join(Global.Path.data, "credentials", "index.json") + const raw = await Bun.file(indexPath).json() + expect(raw).toMatchObject({ + version: 1, + byProvider: { + openai: { + default: [id], + }, + }, + }) + + await CredentialStore.remove(id) + const after = await Bun.file(indexPath).json() + expect(after).toMatchObject({ + version: 1, + byProvider: { + openai: { + default: [], + }, + }, + }) + }) +}) diff --git a/packages/opencode/test/credentials/vault.test.ts b/packages/opencode/test/credentials/vault.test.ts new file mode 100644 index 00000000000..62ec51407c9 --- /dev/null +++ b/packages/opencode/test/credentials/vault.test.ts @@ -0,0 +1,26 @@ +import { describe, expect, test } from "bun:test" +import crypto from "crypto" +import { VaultCrypto } from "../../src/vault/crypto" + +describe("VaultCrypto", () => { + test("encryptJson/decryptJson roundtrip", () => { + const key = crypto.randomBytes(32) + const payload = { a: 1, nested: { b: "x" }, arr: [1, 2, 3] } + + const encrypted = VaultCrypto.encryptJson(key, payload) + const decrypted = VaultCrypto.decryptJson(key, encrypted) + + expect(decrypted).toEqual(payload) + }) + + test("encryptJson uses random nonce", () => { + const key = crypto.randomBytes(32) + const payload = { same: true } + + const a = VaultCrypto.encryptJson(key, payload) + const b = VaultCrypto.encryptJson(key, payload) + + expect(a.nonce_b64).not.toEqual(b.nonce_b64) + }) +}) + diff --git a/packages/opencode/test/inference/rotating-fetch.test.ts b/packages/opencode/test/inference/rotating-fetch.test.ts new file mode 100644 index 00000000000..966e6caf399 --- /dev/null +++ b/packages/opencode/test/inference/rotating-fetch.test.ts @@ -0,0 +1,229 @@ +import { describe, expect, test } from "bun:test" +import fs from "fs/promises" +import path from "path" +import { Global } from "../../src/global" +import { CredentialPool } from "../../src/credentials/pool" +import { CredentialStore } from "../../src/credentials/store" +import { RotatingFetch } from "../../src/inference/rotating-fetch" + +async function resetCredentials() { + await fs.rm(path.join(Global.Path.data, "credentials"), { recursive: true, force: true }) +} + +describe("RotatingFetch", () => { + test("rotates credentials on 429 within the same request", async () => { + await resetCredentials() + + const id1 = "cred-1" + const id2 = "cred-2" + + await CredentialStore.put({ + id: id1, + providerId: "openai", + namespace: "default", + kind: "oauth", + label: "default", + secret: { accessToken: "t1", refreshToken: "r1" }, + }) + await CredentialStore.put({ + id: id2, + providerId: "openai", + namespace: "default", + kind: "oauth", + label: "second", + secret: { accessToken: "t2", refreshToken: "r2" }, + }) + + const seenAuth: Array = [] + const baseFetch = async (input: RequestInfo | URL, init?: RequestInit): Promise => { + const req = new Request(input, init) + const auth = req.headers.get("Authorization") + seenAuth.push(auth) + + if (auth === "Bearer t1") { + return new Response("rate_limited", { status: 429, headers: { "Retry-After": "0" } }) + } + if (auth === "Bearer t2") { + return new Response("ok", { status: 200 }) + } + return new Response("missing_auth", { status: 401 }) + } + + const rotating = RotatingFetch.create(baseFetch, { providerId: "openai", namespace: "default" }) + const resp = await rotating("https://example.com/v1/chat/completions", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: "{}", + }) + + expect(resp.status).toBe(200) + expect(seenAuth).toEqual(["Bearer t1", "Bearer t2"]) + + const ordered = await CredentialPool.getOrderedIds("openai", "default", [id1, id2]) + expect(ordered).toEqual([id2, id1]) + + const updated1 = await CredentialStore.getRecordFile(id1) + const updated2 = await CredentialStore.getRecordFile(id2) + expect(updated1?.meta.health.lastStatusCode).toBe(429) + expect((updated1?.meta.health.failureCount ?? 0) >= 1).toBe(true) + expect(updated1?.meta.health.cooldownUntil).toBeDefined() + expect(updated2?.meta.health.lastStatusCode).toBe(200) + expect((updated2?.meta.health.successCount ?? 0) >= 1).toBe(true) + }) + + test("rotates credentials on 401 within the same request", async () => { + await resetCredentials() + + const id1 = "cred-1" + const id2 = "cred-2" + + await CredentialStore.put({ + id: id1, + providerId: "github-copilot", + namespace: "default", + kind: "oauth", + label: "default", + secret: { accessToken: "t1", refreshToken: "r1" }, + }) + await CredentialStore.put({ + id: id2, + providerId: "github-copilot", + namespace: "default", + kind: "oauth", + label: "second", + secret: { accessToken: "t2", refreshToken: "r2" }, + }) + + const seenAuth: Array = [] + const baseFetch = async (input: RequestInfo | URL, init?: RequestInit): Promise => { + const req = new Request(input, init) + const auth = req.headers.get("Authorization") + seenAuth.push(auth) + + if (auth === "Bearer t1") return new Response("expired", { status: 401 }) + if (auth === "Bearer t2") return new Response("ok", { status: 200 }) + return new Response("missing_auth", { status: 401 }) + } + + const rotating = RotatingFetch.create(baseFetch, { providerId: "github-copilot", namespace: "default" }) + const resp = await rotating("https://example.com/v1/chat/completions", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: "{}", + }) + + expect(resp.status).toBe(200) + expect(seenAuth).toEqual(["Bearer t1", "Bearer t2"]) + + const ordered = await CredentialPool.getOrderedIds("github-copilot", "default", [id1, id2]) + expect(ordered).toEqual([id2, id1]) + + const updated1 = await CredentialStore.getRecordFile(id1) + expect(updated1?.meta.health.lastStatusCode).toBe(401) + expect((updated1?.meta.health.failureCount ?? 0) >= 1).toBe(true) + expect(updated1?.meta.health.cooldownUntil).toBeDefined() + }) + + test("uses canonical provider pool + credentials across alias ids", async () => { + await resetCredentials() + + await CredentialStore.put({ + id: "cred-1", + providerId: "github-copilot", + namespace: "default", + kind: "oauth", + label: "default", + secret: { accessToken: "t1", refreshToken: "r1" }, + }) + + const seenAuth: Array = [] + const baseFetch = async (input: RequestInfo | URL, init?: RequestInit): Promise => { + const req = new Request(input, init) + const auth = req.headers.get("Authorization") + seenAuth.push(auth) + if (auth === "Bearer t1") return new Response("ok", { status: 200 }) + return new Response("missing_auth", { status: 401 }) + } + + const rotating = RotatingFetch.create(baseFetch, { providerId: "github-copilot-enterprise", namespace: "default" }) + const resp = await rotating("https://example.com/v1/chat/completions", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: "{}", + }) + + expect(resp.status).toBe(200) + expect(seenAuth).toEqual(["Bearer t1"]) + }) + + test("adapter prepareRequest can strip API key query params", async () => { + await resetCredentials() + + await CredentialStore.put({ + id: "cred-1", + providerId: "google", + namespace: "default", + kind: "oauth", + label: "default", + secret: { accessToken: "t1", refreshToken: "r1" }, + }) + + const baseFetch = async (input: RequestInfo | URL, init?: RequestInit): Promise => { + const req = new Request(input, init) + const url = new URL(req.url) + expect(url.searchParams.has("key")).toBe(false) + expect(url.searchParams.has("api_key")).toBe(false) + expect(req.headers.get("Authorization")).toBe("Bearer t1") + return new Response("ok", { status: 200 }) + } + + const rotating = RotatingFetch.create(baseFetch, { providerId: "google", namespace: "default" }) + const resp = await rotating("https://example.com/v1/chat/completions?key=leak&api_key=leak2", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: "{}", + }) + + expect(resp.status).toBe(200) + }) + + test("adapter prepareRequest can rewrite base URL from credential metadata", async () => { + await resetCredentials() + + await CredentialStore.put({ + id: "cred-1", + providerId: "github-copilot", + namespace: "default", + kind: "oauth", + label: "default", + secret: { + accessToken: "copilot-token", + refreshToken: "github-token", + extra: { + endpoints: { + api: "https://alt.example.com/api", + }, + }, + }, + }) + + const baseFetch = async (input: RequestInfo | URL, init?: RequestInit): Promise => { + const req = new Request(input, init) + expect(req.url).toBe("https://alt.example.com/api/v1/chat/completions") + expect(req.headers.get("Authorization")).toBe("Bearer copilot-token") + expect(req.headers.get("editor-version")).toBeTruthy() + expect(req.headers.get("editor-plugin-version")).toBeTruthy() + expect(req.headers.get("user-agent")).toBeTruthy() + return new Response("ok", { status: 200 }) + } + + const rotating = RotatingFetch.create(baseFetch, { providerId: "github-copilot", namespace: "default" }) + const resp = await rotating("https://api.githubcopilot.com/v1/chat/completions", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: "{}", + }) + + expect(resp.status).toBe(200) + }) +}) diff --git a/packages/opencode/test/mcp/headers.test.ts b/packages/opencode/test/mcp/headers.test.ts index 69998aaaa8a..380cfdcfe39 100644 --- a/packages/opencode/test/mcp/headers.test.ts +++ b/packages/opencode/test/mcp/headers.test.ts @@ -58,6 +58,7 @@ test("headers are passed to transports when oauth is enabled (default)", async ( "test-server": { type: "remote", url: "https://example.com/mcp", + timeout: 250, headers: { Authorization: "Bearer test-token", "X-Custom-Header": "custom-value", @@ -76,6 +77,7 @@ test("headers are passed to transports when oauth is enabled (default)", async ( await MCP.add("test-server", { type: "remote", url: "https://example.com/mcp", + timeout: 250, headers: { Authorization: "Bearer test-token", "X-Custom-Header": "custom-value", @@ -110,6 +112,7 @@ test("headers are passed to transports when oauth is explicitly disabled", async type: "remote", url: "https://example.com/mcp", oauth: false, + timeout: 250, headers: { Authorization: "Bearer test-token", }, @@ -140,6 +143,7 @@ test("no requestInit when headers are not provided", async () => { await MCP.add("test-server-no-headers", { type: "remote", url: "https://example.com/mcp", + timeout: 250, }).catch(() => {}) expect(transportCalls.length).toBeGreaterThanOrEqual(1) diff --git a/packages/opencode/test/provider-auth/registry.test.ts b/packages/opencode/test/provider-auth/registry.test.ts new file mode 100644 index 00000000000..93d74ae506d --- /dev/null +++ b/packages/opencode/test/provider-auth/registry.test.ts @@ -0,0 +1,15 @@ +import { describe, expect, test } from "bun:test" +import { ProviderAuthRegistry } from "../../src/provider-auth/registry" + +describe("ProviderAuthRegistry", () => { + test("resolves alias provider ids to the correct adapter", () => { + const adapter = ProviderAuthRegistry.getAdapter("github-copilot-enterprise") + expect(adapter?.providerId).toBe("github-copilot") + }) + + test("lists alias provider ids", () => { + const ids = ProviderAuthRegistry.listProviderIds() + expect(ids).toContain("github-copilot-enterprise") + }) +}) + diff --git a/packages/sdk/js/src/v2/gen/sdk.gen.ts b/packages/sdk/js/src/v2/gen/sdk.gen.ts index 16fe07ae4a8..727c6fe2c41 100644 --- a/packages/sdk/js/src/v2/gen/sdk.gen.ts +++ b/packages/sdk/js/src/v2/gen/sdk.gen.ts @@ -16,6 +16,12 @@ import type { ConfigProvidersResponses, ConfigUpdateErrors, ConfigUpdateResponses, + CredentialListResponses, + CredentialRemoveErrors, + CredentialRemoveResponses, + CredentialUpdateErrors, + CredentialUpdateResponses, + DebugRotationResponses, EventSubscribeResponses, EventTuiCommandExecute, EventTuiPromptAppend, @@ -1559,6 +1565,8 @@ export class Oauth extends HeyApiClient { providerID: string directory?: string method?: number + namespace?: string + label?: string }, options?: Options, ) { @@ -1570,6 +1578,8 @@ export class Oauth extends HeyApiClient { { in: "path", key: "providerID" }, { in: "query", key: "directory" }, { in: "body", key: "method" }, + { in: "body", key: "namespace" }, + { in: "body", key: "label" }, ], }, ], @@ -1601,6 +1611,8 @@ export class Oauth extends HeyApiClient { directory?: string method?: number code?: string + namespace?: string + label?: string }, options?: Options, ) { @@ -1613,6 +1625,8 @@ export class Oauth extends HeyApiClient { { in: "query", key: "directory" }, { in: "body", key: "method" }, { in: "body", key: "code" }, + { in: "body", key: "namespace" }, + { in: "body", key: "label" }, ], }, ], @@ -2541,6 +2555,115 @@ export class Tui extends HeyApiClient { control = new Control({ client: this.client }) } +export class Credential extends HeyApiClient { + /** + * List credentials + * + * List credential records (metadata only). + */ + public list( + parameters?: { + directory?: string + }, + options?: Options, + ) { + const params = buildClientParams([parameters], [{ args: [{ in: "query", key: "directory" }] }]) + return (options?.client ?? this.client).get({ + url: "/credential", + ...options, + ...params, + }) + } + + /** + * Remove credential + * + * Remove a credential record. + */ + public remove( + parameters: { + credentialID: string + directory?: string + }, + options?: Options, + ) { + const params = buildClientParams( + [parameters], + [ + { + args: [ + { in: "path", key: "credentialID" }, + { in: "query", key: "directory" }, + ], + }, + ], + ) + return (options?.client ?? this.client).delete({ + url: "/credential/{credentialID}", + ...options, + ...params, + }) + } + + /** + * Update credential + * + * Update a credential record's metadata (e.g. label). + */ + public update( + parameters: { + credentialID: string + directory?: string + label?: string + }, + options?: Options, + ) { + const params = buildClientParams( + [parameters], + [ + { + args: [ + { in: "path", key: "credentialID" }, + { in: "query", key: "directory" }, + { in: "body", key: "label" }, + ], + }, + ], + ) + return (options?.client ?? this.client).patch({ + url: "/credential/{credentialID}", + ...options, + ...params, + headers: { + "Content-Type": "application/json", + ...options?.headers, + ...params.headers, + }, + }) + } +} + +export class Debug extends HeyApiClient { + /** + * Rotation stats + * + * In-memory counters for same-request OAuth rotation events. + */ + public rotation( + parameters?: { + directory?: string + }, + options?: Options, + ) { + const params = buildClientParams([parameters], [{ args: [{ in: "query", key: "directory" }] }]) + return (options?.client ?? this.client).get({ + url: "/debug/rotation", + ...options, + ...params, + }) + } +} + export class Event extends HeyApiClient { /** * Subscribe to events @@ -2610,5 +2733,9 @@ export class OpencodeClient extends HeyApiClient { auth = new Auth({ client: this.client }) + credential = new Credential({ client: this.client }) + + debug = new Debug({ client: this.client }) + event = new Event({ client: this.client }) } diff --git a/packages/sdk/js/src/v2/gen/types.gen.ts b/packages/sdk/js/src/v2/gen/types.gen.ts index 00f209c6d88..490a3a0518e 100644 --- a/packages/sdk/js/src/v2/gen/types.gen.ts +++ b/packages/sdk/js/src/v2/gen/types.gen.ts @@ -1232,6 +1232,23 @@ export type ProviderConfig = { } whitelist?: Array blacklist?: Array + /** + * Authentication settings for subscription OAuth rotation and API keys. + */ + auth?: { + /** + * Auth mode for this provider. 'auto' uses subscription OAuth when available, otherwise API key/env. 'api' forces API key/env. 'subscription' forces OAuth rotation. + */ + mode?: "auto" | "api" | "subscription" + /** + * Credential namespace to use for this provider (default: 'default'). + */ + namespace?: string + /** + * Max credentials to try per request when rotating (default: try all eligible). + */ + maxAttempts?: number + } options?: { apiKey?: string baseURL?: string @@ -1239,6 +1256,10 @@ export type ProviderConfig = { * GitHub Enterprise URL for copilot authentication */ enterpriseUrl?: string + /** + * Enable dynamic model discovery for OpenAI/OpenAI-compatible providers by calling the upstream /models endpoint and caching results under Global.Path.cache. + */ + discoverModels?: boolean /** * Enable promptCacheKey for this provider (default false) */ @@ -1823,14 +1844,6 @@ export type FormatterStatus = { enabled: boolean } -export type OAuth = { - type: "oauth" - refresh: string - access: string - expires: number - enterpriseUrl?: string -} - export type ApiAuth = { type: "api" key: string @@ -1842,7 +1855,48 @@ export type WellKnownAuth = { token: string } -export type Auth = OAuth | ApiAuth | WellKnownAuth +export type Auth = ApiAuth | WellKnownAuth + +export type CredentialKind = "oauth" | "api" | "wellknown" | "mcp" + +export type CredentialHealth = { + cooldownUntil?: number + lastStatusCode?: number + lastErrorAt?: number + successCount?: number + failureCount?: number +} + +export type CredentialRecordMeta = { + id: string + providerId: string + namespace?: string + label?: string + kind: CredentialKind + createdAt: number + updatedAt: number + health?: CredentialHealth +} + +export type RotationStatsCounts = { + requests: number + attempts: number + rotations: number + exhausted: number + refreshAttempts: number + refreshSuccess: number + refreshFailure: number + rateLimited: number + authExpired: number +} + +export type RotationStatsSnapshot = { + since: number + totals: RotationStatsCounts + byProvider: { + [key: string]: RotationStatsCounts + } +} export type GlobalEventData = { body?: never @@ -3290,6 +3344,14 @@ export type ProviderOauthAuthorizeData = { * Auth method index */ method: number + /** + * Credential namespace (optional) + */ + namespace?: string + /** + * Credential label (optional) + */ + label?: string } path: { /** @@ -3331,6 +3393,14 @@ export type ProviderOauthCallbackData = { * OAuth authorization code */ code?: string + /** + * Credential namespace (optional) + */ + namespace?: string + /** + * Credential label (optional) + */ + label?: string } path: { /** @@ -4111,6 +4181,109 @@ export type AuthSetResponses = { export type AuthSetResponse = AuthSetResponses[keyof AuthSetResponses] +export type CredentialListData = { + body?: never + path?: never + query?: { + directory?: string + } + url: "/credential" +} + +export type CredentialListResponses = { + /** + * Credential records + */ + 200: Array +} + +export type CredentialListResponse = CredentialListResponses[keyof CredentialListResponses] + +export type CredentialRemoveData = { + body?: never + path: { + credentialID: string + } + query?: { + directory?: string + } + url: "/credential/{credentialID}" +} + +export type CredentialRemoveErrors = { + /** + * Not found + */ + 404: NotFoundError +} + +export type CredentialRemoveError = CredentialRemoveErrors[keyof CredentialRemoveErrors] + +export type CredentialRemoveResponses = { + /** + * Removed + */ + 200: boolean +} + +export type CredentialRemoveResponse = CredentialRemoveResponses[keyof CredentialRemoveResponses] + +export type CredentialUpdateData = { + body?: { + /** + * New label + */ + label: string + } + path: { + credentialID: string + } + query?: { + directory?: string + } + url: "/credential/{credentialID}" +} + +export type CredentialUpdateErrors = { + /** + * Bad request + */ + 400: BadRequestError + /** + * Not found + */ + 404: NotFoundError +} + +export type CredentialUpdateError = CredentialUpdateErrors[keyof CredentialUpdateErrors] + +export type CredentialUpdateResponses = { + /** + * Updated credential meta + */ + 200: CredentialRecordMeta +} + +export type CredentialUpdateResponse = CredentialUpdateResponses[keyof CredentialUpdateResponses] + +export type DebugRotationData = { + body?: never + path?: never + query?: { + directory?: string + } + url: "/debug/rotation" +} + +export type DebugRotationResponses = { + /** + * Rotation stats snapshot + */ + 200: RotationStatsSnapshot +} + +export type DebugRotationResponse = DebugRotationResponses[keyof DebugRotationResponses] + export type EventSubscribeData = { body?: never path?: never diff --git a/specs/provider-auth-v2.md b/specs/provider-auth-v2.md new file mode 100644 index 00000000000..fce83aea075 --- /dev/null +++ b/specs/provider-auth-v2.md @@ -0,0 +1,216 @@ +# Provider Auth V2 (RFC) + +## Summary + +This RFC proposes and documents a **core** OpenCode refactor that adds: + +- An **encrypted credential vault** (AES-256-GCM) with atomic writes + lockfile coordination. +- A **multi-credential store** for OAuth subscriptions (and other credential kinds), enabling multiple accounts per provider. +- **Same-request credential rotation** on `HTTP 429` (rate limiting) and refresh-on-`401/403` where supported. +- A single, composable integration point: **fetch-level middleware** (no proxy/sidecar required). +- Optional **model discovery** for OpenAI/OpenAI-compatible providers (`/models`), cached locally. +- A user-facing way to **manage connected accounts** (TUI + server endpoints) and to **manage the vault key** (CLI). + +This enables “subscription pools” (user-context OAuth sessions) and “API key mode” (developer credits) to coexist cleanly. + +## Goals + +- Support **multiple OAuth subscription credentials per provider** (Anthropic, OpenAI, Google, Copilot, Qwen, Cursor). +- On throttling (`429`): **move the credential to the back** of the provider pool, apply cooldown (Retry-After aware), and retry **within the same user request**. +- Keep architecture **DRY and composable** (one auth store, one rotation engine, provider adapters for differences). +- Preserve the ability to use OpenCode via **API keys** (no subscription required). + +## Non-goals + +- Mid-stream rotation (if the upstream stream fails after tokens are emitted, we surface error; users retry). +- Universal model discovery across all providers (only where the provider exposes it). + +## Design + +### Credential Vault + +- Vault key: + - Loaded from `OPENCODE_VAULT_KEY` (base64 32 bytes) **or** + - Generated and persisted to `Global.Path.config/vault.key` (mode `0600`). +- Vault key UX (explicit, no OS keychain): + - `opencode auth vault init` (create `vault.key`) + - `opencode auth vault export [--output ]` (print base64 key for backup/migration) + - `opencode auth vault import [--file | --key ]` (restore key) +- Encryption: + - `AES-256-GCM` with random nonce per record. +- IO semantics: + - Atomic writes: write temp + `fsync` + `rename`. + - Lockfile: prevents concurrent writers from clobbering pool/order state. + +Code: +- `packages/opencode/src/vault/crypto.ts` +- `packages/opencode/src/vault/fs.ts` +- `packages/opencode/src/vault/lock.ts` +- `packages/opencode/src/vault/key.ts` + +### Credential Store (multi-account) + +Records are stored as: + +- `Global.Path.data/credentials/records/.json` +- Each record contains: + - `meta` (providerId, namespace, label, kind, timestamps, health) + - `secret` (encrypted blob) + +Kinds: +- `oauth` (subscription tokens) +- `api` (API keys) +- `wellknown` (token + env key) +- `mcp` (MCP OAuth entries) + +Code: +- `packages/opencode/src/credentials/types.ts` +- `packages/opencode/src/credentials/store.ts` +- `packages/opencode/src/credentials/migrate.ts` + +### Provider Adapters (OAuth + header injection) + +Adapters provide a single source of truth for: + +- OAuth login flows (PKCE browser redirect, device flow, polling flows) +- Applying credentials to outgoing requests (`applyAuth(headers, secret)`) +- Optional refresh (`refresh(secret)`), when the provider supports it + +Code: +- `packages/opencode/src/provider-auth/adapter.ts` +- `packages/opencode/src/provider-auth/registry.ts` +- `packages/opencode/src/provider-auth/providers/*` + +### Rotation Engine (fetch middleware) + +Rotation is implemented as a **fetch wrapper**: + +1. Load eligible OAuth credentials for `providerId` + `namespace` +2. Order via a persistent pool +3. Attempt request with selected credential +4. On `429`: + - update cooldown based on `Retry-After` + - move credential to back + - retry in the same request +5. On `401/403`: + - refresh if supported and refresh token exists + - retry once + +Code: +- `packages/opencode/src/inference/rotating-fetch.ts` +- `packages/opencode/src/credentials/pool.ts` + +### Integration Point: Provider.getSDK fetch + +Rotation is injected in one place: the AI SDK `fetch` option, in: + +- `packages/opencode/src/provider/provider.ts` (`getSDK()`) + +This keeps the system composable and provider-agnostic. + +### Config + +Per provider config controls the auth mode: + +```jsonc +{ + "provider": { + "anthropic": { + "auth": { + "mode": "auto", // auto | api | subscription + "namespace": "default", // credential namespace + "maxAttempts": 3 // optional + } + } + } +} +``` + +- `auto`: use OAuth rotation if credentials exist; otherwise use API key/env. +- `api`: disable OAuth rotation for this provider. +- `subscription`: require OAuth credentials; error early if missing. + +Schema: +- `packages/opencode/src/config/config.ts` + +### UX (happy path) + +- TUI: `Connect a provider` dialog + - Choose provider → choose auth method (OAuth or API key) + - OAuth adds a new encrypted record (multi-account) +- TUI: `Manage connected accounts` + - List credential records grouped by provider + - Rename labels and remove credentials +- CLI: + - `opencode auth login` + - `opencode auth list` shows credential records (provider/kind/namespace/label) + - `opencode auth logout` removes all records for a provider + - `opencode auth vault init|export|import` manages the local vault key + +### Credential management endpoints (for UI) + +The server exposes credential metadata (never decrypted secrets) for management UI: + +- `GET /credential` (list `Credentials.RecordMeta[]`) +- `PATCH /credential/:credentialID` (update `label`) +- `DELETE /credential/:credentialID` + +### Rotation stats + +OpenCode tracks in-memory counters for same-request OAuth rotation: + +- requests / attempts / rotations / refresh successes / exhausted + +Exposed via: + +- `GET /debug/rotation` (for UI) +- Shown in TUI `Status` dialog + +## Migration + +Migration is **idempotent** and can run even if the v2 store already contains records. + +- `auth.json` → migrated into v2 records (OAuth records become multi-account entries) +- `mcp-auth.json` → migrated into v2 “mcp:” provider records + +Each legacy provider entry is fingerprinted and tracked in: + +- `Global.Path.data/credentials/migrations.json` + +Code: +- `packages/opencode/src/credentials/migrate.ts` +- MCP credential access: + - `packages/opencode/src/mcp/credentials.ts` + +## Testing + +Unit tests cover: +- Vault crypto roundtrip +- Store encryption at rest +- Same-request rotation on `429` + +Code: +- `packages/opencode/test/credentials/*` +- `packages/opencode/test/inference/rotating-fetch.test.ts` + +## Model discovery (opt-in) + +OpenCode can optionally discover model IDs from OpenAI/OpenAI-compatible providers by calling upstream `/models` and caching results under `Global.Path.cache/model-discovery/*`. + +Config: + +```jsonc +{ + "provider": { + "openai": { + "options": { + "discoverModels": true + } + } + } +} +``` + +Notes: +- Discovery is additive (never removes or overwrites curated IDs). +- Discovery can use subscription rotation when `auth.mode` is `auto|subscription` and OAuth credentials exist.