diff --git a/.gitignore b/.gitignore index 3b6cc969..4c51d5b1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ coverage *.tsbuildinfo .turbo + +# AppKit type generator caches +.databricks diff --git a/apps/dev-playground/.env.dist b/apps/dev-playground/.env.dist index 23c3265a..80eda94b 100644 --- a/apps/dev-playground/.env.dist +++ b/apps/dev-playground/.env.dist @@ -9,6 +9,7 @@ OTEL_SERVICE_NAME='dev-playground' DATABRICKS_VOLUME_PLAYGROUND= DATABRICKS_VOLUME_OTHER= DATABRICKS_GENIE_SPACE_ID= +DATABRICKS_SERVING_ENDPOINT= LAKEBASE_ENDPOINT='' # Run: databricks postgres list-endpoints projects/{project-id}/branches/{branch-id} — use the `name` field from the output PGHOST= PGUSER= diff --git a/apps/dev-playground/client/.gitignore b/apps/dev-playground/client/.gitignore index a547bf36..267b28f3 100644 --- a/apps/dev-playground/client/.gitignore +++ b/apps/dev-playground/client/.gitignore @@ -12,6 +12,9 @@ dist dist-ssr *.local +# Auto-generated types (endpoint-specific, varies per developer) +src/appKitServingTypes.d.ts + # Editor directories and files .vscode/* !.vscode/extensions.json diff --git a/apps/dev-playground/client/src/routeTree.gen.ts b/apps/dev-playground/client/src/routeTree.gen.ts index c4c38d14..99ac75fc 100644 --- a/apps/dev-playground/client/src/routeTree.gen.ts +++ b/apps/dev-playground/client/src/routeTree.gen.ts @@ -12,6 +12,7 @@ import { Route as rootRouteImport } from './routes/__root' import { Route as TypeSafetyRouteRouteImport } from './routes/type-safety.route' import { Route as TelemetryRouteRouteImport } from './routes/telemetry.route' import { Route as SqlHelpersRouteRouteImport } from './routes/sql-helpers.route' +import { Route as ServingRouteRouteImport } from './routes/serving.route' import { Route as ReconnectRouteRouteImport } from './routes/reconnect.route' import { Route as LakebaseRouteRouteImport } from './routes/lakebase.route' import { Route as GenieRouteRouteImport } from './routes/genie.route' @@ -37,6 +38,11 @@ const SqlHelpersRouteRoute = SqlHelpersRouteRouteImport.update({ path: '/sql-helpers', getParentRoute: () => rootRouteImport, } as any) +const ServingRouteRoute = ServingRouteRouteImport.update({ + id: '/serving', + path: '/serving', + getParentRoute: () => rootRouteImport, +} as any) const ReconnectRouteRoute = ReconnectRouteRouteImport.update({ id: '/reconnect', path: '/reconnect', @@ -93,6 +99,7 @@ export interface FileRoutesByFullPath { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -107,6 +114,7 @@ export interface FileRoutesByTo { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -122,6 +130,7 @@ export interface FileRoutesById { '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute '/reconnect': typeof ReconnectRouteRoute + '/serving': typeof ServingRouteRoute '/sql-helpers': typeof SqlHelpersRouteRoute '/telemetry': typeof TelemetryRouteRoute '/type-safety': typeof TypeSafetyRouteRoute @@ -138,6 +147,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -152,6 +162,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -166,6 +177,7 @@ export interface FileRouteTypes { | '/genie' | '/lakebase' | '/reconnect' + | '/serving' | '/sql-helpers' | '/telemetry' | '/type-safety' @@ -181,6 +193,7 @@ export interface RootRouteChildren { GenieRouteRoute: typeof GenieRouteRoute LakebaseRouteRoute: typeof LakebaseRouteRoute ReconnectRouteRoute: typeof ReconnectRouteRoute + ServingRouteRoute: typeof ServingRouteRoute SqlHelpersRouteRoute: typeof SqlHelpersRouteRoute TelemetryRouteRoute: typeof TelemetryRouteRoute TypeSafetyRouteRoute: typeof TypeSafetyRouteRoute @@ -209,6 +222,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof SqlHelpersRouteRouteImport parentRoute: typeof rootRouteImport } + '/serving': { + id: '/serving' + path: '/serving' + fullPath: '/serving' + preLoaderRoute: typeof ServingRouteRouteImport + parentRoute: typeof rootRouteImport + } '/reconnect': { id: '/reconnect' path: '/reconnect' @@ -285,6 +305,7 @@ const rootRouteChildren: RootRouteChildren = { GenieRouteRoute: GenieRouteRoute, LakebaseRouteRoute: LakebaseRouteRoute, ReconnectRouteRoute: ReconnectRouteRoute, + ServingRouteRoute: ServingRouteRoute, SqlHelpersRouteRoute: SqlHelpersRouteRoute, TelemetryRouteRoute: TelemetryRouteRoute, TypeSafetyRouteRoute: TypeSafetyRouteRoute, diff --git a/apps/dev-playground/client/src/routes/__root.tsx b/apps/dev-playground/client/src/routes/__root.tsx index 5cf74ce3..35a2282b 100644 --- a/apps/dev-playground/client/src/routes/__root.tsx +++ b/apps/dev-playground/client/src/routes/__root.tsx @@ -104,6 +104,14 @@ function RootComponent() { Files + + + diff --git a/apps/dev-playground/client/src/routes/index.tsx b/apps/dev-playground/client/src/routes/index.tsx index e331d93c..934b1467 100644 --- a/apps/dev-playground/client/src/routes/index.tsx +++ b/apps/dev-playground/client/src/routes/index.tsx @@ -218,6 +218,24 @@ function IndexRoute() { + + +
+

+ Model Serving +

+

+ Chat with a Databricks Model Serving endpoint using streaming + completions with real-time SSE responses. +

+ +
+
diff --git a/apps/dev-playground/client/src/routes/serving.route.tsx b/apps/dev-playground/client/src/routes/serving.route.tsx new file mode 100644 index 00000000..770d42f4 --- /dev/null +++ b/apps/dev-playground/client/src/routes/serving.route.tsx @@ -0,0 +1,148 @@ +import { useServingStream } from "@databricks/appkit-ui/react"; +import { createFileRoute } from "@tanstack/react-router"; +import { useEffect, useRef, useState } from "react"; + +export const Route = createFileRoute("/serving")({ + component: ServingRoute, +}); + +interface Message { + id: string; + role: "user" | "assistant"; + content: string; +} + +function extractContent(chunk: unknown): string { + return ( + (chunk as { choices?: { delta?: { content?: string } }[] })?.choices?.[0] + ?.delta?.content ?? "" + ); +} + +function ServingRoute() { + const [input, setInput] = useState(""); + const [messages, setMessages] = useState([]); + + const { stream, chunks, streaming, error, reset } = useServingStream({ + messages: [], + }); + + const streamingContent = chunks.map(extractContent).join(""); + + // Commit assistant message when streaming transitions from true → false + const prevStreamingRef = useRef(false); + useEffect(() => { + if (prevStreamingRef.current && !streaming && streamingContent) { + setMessages((prev) => [ + ...prev, + { + id: crypto.randomUUID(), + role: "assistant", + content: streamingContent, + }, + ]); + reset(); + } + prevStreamingRef.current = streaming; + }, [streaming, streamingContent, reset]); + + function handleSubmit(e: React.FormEvent) { + e.preventDefault(); + if (!input.trim() || streaming) return; + + const userMessage: Message = { + id: crypto.randomUUID(), + role: "user", + content: input.trim(), + }; + + const fullMessages = [ + ...messages.map(({ role, content }) => ({ role, content })), + { role: "user" as const, content: userMessage.content }, + ]; + + setMessages((prev) => [...prev, userMessage]); + setInput(""); + reset(); + stream({ messages: fullMessages }); + } + + return ( +
+
+
+
+

+ Model Serving +

+

+ Chat with a Databricks Model Serving endpoint. Set{" "} + + DATABRICKS_SERVING_ENDPOINT + {" "} + to enable. +

+
+ +
+ {/* Messages area */} +
+ {messages.map((msg) => ( +
+
+

{msg.content}

+
+
+ ))} + + {/* Streaming response */} + {streaming && ( +
+
+

+ {streamingContent || "..."} +

+
+
+ )} + + {error && ( +
+ Error: {error} +
+ )} +
+ + {/* Input area */} +
+ setInput(e.target.value)} + placeholder="Send a message..." + className="flex-1 rounded-md border px-3 py-2 text-sm bg-background" + disabled={streaming} + /> + +
+
+
+
+
+ ); +} diff --git a/apps/dev-playground/client/vite.config.ts b/apps/dev-playground/client/vite.config.ts index f892c62f..5f37880b 100644 --- a/apps/dev-playground/client/vite.config.ts +++ b/apps/dev-playground/client/vite.config.ts @@ -1,4 +1,5 @@ import path from "node:path"; +import { appKitServingTypesPlugin } from "@databricks/appkit"; import { tanstackRouter } from "@tanstack/router-plugin/vite"; import react from "@vitejs/plugin-react"; import { defineConfig } from "vite"; @@ -11,6 +12,7 @@ export default defineConfig({ target: "react", autoCodeSplitting: process.env.NODE_ENV !== "development", }), + appKitServingTypesPlugin(), ], server: { hmr: { diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index a4b6a2c6..af05b11f 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -1,5 +1,12 @@ import "reflect-metadata"; -import { analytics, createApp, files, genie, server } from "@databricks/appkit"; +import { + analytics, + createApp, + files, + genie, + server, + serving, +} from "@databricks/appkit"; import { WorkspaceClient } from "@databricks/sdk-experimental"; import { lakebaseExamples } from "./lakebase-examples-plugin"; import { reconnect } from "./reconnect-plugin"; @@ -26,6 +33,7 @@ createApp({ }), lakebaseExamples(), files(), + serving(), ], ...(process.env.APPKIT_E2E_TEST && { client: createMockClient() }), }).then((appkit) => { diff --git a/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md new file mode 100644 index 00000000..bc28660a --- /dev/null +++ b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md @@ -0,0 +1,24 @@ +# Function: appKitServingTypesPlugin() + +```ts +function appKitServingTypesPlugin(options?: AppKitServingTypesPluginOptions): Plugin$1; +``` + +Vite plugin to generate TypeScript types for AppKit serving endpoints. +Fetches OpenAPI schemas from Databricks and generates a .d.ts with +ServingEndpointRegistry module augmentation. + +Endpoint discovery order: +1. Explicit `endpoints` option (override) +2. AST extraction from server file (server/index.ts or server/server.ts) +3. DATABRICKS_SERVING_ENDPOINT env var (single default endpoint) + +## Parameters + +| Parameter | Type | +| ------ | ------ | +| `options?` | `AppKitServingTypesPluginOptions` | + +## Returns + +`Plugin$1` diff --git a/docs/docs/api/appkit/Function.extractServingEndpoints.md b/docs/docs/api/appkit/Function.extractServingEndpoints.md new file mode 100644 index 00000000..24a5b00d --- /dev/null +++ b/docs/docs/api/appkit/Function.extractServingEndpoints.md @@ -0,0 +1,24 @@ +# Function: extractServingEndpoints() + +```ts +function extractServingEndpoints(serverFilePath: string): + | Record + | null; +``` + +Extract serving endpoint config from a server file by AST-parsing it. +Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls +and extracts the endpoint alias names and their environment variable mappings. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `serverFilePath` | `string` | Absolute path to the server entry file | + +## Returns + + \| `Record`\<`string`, [`EndpointConfig`](Interface.EndpointConfig.md)\> + \| `null` + +Extracted endpoint config, or null if not found or not extractable diff --git a/docs/docs/api/appkit/Function.findServerFile.md b/docs/docs/api/appkit/Function.findServerFile.md new file mode 100644 index 00000000..2ed4e268 --- /dev/null +++ b/docs/docs/api/appkit/Function.findServerFile.md @@ -0,0 +1,19 @@ +# Function: findServerFile() + +```ts +function findServerFile(basePath: string): string | null; +``` + +Find the server entry file by checking candidate paths in order. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `basePath` | `string` | Project root directory to search from | + +## Returns + +`string` \| `null` + +Absolute path to the server file, or null if none found diff --git a/docs/docs/api/appkit/Interface.EndpointConfig.md b/docs/docs/api/appkit/Interface.EndpointConfig.md new file mode 100644 index 00000000..6ee94aa3 --- /dev/null +++ b/docs/docs/api/appkit/Interface.EndpointConfig.md @@ -0,0 +1,21 @@ +# Interface: EndpointConfig + +## Properties + +### env + +```ts +env: string; +``` + +Environment variable holding the endpoint name. + +*** + +### servedModel? + +```ts +optional servedModel: string; +``` + +Target a specific served model (bypasses traffic routing). diff --git a/docs/docs/api/appkit/Interface.ServingEndpointEntry.md b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md new file mode 100644 index 00000000..fa054c3f --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointEntry.md @@ -0,0 +1,27 @@ +# Interface: ServingEndpointEntry + +Shape of a single registry entry. + +## Properties + +### chunk + +```ts +chunk: unknown; +``` + +*** + +### request + +```ts +request: Record; +``` + +*** + +### response + +```ts +response: unknown; +``` diff --git a/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md new file mode 100644 index 00000000..defe5270 --- /dev/null +++ b/docs/docs/api/appkit/Interface.ServingEndpointRegistry.md @@ -0,0 +1,5 @@ +# Interface: ServingEndpointRegistry + +Registry interface for serving endpoint type generation. +Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. +When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. diff --git a/docs/docs/api/appkit/TypeAlias.ServingFactory.md b/docs/docs/api/appkit/TypeAlias.ServingFactory.md new file mode 100644 index 00000000..9ccafef5 --- /dev/null +++ b/docs/docs/api/appkit/TypeAlias.ServingFactory.md @@ -0,0 +1,19 @@ +# Type Alias: ServingFactory + +```ts +type ServingFactory = keyof ServingEndpointRegistry extends never ? (alias?: string) => ServingEndpointMethods : (alias: K) => ServingEndpointMethods; +``` + +Factory function returned by `AppKit.serving`. + +This is a conditional type that adapts based on whether `ServingEndpointRegistry` +has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + +- **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + accepts any alias string with untyped request/response/chunk. +- **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + restricts `alias` to known endpoint keys and infers typed request/response/chunk + from the registry entry. + +Run `appKitServingTypesPlugin()` in your Vite config to generate the registry +augmentation and enable full type safety. diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index b5fb7ce0..faadf237 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -33,6 +33,7 @@ plugin architecture, and React integration. | [BasePluginConfig](Interface.BasePluginConfig.md) | Base configuration interface for AppKit plugins | | [CacheConfig](Interface.CacheConfig.md) | Configuration for the CacheInterceptor. Controls TTL, size limits, storage backend, and probabilistic cleanup. | | [DatabaseCredential](Interface.DatabaseCredential.md) | Database credentials with OAuth token for Postgres connection | +| [EndpointConfig](Interface.EndpointConfig.md) | - | | [GenerateDatabaseCredentialRequest](Interface.GenerateDatabaseCredentialRequest.md) | Request parameters for generating database OAuth credentials | | [ITelemetry](Interface.ITelemetry.md) | Plugin-facing interface for OpenTelemetry instrumentation. Provides a thin abstraction over OpenTelemetry APIs for plugins. | | [LakebasePoolConfig](Interface.LakebasePoolConfig.md) | Configuration for creating a Lakebase connection pool | @@ -42,6 +43,8 @@ plugin architecture, and React integration. | [ResourceEntry](Interface.ResourceEntry.md) | Internal representation of a resource in the registry. Extends ResourceRequirement with resolution state and plugin ownership. | | [ResourceFieldEntry](Interface.ResourceFieldEntry.md) | Defines a single field for a resource. Each field has its own environment variable and optional description. Single-value types use one key (e.g. id); multi-value types (database, secret) use multiple (e.g. instance_name, database_name or scope, key). | | [ResourceRequirement](Interface.ResourceRequirement.md) | Declares a resource requirement for a plugin. Can be defined statically in a manifest or dynamically via getResourceRequirements(). Narrows the generated base: type → ResourceType enum, permission → ResourcePermission union. | +| [ServingEndpointEntry](Interface.ServingEndpointEntry.md) | Shape of a single registry entry. | +| [ServingEndpointRegistry](Interface.ServingEndpointRegistry.md) | Registry interface for serving endpoint type generation. Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. | | [StreamExecutionSettings](Interface.StreamExecutionSettings.md) | Execution settings for streaming endpoints. Extends PluginExecutionSettings with SSE stream configuration. | | [TelemetryConfig](Interface.TelemetryConfig.md) | OpenTelemetry configuration for AppKit applications | | [ValidationResult](Interface.ValidationResult.md) | Result of validating all registered resources against the environment. | @@ -54,6 +57,7 @@ plugin architecture, and React integration. | [IAppRouter](TypeAlias.IAppRouter.md) | Express router type for plugin route registration | | [PluginData](TypeAlias.PluginData.md) | Tuple of plugin class, config, and name. Created by `toPlugin()` and passed to `createApp()`. | | [ResourcePermission](TypeAlias.ResourcePermission.md) | Union of all possible permission levels across all resource types. | +| [ServingFactory](TypeAlias.ServingFactory.md) | Factory function returned by `AppKit.serving`. | | [ToPlugin](TypeAlias.ToPlugin.md) | Factory function type returned by `toPlugin()`. Accepts optional config and returns a PluginData tuple. | ## Variables @@ -66,9 +70,12 @@ plugin architecture, and React integration. | Function | Description | | ------ | ------ | +| [appKitServingTypesPlugin](Function.appKitServingTypesPlugin.md) | Vite plugin to generate TypeScript types for AppKit serving endpoints. Fetches OpenAPI schemas from Databricks and generates a .d.ts with ServingEndpointRegistry module augmentation. | | [appKitTypesPlugin](Function.appKitTypesPlugin.md) | Vite plugin to generate types for AppKit queries. Calls generateFromEntryPoint under the hood. | | [createApp](Function.createApp.md) | Bootstraps AppKit with the provided configuration. | | [createLakebasePool](Function.createLakebasePool.md) | Create a Lakebase pool with appkit's logger integration. Telemetry automatically uses appkit's OpenTelemetry configuration via global registry. | +| [extractServingEndpoints](Function.extractServingEndpoints.md) | Extract serving endpoint config from a server file by AST-parsing it. Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls and extracts the endpoint alias names and their environment variable mappings. | +| [findServerFile](Function.findServerFile.md) | Find the server entry file by checking candidate paths in order. | | [generateDatabaseCredential](Function.generateDatabaseCredential.md) | Generate OAuth credentials for Postgres database connection using the proper Postgres API. | | [getExecutionContext](Function.getExecutionContext.md) | Get the current execution context. | | [getLakebaseOrmConfig](Function.getLakebaseOrmConfig.md) | Get Lakebase connection configuration for ORMs that don't accept pg.Pool directly. | diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 2f17b1d2..1d498d1a 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -97,6 +97,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.DatabaseCredential", label: "DatabaseCredential" }, + { + type: "doc", + id: "api/appkit/Interface.EndpointConfig", + label: "EndpointConfig" + }, { type: "doc", id: "api/appkit/Interface.GenerateDatabaseCredentialRequest", @@ -142,6 +147,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Interface.ResourceRequirement", label: "ResourceRequirement" }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointEntry", + label: "ServingEndpointEntry" + }, + { + type: "doc", + id: "api/appkit/Interface.ServingEndpointRegistry", + label: "ServingEndpointRegistry" + }, { type: "doc", id: "api/appkit/Interface.StreamExecutionSettings", @@ -183,6 +198,11 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/TypeAlias.ResourcePermission", label: "ResourcePermission" }, + { + type: "doc", + id: "api/appkit/TypeAlias.ServingFactory", + label: "ServingFactory" + }, { type: "doc", id: "api/appkit/TypeAlias.ToPlugin", @@ -205,6 +225,11 @@ const typedocSidebar: SidebarsConfig = { type: "category", label: "Functions", items: [ + { + type: "doc", + id: "api/appkit/Function.appKitServingTypesPlugin", + label: "appKitServingTypesPlugin" + }, { type: "doc", id: "api/appkit/Function.appKitTypesPlugin", @@ -220,6 +245,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Function.createLakebasePool", label: "createLakebasePool" }, + { + type: "doc", + id: "api/appkit/Function.extractServingEndpoints", + label: "extractServingEndpoints" + }, + { + type: "doc", + id: "api/appkit/Function.findServerFile", + label: "findServerFile" + }, { type: "doc", id: "api/appkit/Function.generateDatabaseCredential", diff --git a/docs/docs/plugins/serving.md b/docs/docs/plugins/serving.md new file mode 100644 index 00000000..4b2d7a54 --- /dev/null +++ b/docs/docs/plugins/serving.md @@ -0,0 +1,213 @@ +--- +sidebar_position: 7 +--- + +# Serving plugin + +Provides an authenticated proxy to [Databricks Model Serving](https://docs.databricks.com/aws/en/machine-learning/model-serving) endpoints, with invoke and streaming support. + +**Key features:** +- Named endpoint aliases for multiple serving endpoints +- Non-streaming (`invoke`) and SSE streaming (`stream`) invocation +- Automatic OpenAPI type generation for request/response schemas +- Request body filtering based on endpoint schema +- On-behalf-of (OBO) user execution + +## Basic usage + +```ts +import { createApp, server, serving } from "@databricks/appkit"; + +await createApp({ + plugins: [ + server(), + serving(), + ], +}); +``` + +With no configuration, the plugin reads `DATABRICKS_SERVING_ENDPOINT` from the environment and registers it under the `default` alias. + +## Configuration options + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `endpoints` | `Record` | `{ default: { env: "DATABRICKS_SERVING_ENDPOINT" } }` | Map of alias names to endpoint configs | +| `timeout` | `number` | `120000` | Request timeout in ms | + +### Endpoint aliases + +Endpoint aliases let you reference multiple serving endpoints by name: + +```ts +serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + classifier: { env: "DATABRICKS_SERVING_ENDPOINT_CLASSIFIER" }, + }, +}) +``` + +Each alias maps to an environment variable holding the actual endpoint name. If an endpoint serves multiple models, you can use `servedModel` to bypass traffic routing and target a specific model directly: + +```ts +serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "llama-v2" }, + }, +}) +``` + +## Type generation + +The `appKitServingTypesPlugin()` Vite plugin generates TypeScript types from your serving endpoints' OpenAPI schemas. Add it to your `vite.config.ts`: + +```ts +import { appKitServingTypesPlugin } from "@databricks/appkit"; + +export default defineConfig({ + plugins: [ + appKitServingTypesPlugin(), + ], +}); +``` + +The plugin auto-discovers endpoint configuration from your server file (`server/index.ts` or `server/server.ts`) — no manual config passing needed. + +Generated types provide: +- **Alias autocomplete** in both backend (`AppKit.serving("alias")`) and frontend hooks (`useServingStream`, `useServingInvoke`) +- **Typed request/response/chunk** per endpoint based on OpenAPI schemas + +If an endpoint's OpenAPI schema is unavailable (not deployed, env var not set), the plugin generates generic fallback types. The endpoint is still usable — just without typed request/response. + +:::note +Endpoints that don't define a streaming response schema in their OpenAPI spec will have `chunk: unknown`. For these endpoints, use `useServingInvoke` instead of `useServingStream` — the `response` type will still be properly typed. +::: + +## Environment variables + +| Variable | Description | +|----------|-------------| +| `DATABRICKS_SERVING_ENDPOINT` | Default endpoint name (used when `endpoints` config is omitted) | + +When using named endpoints, define a custom environment variable per alias (e.g. `DATABRICKS_SERVING_ENDPOINT_CLASSIFIER`). + +## HTTP endpoints + +### Named mode (with `endpoints` config) + +- `POST /api/serving/:alias/invoke` — Non-streaming invocation +- `POST /api/serving/:alias/stream` — SSE streaming invocation + +### Default mode (no `endpoints` config) + +- `POST /api/serving/invoke` — Non-streaming invocation +- `POST /api/serving/stream` — SSE streaming invocation + +### Request format + +``` +POST /api/serving/:alias/invoke +Content-Type: application/json + +{ + "messages": [ + { "role": "user", "content": "Hello" } + ] +} +``` + +## Programmatic access + +The plugin exports `invoke` and `stream` methods for server-side use: + +```ts +const AppKit = await createApp({ + plugins: [ + server(), + serving({ + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }, + }), + ], +}); + +// Non-streaming +const result = await AppKit.serving("llm").invoke({ + messages: [{ role: "user", content: "Hello" }], +}); + +// Streaming +for await (const chunk of AppKit.serving("llm").stream({ + messages: [{ role: "user", content: "Hello" }], +})) { + console.log(chunk); +} +``` + +## Frontend hooks + +The `@databricks/appkit-ui` package provides React hooks for serving endpoints: + +### useServingStream + +Streaming invocation via SSE: + +```tsx +import { useServingStream } from "@databricks/appkit-ui/react"; + +function ChatStream() { + const { stream, chunks, streaming, error, reset } = useServingStream( + { messages: [{ role: "user", content: "Hello" }] }, + { + alias: "llm", + onComplete: (finalChunks) => { + // Called with all accumulated chunks when the stream finishes + console.log("Stream done, got", finalChunks.length, "chunks"); + }, + }, + ); + + return ( + <> + + + {chunks.map((chunk, i) =>
{JSON.stringify(chunk)}
)} + {error &&

{error}

} + + ); +} +``` + +### useServingInvoke + +Non-streaming invocation. `invoke()` returns a promise with the response data (or `null` on error): + +```tsx +import { useServingInvoke } from "@databricks/appkit-ui/react"; + +function Classify() { + const { invoke, data, loading, error } = useServingInvoke( + { inputs: ["sample text"] }, + { alias: "classifier" }, + ); + + async function handleClick() { + const result = await invoke(); + if (result) { + console.log("Classification result:", result); + } + } + + return ( + <> + + {data &&
{JSON.stringify(data)}
} + {error &&

{error}

} + + ); +} +``` + +Both hooks accept `autoStart: true` to invoke automatically on mount. diff --git a/docs/static/appkit-ui/styles.gen.css b/docs/static/appkit-ui/styles.gen.css index 9a9a38eb..a2192039 100644 --- a/docs/static/appkit-ui/styles.gen.css +++ b/docs/static/appkit-ui/styles.gen.css @@ -831,9 +831,6 @@ .max-w-\[calc\(100\%-2rem\)\] { max-width: calc(100% - 2rem); } - .max-w-full { - max-width: 100%; - } .max-w-max { max-width: max-content; } @@ -4514,6 +4511,11 @@ width: calc(var(--spacing) * 5); } } + .\[\&_\[data-slot\=scroll-area-viewport\]\>div\]\:\!block { + & [data-slot=scroll-area-viewport]>div { + display: block !important; + } + } .\[\&_a\]\:underline { & a { text-decoration-line: underline; @@ -4637,11 +4639,26 @@ color: var(--muted-foreground); } } + .\[\&_table\]\:block { + & table { + display: block; + } + } + .\[\&_table\]\:max-w-full { + & table { + max-width: 100%; + } + } .\[\&_table\]\:border-collapse { & table { border-collapse: collapse; } } + .\[\&_table\]\:overflow-x-auto { + & table { + overflow-x: auto; + } + } .\[\&_table\]\:text-xs { & table { font-size: var(--text-xs); @@ -4851,6 +4868,11 @@ width: 100%; } } + .\[\&\>\*\]\:min-w-0 { + &>* { + min-width: calc(var(--spacing) * 0); + } + } .\[\&\>\*\]\:focus-visible\:relative { &>* { &:focus-visible { diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts new file mode 100644 index 00000000..6d5f159f --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts @@ -0,0 +1,117 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { useServingInvoke } from "../use-serving-invoke"; + +describe("useServingInvoke", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ choices: [] }), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + expect(result.current.data).toBeNull(); + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.invoke).toBe("function"); + }); + + test("calls fetch to correct URL on invoke", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [{ role: "user", content: "Hello" }] }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/invoke", + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ + messages: [{ role: "user", content: "Hello" }], + }), + }), + ); + }); + }); + + test("uses alias in URL when provided", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [] }, { alias: "llm" }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/llm/invoke", + expect.any(Object), + ); + }); + }); + + test("sets data on successful response", async () => { + const responseData = { + choices: [{ message: { content: "Hi" } }], + }; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(result.current.data).toEqual(responseData); + expect(result.current.loading).toBe(false); + }); + }); + + test("sets error on failed response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ error: "Not found" }), { status: 404 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + await act(async () => { + result.current.invoke(); + // Wait for the fetch promise chain to resolve + await new Promise((r) => setTimeout(r, 10)); + }); + + await waitFor(() => { + expect(result.current.error).toBe("Not found"); + expect(result.current.loading).toBe(false); + }); + }); + + test("auto starts when autoStart is true", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + renderHook(() => useServingInvoke({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts new file mode 100644 index 00000000..1ab0bf44 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -0,0 +1,291 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, describe, expect, test, vi } from "vitest"; + +// Mock connectSSE — capture callbacks so we can simulate SSE events +let capturedCallbacks: { + onMessage?: (msg: { data: string }) => void; + onError?: (err: Error) => void; + signal?: AbortSignal; +} = {}; + +let resolveStream: (() => void) | null = null; + +const mockConnectSSE = vi.fn().mockImplementation((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + // Also resolve after a tick as fallback for tests that don't manually resolve + setTimeout(resolve, 0); + }); +}); + +vi.mock("@/js", () => ({ + connectSSE: (...args: unknown[]) => mockConnectSSE(...args), +})); + +import { useServingStream } from "../use-serving-stream"; + +describe("useServingStream", () => { + afterEach(() => { + capturedCallbacks = {}; + resolveStream = null; + vi.clearAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.stream).toBe("function"); + expect(typeof result.current.reset).toBe("function"); + }); + + test("calls connectSSE with correct URL on stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/stream", + payload: JSON.stringify({ messages: [] }), + }), + ); + }); + + test("uses override body when passed to stream()", () => { + const { result } = renderHook(() => + useServingStream({ messages: [{ role: "user", content: "old" }] }), + ); + + const overrideBody = { + messages: [{ role: "user" as const, content: "new" }], + }; + + act(() => { + result.current.stream(overrideBody); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + payload: JSON.stringify(overrideBody), + }), + ); + }); + + test("uses alias in URL when provided", () => { + const { result } = renderHook(() => + useServingStream({ messages: [] }, { alias: "embedder" }), + ); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/embedder/stream", + }), + ); + }); + + test("sets streaming to true when stream() is called", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(result.current.streaming).toBe(true); + }); + + test("accumulates chunks from onMessage", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(result.current.chunks).toEqual([{ id: 1 }, { id: 2 }]); + }); + + test("accumulates chunks with error field as normal data", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ + data: JSON.stringify({ error: "Model overloaded" }), + }); + }); + + // Chunks with an `error` field are treated as data, not stream errors. + // Transport-level errors are delivered via onError callback instead. + expect(result.current.chunks).toEqual([{ error: "Model overloaded" }]); + expect(result.current.error).toBeNull(); + expect(result.current.streaming).toBe(true); + }); + + test("sets error from onError callback", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onError?.(new Error("Connection lost")); + }); + + expect(result.current.error).toBe("Connection lost"); + expect(result.current.streaming).toBe(false); + }); + + test("silently skips malformed JSON messages", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: "not valid json{" }); + }); + + // No chunks added, no error set + expect(result.current.chunks).toEqual([]); + expect(result.current.error).toBeNull(); + }); + + test("reset() clears state and aborts active stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + expect(result.current.chunks).toHaveLength(1); + expect(result.current.streaming).toBe(true); + + act(() => { + result.current.reset(); + }); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + }); + + test("autoStart triggers stream on mount", async () => { + renderHook(() => useServingStream({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(mockConnectSSE).toHaveBeenCalled(); + }); + }); + + test("passes abort signal to connectSSE", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(capturedCallbacks.signal).toBeDefined(); + expect(capturedCallbacks.signal?.aborted).toBe(false); + }); + + test("aborts stream on unmount", () => { + const { result, unmount } = renderHook(() => + useServingStream({ messages: [] }), + ); + + act(() => { + result.current.stream(); + }); + + const signal = capturedCallbacks.signal; + expect(signal?.aborted).toBe(false); + + unmount(); + + expect(signal?.aborted).toBe(true); + }); + + test("sets streaming to false when connectSSE resolves", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + await waitFor(() => { + expect(result.current.streaming).toBe(false); + }); + }); + + test("calls onComplete with accumulated chunks when stream finishes", async () => { + const onComplete = vi.fn(); + + // Use a controllable mock so stream doesn't auto-resolve + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + }); + }); + + const { result } = renderHook(() => + useServingStream({ messages: [] }, { onComplete }), + ); + + act(() => { + result.current.stream(); + }); + + // Send two chunks + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(onComplete).not.toHaveBeenCalled(); + + // Complete the stream + await act(async () => { + resolveStream?.(); + await new Promise((r) => setTimeout(r, 0)); + }); + + expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/index.ts b/packages/appkit-ui/src/react/hooks/index.ts index 84d51b53..a425b010 100644 --- a/packages/appkit-ui/src/react/hooks/index.ts +++ b/packages/appkit-ui/src/react/hooks/index.ts @@ -2,8 +2,13 @@ export type { AnalyticsFormat, InferResultByFormat, InferRowType, + InferServingChunk, + InferServingRequest, + InferServingResponse, PluginRegistry, QueryRegistry, + ServingAlias, + ServingEndpointRegistry, TypedArrowTable, UseAnalyticsQueryOptions, UseAnalyticsQueryResult, @@ -15,3 +20,13 @@ export { useChartData, } from "./use-chart-data"; export { usePluginClientConfig } from "./use-plugin-config"; +export { + type UseServingInvokeOptions, + type UseServingInvokeResult, + useServingInvoke, +} from "./use-serving-invoke"; +export { + type UseServingStreamOptions, + type UseServingStreamResult, + useServingStream, +} from "./use-serving-stream"; diff --git a/packages/appkit-ui/src/react/hooks/types.ts b/packages/appkit-ui/src/react/hooks/types.ts index 5db725fc..19ce1fac 100644 --- a/packages/appkit-ui/src/react/hooks/types.ts +++ b/packages/appkit-ui/src/react/hooks/types.ts @@ -134,3 +134,54 @@ export type InferParams = K extends AugmentedRegistry export interface PluginRegistry { [key: string]: Record; } + +// ============================================================================ +// Serving Endpoint Registry +// ============================================================================ + +/** + * Serving endpoint registry for type-safe alias names. + * Extend this interface via module augmentation to get alias autocomplete: + * + * @example + * ```typescript + * // Auto-generated by appKitServingTypesPlugin() + * declare module "@databricks/appkit-ui/react" { + * interface ServingEndpointRegistry { + * llm: { request: {...}; response: {...}; chunk: {...} }; + * } + * } + * ``` + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Resolves to registry keys if populated, otherwise string */ +export type ServingAlias = + AugmentedRegistry extends never + ? string + : AugmentedRegistry; + +/** Infers chunk type from registry when alias is a known key */ +export type InferServingChunk = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { chunk: infer C } + ? C + : unknown + : unknown; + +/** Infers response type from registry when alias is a known key */ +export type InferServingResponse = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { response: infer R } + ? R + : unknown + : unknown; + +/** Infers request type from registry when alias is a known key */ +export type InferServingRequest = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { request: infer Req } + ? Req + : Record + : Record; diff --git a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts new file mode 100644 index 00000000..8e80e82e --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts @@ -0,0 +1,111 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import type { + InferServingRequest, + InferServingResponse, + ServingAlias, +} from "./types"; + +export interface UseServingInvokeOptions< + K extends ServingAlias = ServingAlias, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If false, does not invoke automatically on mount. Default: false */ + autoStart?: boolean; +} + +export interface UseServingInvokeResult< + T = unknown, + TBody = Record, +> { + /** Trigger the invocation. Pass an optional body override for this invocation. */ + invoke: (overrideBody?: TBody) => Promise; + /** Response data, null until loaded. */ + data: T | null; + /** Whether a request is in progress. */ + loading: boolean; + /** Error message, if any. */ + error: string | null; +} + +/** + * Hook for non-streaming invocation of a serving endpoint. + * Calls `POST /api/serving/invoke` (default) or `POST /api/serving/{alias}/invoke` (named). + * + * When the type generator has populated `ServingEndpointRegistry`, the response type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingInvoke( + body: InferServingRequest, + options: UseServingInvokeOptions = {} as UseServingInvokeOptions, +): UseServingInvokeResult, InferServingRequest> { + type TResponse = InferServingResponse; + const { alias, autoStart = false } = options; + + const [data, setData] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/invoke` + : "/api/serving/invoke"; + + const bodyJson = JSON.stringify(body); + + const invoke = useCallback( + (overrideBody?: InferServingRequest): Promise => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + + setLoading(true); + setError(null); + setData(null); + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + return fetch(urlSuffix, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: payload, + signal: abortController.signal, + }) + .then(async (res) => { + if (!res.ok) { + const errorBody = await res.json().catch(() => null); + throw new Error(errorBody?.error || `HTTP ${res.status}`); + } + return res.json(); + }) + .then((result: TResponse) => { + if (abortController.signal.aborted) return null; + setData(result); + setLoading(false); + return result; + }) + .catch((err: Error) => { + if (abortController.signal.aborted) return null; + setError(err.message || "Request failed"); + setLoading(false); + return null; + }); + }, + [urlSuffix, bodyJson], + ); + + useEffect(() => { + if (autoStart) { + invoke(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [invoke, autoStart]); + + return { invoke, data, loading, error }; +} diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts new file mode 100644 index 00000000..f0bb7bf2 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -0,0 +1,137 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { connectSSE } from "@/js"; +import type { + InferServingChunk, + InferServingRequest, + ServingAlias, +} from "./types"; + +export interface UseServingStreamOptions< + K extends ServingAlias = ServingAlias, + T = InferServingChunk, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If true, starts streaming automatically on mount. Default: false */ + autoStart?: boolean; + /** Called with accumulated chunks when the stream completes successfully. */ + onComplete?: (chunks: T[]) => void; +} + +export interface UseServingStreamResult< + T = unknown, + TBody = Record, +> { + /** Trigger the streaming invocation. Pass an optional body override for this invocation. */ + stream: (overrideBody?: TBody) => void; + /** Accumulated chunks received so far. */ + chunks: T[]; + /** Whether streaming is in progress. */ + streaming: boolean; + /** Error message, if any. */ + error: string | null; + /** Reset chunks and abort any active stream. */ + reset: () => void; +} + +/** + * Hook for streaming invocation of a serving endpoint via SSE. + * Calls `POST /api/serving/stream` (default) or `POST /api/serving/{alias}/stream` (named). + * Accumulates parsed chunks in state. + * + * When the type generator has populated `ServingEndpointRegistry`, the chunk type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingStream( + body: InferServingRequest, + options: UseServingStreamOptions = {} as UseServingStreamOptions, +): UseServingStreamResult, InferServingRequest> { + type TChunk = InferServingChunk; + const { alias, autoStart = false, onComplete } = options; + + const [chunks, setChunks] = useState([]); + const [streaming, setStreaming] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + const chunksRef = useRef([]); + const onCompleteRef = useRef(onComplete); + onCompleteRef.current = onComplete; + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/stream` + : "/api/serving/stream"; + + const reset = useCallback(() => { + abortControllerRef.current?.abort(); + abortControllerRef.current = null; + chunksRef.current = []; + setChunks([]); + setStreaming(false); + setError(null); + }, []); + + const bodyJson = JSON.stringify(body); + + const stream = useCallback( + (overrideBody?: InferServingRequest) => { + // Abort any existing stream + abortControllerRef.current?.abort(); + + setStreaming(true); + setError(null); + setChunks([]); + chunksRef.current = []; + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + connectSSE({ + url: urlSuffix, + payload, + signal: abortController.signal, + onMessage: async (message) => { + if (abortController.signal.aborted) return; + try { + const parsed = JSON.parse(message.data); + + chunksRef.current = [...chunksRef.current, parsed as TChunk]; + setChunks(chunksRef.current); + } catch { + // Skip malformed messages + } + }, + onError: (err) => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError(err instanceof Error ? err.message : "Streaming failed"); + }, + }) + .then(() => { + if (abortController.signal.aborted) return; + // Stream completed + setStreaming(false); + onCompleteRef.current?.(chunksRef.current); + }) + .catch(() => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError("Connection error"); + }); + }, + [urlSuffix, bodyJson], + ); + + useEffect(() => { + if (autoStart) { + stream(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [stream, autoStart]); + + return { stream, chunks, streaming, error, reset }; +} diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 9e810b97..06da3ee1 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -50,6 +50,7 @@ "typecheck": "tsc --noEmit" }, "dependencies": { + "@ast-grep/napi": "0.37.0", "@databricks/lakebase": "workspace:*", "@databricks/sdk-experimental": "0.16.0", "@opentelemetry/api": "1.9.0", diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts new file mode 100644 index 00000000..6254426d --- /dev/null +++ b/packages/appkit/src/connectors/serving/client.ts @@ -0,0 +1,223 @@ +import { ApiError, type WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; +import type { ServingInvokeOptions } from "./types"; + +const logger = createLogger("connectors:serving"); + +/** + * Builds the invocation URL for a serving endpoint. + * Uses `/served-models/{model}/invocations` when servedModel is specified, + * otherwise `/serving-endpoints/{name}/invocations`. + */ +function buildInvocationUrl( + host: string, + endpointName: string, + servedModel?: string, +): string { + const base = host.startsWith("http") ? host : `https://${host}`; + const encodedName = encodeURIComponent(endpointName); + const path = servedModel + ? `/serving-endpoints/${encodedName}/served-models/${encodeURIComponent(servedModel)}/invocations` + : `/serving-endpoints/${encodedName}/invocations`; + return new URL(path, base).toString(); +} + +/** + * Maps upstream Databricks error status codes to appropriate proxy responses. + */ +function mapUpstreamError( + status: number, + body: string, + headers: Headers, +): ApiError { + const safeMessage = body.length > 500 ? `${body.slice(0, 500)}...` : body; + + let parsed: { message?: string; error?: string } = {}; + try { + parsed = JSON.parse(body); + } catch { + // body is not JSON + } + + const message = parsed.message || parsed.error || safeMessage; + + switch (true) { + case status === 400: + return new ApiError(message, "BAD_REQUEST", 400, undefined, []); + case status === 401 || status === 403: + logger.warn("Authentication failure from serving endpoint: %s", message); + return new ApiError(message, "AUTH_FAILURE", status, undefined, []); + case status === 404: + return new ApiError(message, "NOT_FOUND", 404, undefined, []); + case status === 429: { + const retryAfter = headers.get("retry-after"); + const retryMessage = retryAfter + ? `${message} (retry-after: ${retryAfter})` + : message; + return new ApiError(retryMessage, "RATE_LIMITED", 429, undefined, []); + } + case status === 503: + return new ApiError( + "Endpoint loading, retry shortly", + "SERVICE_UNAVAILABLE", + 503, + undefined, + [], + ); + case status >= 500: + return new ApiError(message, "BAD_GATEWAY", 502, undefined, []); + default: + return new ApiError(message, "UNKNOWN", status, undefined, []); + } +} + +/** + * Invokes a serving endpoint and returns the parsed JSON response. + */ +export async function invoke( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): Promise { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Always strip `stream` from the body — the connector controls this + const { stream: _stream, ...cleanBody } = body; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "application/json", + }); + await client.config.authenticate(headers); + + logger.debug("Invoking endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(cleanBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + return res.json(); +} + +/** + * Invokes a serving endpoint with streaming enabled. + * Yields parsed JSON chunks from the NDJSON SSE response. + */ +export async function* stream( + client: WorkspaceClient, + endpointName: string, + body: Record, + options?: ServingInvokeOptions, +): AsyncGenerator { + const host = client.config.host; + if (!host) { + throw new Error( + "Databricks host is not configured. Set DATABRICKS_HOST or configure client.config.host.", + ); + } + + const url = buildInvocationUrl(host, endpointName, options?.servedModel); + + // Strip any user-provided `stream` and inject `stream: true` + const { stream: _stream, ...cleanBody } = body; + const streamBody = { ...cleanBody, stream: true }; + + const headers = new Headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }); + await client.config.authenticate(headers); + + logger.debug("Streaming from endpoint %s at %s", endpointName, url); + + const res = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(streamBody), + signal: options?.signal, + }); + + if (!res.ok) { + const text = await res.text(); + throw mapUpstreamError(res.status, text, res.headers); + } + + if (!res.body) { + throw new Error("Response body is null — streaming not supported"); + } + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + const MAX_BUFFER_SIZE = 1024 * 1024; // 1 MB + + try { + while (true) { + if (options?.signal?.aborted) break; + + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + if (buffer.length > MAX_BUFFER_SIZE) { + logger.warn( + "Stream buffer exceeded %d bytes, discarding incomplete data", + MAX_BUFFER_SIZE, + ); + buffer = ""; + } + + // Process complete lines from the buffer + const lines = buffer.split("\n"); + // Keep the last (potentially incomplete) line in the buffer + buffer = lines.pop() ?? ""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed || trimmed.startsWith(":")) continue; // skip empty lines and SSE comments + if (trimmed === "data: [DONE]") return; + + if (trimmed.startsWith("data: ")) { + const jsonStr = trimmed.slice(6); + try { + yield JSON.parse(jsonStr); + } catch { + logger.warn("Failed to parse streaming chunk: %s", jsonStr); + } + } + } + } + + // Process any remaining data in the buffer + if (buffer.trim() && !options?.signal?.aborted) { + const trimmed = buffer.trim(); + if (trimmed.startsWith("data: ") && trimmed !== "data: [DONE]") { + try { + yield JSON.parse(trimmed.slice(6)); + } catch { + logger.warn("Failed to parse final streaming chunk: %s", trimmed); + } + } + } + } finally { + reader.cancel().catch(() => {}); + reader.releaseLock(); + } +} diff --git a/packages/appkit/src/connectors/serving/tests/client.test.ts b/packages/appkit/src/connectors/serving/tests/client.test.ts new file mode 100644 index 00000000..6af859ae --- /dev/null +++ b/packages/appkit/src/connectors/serving/tests/client.test.ts @@ -0,0 +1,303 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { invoke, stream } from "../client"; + +const mockAuthenticate = vi.fn(); + +function createMockClient(host = "https://test.databricks.com") { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +describe("Serving Connector", () => { + beforeEach(() => { + mockAuthenticate.mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("invoke", () => { + test("constructs correct URL for endpoint invocation", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("constructs correct URL with servedModel override", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { servedModel: "llama-v2" }, + ); + + expect(fetchSpy).toHaveBeenCalledWith( + "https://test.databricks.com/serving-endpoints/my-endpoint/served-models/llama-v2/invocations", + expect.objectContaining({ method: "POST" }), + ); + }); + + test("authenticates request headers", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("strips stream property from body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke(client, "my-endpoint", { + messages: [], + stream: true, + temperature: 0.7, + }); + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body).toEqual({ messages: [], temperature: 0.7 }); + expect(body.stream).toBeUndefined(); + }); + + test("returns parsed JSON response", async () => { + const responseData = { choices: [{ message: { content: "Hello" } }] }; + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const client = createMockClient(); + const result = await invoke(client, "my-endpoint", { messages: [] }); + + expect(result).toEqual(responseData); + }); + + test("throws ApiError on 400 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Invalid params" }), { + status: 400, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Invalid params"); + }); + + test("throws ApiError on 404 response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Endpoint not found" }), { + status: 404, + }), + ); + + const client = createMockClient(); + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Endpoint not found"); + }); + + test("maps 5xx to 502 bad gateway", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Internal error" }), { + status: 500, + }), + ); + + const client = createMockClient(); + try { + await invoke(client, "my-endpoint", { messages: [] }); + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(502); + } + }); + + test("forwards AbortSignal", async () => { + const controller = new AbortController(); + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient(); + await invoke( + client, + "my-endpoint", + { messages: [] }, + { signal: controller.signal }, + ); + + expect(fetchSpy.mock.calls[0][1]?.signal).toBe(controller.signal); + }); + + test("throws when host is not configured", async () => { + const client = { + config: { + host: "", + authenticate: mockAuthenticate, + }, + } as any; + await expect( + invoke(client, "my-endpoint", { messages: [] }), + ).rejects.toThrow("Databricks host is not configured"); + }); + + test("prepends https:// to host without protocol", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue( + new Response(JSON.stringify({ result: "ok" }), { status: 200 }), + ); + + const client = createMockClient("test.databricks.com"); + await invoke(client, "my-endpoint", { messages: [] }); + + expect(fetchSpy.mock.calls[0][0]).toContain( + "https://test.databricks.com", + ); + }); + }); + + describe("stream", () => { + function createSSEResponse(chunks: string[]) { + const body = `${chunks.join("\n")}\n`; + return new Response(body, { + status: 200, + headers: { "Content-Type": "text/event-stream" }, + }); + } + + test("yields parsed NDJSON chunks", async () => { + const chunks = [ + 'data: {"choices":[{"delta":{"content":"Hello"}}]}', + 'data: {"choices":[{"delta":{"content":" world"}}]}', + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toEqual([ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]); + }); + + test("injects stream: true into body", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + // Consume the generator + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("strips user-provided stream and re-injects", async () => { + const fetchSpy = vi + .spyOn(globalThis, "fetch") + .mockResolvedValue(createSSEResponse(["data: [DONE]"])); + + const client = createMockClient(); + for await (const _ of stream(client, "my-endpoint", { + messages: [], + stream: false, + })) { + // noop + } + + const body = JSON.parse(fetchSpy.mock.calls[0][1]?.body as string); + expect(body.stream).toBe(true); + }); + + test("skips SSE comments and empty lines", async () => { + const chunks = [ + ": this is a comment", + "", + 'data: {"choices":[{"delta":{"content":"Hi"}}]}', + "", + "data: [DONE]", + ]; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + createSSEResponse(chunks), + ); + + const client = createMockClient(); + const results: unknown[] = []; + for await (const chunk of stream(client, "my-endpoint", { + messages: [], + })) { + results.push(chunk); + } + + expect(results).toHaveLength(1); + expect(results[0]).toEqual({ choices: [{ delta: { content: "Hi" } }] }); + }); + + test("throws on non-OK response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ message: "Rate limited" }), { + status: 429, + headers: { "Retry-After": "5" }, + }), + ); + + const client = createMockClient(); + try { + for await (const _ of stream(client, "my-endpoint", { messages: [] })) { + // noop + } + expect.unreachable("Should have thrown"); + } catch (err: any) { + expect(err.statusCode).toBe(429); + } + }); + }); +}); diff --git a/packages/appkit/src/connectors/serving/types.ts b/packages/appkit/src/connectors/serving/types.ts new file mode 100644 index 00000000..6dd1acba --- /dev/null +++ b/packages/appkit/src/connectors/serving/types.ts @@ -0,0 +1,4 @@ +export interface ServingInvokeOptions { + servedModel?: string; + signal?: AbortSignal; +} diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 8db7f1d7..3df5572b 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -48,7 +48,13 @@ export { } from "./errors"; // Plugin authoring export { Plugin, type ToPlugin, toPlugin } from "./plugin"; -export { analytics, files, genie, lakebase, server } from "./plugins"; +export { analytics, files, genie, lakebase, server, serving } from "./plugins"; +export type { + EndpointConfig, + ServingEndpointEntry, + ServingEndpointRegistry, + ServingFactory, +} from "./plugins/serving/types"; // Registry types and utilities for plugin manifests export type { ConfigSchema, @@ -75,6 +81,10 @@ export { SpanStatusCode, type TelemetryConfig, } from "./telemetry"; - +export { + extractServingEndpoints, + findServerFile, +} from "./type-generator/serving/server-file-extractor"; +export { appKitServingTypesPlugin } from "./type-generator/serving/vite-plugin"; // Vite plugin and type generation export { appKitTypesPlugin } from "./type-generator/vite-plugin"; diff --git a/packages/appkit/src/plugins/index.ts b/packages/appkit/src/plugins/index.ts index 7caa040f..4d58082f 100644 --- a/packages/appkit/src/plugins/index.ts +++ b/packages/appkit/src/plugins/index.ts @@ -3,3 +3,4 @@ export * from "./files"; export * from "./genie"; export * from "./lakebase"; export * from "./server"; +export * from "./serving"; diff --git a/packages/appkit/src/plugins/serving/defaults.ts b/packages/appkit/src/plugins/serving/defaults.ts new file mode 100644 index 00000000..1fea64c2 --- /dev/null +++ b/packages/appkit/src/plugins/serving/defaults.ts @@ -0,0 +1,26 @@ +import type { StreamExecutionSettings } from "shared"; + +export const servingInvokeDefaults = { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, +}; + +export const servingStreamDefaults: StreamExecutionSettings = { + default: { + cache: { + enabled: false, + }, + retry: { + enabled: false, + }, + timeout: 120_000, + }, + stream: { + bufferSize: 200, + }, +}; diff --git a/packages/appkit/src/plugins/serving/index.ts b/packages/appkit/src/plugins/serving/index.ts new file mode 100644 index 00000000..85caf33b --- /dev/null +++ b/packages/appkit/src/plugins/serving/index.ts @@ -0,0 +1,2 @@ +export * from "./serving"; +export * from "./types"; diff --git a/packages/appkit/src/plugins/serving/manifest.json b/packages/appkit/src/plugins/serving/manifest.json new file mode 100644 index 00000000..9ac0845f --- /dev/null +++ b/packages/appkit/src/plugins/serving/manifest.json @@ -0,0 +1,54 @@ +{ + "$schema": "https://databricks.github.io/appkit/schemas/plugin-manifest.schema.json", + "name": "serving", + "displayName": "Model Serving Plugin", + "description": "Authenticated proxy to Databricks Model Serving endpoints", + "resources": { + "required": [ + { + "type": "serving_endpoint", + "alias": "Serving Endpoint", + "resourceKey": "serving-endpoint", + "description": "Model Serving endpoint for inference", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT", + "description": "Serving endpoint name" + } + } + } + ], + "optional": [] + }, + "config": { + "schema": { + "type": "object", + "properties": { + "endpoints": { + "type": "object", + "description": "Map of alias names to endpoint configurations", + "additionalProperties": { + "type": "object", + "properties": { + "env": { + "type": "string", + "description": "Environment variable holding the endpoint name" + }, + "servedModel": { + "type": "string", + "description": "Target a specific served model (bypasses traffic routing)" + } + }, + "required": ["env"] + } + }, + "timeout": { + "type": "number", + "default": 120000, + "description": "Request timeout in ms. Default: 120000 (2 min)" + } + } + } + } +} diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts new file mode 100644 index 00000000..92a25c69 --- /dev/null +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -0,0 +1,95 @@ +import fs from "node:fs/promises"; +import { createLogger } from "../../logging/logger"; +import { + CACHE_VERSION, + type ServingCache, +} from "../../type-generator/serving/cache"; + +const logger = createLogger("serving:schema-filter"); + +function isValidCache(data: unknown): data is ServingCache { + return ( + typeof data === "object" && + data !== null && + "version" in data && + (data as ServingCache).version === CACHE_VERSION && + "endpoints" in data && + typeof (data as ServingCache).endpoints === "object" + ); +} + +/** + * Loads endpoint schemas from the type generation cache file. + * Returns a map of alias → allowed parameter keys. + */ +export async function loadEndpointSchemas( + cacheFile: string, +): Promise>> { + const allowlists = new Map>(); + + try { + const raw = await fs.readFile(cacheFile, "utf8"); + const parsed: unknown = JSON.parse(raw); + if (!isValidCache(parsed)) { + logger.warn("Serving types cache has invalid structure, skipping"); + return allowlists; + } + const cache = parsed; + + for (const [alias, entry] of Object.entries(cache.endpoints)) { + if (entry.requestKeys && entry.requestKeys.length > 0) { + allowlists.set(alias, new Set(entry.requestKeys)); + } + } + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn( + "Failed to load serving types cache: %s", + (err as Error).message, + ); + } + // No cache → no filtering, passthrough mode + } + + return allowlists; +} + +/** + * Filters a request body against the allowed keys for an endpoint alias. + * Returns the filtered body and logs a warning for stripped params. + * + * If no allowlist exists for the alias, returns the body unchanged (passthrough). + */ +export function filterRequestBody( + body: Record, + allowlists: Map>, + alias: string, + filterMode: "strip" | "reject" = "strip", +): Record { + const allowed = allowlists.get(alias); + if (!allowed) return body; + + const stripped: string[] = []; + const filtered: Record = {}; + + for (const [key, value] of Object.entries(body)) { + if (allowed.has(key)) { + filtered[key] = value; + } else { + stripped.push(key); + } + } + + if (stripped.length > 0) { + if (filterMode === "reject") { + throw new Error(`Unknown request parameters: ${stripped.join(", ")}`); + } + logger.warn( + "Stripped unknown params from '%s': %s", + alias, + stripped.join(", "), + ); + } + + return filtered; +} diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts new file mode 100644 index 00000000..e3547bcf --- /dev/null +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -0,0 +1,304 @@ +import { randomUUID } from "node:crypto"; +import path from "node:path"; +import type express from "express"; +import type { IAppRouter, StreamExecutionSettings } from "shared"; +import * as servingConnector from "../../connectors/serving/client"; +import { getWorkspaceClient } from "../../context"; +import { createLogger } from "../../logging"; +import { Plugin, toPlugin } from "../../plugin"; +import type { PluginManifest, ResourceRequirement } from "../../registry"; +import { ResourceType } from "../../registry"; +import { servingInvokeDefaults, servingStreamDefaults } from "./defaults"; +import manifest from "./manifest.json"; +import { filterRequestBody, loadEndpointSchemas } from "./schema-filter"; +import type { EndpointConfig, IServingConfig, ServingFactory } from "./types"; + +const logger = createLogger("serving"); + +class EndpointNotFoundError extends Error { + constructor(alias: string) { + super(`Unknown endpoint alias: ${alias}`); + } +} + +class EndpointNotConfiguredError extends Error { + constructor(alias: string, envVar: string) { + super( + `Endpoint '${alias}' is not configured: env var '${envVar}' is not set`, + ); + } +} + +interface ResolvedEndpoint { + name: string; + servedModel?: string; +} + +export class ServingPlugin extends Plugin { + static manifest = manifest as PluginManifest<"serving">; + + protected static description = + "Authenticated proxy to Databricks Model Serving endpoints"; + protected declare config: IServingConfig; + + private readonly endpoints: Record; + private readonly isNamedMode: boolean; + private schemaAllowlists = new Map>(); + + constructor(config: IServingConfig) { + super(config); + this.config = config; + + if (config.endpoints) { + this.endpoints = config.endpoints; + this.isNamedMode = true; + } else { + this.endpoints = { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + this.isNamedMode = false; + } + } + + async setup(): Promise { + const cacheFile = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", + ".appkit-serving-types-cache.json", + ); + this.schemaAllowlists = await loadEndpointSchemas(cacheFile); + if (this.schemaAllowlists.size > 0) { + logger.debug( + "Loaded schema allowlists for %d endpoint(s)", + this.schemaAllowlists.size, + ); + } + } + + static getResourceRequirements( + config: IServingConfig, + ): ResourceRequirement[] { + const endpoints = config.endpoints ?? { + default: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }; + + return Object.entries(endpoints).map(([alias, endpointConfig]) => ({ + type: ResourceType.SERVING_ENDPOINT, + alias: `serving-${alias}`, + resourceKey: `serving-${alias}`, + description: `Model Serving endpoint for "${alias}" inference`, + permission: "CAN_QUERY" as const, + fields: { + name: { + env: endpointConfig.env, + description: `Serving endpoint name for "${alias}"`, + }, + }, + required: true, + })); + } + + private resolveAndFilter( + alias: string, + body: Record, + ): { endpoint: ResolvedEndpoint; filteredBody: Record } { + const config = this.endpoints[alias]; + if (!config) { + throw new EndpointNotFoundError(alias); + } + + const name = process.env[config.env]; + if (!name) { + throw new EndpointNotConfiguredError(alias, config.env); + } + + const endpoint: ResolvedEndpoint = { + name, + servedModel: config.servedModel, + }; + const filteredBody = filterRequestBody( + body, + this.schemaAllowlists, + alias, + this.config.filterMode, + ); + return { endpoint, filteredBody }; + } + + injectRoutes(router: IAppRouter) { + if (this.isNamedMode) { + this.route(router, { + name: "invoke", + method: "post", + path: "/:alias/invoke", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/:alias/stream", + handler: async (req: express.Request, res: express.Response) => { + await this.asUser(req)._handleStream(req, res); + }, + }); + } else { + this.route(router, { + name: "invoke", + method: "post", + path: "/invoke", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleInvoke(req, res); + }, + }); + + this.route(router, { + name: "stream", + method: "post", + path: "/stream", + handler: async (req: express.Request, res: express.Response) => { + req.params.alias = "default"; + await this.asUser(req)._handleStream(req, res); + }, + }); + } + } + + async _handleInvoke( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + try { + const result = await this.invoke(alias, rawBody); + if (result === undefined) { + res.status(502).json({ error: "Invocation returned no result" }); + return; + } + res.json(result); + } catch (err) { + const message = err instanceof Error ? err.message : "Invocation failed"; + if (err instanceof EndpointNotFoundError) { + res.status(404).json({ error: message }); + } else if ( + err instanceof EndpointNotConfiguredError || + message.startsWith("Unknown request parameters:") + ) { + res.status(400).json({ error: message }); + } else { + res.status(502).json({ error: message }); + } + } + } + + async _handleStream( + req: express.Request, + res: express.Response, + ): Promise { + const { alias } = req.params; + const rawBody = req.body as Record; + + let endpoint: ResolvedEndpoint; + let filteredBody: Record; + try { + ({ endpoint, filteredBody } = this.resolveAndFilter(alias, rawBody)); + } catch (err) { + const message = err instanceof Error ? err.message : "Invalid request"; + const status = err instanceof EndpointNotFoundError ? 404 : 400; + res.status(status).json({ error: message }); + return; + } + + const timeout = this.config.timeout ?? 120_000; + const requestId = + (typeof req.query.requestId === "string" && req.query.requestId) || + randomUUID(); + + const streamSettings: StreamExecutionSettings = { + ...servingStreamDefaults, + default: { + ...servingStreamDefaults.default, + timeout, + }, + stream: { + ...servingStreamDefaults.stream, + streamId: requestId, + }, + }; + + const workspaceClient = getWorkspaceClient(); + if (!workspaceClient.config.host) { + res.status(500).json({ error: "Databricks host not configured" }); + return; + } + + await this.executeStream( + res, + (signal) => + servingConnector.stream(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + signal, + }), + streamSettings, + ); + } + + async invoke(alias: string, body: Record): Promise { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + const timeout = this.config.timeout ?? 120_000; + + return this.execute( + () => + servingConnector.invoke(workspaceClient, endpoint.name, filteredBody, { + servedModel: endpoint.servedModel, + }), + { + default: { + ...servingInvokeDefaults, + timeout, + }, + }, + ); + } + + async *stream( + alias: string, + body: Record, + ): AsyncGenerator { + const { endpoint, filteredBody } = this.resolveAndFilter(alias, body); + const workspaceClient = getWorkspaceClient(); + + yield* servingConnector.stream( + workspaceClient, + endpoint.name, + filteredBody, + { servedModel: endpoint.servedModel }, + ); + } + + async shutdown(): Promise { + this.streamManager.abortAll(); + } + + exports(): ServingFactory { + return ((alias?: string) => ({ + invoke: (body: Record) => + this.invoke(alias ?? "default", body), + stream: (body: Record) => + this.stream(alias ?? "default", body), + })) as ServingFactory; + } +} + +/** + * @internal + */ +export const serving = toPlugin(ServingPlugin); diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts new file mode 100644 index 00000000..4fc030d8 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -0,0 +1,159 @@ +import { describe, expect, test, vi } from "vitest"; +import { filterRequestBody, loadEndpointSchemas } from "../schema-filter"; + +vi.mock("node:fs/promises", () => ({ + default: { + readFile: vi.fn(), + }, +})); + +describe("schema-filter", () => { + describe("filterRequestBody", () => { + test("strips unknown keys when allowlist exists", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7, unknown_param: true }, + allowlists, + "default", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("preserves all keys when no allowlist for alias", () => { + const allowlists = new Map>(); + + const body = { messages: [], custom: "value" }; + const result = filterRequestBody(body, allowlists, "default"); + + expect(result).toBe(body); // Same reference, no filtering + }); + + test("returns empty object when all keys are unknown", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { bad1: 1, bad2: 2 }, + allowlists, + "default", + ); + + expect(result).toEqual({}); + }); + + test("returns full body when all keys are allowed", () => { + const allowlists = new Map([["default", new Set(["a", "b", "c"])]]); + + const result = filterRequestBody( + { a: 1, b: 2, c: 3 }, + allowlists, + "default", + ); + + expect(result).toEqual({ a: 1, b: 2, c: 3 }); + }); + + test("throws in reject mode when unknown keys are present", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + expect(() => + filterRequestBody( + { messages: [], unknown_param: true }, + allowlists, + "default", + "reject", + ), + ).toThrow("Unknown request parameters: unknown_param"); + }); + + test("does not throw in reject mode when all keys are allowed", () => { + const allowlists = new Map([ + ["default", new Set(["messages", "temperature"])], + ]); + + const result = filterRequestBody( + { messages: [], temperature: 0.7 }, + allowlists, + "default", + "reject", + ); + + expect(result).toEqual({ messages: [], temperature: 0.7 }); + }); + + test("strips in default mode (strip)", () => { + const allowlists = new Map([["default", new Set(["messages"])]]); + + const result = filterRequestBody( + { messages: [], extra: true }, + allowlists, + "default", + "strip", + ); + + expect(result).toEqual({ messages: [] }); + }); + }); + + describe("loadEndpointSchemas", () => { + test("returns empty map when cache file does not exist", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const result = await loadEndpointSchemas("/nonexistent/path"); + expect(result.size).toBe(0); + }); + + test("reads requestKeys from cache entries", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: "{}", + responseType: "{}", + chunkType: null, + requestKeys: ["messages", "temperature", "max_tokens"], + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + expect(result.size).toBe(1); + const keys = result.get("default"); + expect(keys).toBeDefined(); + expect(keys?.has("messages")).toBe(true); + expect(keys?.has("temperature")).toBe(true); + expect(keys?.has("max_tokens")).toBe(true); + }); + + test("skips entries without requestKeys (backwards compat)", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{}", + chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + // No requestKeys → passthrough mode (no allowlist) + expect(result.size).toBe(0); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/tests/serving.test.ts b/packages/appkit/src/plugins/serving/tests/serving.test.ts new file mode 100644 index 00000000..1a953b77 --- /dev/null +++ b/packages/appkit/src/plugins/serving/tests/serving.test.ts @@ -0,0 +1,339 @@ +import { + createMockRequest, + createMockResponse, + createMockRouter, + mockServiceContext, + setupDatabricksEnv, +} from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ServiceContext } from "../../../context/service-context"; +import { ServingPlugin, serving } from "../serving"; +import type { IServingConfig } from "../types"; + +// Mock CacheManager singleton +const { mockCacheInstance } = vi.hoisted(() => { + const instance = { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + getOrExecute: vi + .fn() + .mockImplementation( + async (_key: unknown[], fn: () => Promise) => { + return await fn(); + }, + ), + generateKey: vi.fn((...args: unknown[]) => JSON.stringify(args)), + }; + return { mockCacheInstance: instance }; +}); + +vi.mock("../../../cache", () => ({ + CacheManager: { + getInstanceSync: vi.fn(() => mockCacheInstance), + }, +})); + +// Mock the serving connector +const mockInvoke = vi.fn(); +const mockStream = vi.fn(); + +vi.mock("../../../connectors/serving/client", () => ({ + invoke: (...args: any[]) => mockInvoke(...args), + stream: (...args: any[]) => mockStream(...args), +})); + +describe("Serving Plugin", () => { + let serviceContextMock: Awaited>; + + beforeEach(async () => { + setupDatabricksEnv(); + process.env.DATABRICKS_SERVING_ENDPOINT = "test-endpoint"; + ServiceContext.reset(); + + serviceContextMock = await mockServiceContext(); + }); + + afterEach(() => { + serviceContextMock?.restore(); + delete process.env.DATABRICKS_SERVING_ENDPOINT; + vi.restoreAllMocks(); + }); + + test("serving factory should have correct name", () => { + const pluginData = serving(); + expect(pluginData.name).toBe("serving"); + }); + + test("serving factory with config should have correct name", () => { + const pluginData = serving({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + expect(pluginData.name).toBe("serving"); + }); + + describe("default mode", () => { + test("reads DATABRICKS_SERVING_ENDPOINT", () => { + const plugin = new ServingPlugin({}); + const api = (plugin.exports() as any)(); + expect(api.invoke).toBeDefined(); + expect(api.stream).toBeDefined(); + }); + + test("injects /invoke and /stream routes", () => { + const plugin = new ServingPlugin({}); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/invoke"]).toBeDefined(); + expect(handlers["POST:/stream"]).toBeDefined(); + }); + + test("exports returns a factory that provides invoke and stream", () => { + const plugin = new ServingPlugin({}); + const factory = plugin.exports() as any; + const api = factory(); + + expect(typeof api.invoke).toBe("function"); + expect(typeof api.stream).toBe("function"); + }); + }); + + describe("named mode", () => { + const namedConfig: IServingConfig = { + endpoints: { + llm: { env: "DATABRICKS_SERVING_ENDPOINT" }, + embedder: { env: "DATABRICKS_SERVING_ENDPOINT_EMBEDDING" }, + }, + }; + + test("injects /:alias/invoke and /:alias/stream routes", () => { + const plugin = new ServingPlugin(namedConfig); + const { router, handlers } = createMockRouter(); + + plugin.injectRoutes(router); + + expect(handlers["POST:/:alias/invoke"]).toBeDefined(); + expect(handlers["POST:/:alias/stream"]).toBeDefined(); + }); + + test("exports factory returns invoke and stream for named aliases", () => { + const plugin = new ServingPlugin(namedConfig); + const factory = plugin.exports() as any; + + expect(typeof factory("llm").invoke).toBe("function"); + expect(typeof factory("llm").stream).toBe("function"); + expect(typeof factory("embedder").invoke).toBe("function"); + expect(typeof factory("embedder").stream).toBe("function"); + }); + }); + + describe("route handlers", () => { + test("_handleInvoke returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleInvoke calls connector with correct endpoint", async () => { + mockInvoke.mockResolvedValue({ choices: [] }); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [{ role: "user", content: "Hello" }] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [{ role: "user", content: "Hello" }] }, + { servedModel: undefined }, + ); + expect(res.json).toHaveBeenCalledWith({ choices: [] }); + }); + + test("_handleInvoke returns 400 with descriptive message when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + await plugin._handleInvoke(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + + test("_handleInvoke does not throw when connector fails", async () => { + mockInvoke.mockRejectedValue(new Error("Connection refused")); + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + }); + const res = createMockResponse(); + + // Should not throw — execute() handles the error internally + await expect( + plugin._handleInvoke(req as any, res as any), + ).resolves.not.toThrow(); + }); + + test("_handleStream returns 404 for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + const req = createMockRequest({ + params: { alias: "unknown" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: "Unknown endpoint alias: unknown", + }); + }); + + test("_handleStream returns 400 when env var is not set", async () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + + const plugin = new ServingPlugin({}); + const req = createMockRequest({ + params: { alias: "default" }, + body: { messages: [] }, + query: {}, + }); + const res = createMockResponse(); + + await plugin._handleStream(req as any, res as any); + + expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith({ + error: + "Endpoint 'default' is not configured: env var 'DATABRICKS_SERVING_ENDPOINT' is not set", + }); + }); + }); + + describe("getResourceRequirements", () => { + test("generates requirements for default mode", () => { + const reqs = ServingPlugin.getResourceRequirements({}); + expect(reqs).toHaveLength(1); + expect(reqs[0]).toMatchObject({ + type: "serving_endpoint", + alias: "serving-default", + permission: "CAN_QUERY", + fields: { + name: { + env: "DATABRICKS_SERVING_ENDPOINT", + }, + }, + }); + }); + + test("generates requirements for named mode", () => { + const reqs = ServingPlugin.getResourceRequirements({ + endpoints: { + llm: { env: "LLM_ENDPOINT" }, + embedder: { env: "EMBED_ENDPOINT" }, + }, + }); + expect(reqs).toHaveLength(2); + expect(reqs[0].fields.name.env).toBe("LLM_ENDPOINT"); + expect(reqs[1].fields.name.env).toBe("EMBED_ENDPOINT"); + }); + }); + + describe("programmatic API", () => { + test("invoke calls connector correctly", async () => { + mockInvoke.mockResolvedValue({ + choices: [{ message: { content: "Hi" } }], + }); + + const plugin = new ServingPlugin({}); + const result = await plugin.invoke("default", { messages: [] }); + + expect(mockInvoke).toHaveBeenCalledWith( + expect.anything(), + "test-endpoint", + { messages: [] }, + { servedModel: undefined }, + ); + expect(result).toEqual({ choices: [{ message: { content: "Hi" } }] }); + }); + + test("invoke throws for unknown alias", async () => { + const plugin = new ServingPlugin({ + endpoints: { llm: { env: "DATABRICKS_SERVING_ENDPOINT" } }, + }); + + await expect(plugin.invoke("unknown", { messages: [] })).rejects.toThrow( + "Unknown endpoint alias: unknown", + ); + }); + + test("stream yields chunks from connector", async () => { + const chunks = [ + { choices: [{ delta: { content: "Hello" } }] }, + { choices: [{ delta: { content: " world" } }] }, + ]; + + mockStream.mockImplementation(async function* () { + for (const chunk of chunks) { + yield chunk; + } + }); + + const plugin = new ServingPlugin({}); + const results: unknown[] = []; + for await (const chunk of plugin.stream("default", { messages: [] })) { + results.push(chunk); + } + + expect(results).toEqual(chunks); + }); + }); + + describe("shutdown", () => { + test("calls streamManager.abortAll", async () => { + const plugin = new ServingPlugin({}); + // Accessing the protected streamManager through the plugin + const abortSpy = vi.spyOn((plugin as any).streamManager, "abortAll"); + + await plugin.shutdown(); + + expect(abortSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit/src/plugins/serving/types.ts b/packages/appkit/src/plugins/serving/types.ts new file mode 100644 index 00000000..9a2dd230 --- /dev/null +++ b/packages/appkit/src/plugins/serving/types.ts @@ -0,0 +1,67 @@ +import type { BasePluginConfig } from "shared"; + +export interface EndpointConfig { + /** Environment variable holding the endpoint name. */ + env: string; + /** Target a specific served model (bypasses traffic routing). */ + servedModel?: string; +} + +export interface IServingConfig extends BasePluginConfig { + /** Map of alias → endpoint config. Defaults to { default: { env: "DATABRICKS_SERVING_ENDPOINT" } } if omitted. */ + endpoints?: Record; + /** Request timeout in ms. Default: 120000 (2 min) */ + timeout?: number; + /** How to handle unknown request parameters. 'strip' silently removes them (default). 'reject' returns 400. */ + filterMode?: "strip" | "reject"; +} + +/** + * Registry interface for serving endpoint type generation. + * Empty by default — augmented by the Vite type generator's `.d.ts` output via module augmentation. + * When populated, provides autocomplete for alias names and typed request/response/chunk per endpoint. + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Shape of a single registry entry. */ +export interface ServingEndpointEntry { + request: Record; + response: unknown; + chunk: unknown; +} + +/** Typed invoke/stream methods for a serving endpoint. */ +export interface ServingEndpointMethods< + TRequest extends Record = Record, + TResponse = unknown, + TChunk = unknown, +> { + invoke: (body: TRequest) => Promise; + stream: (body: TRequest) => AsyncGenerator; +} + +/** + * Factory function returned by `AppKit.serving`. + * + * This is a conditional type that adapts based on whether `ServingEndpointRegistry` + * has been populated via module augmentation (generated by `appKitServingTypesPlugin()`): + * + * - **Registry empty (default):** `(alias?: string) => ServingEndpointMethods` — + * accepts any alias string with untyped request/response/chunk. + * - **Registry populated:** `(alias: K) => ServingEndpointMethods<...>` — + * restricts `alias` to known endpoint keys and infers typed request/response/chunk + * from the registry entry. + * + * Run `appKitServingTypesPlugin()` in your Vite config to generate the registry + * augmentation and enable full type safety. + */ +export type ServingFactory = keyof ServingEndpointRegistry extends never + ? (alias?: string) => ServingEndpointMethods + : ( + alias: K, + ) => ServingEndpointMethods< + ServingEndpointRegistry[K]["request"], + ServingEndpointRegistry[K]["response"], + ServingEndpointRegistry[K]["chunk"] + >; diff --git a/packages/appkit/src/stream/stream-manager.ts b/packages/appkit/src/stream/stream-manager.ts index 41764772..8b511fac 100644 --- a/packages/appkit/src/stream/stream-manager.ts +++ b/packages/appkit/src/stream/stream-manager.ts @@ -374,6 +374,14 @@ export class StreamManager { if (error.name === "AbortError") { return SSEErrorCode.STREAM_ABORTED; } + + // Detect upstream API errors (e.g., from Databricks SDK ApiError) + if ( + "statusCode" in error && + typeof (error as any).statusCode === "number" + ) { + return SSEErrorCode.UPSTREAM_ERROR; + } } return SSEErrorCode.INTERNAL_ERROR; diff --git a/packages/appkit/src/stream/types.ts b/packages/appkit/src/stream/types.ts index 0fd862ba..3841bfd1 100644 --- a/packages/appkit/src/stream/types.ts +++ b/packages/appkit/src/stream/types.ts @@ -16,6 +16,7 @@ export const SSEErrorCode = { INVALID_REQUEST: "INVALID_REQUEST", STREAM_ABORTED: "STREAM_ABORTED", STREAM_EVICTED: "STREAM_EVICTED", + UPSTREAM_ERROR: "UPSTREAM_ERROR", } as const satisfies Record; export type SSEErrorCode = (typeof SSEErrorCode)[keyof typeof SSEErrorCode]; diff --git a/packages/appkit/src/type-generator/serving/cache.ts b/packages/appkit/src/type-generator/serving/cache.ts new file mode 100644 index 00000000..dc9bf7e2 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/cache.ts @@ -0,0 +1,56 @@ +import crypto from "node:crypto"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:cache"); + +export const CACHE_VERSION = "1"; +const CACHE_FILE = ".appkit-serving-types-cache.json"; +const CACHE_DIR = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", +); + +export interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; + requestKeys: string[]; +} + +export interface ServingCache { + version: string; + endpoints: Record; +} + +export function hashSchema(schemaJson: string): string { + return crypto.createHash("sha256").update(schemaJson).digest("hex"); +} + +export async function loadServingCache(): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + try { + await fs.mkdir(CACHE_DIR, { recursive: true }); + const raw = await fs.readFile(cachePath, "utf8"); + const cache = JSON.parse(raw) as ServingCache; + if (cache.version === CACHE_VERSION) { + return cache; + } + logger.debug("Cache version mismatch, starting fresh"); + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn("Cache file is corrupted, flushing cache completely."); + } + } + return { version: CACHE_VERSION, endpoints: {} }; +} + +export async function saveServingCache(cache: ServingCache): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + await fs.mkdir(CACHE_DIR, { recursive: true }); + await fs.writeFile(cachePath, JSON.stringify(cache, null, 2), "utf8"); +} diff --git a/packages/appkit/src/type-generator/serving/converter.ts b/packages/appkit/src/type-generator/serving/converter.ts new file mode 100644 index 00000000..b56b0460 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/converter.ts @@ -0,0 +1,159 @@ +import type { OpenApiOperation, OpenApiSchema } from "./fetcher"; + +/** + * Converts an OpenAPI schema to a TypeScript type string. + */ +function schemaToTypeString(schema: OpenApiSchema, indent = 0): string { + const pad = " ".repeat(indent); + + if (schema.oneOf) { + return schema.oneOf.map((s) => schemaToTypeString(s, indent)).join(" | "); + } + + if (schema.enum) { + return schema.enum.map((v) => JSON.stringify(v)).join(" | "); + } + + switch (schema.type) { + case "string": + return "string"; + case "integer": + case "number": + return "number"; + case "boolean": + return "boolean"; + case "array": { + if (!schema.items) return "unknown[]"; + const itemType = schemaToTypeString(schema.items, indent); + // Wrap union types in parens for array + if (itemType.includes(" | ") && !itemType.startsWith("{")) { + return `(${itemType})[]`; + } + return `${itemType}[]`; + } + case "object": { + if (!schema.properties) return "Record"; + const required = new Set(schema.required ?? []); + const entries = Object.entries(schema.properties).map(([key, prop]) => { + const optional = !required.has(key) ? "?" : ""; + const nullable = prop.nullable ? " | null" : ""; + const typeStr = schemaToTypeString(prop, indent + 1); + const formatComment = + prop.format && (prop.type === "number" || prop.type === "integer") + ? `/** @openapi ${prop.format}${prop.nullable ? ", nullable" : ""} */\n${pad} ` + : prop.nullable && prop.type === "integer" + ? `/** @openapi integer, nullable */\n${pad} ` + : ""; + return `${pad} ${formatComment}${key}${optional}: ${typeStr}${nullable};`; + }); + return `{\n${entries.join("\n")}\n${pad}}`; + } + default: + return "unknown"; + } +} + +/** + * Extracts the top-level property keys from the request schema. + * Strips the `stream` property (plugin-controlled). + */ +export function extractRequestKeys(operation: OpenApiOperation): string[] { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema?.properties) return []; + return Object.keys(schema.properties).filter((k) => k !== "stream"); +} + +/** + * Extracts and converts the request schema from an OpenAPI path operation. + * Strips the `stream` property from the request type. + */ +export function convertRequestSchema(operation: OpenApiOperation): string { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema || !schema.properties) return "Record"; + + // Strip `stream` property — the plugin controls this + const { stream: _stream, ...filteredProps } = schema.properties; + const filteredRequired = (schema.required ?? []).filter( + (r) => r !== "stream", + ); + + const filteredSchema: OpenApiSchema = { + ...schema, + properties: filteredProps, + required: filteredRequired.length > 0 ? filteredRequired : undefined, + }; + + return schemaToTypeString(filteredSchema); +} + +/** + * Extracts and converts the response schema from an OpenAPI path operation. + */ +export function convertResponseSchema(operation: OpenApiOperation): string { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema) return "unknown"; + return schemaToTypeString(schema); +} + +/** + * Derives a streaming chunk type from the response schema. + * Returns null if the response doesn't follow OpenAI-compatible format. + * + * OpenAI-compatible heuristic: response has `choices` array where items + * have a `message` object property. + */ +export function deriveChunkType(operation: OpenApiOperation): string | null { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema?.properties) return null; + + const choicesProp = schema.properties.choices; + if (!choicesProp || choicesProp.type !== "array" || !choicesProp.items) + return null; + + const choiceItemProps = choicesProp.items.properties; + if (!choiceItemProps?.message) return null; + + // It's OpenAI-compatible. Build the chunk type by transforming. + const messageSchema = choiceItemProps.message; + + // Build chunk schema: replace message with delta (Partial), make finish_reason nullable, drop usage + const chunkProperties: Record = {}; + + for (const [key, prop] of Object.entries(schema.properties)) { + if (key === "usage") continue; // Drop usage from chunks + if (key === "choices") { + // Transform choices items + const chunkChoiceProps: Record = {}; + for (const [ck, cp] of Object.entries(choiceItemProps)) { + if (ck === "message") { + // Replace message with delta: Partial + chunkChoiceProps.delta = { ...messageSchema }; + } else if (ck === "finish_reason") { + chunkChoiceProps[ck] = { ...cp, nullable: true }; + } else { + chunkChoiceProps[ck] = cp; + } + } + chunkProperties[key] = { + type: "array", + items: { + type: "object", + properties: chunkChoiceProps, + }, + }; + } else { + chunkProperties[key] = prop; + } + } + + const chunkSchema: OpenApiSchema = { + type: "object", + properties: chunkProperties, + }; + + // Delta properties are already optional (no `required` array in the schema), + // so schemaToTypeString renders them with `?:` — no Partial<> wrapper needed. + return schemaToTypeString(chunkSchema); +} diff --git a/packages/appkit/src/type-generator/serving/fetcher.ts b/packages/appkit/src/type-generator/serving/fetcher.ts new file mode 100644 index 00000000..bf733d7b --- /dev/null +++ b/packages/appkit/src/type-generator/serving/fetcher.ts @@ -0,0 +1,158 @@ +import type { WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:fetcher"); + +interface OpenApiSpec { + openapi: string; + info: { title: string; version: string }; + paths: Record>; +} + +export interface OpenApiOperation { + requestBody?: { + content: { + "application/json": { + schema: OpenApiSchema; + }; + }; + }; + responses?: Record< + string, + { + content?: { + "application/json": { + schema: OpenApiSchema; + }; + }; + } + >; +} + +export interface OpenApiSchema { + type?: string; + properties?: Record; + required?: string[]; + items?: OpenApiSchema; + enum?: string[]; + nullable?: boolean; + oneOf?: OpenApiSchema[]; + format?: string; +} + +/** + * Fetches the OpenAPI schema for a serving endpoint. + * Returns null if the endpoint is not found or access is denied. + */ +export async function fetchOpenApiSchema( + client: WorkspaceClient, + endpointName: string, + servedModel?: string, +): Promise<{ spec: OpenApiSpec; pathKey: string } | null> { + const headers = new Headers({ Accept: "application/json" }); + await client.config.authenticate(headers); + + const host = client.config.host; + if (!host) { + logger.warn("Databricks host not configured, skipping schema fetch"); + return null; + } + + const base = host.startsWith("http") ? host : `https://${host}`; + const url = new URL( + `/api/2.0/serving-endpoints/${encodeURIComponent(endpointName)}/openapi`, + base, + ); + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 5000); + + try { + const res = await fetch(url.toString(), { + headers, + signal: controller.signal, + }); + + if (!res.ok) { + const body = await res.text().catch(() => ""); + if (res.status === 404) { + logger.warn( + "Endpoint '%s' not found, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else if (res.status === 403) { + logger.warn( + "Access denied to endpoint '%s' schema, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else { + logger.warn( + "Failed to fetch schema for '%s' (HTTP %d), skipping%s", + endpointName, + res.status, + body ? `: ${body}` : "", + ); + } + return null; + } + + const rawSpec: unknown = await res.json(); + if ( + typeof rawSpec !== "object" || + rawSpec === null || + !("paths" in rawSpec) || + typeof (rawSpec as OpenApiSpec).paths !== "object" + ) { + logger.warn( + "Invalid OpenAPI schema structure for '%s', skipping", + endpointName, + ); + return null; + } + const spec = rawSpec as OpenApiSpec; + + // Find the right path key + const pathKeys = Object.keys(spec.paths ?? {}); + if (pathKeys.length === 0) { + logger.warn("No paths in OpenAPI schema for '%s'", endpointName); + return null; + } + + let pathKey: string; + if (servedModel) { + const match = pathKeys.find((k) => k.includes(`/${servedModel}/`)); + if (!match) { + logger.warn( + "Served model '%s' not found in schema for '%s', using first path", + servedModel, + endpointName, + ); + pathKey = pathKeys[0]; + } else { + pathKey = match; + } + } else { + pathKey = pathKeys[0]; + } + + return { spec, pathKey }; + } catch (err) { + if ((err as Error).name === "AbortError") { + logger.warn( + "Timeout fetching schema for '%s', skipping type generation", + endpointName, + ); + } else { + logger.warn( + "Error fetching schema for '%s': %s", + endpointName, + (err as Error).message, + ); + } + return null; + } finally { + clearTimeout(timeout); + } +} diff --git a/packages/appkit/src/type-generator/serving/generator.ts b/packages/appkit/src/type-generator/serving/generator.ts new file mode 100644 index 00000000..2cd88619 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/generator.ts @@ -0,0 +1,276 @@ +import fs from "node:fs/promises"; +import { WorkspaceClient } from "@databricks/sdk-experimental"; +import pc from "picocolors"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "./cache"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, + extractRequestKeys, +} from "./converter"; +import { fetchOpenApiSchema } from "./fetcher"; + +const logger = createLogger("type-generator:serving"); + +const GENERIC_REQUEST = "Record"; +const GENERIC_RESPONSE = "unknown"; +const GENERIC_CHUNK = "unknown"; + +interface GenerateServingTypesOptions { + outFile: string; + endpoints?: Record; + noCache?: boolean; +} + +/** + * Generates TypeScript type declarations for serving endpoints + * by fetching their OpenAPI schemas and converting to TypeScript. + */ +export async function generateServingTypes( + options: GenerateServingTypesOptions, +): Promise { + const { outFile, noCache } = options; + + // Resolve endpoints from config or env + const endpoints = options.endpoints ?? resolveDefaultEndpoints(); + if (Object.keys(endpoints).length === 0) { + logger.debug("No serving endpoints configured, skipping type generation"); + return; + } + + const startTime = performance.now(); + + const cache = noCache + ? { version: CACHE_VERSION, endpoints: {} } + : await loadServingCache(); + + let client: WorkspaceClient | undefined; + let updated = false; + + const registryEntries: string[] = []; + const logEntries: Array<{ + alias: string; + status: "HIT" | "MISS"; + error?: string; + }> = []; + + for (const [alias, config] of Object.entries(endpoints)) { + const endpointName = process.env[config.env]; + if (!endpointName) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: `env ${config.env} not set`, + }); + continue; + } + + client ??= new WorkspaceClient({}); + const result = await fetchOpenApiSchema( + client, + endpointName, + config.servedModel, + ); + if (!result) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema fetch failed", + }); + continue; + } + + const { spec, pathKey } = result; + const schemaJson = JSON.stringify(spec); + const hash = hashSchema(schemaJson); + + // Check cache + const cached = cache.endpoints[alias]; + if (cached && cached.hash === hash) { + registryEntries.push( + buildRegistryEntry( + alias, + cached.requestType, + cached.responseType, + cached.chunkType, + ), + ); + logEntries.push({ alias, status: "HIT" }); + continue; + } + + // Cache miss — convert + const operation = spec.paths[pathKey]?.post; + if (!operation) { + logEntries.push({ + alias, + status: "MISS", + error: "no POST operation", + }); + continue; + } + + let requestType: string; + let responseType: string; + let chunkType: string | null; + let requestKeys: string[]; + try { + requestType = convertRequestSchema(operation); + responseType = convertResponseSchema(operation); + chunkType = deriveChunkType(operation); + requestKeys = extractRequestKeys(operation); + } catch (convErr) { + logger.warn( + "Schema conversion failed for '%s': %s", + alias, + (convErr as Error).message, + ); + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema conversion failed", + }); + continue; + } + + cache.endpoints[alias] = { + hash, + requestType, + responseType, + chunkType, + requestKeys, + }; + updated = true; + + registryEntries.push( + buildRegistryEntry(alias, requestType, responseType, chunkType), + ); + logEntries.push({ alias, status: "MISS" }); + } + + // Print formatted table (matching analytics typegen output) + if (logEntries.length > 0) { + const maxNameLen = Math.max(...logEntries.map((e) => e.alias.length)); + const separator = pc.dim("─".repeat(50)); + console.log(""); + console.log( + ` ${pc.bold("Typegen Serving")} ${pc.dim(`(${logEntries.length})`)}`, + ); + console.log(` ${separator}`); + for (const entry of logEntries) { + const tag = + entry.status === "HIT" + ? `cache ${pc.bold(pc.green("HIT "))}` + : `cache ${pc.bold(pc.yellow("MISS "))}`; + const rawName = entry.alias.padEnd(maxNameLen); + const reason = entry.error ? ` ${pc.dim(entry.error)}` : ""; + console.log(` ${tag} ${rawName}${reason}`); + } + const elapsed = ((performance.now() - startTime) / 1000).toFixed(2); + const newCount = logEntries.filter((e) => e.status === "MISS").length; + const cacheCount = logEntries.filter((e) => e.status === "HIT").length; + console.log(` ${separator}`); + console.log( + ` ${newCount} new, ${cacheCount} from cache. ${pc.dim(`${elapsed}s`)}`, + ); + console.log(""); + } + + const output = generateTypeDeclarations(registryEntries); + await fs.writeFile(outFile, output, "utf-8"); + + if (registryEntries.length === 0) { + logger.debug( + "Wrote empty serving types to %s (no endpoints resolved)", + outFile, + ); + } else { + logger.debug("Wrote serving types to %s", outFile); + } + + if (updated) { + await saveServingCache(cache as ServingCache); + } +} + +function resolveDefaultEndpoints(): Record { + if (process.env.DATABRICKS_SERVING_ENDPOINT) { + return { default: { env: "DATABRICKS_SERVING_ENDPOINT" } }; + } + return {}; +} + +function buildRegistryEntry( + alias: string, + requestType: string, + responseType: string, + chunkType: string | null, +): string { + const indent = " "; + const chunkEntry = chunkType ? chunkType : "unknown"; + return ` ${alias}: { +${indent}request: ${indentType(requestType, indent)}; +${indent}response: ${indentType(responseType, indent)}; +${indent}chunk: ${indentType(chunkEntry, indent)}; + };`; +} + +function indentType(typeStr: string, baseIndent: string): string { + if (!typeStr.includes("\n")) return typeStr; + return typeStr + .split("\n") + .map((line, i) => (i === 0 ? line : `${baseIndent}${line}`)) + .join("\n"); +} + +function generateTypeDeclarations(entries: string[]): string { + return `// Auto-generated by AppKit - DO NOT EDIT +// Generated from serving endpoint OpenAPI schemas +import "@databricks/appkit"; +import "@databricks/appkit-ui/react"; + +declare module "@databricks/appkit" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} + +declare module "@databricks/appkit-ui/react" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} +`; +} diff --git a/packages/appkit/src/type-generator/serving/server-file-extractor.ts b/packages/appkit/src/type-generator/serving/server-file-extractor.ts new file mode 100644 index 00000000..cb1fbe7e --- /dev/null +++ b/packages/appkit/src/type-generator/serving/server-file-extractor.ts @@ -0,0 +1,221 @@ +import fs from "node:fs"; +import path from "node:path"; +import { Lang, parse, type SgNode } from "@ast-grep/napi"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; + +const logger = createLogger("type-generator:serving:extractor"); + +/** + * Candidate paths for the server entry file, relative to the project root. + * Checked in order; the first that exists is used. + * Same convention as plugin sync (sync.ts SERVER_FILE_CANDIDATES). + */ +const SERVER_FILE_CANDIDATES = ["server/index.ts", "server/server.ts"]; + +/** + * Find the server entry file by checking candidate paths in order. + * + * @param basePath - Project root directory to search from + * @returns Absolute path to the server file, or null if none found + */ +export function findServerFile(basePath: string): string | null { + for (const candidate of SERVER_FILE_CANDIDATES) { + const fullPath = path.join(basePath, candidate); + if (fs.existsSync(fullPath)) { + return fullPath; + } + } + return null; +} + +/** + * Extract serving endpoint config from a server file by AST-parsing it. + * Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls + * and extracts the endpoint alias names and their environment variable mappings. + * + * @param serverFilePath - Absolute path to the server entry file + * @returns Extracted endpoint config, or null if not found or not extractable + */ +export function extractServingEndpoints( + serverFilePath: string, +): Record | null { + let content: string; + try { + content = fs.readFileSync(serverFilePath, "utf-8"); + } catch { + logger.debug("Could not read server file: %s", serverFilePath); + return null; + } + + const lang = serverFilePath.endsWith(".tsx") ? Lang.Tsx : Lang.TypeScript; + const ast = parse(lang, content); + const root = ast.root(); + + // Find serving(...) call expressions + const servingCall = findServingCall(root); + if (!servingCall) { + logger.debug("No serving() call found in %s", serverFilePath); + return null; + } + + // Get the first argument (the config object) + const args = servingCall.field("arguments"); + if (!args) { + return null; + } + + const configArg = args.children().find((child) => child.kind() === "object"); + if (!configArg) { + // serving() called with no args or non-object arg + return null; + } + + // Find the "endpoints" property in the config object + const endpointsPair = findProperty(configArg, "endpoints"); + if (!endpointsPair) { + // Config object has no "endpoints" property (e.g. serving({ timeout: 5000 })) + return null; + } + + // Get the value of the endpoints property + const endpointsValue = getPropertyValue(endpointsPair); + if (!endpointsValue || endpointsValue.kind() !== "object") { + // endpoints is a variable reference, not an inline object + logger.debug( + "serving() endpoints is not an inline object literal in %s. " + + "Pass endpoints explicitly via appKitServingTypesPlugin({ endpoints }) in vite.config.ts.", + serverFilePath, + ); + return null; + } + + // Extract each endpoint entry + const endpoints: Record = {}; + const pairs = endpointsValue + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const entry = extractEndpointEntry(pair); + if (entry) { + endpoints[entry.alias] = entry.config; + } + } + + if (Object.keys(endpoints).length === 0) { + return null; + } + + logger.debug( + "Extracted %d endpoint(s) from %s: %s", + Object.keys(endpoints).length, + serverFilePath, + Object.keys(endpoints).join(", "), + ); + + return endpoints; +} + +/** + * Find the serving() call expression in the AST. + * Looks for call expressions where the callee identifier is "serving". + */ +function findServingCall(root: SgNode): SgNode | null { + const callExpressions = root.findAll({ + rule: { kind: "call_expression" }, + }); + + for (const call of callExpressions) { + const callee = call.children()[0]; + if (callee?.kind() === "identifier" && callee.text() === "serving") { + return call; + } + } + + return null; +} + +/** + * Find a property (pair node) with the given key name in an object expression. + */ +function findProperty(objectNode: SgNode, propertyName: string): SgNode | null { + const pairs = objectNode + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const key = pair.children()[0]; + if (!key) continue; + + const keyText = + key.kind() === "property_identifier" + ? key.text() + : key.kind() === "string" + ? key.text().replace(/^['"]|['"]$/g, "") + : null; + + if (keyText === propertyName) { + return pair; + } + } + + return null; +} + +/** + * Get the value node from a pair (property: value). + * The value is typically the last meaningful child after the colon. + */ +function getPropertyValue(pairNode: SgNode): SgNode | null { + const children = pairNode.children(); + // pair children: [key, ":", value] + return children.length >= 3 ? children[children.length - 1] : null; +} + +/** + * Extract a single endpoint entry from a pair node like: + * `demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }` + */ +function extractEndpointEntry( + pair: SgNode, +): { alias: string; config: EndpointConfig } | null { + const children = pair.children(); + if (children.length < 3) return null; + + // Get alias name (the key) + const keyNode = children[0]; + const alias = + keyNode.kind() === "property_identifier" + ? keyNode.text() + : keyNode.kind() === "string" + ? keyNode.text().replace(/^['"]|['"]$/g, "") + : null; + + if (!alias) return null; + + // Get the value (should be an object like { env: "..." }) + const valueNode = children[children.length - 1]; + if (valueNode.kind() !== "object") return null; + + // Extract env field + const envPair = findProperty(valueNode, "env"); + if (!envPair) return null; + + const envValue = getPropertyValue(envPair); + if (!envValue || envValue.kind() !== "string") return null; + + const env = envValue.text().replace(/^['"]|['"]$/g, ""); + + // Extract optional servedModel field + const config: EndpointConfig = { env }; + const servedModelPair = findProperty(valueNode, "servedModel"); + if (servedModelPair) { + const servedModelValue = getPropertyValue(servedModelPair); + if (servedModelValue?.kind() === "string") { + config.servedModel = servedModelValue.text().replace(/^['"]|['"]$/g, ""); + } + } + + return { alias, config }; +} diff --git a/packages/appkit/src/type-generator/serving/tests/cache.test.ts b/packages/appkit/src/type-generator/serving/tests/cache.test.ts new file mode 100644 index 00000000..0c99c997 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/cache.test.ts @@ -0,0 +1,109 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "../cache"; + +vi.mock("node:fs/promises"); + +describe("serving cache", () => { + beforeEach(() => { + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("hashSchema", () => { + test("returns consistent SHA256 hash", () => { + const hash1 = hashSchema('{"openapi": "3.1.0"}'); + const hash2 = hashSchema('{"openapi": "3.1.0"}'); + expect(hash1).toBe(hash2); + expect(hash1).toHaveLength(64); // SHA256 hex + }); + + test("different inputs produce different hashes", () => { + const hash1 = hashSchema('{"a": 1}'); + const hash2 = hashSchema('{"a": 2}'); + expect(hash1).not.toBe(hash2); + }); + }); + + describe("loadServingCache", () => { + test("returns empty cache when file does not exist", async () => { + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("returns parsed cache when file exists with correct version", async () => { + const cached: ServingCache = { + version: CACHE_VERSION, + endpoints: { + llm: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{ model: string }", + chunkType: null, + requestKeys: ["messages"], + }, + }, + }; + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(cached)); + + const cache = await loadServingCache(); + expect(cache).toEqual(cached); + }); + + test("flushes cache when version mismatches", async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ version: "0", endpoints: { old: {} } }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("flushes cache when file is corrupted", async () => { + vi.mocked(fs.readFile).mockResolvedValue("not json"); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + }); + + describe("saveServingCache", () => { + test("writes cache to file", async () => { + vi.mocked(fs.writeFile).mockResolvedValue(); + + const cache: ServingCache = { + version: CACHE_VERSION, + endpoints: { + test: { + hash: "xyz", + requestType: "{}", + responseType: "{}", + chunkType: null, + requestKeys: [], + }, + }, + }; + + await saveServingCache(cache); + + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringContaining(".appkit-serving-types-cache.json"), + JSON.stringify(cache, null, 2), + "utf8", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/converter.test.ts b/packages/appkit/src/type-generator/serving/tests/converter.test.ts new file mode 100644 index 00000000..1be30738 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/converter.test.ts @@ -0,0 +1,308 @@ +import { describe, expect, test } from "vitest"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, + extractRequestKeys, +} from "../converter"; +import type { OpenApiOperation, OpenApiSchema } from "../fetcher"; + +function makeOperation( + requestProps: Record, + responseProps?: Record, + required?: string[], +): OpenApiOperation { + return { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: requestProps, + required, + }, + }, + }, + }, + responses: responseProps + ? { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: responseProps, + }, + }, + }, + }, + } + : undefined, + }; +} + +describe("converter", () => { + describe("convertRequestSchema", () => { + test("converts string type", () => { + const op = makeOperation({ name: { type: "string" } }); + const result = convertRequestSchema(op); + expect(result).toContain("name?: string;"); + }); + + test("converts integer type to number", () => { + const op = makeOperation({ count: { type: "integer" } }); + expect(convertRequestSchema(op)).toContain("count?: number;"); + }); + + test("converts number type", () => { + const op = makeOperation({ + temp: { type: "number", format: "double" }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number;"); + }); + + test("converts boolean type", () => { + const op = makeOperation({ flag: { type: "boolean" } }); + expect(convertRequestSchema(op)).toContain("flag?: boolean;"); + }); + + test("converts enum to string literal union", () => { + const op = makeOperation({ + role: { type: "string", enum: ["user", "assistant"] }, + }); + const result = convertRequestSchema(op); + expect(result).toContain('"user" | "assistant"'); + }); + + test("converts array type", () => { + const op = makeOperation({ + items: { type: "array", items: { type: "string" } }, + }); + expect(convertRequestSchema(op)).toContain("items?: string[];"); + }); + + test("converts nested object", () => { + const op = makeOperation({ + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("role?: string;"); + expect(result).toContain("content?: string;"); + }); + + test("handles nullable properties", () => { + const op = makeOperation({ + temp: { type: "number", nullable: true }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number | null;"); + }); + + test("handles oneOf union types", () => { + const op = makeOperation({ + stop: { + oneOf: [ + { type: "string" }, + { type: "array", items: { type: "string" } }, + ], + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("string | string[]"); + }); + + test("strips stream property from request", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + stream: { type: "boolean", nullable: true }, + temperature: { type: "number" }, + }); + const result = convertRequestSchema(op); + expect(result).not.toContain("stream"); + expect(result).toContain("messages"); + expect(result).toContain("temperature"); + }); + + test("marks required properties without ?", () => { + const op = makeOperation( + { + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + }, + undefined, + ["messages"], + ); + const result = convertRequestSchema(op); + expect(result).toContain("messages: string[];"); + expect(result).toContain("temperature?: number;"); + }); + + test("returns Record for missing schema", () => { + const op: OpenApiOperation = {}; + expect(convertRequestSchema(op)).toBe("Record"); + }); + }); + + describe("convertResponseSchema", () => { + test("converts response schema", () => { + const op = makeOperation( + {}, + { + model: { type: "string" }, + id: { type: "string" }, + }, + ); + const result = convertResponseSchema(op); + expect(result).toContain("model?: string;"); + expect(result).toContain("id?: string;"); + }); + + test("returns unknown for missing response", () => { + const op: OpenApiOperation = {}; + expect(convertResponseSchema(op)).toBe("unknown"); + }); + }); + + describe("deriveChunkType", () => { + test("derives chunk type from OpenAI-compatible response", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + index: { type: "integer" }, + message: { + type: "object", + properties: { + role: { + type: "string", + enum: ["user", "assistant"], + }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + usage: { + type: "object", + properties: { + prompt_tokens: { type: "integer" }, + }, + nullable: true, + }, + id: { type: "string" }, + }, + }, + }, + }, + }, + }, + }; + + const result = deriveChunkType(op); + expect(result).not.toBeNull(); + // Should have delta instead of message + expect(result).toContain("delta"); + expect(result).not.toContain("message"); + // Should make finish_reason nullable + expect(result).toContain("finish_reason"); + expect(result).toContain("| null"); + // Should drop usage + expect(result).not.toContain("usage"); + // Should keep model and id + expect(result).toContain("model"); + expect(result).toContain("id"); + }); + + test("returns null for non-OpenAI response (no choices)", () => { + const op = makeOperation( + {}, + { + predictions: { type: "array", items: { type: "number" } }, + }, + ); + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for choices without message", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + choices: { + type: "array", + items: { + type: "object", + properties: { + score: { type: "number" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }; + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for missing response", () => { + const op: OpenApiOperation = {}; + expect(deriveChunkType(op)).toBeNull(); + }); + }); + + describe("extractRequestKeys", () => { + test("extracts top-level property keys excluding stream", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + stream: { type: "boolean", nullable: true }, + }); + expect(extractRequestKeys(op)).toEqual(["messages", "temperature"]); + }); + + test("returns empty array for missing schema", () => { + const op: OpenApiOperation = {}; + expect(extractRequestKeys(op)).toEqual([]); + }); + + test("returns empty array for schema without properties", () => { + const op: OpenApiOperation = { + requestBody: { + content: { + "application/json": { + schema: { type: "object" }, + }, + }, + }, + }; + expect(extractRequestKeys(op)).toEqual([]); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts new file mode 100644 index 00000000..802540b0 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts @@ -0,0 +1,209 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { fetchOpenApiSchema } from "../fetcher"; + +const mockAuthenticate = vi.fn(async () => {}); + +function createMockClient(host?: string) { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +function makeValidSpec( + paths: Record = { "/invocations": { post: {} } }, +) { + return { + openapi: "3.0.0", + info: { title: "test", version: "1" }, + paths, + }; +} + +describe("fetchOpenApiSchema", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec()), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns null when host is not configured", async () => { + const result = await fetchOpenApiSchema(createMockClient(undefined), "ep"); + expect(result).toBeNull(); + }); + + test("returns null on HTTP 404", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Not found", { status: 404 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on HTTP 403", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Forbidden", { status: 403 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on generic error status", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Server error", { status: 500 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on timeout (AbortError)", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue( + Object.assign(new Error("The operation was aborted"), { + name: "AbortError", + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on network error", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue(new Error("fetch failed")); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns spec and pathKey for valid response", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: { requestBody: {} } }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).not.toBeNull(); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + expect(result?.spec.openapi).toBe("3.0.0"); + }); + + test("matches servedModel path when provided", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/served-models/gpt4/invocations": { post: {} }, + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "gpt4", + ); + expect(result?.pathKey).toBe( + "/serving-endpoints/ep/served-models/gpt4/invocations", + ); + }); + + test("falls back to first path when servedModel not found", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "nonexistent-model", + ); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + }); + + test("returns null for invalid spec structure (missing paths)", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ openapi: "3.0.0", info: {} }), { + status: 200, + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("returns null when paths object is empty", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec({})), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("authenticates request headers", async () => { + await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("constructs correct URL with encoded endpoint name", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my endpoint", + ); + + expect(fetchSpy).toHaveBeenCalledWith( + expect.stringContaining("/serving-endpoints/my%20endpoint/openapi"), + expect.any(Object), + ); + }); + + test("prepends https when host lacks protocol", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + await fetchOpenApiSchema(createMockClient("host.databricks.com"), "ep"); + + const url = fetchSpy.mock.calls[0][0] as string; + expect(url.startsWith("https://")).toBe(true); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/generator.test.ts b/packages/appkit/src/type-generator/serving/tests/generator.test.ts new file mode 100644 index 00000000..f9d1b378 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/generator.test.ts @@ -0,0 +1,215 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { generateServingTypes } from "../generator"; + +vi.mock("node:fs/promises"); + +// Mock cache module +vi.mock("../cache", () => ({ + CACHE_VERSION: "1", + hashSchema: vi.fn(() => "mock-hash"), + loadServingCache: vi.fn(async () => ({ version: "1", endpoints: {} })), + saveServingCache: vi.fn(async () => {}), +})); + +// Mock fetcher +const mockFetchOpenApiSchema = vi.fn(); +vi.mock("../fetcher", () => ({ + fetchOpenApiSchema: (...args: any[]) => mockFetchOpenApiSchema(...args), +})); + +// Mock WorkspaceClient +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn(() => ({ config: {} })), +})); + +const CHAT_OPENAPI_SPEC = { + openapi: "3.1.0", + info: { title: "test", version: "1" }, + paths: { + "/served-models/llm/invocations": { + post: { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: { + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + temperature: { type: "number", nullable: true }, + stream: { type: "boolean", nullable: true }, + }, + }, + }, + }, + }, + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + message: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, +}; + +describe("generateServingTypes", () => { + const outFile = "/tmp/test-serving-types.d.ts"; + + beforeEach(() => { + vi.mocked(fs.writeFile).mockResolvedValue(); + process.env.TEST_SERVING_ENDPOINT = "my-endpoint"; + }); + + afterEach(() => { + delete process.env.TEST_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT; + vi.restoreAllMocks(); + }); + + test("generates .d.ts with module augmentation for a chat endpoint", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(fs.writeFile).toHaveBeenCalledWith( + outFile, + expect.any(String), + "utf-8", + ); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + + // Verify module augmentation structure + expect(output).toContain("// Auto-generated by AppKit - DO NOT EDIT"); + expect(output).toContain('import "@databricks/appkit"'); + expect(output).toContain('import "@databricks/appkit-ui/react"'); + expect(output).toContain('declare module "@databricks/appkit"'); + expect(output).toContain('declare module "@databricks/appkit-ui/react"'); + expect(output).toContain("interface ServingEndpointRegistry"); + expect(output).toContain("llm:"); + expect(output).toContain("request:"); + expect(output).toContain("response:"); + expect(output).toContain("chunk:"); + }); + + test("strips stream property from generated request type", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + // `stream` should be stripped from request type + expect(output).toContain("messages"); + expect(output).toContain("temperature"); + expect(output).not.toMatch(/\bstream\??\s*:/); + }); + + test("emits generic types when env var is not set", async () => { + delete process.env.TEST_SERVING_ENDPOINT; + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("skips generation when no endpoints configured and no env var", async () => { + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + expect(fs.writeFile).not.toHaveBeenCalled(); + }); + + test("emits generic types when schema fetch returns null", async () => { + mockFetchOpenApiSchema.mockResolvedValue(null); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("resolves default endpoint from DATABRICKS_SERVING_ENDPOINT", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT = "my-default-endpoint"; + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).toHaveBeenCalledWith( + expect.anything(), + "my-default-endpoint", + undefined, + ); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("default:"); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts new file mode 100644 index 00000000..f0a94709 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts @@ -0,0 +1,213 @@ +import fs from "node:fs"; +import path from "node:path"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { + extractServingEndpoints, + findServerFile, +} from "../server-file-extractor"; + +describe("findServerFile", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns server/index.ts when it exists", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "index.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "index.ts"), + ); + }); + + test("returns server/server.ts when index.ts does not exist", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "server.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "server.ts"), + ); + }); + + test("returns null when no server file exists", () => { + vi.spyOn(fs, "existsSync").mockReturnValue(false); + expect(findServerFile("/app")).toBeNull(); + }); +}); + +describe("extractServingEndpoints", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + function mockServerFile(content: string) { + vi.spyOn(fs, "readFileSync").mockReturnValue(content); + } + + test("extracts inline endpoints from serving() call", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); + + test("extracts servedModel when present", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }, + }); + }); + + test("returns null when serving() has no arguments", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving()], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has config but no endpoints", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ timeout: 5000 }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has empty config object", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when endpoints is a variable reference", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +const myEndpoints = { demo: { env: "DATABRICKS_SERVING_ENDPOINT" } }; +createApp({ + plugins: [ + serving({ endpoints: myEndpoints }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when no serving() call exists", () => { + mockServerFile(` +import { createApp, analytics } from '@databricks/appkit'; + +createApp({ + plugins: [analytics({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when server file cannot be read", () => { + vi.spyOn(fs, "readFileSync").mockImplementation(() => { + throw new Error("ENOENT"); + }); + + const result = extractServingEndpoints("/app/server/nonexistent.ts"); + expect(result).toBeNull(); + }); + + test("handles single-quoted env values", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: 'DATABRICKS_SERVING_ENDPOINT' }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }); + }); + + test("handles endpoints with trailing commas", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }, + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts new file mode 100644 index 00000000..bcd10915 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts @@ -0,0 +1,186 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +const mockGenerateServingTypes = vi.fn(async () => {}); +const mockFindServerFile = vi.fn((): string | null => null); +const mockExtractServingEndpoints = vi.fn( + (): Record | null => null, +); + +vi.mock("../generator", () => ({ + generateServingTypes: (...args: any[]) => mockGenerateServingTypes(...args), +})); + +vi.mock("../server-file-extractor", () => ({ + findServerFile: (...args: any[]) => mockFindServerFile(...args), + extractServingEndpoints: (...args: any[]) => + mockExtractServingEndpoints(...args), +})); + +import { appKitServingTypesPlugin } from "../vite-plugin"; + +describe("appKitServingTypesPlugin", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + mockGenerateServingTypes.mockReset(); + mockFindServerFile.mockReset(); + mockExtractServingEndpoints.mockReset(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + vi.restoreAllMocks(); + }); + + describe("apply()", () => { + test("returns true when explicit endpoints provided", () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM_ENDPOINT" } }, + }); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when DATABRICKS_SERVING_ENDPOINT is set", () => { + process.env.DATABRICKS_SERVING_ENDPOINT = "my-endpoint"; + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in cwd", () => { + mockFindServerFile.mockReturnValueOnce("/app/server/index.ts"); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in parent dir", () => { + mockFindServerFile + .mockReturnValueOnce(null) // cwd check + .mockReturnValueOnce("/app/server/index.ts"); // parent check + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns false when nothing configured", () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + mockFindServerFile.mockReturnValue(null); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(false); + }); + }); + + describe("configResolved()", () => { + test("resolves outFile relative to config.root", async () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining( + "/app/client/src/appKitServingTypes.d.ts", + ), + }), + ); + }); + + test("uses custom outFile when provided", async () => { + const plugin = appKitServingTypesPlugin({ + outFile: "types/serving.d.ts", + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining("types/serving.d.ts"), + }), + ); + }); + }); + + describe("buildStart()", () => { + test("calls generateServingTypes with explicit endpoints", async () => { + const endpoints = { llm: { env: "LLM_ENDPOINT" } }; + const plugin = appKitServingTypesPlugin({ endpoints }); + (plugin as any).configResolved({ root: "/app/client" }); + + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + endpoints, + noCache: false, + }), + ); + }); + + test("extracts endpoints from server file when not explicit", async () => { + const extracted = { llm: { env: "LLM_EP" } }; + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(extracted); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: extracted }), + ); + }); + + test("passes undefined endpoints when no server file found", async () => { + mockFindServerFile.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("passes undefined when AST extraction returns null", async () => { + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("swallows errors in dev mode", async () => { + process.env.NODE_ENV = "development"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + // Should not throw + await expect((plugin as any).buildStart()).resolves.toBeUndefined(); + }); + + test("rethrows errors in production mode", async () => { + process.env.NODE_ENV = "production"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + await expect((plugin as any).buildStart()).rejects.toThrow( + "fetch failed", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/vite-plugin.ts b/packages/appkit/src/type-generator/serving/vite-plugin.ts new file mode 100644 index 00000000..9903a253 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/vite-plugin.ts @@ -0,0 +1,109 @@ +import path from "node:path"; +import type { Plugin } from "vite"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { generateServingTypes } from "./generator"; +import { + extractServingEndpoints, + findServerFile, +} from "./server-file-extractor"; + +const logger = createLogger("type-generator:serving:vite-plugin"); + +interface AppKitServingTypesPluginOptions { + /** Path to the output .d.ts file (relative to client root). Default: "src/appKitServingTypes.d.ts" */ + outFile?: string; + /** Endpoint config override. If omitted, auto-discovers from the server file or falls back to DATABRICKS_SERVING_ENDPOINT env var. */ + endpoints?: Record; +} + +/** + * Vite plugin to generate TypeScript types for AppKit serving endpoints. + * Fetches OpenAPI schemas from Databricks and generates a .d.ts with + * ServingEndpointRegistry module augmentation. + * + * Endpoint discovery order: + * 1. Explicit `endpoints` option (override) + * 2. AST extraction from server file (server/index.ts or server/server.ts) + * 3. DATABRICKS_SERVING_ENDPOINT env var (single default endpoint) + */ +export function appKitServingTypesPlugin( + options?: AppKitServingTypesPluginOptions, +): Plugin { + let outFile: string; + let projectRoot: string; + + async function generate() { + try { + // Resolve endpoints: explicit option > server file AST > env var fallback (handled by generator) + let endpoints = options?.endpoints; + if (!endpoints) { + const serverFile = findServerFile(projectRoot); + if (serverFile) { + endpoints = extractServingEndpoints(serverFile) ?? undefined; + } + } + + await generateServingTypes({ + outFile, + endpoints, + noCache: false, + }); + } catch (error) { + if (process.env.NODE_ENV === "production") { + throw error; + } + logger.error("Error generating serving types: %O", error); + } + } + + return { + name: "appkit-serving-types", + + apply() { + // Fast checks — no AST parsing here + if (options?.endpoints && Object.keys(options.endpoints).length > 0) { + return true; + } + + if (process.env.DATABRICKS_SERVING_ENDPOINT) { + return true; + } + + // Check if a server file exists (may contain serving() config) + // Use process.cwd() for apply() since configResolved hasn't run yet + if (findServerFile(process.cwd())) { + return true; + } + + // Also check parent dir (for when cwd is client/) + const parentDir = path.resolve(process.cwd(), ".."); + if (findServerFile(parentDir)) { + return true; + } + + logger.debug( + "No serving endpoints configured. Skipping type generation.", + ); + return false; + }, + + configResolved(config) { + // Resolve project root: go up one level from Vite root (client dir) + // This handles both: + // - pnpm dev: process.cwd() is app root, config.root is client/ + // - pnpm build: process.cwd() is client/ (cd client && vite build), config.root is client/ + projectRoot = path.resolve(config.root, ".."); + outFile = path.resolve( + config.root, + options?.outFile ?? "src/appKitServingTypes.d.ts", + ); + }, + + async buildStart() { + await generate(); + }, + + // No configureServer / watcher — schemas change on endpoint redeploy, not on file edit + }; +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 199fcfb8..9ca11b81 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -242,6 +242,9 @@ importers: packages/appkit: dependencies: + '@ast-grep/napi': + specifier: 0.37.0 + version: 0.37.0 '@databricks/lakebase': specifier: workspace:* version: link:../lakebase diff --git a/template/appkit.plugins.json b/template/appkit.plugins.json index cf60a8af..c21d8e80 100644 --- a/template/appkit.plugins.json +++ b/template/appkit.plugins.json @@ -149,6 +149,30 @@ "optional": [] }, "requiredByTemplate": true + }, + "serving": { + "name": "serving", + "displayName": "Model Serving Plugin", + "description": "Authenticated proxy to Databricks Model Serving endpoints", + "package": "@databricks/appkit", + "resources": { + "required": [ + { + "type": "serving_endpoint", + "alias": "Serving Endpoint", + "resourceKey": "serving-endpoint", + "description": "Model Serving endpoint for inference", + "permission": "CAN_QUERY", + "fields": { + "name": { + "env": "DATABRICKS_SERVING_ENDPOINT", + "description": "Serving endpoint name" + } + } + } + ], + "optional": [] + } } } } diff --git a/template/client/src/App.tsx b/template/client/src/App.tsx index fb4c28e6..a94bb5bc 100644 --- a/template/client/src/App.tsx +++ b/template/client/src/App.tsx @@ -17,6 +17,9 @@ import { GeniePage } from './pages/genie/GeniePage'; {{- if .plugins.files}} import { FilesPage } from './pages/files/FilesPage'; {{- end}} +{{- if .plugins.serving}} +import { ServingPage } from './pages/serving/ServingPage'; +{{- end}} const navLinkClass = ({ isActive }: { isActive: boolean }) => `px-3 py-1.5 rounded-md text-sm font-medium transition-colors ${ @@ -53,6 +56,11 @@ function Layout() { Files +{{- end}} +{{- if .plugins.serving}} + + Serving + {{- end}} @@ -80,6 +88,9 @@ const router = createBrowserRouter([ {{- end}} {{- if .plugins.files}} { path: '/files', element: }, +{{- end}} +{{- if .plugins.serving}} + { path: '/serving', element: }, {{- end}} ], }, diff --git a/template/client/src/pages/serving/ServingPage.tsx b/template/client/src/pages/serving/ServingPage.tsx new file mode 100644 index 00000000..b80934ba --- /dev/null +++ b/template/client/src/pages/serving/ServingPage.tsx @@ -0,0 +1,127 @@ +{{if .plugins.serving -}} +import { useServingInvoke } from '@databricks/appkit-ui/react'; +// For streaming endpoints (e.g. chat models), use useServingStream instead: +// import { useServingStream } from '@databricks/appkit-ui/react'; +import { useState } from 'react'; + +interface ChatChoice { + message?: { content?: string }; +} + +interface ChatResponse { + choices?: ChatChoice[]; +} + +function extractContent(data: unknown): string { + const resp = data as ChatResponse; + return resp?.choices?.[0]?.message?.content ?? JSON.stringify(data); +} + +interface Message { + id: string; + role: 'user' | 'assistant'; + content: string; +} + +export function ServingPage() { + const [input, setInput] = useState(''); + const [messages, setMessages] = useState([]); + + const { invoke, loading, error } = useServingInvoke({ messages: [] }); + // For streaming endpoints (e.g. chat models), use useServingStream instead: + // const { stream, chunks, streaming, error, reset } = useServingStream({ messages: [] }); + // Then accumulate chunks: chunks.map(c => c?.choices?.[0]?.delta?.content ?? '').join('') + + function handleSubmit(e: React.FormEvent) { + e.preventDefault(); + if (!input.trim() || loading) return; + + const userMessage: Message = { + id: crypto.randomUUID(), + role: 'user', + content: input.trim(), + }; + + const fullMessages = [ + ...messages.map(({ role, content }) => ({ role, content })), + { role: 'user' as const, content: userMessage.content }, + ]; + + setMessages((prev) => [...prev, userMessage]); + setInput(''); + + void invoke({ messages: fullMessages }).then((result) => { + if (result) { + setMessages((prev) => [ + ...prev, + { id: crypto.randomUUID(), role: 'assistant', content: extractContent(result) }, + ]); + } + }); + } + + return ( +
+
+

Model Serving

+

+ Chat with a Databricks Model Serving endpoint. +

+
+ +
+
+ {messages.map((msg) => ( +
+
+

{msg.content}

+
+
+ ))} + + {loading && ( +
+
+

...

+
+
+ )} + + {error && ( +
+ Error: {error} +
+ )} +
+ +
+ setInput(e.target.value)} + placeholder="Send a message..." + className="flex-1 rounded-md border px-3 py-2 text-sm bg-background" + disabled={loading} + /> + +
+
+
+ ); +} +{{- end}} diff --git a/template/client/vite.config.ts b/template/client/vite.config.ts index b49d4055..12c1d864 100644 --- a/template/client/vite.config.ts +++ b/template/client/vite.config.ts @@ -2,11 +2,20 @@ import { defineConfig } from 'vite'; import react from '@vitejs/plugin-react'; import tailwindcss from '@tailwindcss/vite'; import path from 'node:path'; +{{- if .plugins.serving}} +import { appKitServingTypesPlugin } from '@databricks/appkit'; +{{- end}} // https://vite.dev/config/ export default defineConfig({ root: __dirname, - plugins: [react(), tailwindcss()], + plugins: [ + react(), + tailwindcss(), +{{- if .plugins.serving}} + appKitServingTypesPlugin(), +{{- end}} + ], server: { middlewareMode: true, }, diff --git a/template/databricks.yml.tmpl b/template/databricks.yml.tmpl index accf7709..77997d31 100644 --- a/template/databricks.yml.tmpl +++ b/template/databricks.yml.tmpl @@ -13,7 +13,7 @@ resources: description: "{{.appDescription}}" source_code_path: ./ -{{- if or .plugins.genie .plugins.files}} +{{- if or .plugins.genie .plugins.files .plugins.serving}} user_api_scopes: {{- if .plugins.genie}} - dashboards.genie @@ -21,8 +21,11 @@ resources: {{- if .plugins.files}} - files.files {{- end}} +{{- if .plugins.serving}} + - serving.serving-endpoints +{{- end}} {{- else}} - # Uncomment to enable on behalf of user API scopes. Available scopes: sql, dashboards.genie, files.files + # Uncomment to enable on behalf of user API scopes. Available scopes: sql, dashboards.genie, files.files, serving.serving-endpoints # user_api_scopes: # - sql {{- end}} diff --git a/tools/generate-app-templates.ts b/tools/generate-app-templates.ts index 4b029121..1eff9357 100644 --- a/tools/generate-app-templates.ts +++ b/tools/generate-app-templates.ts @@ -55,21 +55,23 @@ const FEATURE_DEPENDENCIES: Record = { files: "Volume", genie: "Genie Space", lakebase: "Database", + serving: "Serving Endpoint", }; const APP_TEMPLATES: AppTemplate[] = [ { name: "appkit-all-in-one", - features: ["analytics", "files", "genie", "lakebase"], + features: ["analytics", "files", "genie", "lakebase", "serving"], set: { "analytics.sql-warehouse.id": "placeholder", "files.files.path": "placeholder", "genie.genie-space.id": "placeholder", "lakebase.postgres.branch": "placeholder", "lakebase.postgres.database": "placeholder", + "serving.serving-endpoint.name": "placeholder", }, description: - "Full-stack Node.js app with SQL analytics dashboards, file browser, Genie AI conversations, and Lakebase Autoscaling (Postgres) CRUD", + "Full-stack Node.js app with SQL analytics dashboards, file browser, Genie AI conversations, Lakebase Autoscaling (Postgres) CRUD, and Model Serving", }, { name: "appkit-analytics", @@ -96,6 +98,15 @@ const APP_TEMPLATES: AppTemplate[] = [ }, description: "Node.js app with file browser for Databricks Volumes", }, + { + name: "appkit-serving", + features: ["serving"], + set: { + "serving.serving-endpoint.name": "placeholder", + }, + description: + "Node.js app with Databricks Model Serving endpoint integration", + }, { name: "appkit-lakebase", features: ["lakebase"],