diff --git a/knip.json b/knip.json index fae5b9c1..f11192cc 100644 --- a/knip.json +++ b/knip.json @@ -8,7 +8,9 @@ ], "workspaces": { "packages/appkit": {}, - "packages/appkit-ui": {} + "packages/appkit-ui": { + "ignoreDependencies": ["react-dom", "@types/react-dom"] + } }, "ignore": [ "**/*.generated.ts", diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 9e810b97..784b1c4e 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -68,8 +68,10 @@ "@opentelemetry/sdk-trace-base": "2.6.0", "@opentelemetry/semantic-conventions": "1.38.0", "@types/semver": "7.7.1", + "cors": "^2.8.6", "dotenv": "16.6.1", "express": "4.22.0", + "helmet": "^8.1.0", "obug": "2.1.1", "pg": "8.18.0", "picocolors": "1.1.1", @@ -79,6 +81,7 @@ "ws": "8.18.3" }, "devDependencies": { + "@types/cors": "^2.8.19", "@types/express": "4.17.25", "@types/json-schema": "7.0.15", "@types/pg": "8.16.0", diff --git a/packages/appkit/src/plugins/server/index.ts b/packages/appkit/src/plugins/server/index.ts index e7b9b31a..a6f0f806 100644 --- a/packages/appkit/src/plugins/server/index.ts +++ b/packages/appkit/src/plugins/server/index.ts @@ -12,6 +12,7 @@ import { instrumentations } from "../../telemetry"; import { sanitizeClientConfig } from "./client-config-sanitizer"; import manifest from "./manifest.json"; import { RemoteTunnelController } from "./remote-tunnel/remote-tunnel-controller"; +import { registerErrorHandler, registerSecurityMiddleware } from "./security"; import { StaticServer } from "./static-server"; import type { ServerConfig } from "./types"; import { getRoutes, type PluginEndpoints, printRoutes } from "./utils"; @@ -93,8 +94,12 @@ export class ServerPlugin extends Plugin { * @returns The express application. */ async start(): Promise { + // Security middleware first — inspects headers only, no body needed + registerSecurityMiddleware(this.serverApplication, this.config.security); + this.serverApplication.use( express.json({ + limit: this.config.bodyLimit ?? "100kb", type: (req) => { // Skip JSON parsing for routes that declared skipBodyParsing // (e.g. file uploads where the raw body must flow through). @@ -122,6 +127,9 @@ export class ServerPlugin extends Plugin { await this.setupFrontend(endpoints, pluginConfigs); + // Error handler last — catches unhandled errors from API routes + registerErrorHandler(this.serverApplication, this.config.security); + const server = this.serverApplication.listen( this.config.port ?? ServerPlugin.DEFAULT_CONFIG.port, this.config.host ?? ServerPlugin.DEFAULT_CONFIG.host, diff --git a/packages/appkit/src/plugins/server/manifest.json b/packages/appkit/src/plugins/server/manifest.json index 11822beb..6c563874 100644 --- a/packages/appkit/src/plugins/server/manifest.json +++ b/packages/appkit/src/plugins/server/manifest.json @@ -29,6 +29,75 @@ "staticPath": { "type": "string", "description": "Path to static files directory (auto-detected if not provided)" + }, + "bodyLimit": { + "type": "string", + "description": "JSON body size limit (e.g. '100kb', '1mb'). Default: '100kb'" + }, + "security": { + "type": "object", + "description": "Security configuration. Secure defaults applied when omitted.", + "properties": { + "csrf": { + "oneOf": [ + { + "type": "object", + "properties": { + "allowedOrigins": { + "type": "array", + "items": { "type": "string" }, + "description": "Additional trusted origins for CSRF validation" + } + } + }, + { "const": false } + ] + }, + "helmet": { + "oneOf": [ + { + "type": "object", + "description": "HelmetOptions — fully replaces defaults" + }, + { "const": false } + ] + }, + "cors": { + "oneOf": [ + { + "type": "object", + "properties": { + "allowedOrigins": { + "type": "array", + "items": { "type": "string" } + }, + "credentials": { "type": "boolean" }, + "maxAge": { "type": "number" }, + "allowedMethods": { + "type": "array", + "items": { "type": "string" } + }, + "allowedHeaders": { + "type": "array", + "items": { "type": "string" } + } + } + }, + { "const": false } + ] + }, + "errorHandler": { + "oneOf": [ + { + "type": "object", + "properties": { + "includeErrorCode": { "type": "boolean" } + } + }, + { "const": false } + ] + } + } } } } diff --git a/packages/appkit/src/plugins/server/security/csrf.ts b/packages/appkit/src/plugins/server/security/csrf.ts new file mode 100644 index 00000000..7b52d1fe --- /dev/null +++ b/packages/appkit/src/plugins/server/security/csrf.ts @@ -0,0 +1,184 @@ +import type { NextFunction, Request, Response } from "express"; +import { createLogger } from "../../../logging/logger"; +import type { CsrfConfig } from "./types"; + +const logger = createLogger("server"); + +const STATE_CHANGING_METHODS = new Set(["POST", "PUT", "DELETE", "PATCH"]); + +/** + * Parse a comma-separated env var into trimmed, non-empty strings. + */ +function parseEnvOrigins(envVar: string | undefined): string[] { + if (!envVar) return []; + return envVar + .split(",") + .map((s) => s.trim()) + .filter(Boolean); +} + +/** + * Build the set of trusted origins from all sources: + * 1. DATABRICKS_APP_URL env var + * 2. Config allowedOrigins + * 3. APPKIT_CSRF_ALLOWED_ORIGINS env var + */ +function buildTrustedOrigins(config?: CsrfConfig): Set { + const origins = new Set(); + + const appUrl = process.env.DATABRICKS_APP_URL; + if (appUrl) { + try { + origins.add(new URL(appUrl).origin.toLowerCase()); + } catch { + logger.warn( + "DATABRICKS_APP_URL is not a valid URL: %s — skipping for CSRF", + appUrl, + ); + } + } + + for (const o of config?.allowedOrigins ?? []) { + origins.add(o.toLowerCase().replace(/\/$/, "")); + } + + for (const o of parseEnvOrigins(process.env.APPKIT_CSRF_ALLOWED_ORIGINS)) { + origins.add(o.toLowerCase().replace(/\/$/, "")); + } + + return origins; +} + +/** + * Check if an origin matches localhost (any port). + */ +function isLocalhostOrigin(origin: string): boolean { + try { + const url = new URL(origin); + return url.hostname === "localhost" || url.hostname === "127.0.0.1"; + } catch { + return false; + } +} + +/** + * Same-origin heuristic: compare Origin against Host header. + * Used as fallback when no trusted origins are configured. + */ +function isSameOrigin(origin: string, req: Request): boolean { + const host = req.headers.host; + if (!host) return false; + + try { + const originUrl = new URL(origin); + const originHost = originUrl.host.toLowerCase(); + return originHost === host.toLowerCase(); + } catch { + return false; + } +} + +/** + * Create CSRF protection middleware using Origin header validation. + * + * - Applies to state-changing methods (POST, PUT, DELETE, PATCH) only + * - Allows absent/empty Origin (same-origin browser or non-browser client) + * - Rejects `Origin: null` (sandboxed iframe attack vector) + * - In dev mode, auto-allows localhost origins + * - Falls back to Host header comparison when no trusted origins are configured + */ +export function createCsrfMiddleware( + config?: CsrfConfig | false, +): (req: Request, res: Response, next: NextFunction) => void { + if (config === false) { + return (_req, _res, next) => next(); + } + + const isDev = process.env.NODE_ENV === "development"; + const trustedOrigins = buildTrustedOrigins( + config === undefined ? undefined : config, + ); + + if (!isDev && trustedOrigins.size === 0) { + logger.warn( + "DATABRICKS_APP_URL not set and no CSRF origins configured — CSRF will use Host header fallback. Set DATABRICKS_APP_URL for full protection.", + ); + } + + return (req: Request, res: Response, next: NextFunction) => { + if (!STATE_CHANGING_METHODS.has(req.method)) { + return next(); + } + + const origin = req.headers.origin; + + // No Origin header — allow (same-origin or non-browser client) + if (!origin || origin === "") { + return next(); + } + + // Reject Origin: null (sandboxed iframe, data: URI) + if (origin === "null") { + logger.debug("CSRF rejected: null Origin on %s %s", req.method, req.path); + return res.status(403).json( + isDev + ? { + error: "CSRF validation failed", + detail: + "Origin: null rejected — possible sandboxed iframe or data: URI", + } + : { error: "CSRF validation failed" }, + ); + } + + const normalizedOrigin = origin.toLowerCase().replace(/\/$/, ""); + + // In dev mode, allow localhost origins + if (isDev && isLocalhostOrigin(normalizedOrigin)) { + return next(); + } + + // In production, reject non-HTTPS origins + if (!isDev && !normalizedOrigin.startsWith("https://")) { + logger.debug( + "CSRF rejected: non-HTTPS Origin %s on %s %s", + origin, + req.method, + req.path, + ); + return res.status(403).json( + isDev + ? { + error: "CSRF validation failed", + detail: `Origin must use HTTPS in production: ${origin}`, + } + : { error: "CSRF validation failed" }, + ); + } + + // Check against trusted origins + if (trustedOrigins.has(normalizedOrigin)) { + return next(); + } + + // Fallback: same-origin heuristic (compare Origin vs Host) + if (trustedOrigins.size === 0 && isSameOrigin(origin, req)) { + return next(); + } + + logger.debug( + "CSRF rejected: Origin %s not trusted on %s %s", + origin, + req.method, + req.path, + ); + return res.status(403).json( + isDev + ? { + error: "CSRF validation failed", + detail: `Origin ${origin} not in trusted set`, + } + : { error: "CSRF validation failed" }, + ); + }; +} diff --git a/packages/appkit/src/plugins/server/security/error-handler.ts b/packages/appkit/src/plugins/server/security/error-handler.ts new file mode 100644 index 00000000..8361c723 --- /dev/null +++ b/packages/appkit/src/plugins/server/security/error-handler.ts @@ -0,0 +1,85 @@ +import type { NextFunction, Request, Response } from "express"; +import { AppKitError } from "../../../errors/base"; +import { createLogger } from "../../../logging/logger"; +import type { ErrorHandlerConfig } from "./types"; + +const logger = createLogger("server"); + +/** + * Create a global error handler middleware that prevents information disclosure. + * + * - Logs full error details server-side (using AppKitError.toJSON() for safe sanitization) + * - Returns generic error messages in production + * - Includes message/stack in dev mode for debugging + * - Handles SyntaxError from JSON body parsing (returns 400) + * - Respects headersSent to avoid double-send + */ +export function createErrorHandler( + config?: ErrorHandlerConfig | false, +): (err: Error, req: Request, res: Response, next: NextFunction) => void { + if (config === false) { + return (_err, _req, _res, next) => next(_err); + } + + const isDev = process.env.NODE_ENV === "development"; + const includeErrorCode = config?.includeErrorCode ?? true; + + return (err: Error, _req: Request, res: Response, next: NextFunction) => { + // If headers already sent, delegate to Express default handler + if (res.headersSent) { + return next(err); + } + + // Log the error server-side + if (err instanceof AppKitError) { + logger.error("Unhandled error: %O", err.toJSON()); + } else { + logger.error("Unhandled error: %s", err.message); + if (err.stack) { + logger.debug("Stack trace: %s", err.stack); + } + } + + // Handle JSON parsing errors from express.json() + if ( + err instanceof SyntaxError && + "status" in err && + (err as { status?: number }).status === 400 + ) { + return res + .status(400) + .json( + isDev + ? { error: "Bad Request", message: err.message } + : { error: "Bad Request" }, + ); + } + + // Handle AppKitError with proper status code + if (err instanceof AppKitError) { + const body: Record = { + error: isDev ? err.message : "Internal Server Error", + }; + + if (includeErrorCode) { + body.code = err.code; + } + + if (isDev && err.stack) { + body.stack = err.stack; + } + + return res.status(err.statusCode).json(body); + } + + // Generic error + return res.status(500).json( + isDev + ? { + error: err.message || "Internal Server Error", + stack: err.stack, + } + : { error: "Internal Server Error" }, + ); + }; +} diff --git a/packages/appkit/src/plugins/server/security/index.ts b/packages/appkit/src/plugins/server/security/index.ts new file mode 100644 index 00000000..5595f43e --- /dev/null +++ b/packages/appkit/src/plugins/server/security/index.ts @@ -0,0 +1,140 @@ +import corsMiddleware from "cors"; +import type { Application } from "express"; +import helmet from "helmet"; +import { createLogger } from "../../../logging/logger"; +import { createCsrfMiddleware } from "./csrf"; +import { createErrorHandler } from "./error-handler"; +import type { CorsConfig, SecurityConfig } from "./types"; + +const logger = createLogger("server"); + +/** + * Build the default Helmet options based on the environment. + */ +function getDefaultHelmetOptions(isDev: boolean) { + if (isDev) { + return { + contentSecurityPolicy: { + directives: { + defaultSrc: ["'self'", "http:", "https:", "ws:", "wss:"], + scriptSrc: ["'self'", "'unsafe-inline'", "http:", "https:"], + styleSrc: ["'self'", "'unsafe-inline'", "http:", "https:"], + imgSrc: ["'self'", "http:", "https:", "data:", "blob:"], + fontSrc: ["'self'", "http:", "https:", "data:"], + objectSrc: ["'none'"], + baseUri: ["'self'"], + connectSrc: ["'self'", "http:", "https:", "ws:", "wss:"], + frameAncestors: ["'self'"], + }, + }, + crossOriginOpenerPolicy: { policy: "same-origin" as const }, + }; + } + + return { + contentSecurityPolicy: { + directives: { + defaultSrc: ["https:", "wss:"], + scriptSrc: ["https:"], + styleSrc: ["'self'", "https:", "'unsafe-inline'"], + imgSrc: ["https:", "data:"], + fontSrc: ["https:", "data:"], + objectSrc: ["'none'"], + baseUri: ["'self'"], + connectSrc: ["https:", "wss:"], + frameAncestors: ["'none'"], + }, + }, + crossOriginOpenerPolicy: { policy: "same-origin" as const }, + }; +} + +/** + * Build CORS options from CorsConfig + env var. + */ +function buildCorsOptions(config: CorsConfig) { + const origins = [ + ...(config.allowedOrigins ?? []), + ...(process.env.APPKIT_CORS_ALLOWED_ORIGINS?.split(",") + .map((s) => s.trim()) + .filter(Boolean) ?? []), + ]; + + return { + origin: origins.length > 0 ? origins : (false as false), + credentials: config.credentials ?? false, + maxAge: config.maxAge ?? 86400, + methods: config.allowedMethods ?? ["GET", "POST", "PUT", "DELETE", "PATCH"], + allowedHeaders: config.allowedHeaders ?? ["Content-Type", "Authorization"], + optionsSuccessStatus: 204, + }; +} + +/** + * Register security middleware on the Express application. + * + * Applied in order: + * 1. Helmet (security headers + CSP) + * 2. CORS (if enabled) + * 3. CSRF (origin validation) + * + * All middleware only inspect headers — no body parsing required. + * Must be registered before route handlers. + */ +export function registerSecurityMiddleware( + app: Application, + config?: SecurityConfig, +): void { + const isDev = process.env.NODE_ENV === "development"; + const features: string[] = []; + + // 1. Helmet (security headers) + if (config?.helmet !== false) { + const helmetOptions = + config?.helmet && typeof config.helmet === "object" + ? config.helmet // User-provided options fully replace defaults + : getDefaultHelmetOptions(isDev); + + app.use(helmet(helmetOptions)); + features.push("Helmet (CSP + security headers)"); + } + + // 2. CORS (opt-in) + if (config?.cors) { + const corsOptions = buildCorsOptions(config.cors); + app.use(corsMiddleware(corsOptions)); + features.push("CORS"); + } + + // 3. CSRF (origin validation) + if (config?.csrf !== false) { + const csrfConfig = + config?.csrf && typeof config.csrf === "object" ? config.csrf : undefined; + app.use(createCsrfMiddleware(csrfConfig)); + features.push("CSRF (origin validation)"); + } + + if (features.length > 0) { + logger.info("Security middleware enabled: %s", features.join(", ")); + } +} + +/** + * Register the global error handler middleware. + * + * Must be registered after all route handlers (Express convention). + * Acts as a safety net for unhandled errors — plugins can handle + * their own errors in route handlers without being affected. + */ +export function registerErrorHandler( + app: Application, + config?: SecurityConfig, +): void { + if (config?.errorHandler !== false) { + const errorConfig = + config?.errorHandler && typeof config.errorHandler === "object" + ? config.errorHandler + : undefined; + app.use(createErrorHandler(errorConfig)); + } +} diff --git a/packages/appkit/src/plugins/server/security/types.ts b/packages/appkit/src/plugins/server/security/types.ts new file mode 100644 index 00000000..ef2188ff --- /dev/null +++ b/packages/appkit/src/plugins/server/security/types.ts @@ -0,0 +1,48 @@ +import type { HelmetOptions } from "helmet"; + +/** Security configuration for the server plugin. Secure defaults applied when omitted. */ +export interface SecurityConfig { + /** CSRF protection via Origin header validation. Enabled by default in production. Set `false` to disable. */ + csrf?: CsrfConfig | false; + + /** Helmet security headers (CSP, X-Content-Type-Options, X-Frame-Options, COOP, etc.). + * Enabled by default with secure presets. Pass custom HelmetOptions to fully replace defaults, or `false` to disable. */ + helmet?: HelmetOptions | false; + + /** CORS configuration. Disabled by default — not registered unless explicitly configured. */ + cors?: CorsConfig | false; + + /** Global error handler preventing info disclosure. Enabled by default. Set `false` to disable (e.g. if you have your own). */ + errorHandler?: ErrorHandlerConfig | false; +} + +export interface CsrfConfig { + /** + * Additional trusted origins for CSRF validation (beyond DATABRICKS_APP_URL). + * Also merged with APPKIT_CSRF_ALLOWED_ORIGINS env var (comma-separated). + * All sources are unioned and deduplicated. + */ + allowedOrigins?: string[]; +} + +export interface CorsConfig { + /** + * Allowed origins for CORS. Also merged with APPKIT_CORS_ALLOWED_ORIGINS env var (comma-separated). + * All sources are unioned and deduplicated. + * If empty after merging, CORS rejects all cross-origin requests (safe default). + */ + allowedOrigins?: string[]; + /** Allow credentials (cookies). Default: false */ + credentials?: boolean; + /** Preflight cache duration in seconds. Default: 86400 (24h) */ + maxAge?: number; + /** Allowed HTTP methods. Default: ["GET","POST","PUT","DELETE","PATCH"] */ + allowedMethods?: string[]; + /** Allowed request headers. Default: ["Content-Type","Authorization"] */ + allowedHeaders?: string[]; +} + +export interface ErrorHandlerConfig { + /** Include AppKitError.code in error responses. Default: true (codes are safe to expose and help clients handle errors). */ + includeErrorCode?: boolean; +} diff --git a/packages/appkit/src/plugins/server/tests/security/csrf.test.ts b/packages/appkit/src/plugins/server/tests/security/csrf.test.ts new file mode 100644 index 00000000..aeaabd63 --- /dev/null +++ b/packages/appkit/src/plugins/server/tests/security/csrf.test.ts @@ -0,0 +1,228 @@ +import express from "express"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { createCsrfMiddleware } from "../../security/csrf"; + +function createTestApp(config?: Parameters[0]) { + const app = express(); + app.use(createCsrfMiddleware(config)); + app.post("/test", (_req, res) => res.json({ ok: true })); + app.get("/test", (_req, res) => res.json({ ok: true })); + app.put("/test", (_req, res) => res.json({ ok: true })); + app.delete("/test", (_req, res) => res.json({ ok: true })); + app.patch("/test", (_req, res) => res.json({ ok: true })); + return app; +} + +async function request( + app: express.Application, + method: string, + path: string, + headers?: Record, +) { + const server = app.listen(0); + const addr = server.address() as { port: number }; + try { + const res = await fetch(`http://127.0.0.1:${addr.port}${path}`, { + method, + headers, + }); + const body = await res.json(); + return { status: res.status, body }; + } finally { + await new Promise((resolve) => server.close(() => resolve())); + } +} + +describe("CSRF Middleware", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + vi.stubEnv("NODE_ENV", "production"); + vi.stubEnv("DATABRICKS_APP_URL", "https://my-app.databricksapps.com"); + delete process.env.APPKIT_CSRF_ALLOWED_ORIGINS; + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + describe("method filtering", () => { + test("GET requests bypass CSRF", async () => { + const app = createTestApp(); + const res = await request(app, "GET", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(200); + }); + + test("POST requests are checked", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(403); + }); + + test("PUT requests are checked", async () => { + const app = createTestApp(); + const res = await request(app, "PUT", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(403); + }); + + test("DELETE requests are checked", async () => { + const app = createTestApp(); + const res = await request(app, "DELETE", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(403); + }); + + test("PATCH requests are checked", async () => { + const app = createTestApp(); + const res = await request(app, "PATCH", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(403); + }); + }); + + describe("origin validation", () => { + test("allows POST without Origin header (same-origin)", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/test"); + expect(res.status).toBe(200); + }); + + test("rejects POST with Origin: null", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/test", { origin: "null" }); + expect(res.status).toBe(403); + expect(res.body.error).toBe("CSRF validation failed"); + }); + + test("allows POST with matching DATABRICKS_APP_URL origin", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "https://my-app.databricksapps.com", + }); + expect(res.status).toBe(200); + }); + + test("rejects POST with non-matching origin", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(403); + }); + + test("case-insensitive origin comparison", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "https://MY-APP.DATABRICKSAPPS.COM", + }); + expect(res.status).toBe(200); + }); + + test("rejects non-HTTPS origins in production", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "http://my-app.databricksapps.com", + }); + expect(res.status).toBe(403); + }); + }); + + describe("config allowedOrigins", () => { + test("allows additional configured origins", async () => { + const app = createTestApp({ + allowedOrigins: ["https://partner.example.com"], + }); + const res = await request(app, "POST", "/test", { + origin: "https://partner.example.com", + }); + expect(res.status).toBe(200); + }); + }); + + describe("APPKIT_CSRF_ALLOWED_ORIGINS env var", () => { + test("merges origins from env var", async () => { + vi.stubEnv( + "APPKIT_CSRF_ALLOWED_ORIGINS", + "https://extra1.com, https://extra2.com", + ); + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "https://extra1.com", + }); + expect(res.status).toBe(200); + }); + }); + + describe("Host header fallback", () => { + test("falls back to Host header when no trusted origins configured", async () => { + delete process.env.DATABRICKS_APP_URL; + const app = createTestApp(); + + // fetch sets Host header automatically + const server = app.listen(0); + const addr = server.address() as { port: number }; + try { + const res = await fetch(`http://127.0.0.1:${addr.port}/test`, { + method: "POST", + headers: { origin: `http://127.0.0.1:${addr.port}` }, + }); + // In production, non-HTTPS is rejected before fallback + expect(res.status).toBe(403); + } finally { + await new Promise((resolve) => server.close(() => resolve())); + } + }); + }); + + describe("dev mode", () => { + test("allows localhost origins in dev mode", async () => { + vi.stubEnv("NODE_ENV", "development"); + delete process.env.DATABRICKS_APP_URL; + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "http://localhost:5173", + }); + expect(res.status).toBe(200); + }); + + test("allows 127.0.0.1 origins in dev mode", async () => { + vi.stubEnv("NODE_ENV", "development"); + delete process.env.DATABRICKS_APP_URL; + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "http://127.0.0.1:8000", + }); + expect(res.status).toBe(200); + }); + + test("still rejects non-localhost origins in dev mode", async () => { + vi.stubEnv("NODE_ENV", "development"); + delete process.env.DATABRICKS_APP_URL; + const app = createTestApp(); + const res = await request(app, "POST", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(403); + // Dev mode includes detail + expect(res.body.detail).toBeDefined(); + }); + }); + + describe("disabled", () => { + test("csrf: false disables all CSRF checks", async () => { + const app = createTestApp(false); + const res = await request(app, "POST", "/test", { + origin: "https://evil.com", + }); + expect(res.status).toBe(200); + }); + }); +}); diff --git a/packages/appkit/src/plugins/server/tests/security/error-handler.test.ts b/packages/appkit/src/plugins/server/tests/security/error-handler.test.ts new file mode 100644 index 00000000..b261f608 --- /dev/null +++ b/packages/appkit/src/plugins/server/tests/security/error-handler.test.ts @@ -0,0 +1,162 @@ +import express from "express"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { AppKitError } from "../../../../errors/base"; +import { createErrorHandler } from "../../security/error-handler"; + +class TestAppKitError extends AppKitError { + readonly code = "TEST_ERROR"; + readonly statusCode = 422; + readonly isRetryable = false; +} + +function createTestApp( + config?: Parameters[0], + routeSetup?: (app: express.Application) => void, +) { + const app = express(); + app.use(express.json()); + + if (routeSetup) { + routeSetup(app); + } else { + app.get("/throw-appkit", (_req, res, next) => { + next( + new TestAppKitError("Something went wrong", { + context: { userId: "123" }, + }), + ); + }); + app.get("/throw-generic", (_req, res, next) => { + next(new Error("Unexpected failure with secret-token-abc")); + }); + app.post("/parse-json", (req, res) => { + res.json({ received: req.body }); + }); + } + + app.use(createErrorHandler(config)); + return app; +} + +async function request( + app: express.Application, + method: string, + path: string, + options?: { body?: string; headers?: Record }, +) { + const server = app.listen(0); + const addr = server.address() as { port: number }; + try { + const res = await fetch(`http://127.0.0.1:${addr.port}${path}`, { + method, + headers: options?.headers, + body: options?.body, + }); + const body = await res.json(); + return { status: res.status, body }; + } finally { + await new Promise((resolve) => server.close(() => resolve())); + } +} + +describe("Error Handler Middleware", () => { + beforeEach(() => { + vi.stubEnv("NODE_ENV", "production"); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + describe("production mode", () => { + test("AppKitError returns statusCode and code, not message", async () => { + const app = createTestApp(); + const res = await request(app, "GET", "/throw-appkit"); + expect(res.status).toBe(422); + expect(res.body.error).toBe("Internal Server Error"); + expect(res.body.code).toBe("TEST_ERROR"); + expect(res.body.stack).toBeUndefined(); + expect(res.body.message).toBeUndefined(); + }); + + test("AppKitError hides code when includeErrorCode is false", async () => { + const app = createTestApp({ includeErrorCode: false }); + const res = await request(app, "GET", "/throw-appkit"); + expect(res.status).toBe(422); + expect(res.body.code).toBeUndefined(); + }); + + test("generic error returns 500 with no details", async () => { + const app = createTestApp(); + const res = await request(app, "GET", "/throw-generic"); + expect(res.status).toBe(500); + expect(res.body.error).toBe("Internal Server Error"); + expect(res.body.stack).toBeUndefined(); + // Should not leak the error message with "secret-token" + expect(JSON.stringify(res.body)).not.toContain("secret-token"); + }); + + test("malformed JSON returns 400 Bad Request", async () => { + const app = createTestApp(); + const res = await request(app, "POST", "/parse-json", { + body: "{invalid json", + headers: { "Content-Type": "application/json" }, + }); + expect(res.status).toBe(400); + expect(res.body.error).toBe("Bad Request"); + expect(res.body.message).toBeUndefined(); + }); + }); + + describe("dev mode", () => { + test("includes message and stack in dev mode", async () => { + vi.stubEnv("NODE_ENV", "development"); + const app = createTestApp(); + const res = await request(app, "GET", "/throw-generic"); + expect(res.status).toBe(500); + expect(res.body.error).toContain("Unexpected failure"); + expect(res.body.stack).toBeDefined(); + }); + + test("AppKitError includes message in dev mode", async () => { + vi.stubEnv("NODE_ENV", "development"); + const app = createTestApp(); + const res = await request(app, "GET", "/throw-appkit"); + expect(res.status).toBe(422); + expect(res.body.error).toBe("Something went wrong"); + expect(res.body.code).toBe("TEST_ERROR"); + expect(res.body.stack).toBeDefined(); + }); + + test("malformed JSON includes message in dev mode", async () => { + vi.stubEnv("NODE_ENV", "development"); + const app = createTestApp(); + const res = await request(app, "POST", "/parse-json", { + body: "{bad", + headers: { "Content-Type": "application/json" }, + }); + expect(res.status).toBe(400); + expect(res.body.message).toBeDefined(); + }); + }); + + describe("disabled", () => { + test("errorHandler: false passes errors through", async () => { + const app = express(); + app.get("/throw", (_req, _res, next) => { + next(new Error("test")); + }); + app.use(createErrorHandler(false)); + // Express default handler will return HTML 500 + const server = app.listen(0); + const addr = server.address() as { port: number }; + try { + const res = await fetch(`http://127.0.0.1:${addr.port}/throw`); + // Express default error handler returns 500 with HTML + expect(res.status).toBe(500); + } finally { + await new Promise((resolve) => server.close(() => resolve())); + } + }); + }); +}); diff --git a/packages/appkit/src/plugins/server/tests/security/security.test.ts b/packages/appkit/src/plugins/server/tests/security/security.test.ts new file mode 100644 index 00000000..3bc57786 --- /dev/null +++ b/packages/appkit/src/plugins/server/tests/security/security.test.ts @@ -0,0 +1,171 @@ +import type { Server } from "node:http"; +import { mockServiceContext, setupDatabricksEnv } from "@tools/test-helpers"; +import type { PluginManifest } from "shared"; +import { + afterAll, + afterEach, + beforeAll, + beforeEach, + describe, + expect, + test, + vi, +} from "vitest"; +import { ServiceContext } from "../../../../context/service-context"; +import { createApp } from "../../../../core"; +import { Plugin, toPlugin } from "../../../../plugin"; +import { server as serverPlugin } from "../../index"; + +describe("Security Integration", () => { + let server: Server; + let baseUrl: string; + let serviceContextMock: Awaited>; + const TEST_PORT = 9890; + + beforeAll(async () => { + vi.stubEnv("NODE_ENV", "production"); + vi.stubEnv("DATABRICKS_APP_URL", "https://my-app.databricksapps.com"); + setupDatabricksEnv(); + ServiceContext.reset(); + serviceContextMock = await mockServiceContext(); + + class TestPlugin extends Plugin { + static manifest = { + name: "test-sec", + displayName: "Test Security Plugin", + description: "Test plugin for security integration tests", + resources: { required: [], optional: [] }, + } satisfies PluginManifest<"test-sec">; + + injectRoutes(router: any) { + router.get("/data", (_req: any, res: any) => { + res.json({ data: "hello" }); + }); + router.post("/data", (req: any, res: any) => { + res.json({ received: req.body }); + }); + } + } + + const testPlugin = toPlugin(TestPlugin); + + const app = await createApp({ + plugins: [ + serverPlugin({ + port: TEST_PORT, + host: "127.0.0.1", + autoStart: false, + }), + testPlugin({}), + ], + }); + + await app.server.start(); + server = app.server.getServer(); + baseUrl = `http://127.0.0.1:${TEST_PORT}`; + await new Promise((resolve) => setTimeout(resolve, 100)); + }); + + afterAll(async () => { + vi.unstubAllEnvs(); + serviceContextMock?.restore(); + if (server) { + await new Promise((resolve, reject) => { + server.close((err) => { + if (err) reject(err); + else resolve(); + }); + }); + } + }); + + describe("security headers (Helmet)", () => { + test("sets Content-Security-Policy on responses", async () => { + const res = await fetch(`${baseUrl}/health`); + const csp = res.headers.get("content-security-policy"); + expect(csp).toBeDefined(); + expect(csp).toContain("default-src"); + expect(csp).toContain("frame-ancestors 'none'"); + }); + + test("sets X-Content-Type-Options: nosniff", async () => { + const res = await fetch(`${baseUrl}/health`); + expect(res.headers.get("x-content-type-options")).toBe("nosniff"); + }); + + test("sets Cross-Origin-Opener-Policy", async () => { + const res = await fetch(`${baseUrl}/health`); + expect(res.headers.get("cross-origin-opener-policy")).toBe("same-origin"); + }); + + test("sets X-Frame-Options", async () => { + const res = await fetch(`${baseUrl}/health`); + expect(res.headers.get("x-frame-options")).toBeDefined(); + }); + }); + + describe("CSRF protection", () => { + test("GET requests pass through CSRF", async () => { + const res = await fetch(`${baseUrl}/api/test-sec/data`, { + headers: { origin: "https://evil.com" }, + }); + expect(res.status).toBe(200); + }); + + test("POST without Origin passes (same-origin)", async () => { + const res = await fetch(`${baseUrl}/api/test-sec/data`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ test: true }), + }); + expect(res.status).toBe(200); + }); + + test("POST with matching origin passes", async () => { + const res = await fetch(`${baseUrl}/api/test-sec/data`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Origin: "https://my-app.databricksapps.com", + }, + body: JSON.stringify({ test: true }), + }); + expect(res.status).toBe(200); + }); + + test("POST with evil origin is rejected 403", async () => { + const res = await fetch(`${baseUrl}/api/test-sec/data`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Origin: "https://evil.com", + }, + body: JSON.stringify({ test: true }), + }); + expect(res.status).toBe(403); + const body = await res.json(); + expect(body.error).toBe("CSRF validation failed"); + }); + + test("POST with Origin: null is rejected 403", async () => { + const res = await fetch(`${baseUrl}/api/test-sec/data`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Origin: "null", + }, + body: JSON.stringify({ test: true }), + }); + expect(res.status).toBe(403); + }); + }); + + describe("CORS (disabled by default)", () => { + test("no Access-Control headers when CORS is not configured", async () => { + const res = await fetch(`${baseUrl}/health`, { + headers: { Origin: "https://other.com" }, + }); + expect(res.headers.get("access-control-allow-origin")).toBeNull(); + }); + }); +}); diff --git a/packages/appkit/src/plugins/server/types.ts b/packages/appkit/src/plugins/server/types.ts index e187cacc..b2c36a35 100644 --- a/packages/appkit/src/plugins/server/types.ts +++ b/packages/appkit/src/plugins/server/types.ts @@ -1,5 +1,8 @@ import type { BasePluginConfig } from "shared"; import type { Plugin } from "../../plugin"; +import type { SecurityConfig } from "./security/types"; + +export type { SecurityConfig } from "./security/types"; export interface ServerConfig extends BasePluginConfig { port?: number; @@ -7,4 +10,8 @@ export interface ServerConfig extends BasePluginConfig { staticPath?: string; autoStart?: boolean; host?: string; + /** Request body size limit for JSON parsing. Default: "100kb". */ + bodyLimit?: string; + /** Security configuration. Secure defaults applied when omitted. */ + security?: SecurityConfig; } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 199fcfb8..30b3e529 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -296,12 +296,18 @@ importers: '@types/semver': specifier: 7.7.1 version: 7.7.1 + cors: + specifier: ^2.8.6 + version: 2.8.6 dotenv: specifier: 16.6.1 version: 16.6.1 express: specifier: 4.22.0 version: 4.22.0 + helmet: + specifier: ^8.1.0 + version: 8.1.0 obug: specifier: 2.1.1 version: 2.1.1 @@ -324,6 +330,9 @@ importers: specifier: 8.18.3 version: 8.18.3(bufferutil@4.0.9) devDependencies: + '@types/cors': + specifier: ^2.8.19 + version: 2.8.19 '@types/express': specifier: 4.17.25 version: 4.17.25 @@ -4789,6 +4798,9 @@ packages: '@types/conventional-commits-parser@5.0.1': resolution: {integrity: sha512-7uz5EHdzz2TqoMfV7ee61Egf5y6NkcO4FB/1iCCQnbeiI1F3xzv3vK5dBCXUCLQgGYS+mUeigK1iKQzvED+QnQ==} + '@types/cors@2.8.19': + resolution: {integrity: sha512-mFNylyeyqN93lfe/9CSxOGREz8cpzAhH+E93xJ4xWQf62V8sQ/24reV2nyzUWM6H6Xji+GGHpkbLe7pVoUEskg==} + '@types/d3-array@3.2.2': resolution: {integrity: sha512-hOLWVbm7uRza0BYXpIIW5pxfrKe0W+D5lrFiAEYR+pb6w3N2SwSMaJbXdUfSEv+dT4MfHBLtn5js0LAWaO6otw==} @@ -6077,6 +6089,10 @@ packages: core-util-is@1.0.3: resolution: {integrity: sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==} + cors@2.8.6: + resolution: {integrity: sha512-tJtZBBHA6vjIAaF6EnIaq6laBBP9aq/Y3ouVJjEfoHbRBcHBAHYcMh/w8LDrk2PvIMMq8gmopa5D4V8RmbrxGw==} + engines: {node: '>= 0.10'} + cose-base@1.0.3: resolution: {integrity: sha512-s9whTXInMSgAp/NVXVNuVxVKzGH2qck3aQlVHxDCdAEPgtMKwc4Wq6/QKhgdEdgbLSi9rBTAcPoRa6JpiG4ksg==} @@ -7585,6 +7601,10 @@ packages: resolution: {integrity: sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw==} hasBin: true + helmet@8.1.0: + resolution: {integrity: sha512-jOiHyAZsmnr8LqoPGmCjYAaiuWwjAPLgY8ZX2XrmHawt99/u1y6RgrZMTeoPfpUbV96HOalYgz1qzkRbw54Pmg==} + engines: {node: '>=18.0.0'} + hermes-estree@0.25.1: resolution: {integrity: sha512-0wUoCcLp+5Ev5pDW2OriHC2MJCbwLwuRx+gAqMTOkGKJJiBCLjtrvy4PWUGn6MIVefecRpzoOZ/UV6iGdOr+Cw==} @@ -17092,6 +17112,10 @@ snapshots: dependencies: '@types/node': 25.2.3 + '@types/cors@2.8.19': + dependencies: + '@types/node': 25.2.3 + '@types/d3-array@3.2.2': {} '@types/d3-axis@3.0.6': @@ -18640,6 +18664,11 @@ snapshots: core-util-is@1.0.3: {} + cors@2.8.6: + dependencies: + object-assign: 4.1.1 + vary: 1.1.2 + cose-base@1.0.3: dependencies: layout-base: 1.0.2 @@ -20374,6 +20403,8 @@ snapshots: he@1.2.0: {} + helmet@8.1.0: {} + hermes-estree@0.25.1: {} hermes-estree@0.33.3: