From 7ba88a217304a46bc0d5250d52748192b7511aa4 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 12:05:08 +0200 Subject: [PATCH 1/6] feat: add serving type generator, Vite plugin, and UI hooks Add Vite plugin that auto-generates TypeScript types from serving endpoint OpenAPI schemas. Includes AST-based server file extraction (@ast-grep/napi), schema-to-TypeScript conversion, and caching. Also adds useServingInvoke and useServingStream React hooks in appkit-ui with full type-safe registry support. Signed-off-by: Pawel Kosiec --- .gitignore | 3 + .../Function.appKitServingTypesPlugin.md | 24 ++ .../Function.extractServingEndpoints.md | 24 ++ .../api/appkit/Function.findServerFile.md | 19 ++ docs/docs/api/appkit/index.md | 3 + docs/docs/api/appkit/typedoc-sidebar.ts | 15 + .../__tests__/use-serving-invoke.test.ts | 117 ++++++++ .../__tests__/use-serving-stream.test.ts | 271 +++++++++++++++++ packages/appkit-ui/src/react/hooks/index.ts | 15 + packages/appkit-ui/src/react/hooks/types.ts | 51 ++++ .../src/react/hooks/use-serving-invoke.ts | 103 +++++++ .../src/react/hooks/use-serving-stream.ts | 123 ++++++++ packages/appkit/package.json | 1 + packages/appkit/src/index.ts | 6 +- .../src/plugins/serving/schema-filter.ts | 18 +- .../src/type-generator/serving/cache.ts | 55 ++++ .../src/type-generator/serving/converter.ts | 149 ++++++++++ .../src/type-generator/serving/fetcher.ts | 158 ++++++++++ .../src/type-generator/serving/generator.ts | 266 +++++++++++++++++ .../serving/server-file-extractor.ts | 221 ++++++++++++++ .../serving/tests/cache.test.ts | 107 +++++++ .../serving/tests/converter.test.ts | 278 ++++++++++++++++++ .../serving/tests/fetcher.test.ts | 209 +++++++++++++ .../serving/tests/generator.test.ts | 215 ++++++++++++++ .../tests/server-file-extractor.test.ts | 213 ++++++++++++++ .../serving/tests/vite-plugin.test.ts | 186 ++++++++++++ .../src/type-generator/serving/vite-plugin.ts | 109 +++++++ pnpm-lock.yaml | 3 + 28 files changed, 2947 insertions(+), 15 deletions(-) create mode 100644 docs/docs/api/appkit/Function.appKitServingTypesPlugin.md create mode 100644 docs/docs/api/appkit/Function.extractServingEndpoints.md create mode 100644 docs/docs/api/appkit/Function.findServerFile.md create mode 100644 packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts create mode 100644 packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts create mode 100644 packages/appkit-ui/src/react/hooks/use-serving-invoke.ts create mode 100644 packages/appkit-ui/src/react/hooks/use-serving-stream.ts create mode 100644 packages/appkit/src/type-generator/serving/cache.ts create mode 100644 packages/appkit/src/type-generator/serving/converter.ts create mode 100644 packages/appkit/src/type-generator/serving/fetcher.ts create mode 100644 packages/appkit/src/type-generator/serving/generator.ts create mode 100644 packages/appkit/src/type-generator/serving/server-file-extractor.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/cache.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/converter.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/fetcher.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/generator.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts create mode 100644 packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts create mode 100644 packages/appkit/src/type-generator/serving/vite-plugin.ts diff --git a/.gitignore b/.gitignore index 3b6cc969..4c51d5b1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ coverage *.tsbuildinfo .turbo + +# AppKit type generator caches +.databricks diff --git a/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md new file mode 100644 index 00000000..bc28660a --- /dev/null +++ b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md @@ -0,0 +1,24 @@ +# Function: appKitServingTypesPlugin() + +```ts +function appKitServingTypesPlugin(options?: AppKitServingTypesPluginOptions): Plugin$1; +``` + +Vite plugin to generate TypeScript types for AppKit serving endpoints. +Fetches OpenAPI schemas from Databricks and generates a .d.ts with +ServingEndpointRegistry module augmentation. + +Endpoint discovery order: +1. Explicit `endpoints` option (override) +2. AST extraction from server file (server/index.ts or server/server.ts) +3. DATABRICKS_SERVING_ENDPOINT env var (single default endpoint) + +## Parameters + +| Parameter | Type | +| ------ | ------ | +| `options?` | `AppKitServingTypesPluginOptions` | + +## Returns + +`Plugin$1` diff --git a/docs/docs/api/appkit/Function.extractServingEndpoints.md b/docs/docs/api/appkit/Function.extractServingEndpoints.md new file mode 100644 index 00000000..24a5b00d --- /dev/null +++ b/docs/docs/api/appkit/Function.extractServingEndpoints.md @@ -0,0 +1,24 @@ +# Function: extractServingEndpoints() + +```ts +function extractServingEndpoints(serverFilePath: string): + | Record + | null; +``` + +Extract serving endpoint config from a server file by AST-parsing it. +Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls +and extracts the endpoint alias names and their environment variable mappings. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `serverFilePath` | `string` | Absolute path to the server entry file | + +## Returns + + \| `Record`\<`string`, [`EndpointConfig`](Interface.EndpointConfig.md)\> + \| `null` + +Extracted endpoint config, or null if not found or not extractable diff --git a/docs/docs/api/appkit/Function.findServerFile.md b/docs/docs/api/appkit/Function.findServerFile.md new file mode 100644 index 00000000..2ed4e268 --- /dev/null +++ b/docs/docs/api/appkit/Function.findServerFile.md @@ -0,0 +1,19 @@ +# Function: findServerFile() + +```ts +function findServerFile(basePath: string): string | null; +``` + +Find the server entry file by checking candidate paths in order. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `basePath` | `string` | Project root directory to search from | + +## Returns + +`string` \| `null` + +Absolute path to the server file, or null if none found diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index f4685e04..faadf237 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -70,9 +70,12 @@ plugin architecture, and React integration. | Function | Description | | ------ | ------ | +| [appKitServingTypesPlugin](Function.appKitServingTypesPlugin.md) | Vite plugin to generate TypeScript types for AppKit serving endpoints. Fetches OpenAPI schemas from Databricks and generates a .d.ts with ServingEndpointRegistry module augmentation. | | [appKitTypesPlugin](Function.appKitTypesPlugin.md) | Vite plugin to generate types for AppKit queries. Calls generateFromEntryPoint under the hood. | | [createApp](Function.createApp.md) | Bootstraps AppKit with the provided configuration. | | [createLakebasePool](Function.createLakebasePool.md) | Create a Lakebase pool with appkit's logger integration. Telemetry automatically uses appkit's OpenTelemetry configuration via global registry. | +| [extractServingEndpoints](Function.extractServingEndpoints.md) | Extract serving endpoint config from a server file by AST-parsing it. Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls and extracts the endpoint alias names and their environment variable mappings. | +| [findServerFile](Function.findServerFile.md) | Find the server entry file by checking candidate paths in order. | | [generateDatabaseCredential](Function.generateDatabaseCredential.md) | Generate OAuth credentials for Postgres database connection using the proper Postgres API. | | [getExecutionContext](Function.getExecutionContext.md) | Get the current execution context. | | [getLakebaseOrmConfig](Function.getLakebaseOrmConfig.md) | Get Lakebase connection configuration for ORMs that don't accept pg.Pool directly. | diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 91815e3d..1d498d1a 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -225,6 +225,11 @@ const typedocSidebar: SidebarsConfig = { type: "category", label: "Functions", items: [ + { + type: "doc", + id: "api/appkit/Function.appKitServingTypesPlugin", + label: "appKitServingTypesPlugin" + }, { type: "doc", id: "api/appkit/Function.appKitTypesPlugin", @@ -240,6 +245,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Function.createLakebasePool", label: "createLakebasePool" }, + { + type: "doc", + id: "api/appkit/Function.extractServingEndpoints", + label: "extractServingEndpoints" + }, + { + type: "doc", + id: "api/appkit/Function.findServerFile", + label: "findServerFile" + }, { type: "doc", id: "api/appkit/Function.generateDatabaseCredential", diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts new file mode 100644 index 00000000..6d5f159f --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts @@ -0,0 +1,117 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { useServingInvoke } from "../use-serving-invoke"; + +describe("useServingInvoke", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ choices: [] }), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + expect(result.current.data).toBeNull(); + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.invoke).toBe("function"); + }); + + test("calls fetch to correct URL on invoke", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [{ role: "user", content: "Hello" }] }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/invoke", + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ + messages: [{ role: "user", content: "Hello" }], + }), + }), + ); + }); + }); + + test("uses alias in URL when provided", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [] }, { alias: "llm" }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/llm/invoke", + expect.any(Object), + ); + }); + }); + + test("sets data on successful response", async () => { + const responseData = { + choices: [{ message: { content: "Hi" } }], + }; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(result.current.data).toEqual(responseData); + expect(result.current.loading).toBe(false); + }); + }); + + test("sets error on failed response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ error: "Not found" }), { status: 404 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + await act(async () => { + result.current.invoke(); + // Wait for the fetch promise chain to resolve + await new Promise((r) => setTimeout(r, 10)); + }); + + await waitFor(() => { + expect(result.current.error).toBe("Not found"); + expect(result.current.loading).toBe(false); + }); + }); + + test("auto starts when autoStart is true", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + renderHook(() => useServingInvoke({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts new file mode 100644 index 00000000..0a1a736c --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -0,0 +1,271 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, describe, expect, test, vi } from "vitest"; + +// Mock connectSSE — capture callbacks so we can simulate SSE events +let capturedCallbacks: { + onMessage?: (msg: { data: string }) => void; + onError?: (err: Error) => void; + signal?: AbortSignal; +} = {}; + +let resolveStream: (() => void) | null = null; + +const mockConnectSSE = vi.fn().mockImplementation((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + // Also resolve after a tick as fallback for tests that don't manually resolve + setTimeout(resolve, 0); + }); +}); + +vi.mock("@/js", () => ({ + connectSSE: (...args: unknown[]) => mockConnectSSE(...args), +})); + +import { useServingStream } from "../use-serving-stream"; + +describe("useServingStream", () => { + afterEach(() => { + capturedCallbacks = {}; + resolveStream = null; + vi.clearAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.stream).toBe("function"); + expect(typeof result.current.reset).toBe("function"); + }); + + test("calls connectSSE with correct URL on stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/stream", + payload: JSON.stringify({ messages: [] }), + }), + ); + }); + + test("uses alias in URL when provided", () => { + const { result } = renderHook(() => + useServingStream({ messages: [] }, { alias: "embedder" }), + ); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/embedder/stream", + }), + ); + }); + + test("sets streaming to true when stream() is called", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(result.current.streaming).toBe(true); + }); + + test("accumulates chunks from onMessage", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(result.current.chunks).toEqual([{ id: 1 }, { id: 2 }]); + }); + + test("accumulates chunks with error field as normal data", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ + data: JSON.stringify({ error: "Model overloaded" }), + }); + }); + + // Chunks with an `error` field are treated as data, not stream errors. + // Transport-level errors are delivered via onError callback instead. + expect(result.current.chunks).toEqual([{ error: "Model overloaded" }]); + expect(result.current.error).toBeNull(); + expect(result.current.streaming).toBe(true); + }); + + test("sets error from onError callback", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onError?.(new Error("Connection lost")); + }); + + expect(result.current.error).toBe("Connection lost"); + expect(result.current.streaming).toBe(false); + }); + + test("silently skips malformed JSON messages", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: "not valid json{" }); + }); + + // No chunks added, no error set + expect(result.current.chunks).toEqual([]); + expect(result.current.error).toBeNull(); + }); + + test("reset() clears state and aborts active stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + expect(result.current.chunks).toHaveLength(1); + expect(result.current.streaming).toBe(true); + + act(() => { + result.current.reset(); + }); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + }); + + test("autoStart triggers stream on mount", async () => { + renderHook(() => useServingStream({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(mockConnectSSE).toHaveBeenCalled(); + }); + }); + + test("passes abort signal to connectSSE", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(capturedCallbacks.signal).toBeDefined(); + expect(capturedCallbacks.signal?.aborted).toBe(false); + }); + + test("aborts stream on unmount", () => { + const { result, unmount } = renderHook(() => + useServingStream({ messages: [] }), + ); + + act(() => { + result.current.stream(); + }); + + const signal = capturedCallbacks.signal; + expect(signal?.aborted).toBe(false); + + unmount(); + + expect(signal?.aborted).toBe(true); + }); + + test("sets streaming to false when connectSSE resolves", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + await waitFor(() => { + expect(result.current.streaming).toBe(false); + }); + }); + + test("calls onComplete with accumulated chunks when stream finishes", async () => { + const onComplete = vi.fn(); + + // Use a controllable mock so stream doesn't auto-resolve + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + }); + }); + + const { result } = renderHook(() => + useServingStream({ messages: [] }, { onComplete }), + ); + + act(() => { + result.current.stream(); + }); + + // Send two chunks + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(onComplete).not.toHaveBeenCalled(); + + // Complete the stream + await act(async () => { + resolveStream?.(); + await new Promise((r) => setTimeout(r, 0)); + }); + + expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/index.ts b/packages/appkit-ui/src/react/hooks/index.ts index 84d51b53..a425b010 100644 --- a/packages/appkit-ui/src/react/hooks/index.ts +++ b/packages/appkit-ui/src/react/hooks/index.ts @@ -2,8 +2,13 @@ export type { AnalyticsFormat, InferResultByFormat, InferRowType, + InferServingChunk, + InferServingRequest, + InferServingResponse, PluginRegistry, QueryRegistry, + ServingAlias, + ServingEndpointRegistry, TypedArrowTable, UseAnalyticsQueryOptions, UseAnalyticsQueryResult, @@ -15,3 +20,13 @@ export { useChartData, } from "./use-chart-data"; export { usePluginClientConfig } from "./use-plugin-config"; +export { + type UseServingInvokeOptions, + type UseServingInvokeResult, + useServingInvoke, +} from "./use-serving-invoke"; +export { + type UseServingStreamOptions, + type UseServingStreamResult, + useServingStream, +} from "./use-serving-stream"; diff --git a/packages/appkit-ui/src/react/hooks/types.ts b/packages/appkit-ui/src/react/hooks/types.ts index 5db725fc..19ce1fac 100644 --- a/packages/appkit-ui/src/react/hooks/types.ts +++ b/packages/appkit-ui/src/react/hooks/types.ts @@ -134,3 +134,54 @@ export type InferParams = K extends AugmentedRegistry export interface PluginRegistry { [key: string]: Record; } + +// ============================================================================ +// Serving Endpoint Registry +// ============================================================================ + +/** + * Serving endpoint registry for type-safe alias names. + * Extend this interface via module augmentation to get alias autocomplete: + * + * @example + * ```typescript + * // Auto-generated by appKitServingTypesPlugin() + * declare module "@databricks/appkit-ui/react" { + * interface ServingEndpointRegistry { + * llm: { request: {...}; response: {...}; chunk: {...} }; + * } + * } + * ``` + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Resolves to registry keys if populated, otherwise string */ +export type ServingAlias = + AugmentedRegistry extends never + ? string + : AugmentedRegistry; + +/** Infers chunk type from registry when alias is a known key */ +export type InferServingChunk = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { chunk: infer C } + ? C + : unknown + : unknown; + +/** Infers response type from registry when alias is a known key */ +export type InferServingResponse = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { response: infer R } + ? R + : unknown + : unknown; + +/** Infers request type from registry when alias is a known key */ +export type InferServingRequest = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { request: infer Req } + ? Req + : Record + : Record; diff --git a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts new file mode 100644 index 00000000..343a5e71 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts @@ -0,0 +1,103 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import type { + InferServingRequest, + InferServingResponse, + ServingAlias, +} from "./types"; + +export interface UseServingInvokeOptions< + K extends ServingAlias = ServingAlias, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If false, does not invoke automatically on mount. Default: false */ + autoStart?: boolean; +} + +export interface UseServingInvokeResult { + /** Trigger the invocation. Returns the response data, or null on error/abort. */ + invoke: () => Promise; + /** Response data, null until loaded. */ + data: T | null; + /** Whether a request is in progress. */ + loading: boolean; + /** Error message, if any. */ + error: string | null; +} + +/** + * Hook for non-streaming invocation of a serving endpoint. + * Calls `POST /api/serving/invoke` (default) or `POST /api/serving/{alias}/invoke` (named). + * + * When the type generator has populated `ServingEndpointRegistry`, the response type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingInvoke( + body: InferServingRequest, + options: UseServingInvokeOptions = {} as UseServingInvokeOptions, +): UseServingInvokeResult> { + type TResponse = InferServingResponse; + const { alias, autoStart = false } = options; + + const [data, setData] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/invoke` + : "/api/serving/invoke"; + + const bodyJson = JSON.stringify(body); + + const invoke = useCallback((): Promise => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + + setLoading(true); + setError(null); + setData(null); + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + return fetch(urlSuffix, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: bodyJson, + signal: abortController.signal, + }) + .then(async (res) => { + if (!res.ok) { + const errorBody = await res.json().catch(() => null); + throw new Error(errorBody?.error || `HTTP ${res.status}`); + } + return res.json(); + }) + .then((result: TResponse) => { + if (abortController.signal.aborted) return null; + setData(result); + setLoading(false); + return result; + }) + .catch((err: Error) => { + if (abortController.signal.aborted) return null; + setError(err.message || "Request failed"); + setLoading(false); + return null; + }); + }, [urlSuffix, bodyJson]); + + useEffect(() => { + if (autoStart) { + invoke(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [invoke, autoStart]); + + return { invoke, data, loading, error }; +} diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts new file mode 100644 index 00000000..4801d94c --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -0,0 +1,123 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { connectSSE } from "@/js"; +import type { + InferServingChunk, + InferServingRequest, + ServingAlias, +} from "./types"; + +export interface UseServingStreamOptions< + K extends ServingAlias = ServingAlias, + T = InferServingChunk, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If true, starts streaming automatically on mount. Default: false */ + autoStart?: boolean; + /** Called with accumulated chunks when the stream completes successfully. */ + onComplete?: (chunks: T[]) => void; +} + +export interface UseServingStreamResult { + /** Trigger the streaming invocation. */ + stream: () => void; + /** Accumulated chunks received so far. */ + chunks: T[]; + /** Whether streaming is in progress. */ + streaming: boolean; + /** Error message, if any. */ + error: string | null; + /** Reset chunks and abort any active stream. */ + reset: () => void; +} + +/** + * Hook for streaming invocation of a serving endpoint via SSE. + * Calls `POST /api/serving/stream` (default) or `POST /api/serving/{alias}/stream` (named). + * Accumulates parsed chunks in state. + * + * When the type generator has populated `ServingEndpointRegistry`, the chunk type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingStream( + body: InferServingRequest, + options: UseServingStreamOptions = {} as UseServingStreamOptions, +): UseServingStreamResult> { + type TChunk = InferServingChunk; + const { alias, autoStart = false, onComplete } = options; + + const [chunks, setChunks] = useState([]); + const [streaming, setStreaming] = useState(false); + const [error, setError] = useState(null); + const abortControllerRef = useRef(null); + const chunksRef = useRef([]); + const onCompleteRef = useRef(onComplete); + onCompleteRef.current = onComplete; + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/stream` + : "/api/serving/stream"; + + const reset = useCallback(() => { + abortControllerRef.current?.abort(); + abortControllerRef.current = null; + chunksRef.current = []; + setChunks([]); + setStreaming(false); + setError(null); + }, []); + + const bodyJson = JSON.stringify(body); + + const stream = useCallback(() => { + // Abort any existing stream + abortControllerRef.current?.abort(); + + setStreaming(true); + setError(null); + setChunks([]); + chunksRef.current = []; + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + connectSSE({ + url: urlSuffix, + payload: bodyJson, + signal: abortController.signal, + onMessage: async (message) => { + if (abortController.signal.aborted) return; + try { + const parsed = JSON.parse(message.data); + + chunksRef.current = [...chunksRef.current, parsed as TChunk]; + setChunks(chunksRef.current); + } catch { + // Skip malformed messages + } + }, + onError: (err) => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError(err instanceof Error ? err.message : "Streaming failed"); + }, + }).then(() => { + if (abortController.signal.aborted) return; + // Stream completed + setStreaming(false); + onCompleteRef.current?.(chunksRef.current); + }); + }, [urlSuffix, bodyJson]); + + useEffect(() => { + if (autoStart) { + stream(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [stream, autoStart]); + + return { stream, chunks, streaming, error, reset }; +} diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 9e810b97..06da3ee1 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -50,6 +50,7 @@ "typecheck": "tsc --noEmit" }, "dependencies": { + "@ast-grep/napi": "0.37.0", "@databricks/lakebase": "workspace:*", "@databricks/sdk-experimental": "0.16.0", "@opentelemetry/api": "1.9.0", diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 662a9178..3df5572b 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -81,6 +81,10 @@ export { SpanStatusCode, type TelemetryConfig, } from "./telemetry"; - +export { + extractServingEndpoints, + findServerFile, +} from "./type-generator/serving/server-file-extractor"; +export { appKitServingTypesPlugin } from "./type-generator/serving/vite-plugin"; // Vite plugin and type generation export { appKitTypesPlugin } from "./type-generator/vite-plugin"; diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts index 6e52294a..07683ede 100644 --- a/packages/appkit/src/plugins/serving/schema-filter.ts +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -1,19 +1,9 @@ import fs from "node:fs/promises"; import { createLogger } from "../../logging/logger"; - -const CACHE_VERSION = "1"; - -interface ServingCacheEntry { - hash: string; - requestType: string; - responseType: string; - chunkType: string | null; -} - -interface ServingCache { - version: string; - endpoints: Record; -} +import { + CACHE_VERSION, + type ServingCache, +} from "../../type-generator/serving/cache"; const logger = createLogger("serving:schema-filter"); diff --git a/packages/appkit/src/type-generator/serving/cache.ts b/packages/appkit/src/type-generator/serving/cache.ts new file mode 100644 index 00000000..2737f117 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/cache.ts @@ -0,0 +1,55 @@ +import crypto from "node:crypto"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:cache"); + +export const CACHE_VERSION = "1"; +const CACHE_FILE = ".appkit-serving-types-cache.json"; +const CACHE_DIR = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", +); + +export interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; +} + +export interface ServingCache { + version: string; + endpoints: Record; +} + +export function hashSchema(schemaJson: string): string { + return crypto.createHash("sha256").update(schemaJson).digest("hex"); +} + +export async function loadServingCache(): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + try { + await fs.mkdir(CACHE_DIR, { recursive: true }); + const raw = await fs.readFile(cachePath, "utf8"); + const cache = JSON.parse(raw) as ServingCache; + if (cache.version === CACHE_VERSION) { + return cache; + } + logger.debug("Cache version mismatch, starting fresh"); + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn("Cache file is corrupted, flushing cache completely."); + } + } + return { version: CACHE_VERSION, endpoints: {} }; +} + +export async function saveServingCache(cache: ServingCache): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + await fs.mkdir(CACHE_DIR, { recursive: true }); + await fs.writeFile(cachePath, JSON.stringify(cache, null, 2), "utf8"); +} diff --git a/packages/appkit/src/type-generator/serving/converter.ts b/packages/appkit/src/type-generator/serving/converter.ts new file mode 100644 index 00000000..1849e720 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/converter.ts @@ -0,0 +1,149 @@ +import type { OpenApiOperation, OpenApiSchema } from "./fetcher"; + +/** + * Converts an OpenAPI schema to a TypeScript type string. + */ +function schemaToTypeString(schema: OpenApiSchema, indent = 0): string { + const pad = " ".repeat(indent); + + if (schema.oneOf) { + return schema.oneOf.map((s) => schemaToTypeString(s, indent)).join(" | "); + } + + if (schema.enum) { + return schema.enum.map((v) => JSON.stringify(v)).join(" | "); + } + + switch (schema.type) { + case "string": + return "string"; + case "integer": + case "number": + return "number"; + case "boolean": + return "boolean"; + case "array": { + if (!schema.items) return "unknown[]"; + const itemType = schemaToTypeString(schema.items, indent); + // Wrap union types in parens for array + if (itemType.includes(" | ") && !itemType.startsWith("{")) { + return `(${itemType})[]`; + } + return `${itemType}[]`; + } + case "object": { + if (!schema.properties) return "Record"; + const required = new Set(schema.required ?? []); + const entries = Object.entries(schema.properties).map(([key, prop]) => { + const optional = !required.has(key) ? "?" : ""; + const nullable = prop.nullable ? " | null" : ""; + const typeStr = schemaToTypeString(prop, indent + 1); + const formatComment = + prop.format && (prop.type === "number" || prop.type === "integer") + ? `/** @openapi ${prop.format}${prop.nullable ? ", nullable" : ""} */\n${pad} ` + : prop.nullable && prop.type === "integer" + ? `/** @openapi integer, nullable */\n${pad} ` + : ""; + return `${pad} ${formatComment}${key}${optional}: ${typeStr}${nullable};`; + }); + return `{\n${entries.join("\n")}\n${pad}}`; + } + default: + return "unknown"; + } +} + +/** + * Extracts and converts the request schema from an OpenAPI path operation. + * Strips the `stream` property from the request type. + */ +export function convertRequestSchema(operation: OpenApiOperation): string { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema || !schema.properties) return "Record"; + + // Strip `stream` property — the plugin controls this + const { stream: _stream, ...filteredProps } = schema.properties; + const filteredRequired = (schema.required ?? []).filter( + (r) => r !== "stream", + ); + + const filteredSchema: OpenApiSchema = { + ...schema, + properties: filteredProps, + required: filteredRequired.length > 0 ? filteredRequired : undefined, + }; + + return schemaToTypeString(filteredSchema); +} + +/** + * Extracts and converts the response schema from an OpenAPI path operation. + */ +export function convertResponseSchema(operation: OpenApiOperation): string { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema) return "unknown"; + return schemaToTypeString(schema); +} + +/** + * Derives a streaming chunk type from the response schema. + * Returns null if the response doesn't follow OpenAI-compatible format. + * + * OpenAI-compatible heuristic: response has `choices` array where items + * have a `message` object property. + */ +export function deriveChunkType(operation: OpenApiOperation): string | null { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema?.properties) return null; + + const choicesProp = schema.properties.choices; + if (!choicesProp || choicesProp.type !== "array" || !choicesProp.items) + return null; + + const choiceItemProps = choicesProp.items.properties; + if (!choiceItemProps?.message) return null; + + // It's OpenAI-compatible. Build the chunk type by transforming. + const messageSchema = choiceItemProps.message; + + // Build chunk schema: replace message with delta (Partial), make finish_reason nullable, drop usage + const chunkProperties: Record = {}; + + for (const [key, prop] of Object.entries(schema.properties)) { + if (key === "usage") continue; // Drop usage from chunks + if (key === "choices") { + // Transform choices items + const chunkChoiceProps: Record = {}; + for (const [ck, cp] of Object.entries(choiceItemProps)) { + if (ck === "message") { + // Replace message with delta: Partial + chunkChoiceProps.delta = { ...messageSchema }; + } else if (ck === "finish_reason") { + chunkChoiceProps[ck] = { ...cp, nullable: true }; + } else { + chunkChoiceProps[ck] = cp; + } + } + chunkProperties[key] = { + type: "array", + items: { + type: "object", + properties: chunkChoiceProps, + }, + }; + } else { + chunkProperties[key] = prop; + } + } + + const chunkSchema: OpenApiSchema = { + type: "object", + properties: chunkProperties, + }; + + // Delta properties are already optional (no `required` array in the schema), + // so schemaToTypeString renders them with `?:` — no Partial<> wrapper needed. + return schemaToTypeString(chunkSchema); +} diff --git a/packages/appkit/src/type-generator/serving/fetcher.ts b/packages/appkit/src/type-generator/serving/fetcher.ts new file mode 100644 index 00000000..bf733d7b --- /dev/null +++ b/packages/appkit/src/type-generator/serving/fetcher.ts @@ -0,0 +1,158 @@ +import type { WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:fetcher"); + +interface OpenApiSpec { + openapi: string; + info: { title: string; version: string }; + paths: Record>; +} + +export interface OpenApiOperation { + requestBody?: { + content: { + "application/json": { + schema: OpenApiSchema; + }; + }; + }; + responses?: Record< + string, + { + content?: { + "application/json": { + schema: OpenApiSchema; + }; + }; + } + >; +} + +export interface OpenApiSchema { + type?: string; + properties?: Record; + required?: string[]; + items?: OpenApiSchema; + enum?: string[]; + nullable?: boolean; + oneOf?: OpenApiSchema[]; + format?: string; +} + +/** + * Fetches the OpenAPI schema for a serving endpoint. + * Returns null if the endpoint is not found or access is denied. + */ +export async function fetchOpenApiSchema( + client: WorkspaceClient, + endpointName: string, + servedModel?: string, +): Promise<{ spec: OpenApiSpec; pathKey: string } | null> { + const headers = new Headers({ Accept: "application/json" }); + await client.config.authenticate(headers); + + const host = client.config.host; + if (!host) { + logger.warn("Databricks host not configured, skipping schema fetch"); + return null; + } + + const base = host.startsWith("http") ? host : `https://${host}`; + const url = new URL( + `/api/2.0/serving-endpoints/${encodeURIComponent(endpointName)}/openapi`, + base, + ); + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 5000); + + try { + const res = await fetch(url.toString(), { + headers, + signal: controller.signal, + }); + + if (!res.ok) { + const body = await res.text().catch(() => ""); + if (res.status === 404) { + logger.warn( + "Endpoint '%s' not found, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else if (res.status === 403) { + logger.warn( + "Access denied to endpoint '%s' schema, skipping type generation%s", + endpointName, + body ? `: ${body}` : "", + ); + } else { + logger.warn( + "Failed to fetch schema for '%s' (HTTP %d), skipping%s", + endpointName, + res.status, + body ? `: ${body}` : "", + ); + } + return null; + } + + const rawSpec: unknown = await res.json(); + if ( + typeof rawSpec !== "object" || + rawSpec === null || + !("paths" in rawSpec) || + typeof (rawSpec as OpenApiSpec).paths !== "object" + ) { + logger.warn( + "Invalid OpenAPI schema structure for '%s', skipping", + endpointName, + ); + return null; + } + const spec = rawSpec as OpenApiSpec; + + // Find the right path key + const pathKeys = Object.keys(spec.paths ?? {}); + if (pathKeys.length === 0) { + logger.warn("No paths in OpenAPI schema for '%s'", endpointName); + return null; + } + + let pathKey: string; + if (servedModel) { + const match = pathKeys.find((k) => k.includes(`/${servedModel}/`)); + if (!match) { + logger.warn( + "Served model '%s' not found in schema for '%s', using first path", + servedModel, + endpointName, + ); + pathKey = pathKeys[0]; + } else { + pathKey = match; + } + } else { + pathKey = pathKeys[0]; + } + + return { spec, pathKey }; + } catch (err) { + if ((err as Error).name === "AbortError") { + logger.warn( + "Timeout fetching schema for '%s', skipping type generation", + endpointName, + ); + } else { + logger.warn( + "Error fetching schema for '%s': %s", + endpointName, + (err as Error).message, + ); + } + return null; + } finally { + clearTimeout(timeout); + } +} diff --git a/packages/appkit/src/type-generator/serving/generator.ts b/packages/appkit/src/type-generator/serving/generator.ts new file mode 100644 index 00000000..44026f89 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/generator.ts @@ -0,0 +1,266 @@ +import fs from "node:fs/promises"; +import { WorkspaceClient } from "@databricks/sdk-experimental"; +import pc from "picocolors"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "./cache"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, +} from "./converter"; +import { fetchOpenApiSchema } from "./fetcher"; + +const logger = createLogger("type-generator:serving"); + +const GENERIC_REQUEST = "Record"; +const GENERIC_RESPONSE = "unknown"; +const GENERIC_CHUNK = "unknown"; + +interface GenerateServingTypesOptions { + outFile: string; + endpoints?: Record; + noCache?: boolean; +} + +/** + * Generates TypeScript type declarations for serving endpoints + * by fetching their OpenAPI schemas and converting to TypeScript. + */ +export async function generateServingTypes( + options: GenerateServingTypesOptions, +): Promise { + const { outFile, noCache } = options; + + // Resolve endpoints from config or env + const endpoints = options.endpoints ?? resolveDefaultEndpoints(); + if (Object.keys(endpoints).length === 0) { + logger.debug("No serving endpoints configured, skipping type generation"); + return; + } + + const startTime = performance.now(); + + const cache = noCache + ? { version: CACHE_VERSION, endpoints: {} } + : await loadServingCache(); + + const client = new WorkspaceClient({}); + let updated = false; + + const registryEntries: string[] = []; + const logEntries: Array<{ + alias: string; + status: "HIT" | "MISS"; + error?: string; + }> = []; + + for (const [alias, config] of Object.entries(endpoints)) { + const endpointName = process.env[config.env]; + if (!endpointName) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: `env ${config.env} not set`, + }); + continue; + } + + const result = await fetchOpenApiSchema( + client, + endpointName, + config.servedModel, + ); + if (!result) { + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema fetch failed", + }); + continue; + } + + const { spec, pathKey } = result; + const schemaJson = JSON.stringify(spec); + const hash = hashSchema(schemaJson); + + // Check cache + const cached = cache.endpoints[alias]; + if (cached && cached.hash === hash) { + registryEntries.push( + buildRegistryEntry( + alias, + cached.requestType, + cached.responseType, + cached.chunkType, + ), + ); + logEntries.push({ alias, status: "HIT" }); + continue; + } + + // Cache miss — convert + const operation = spec.paths[pathKey]?.post; + if (!operation) { + logEntries.push({ + alias, + status: "MISS", + error: "no POST operation", + }); + continue; + } + + let requestType: string; + let responseType: string; + let chunkType: string | null; + try { + requestType = convertRequestSchema(operation); + responseType = convertResponseSchema(operation); + chunkType = deriveChunkType(operation); + } catch (convErr) { + logger.warn( + "Schema conversion failed for '%s': %s", + alias, + (convErr as Error).message, + ); + registryEntries.push( + buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ), + ); + logEntries.push({ + alias, + status: "MISS", + error: "schema conversion failed", + }); + continue; + } + + cache.endpoints[alias] = { hash, requestType, responseType, chunkType }; + updated = true; + + registryEntries.push( + buildRegistryEntry(alias, requestType, responseType, chunkType), + ); + logEntries.push({ alias, status: "MISS" }); + } + + // Print formatted table (matching analytics typegen output) + if (logEntries.length > 0) { + const maxNameLen = Math.max(...logEntries.map((e) => e.alias.length)); + const separator = pc.dim("─".repeat(50)); + console.log(""); + console.log( + ` ${pc.bold("Typegen Serving")} ${pc.dim(`(${logEntries.length})`)}`, + ); + console.log(` ${separator}`); + for (const entry of logEntries) { + const tag = + entry.status === "HIT" + ? `cache ${pc.bold(pc.green("HIT "))}` + : `cache ${pc.bold(pc.yellow("MISS "))}`; + const rawName = entry.alias.padEnd(maxNameLen); + const reason = entry.error ? ` ${pc.dim(entry.error)}` : ""; + console.log(` ${tag} ${rawName}${reason}`); + } + const elapsed = ((performance.now() - startTime) / 1000).toFixed(2); + const newCount = logEntries.filter((e) => e.status === "MISS").length; + const cacheCount = logEntries.filter((e) => e.status === "HIT").length; + console.log(` ${separator}`); + console.log( + ` ${newCount} new, ${cacheCount} from cache. ${pc.dim(`${elapsed}s`)}`, + ); + console.log(""); + } + + const output = generateTypeDeclarations(registryEntries); + await fs.writeFile(outFile, output, "utf-8"); + + if (registryEntries.length === 0) { + logger.debug( + "Wrote empty serving types to %s (no endpoints resolved)", + outFile, + ); + } else { + logger.debug("Wrote serving types to %s", outFile); + } + + if (updated) { + await saveServingCache(cache as ServingCache); + } +} + +function resolveDefaultEndpoints(): Record { + if (process.env.DATABRICKS_SERVING_ENDPOINT) { + return { default: { env: "DATABRICKS_SERVING_ENDPOINT" } }; + } + return {}; +} + +function buildRegistryEntry( + alias: string, + requestType: string, + responseType: string, + chunkType: string | null, +): string { + const indent = " "; + const chunkEntry = chunkType ? chunkType : "unknown"; + return ` ${alias}: { +${indent}request: ${indentType(requestType, indent)}; +${indent}response: ${indentType(responseType, indent)}; +${indent}chunk: ${indentType(chunkEntry, indent)}; + };`; +} + +function indentType(typeStr: string, baseIndent: string): string { + if (!typeStr.includes("\n")) return typeStr; + return typeStr + .split("\n") + .map((line, i) => (i === 0 ? line : `${baseIndent}${line}`)) + .join("\n"); +} + +function generateTypeDeclarations(entries: string[]): string { + return `// Auto-generated by AppKit - DO NOT EDIT +// Generated from serving endpoint OpenAPI schemas +import "@databricks/appkit"; +import "@databricks/appkit-ui/react"; + +declare module "@databricks/appkit" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} + +declare module "@databricks/appkit-ui/react" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} +`; +} diff --git a/packages/appkit/src/type-generator/serving/server-file-extractor.ts b/packages/appkit/src/type-generator/serving/server-file-extractor.ts new file mode 100644 index 00000000..cb1fbe7e --- /dev/null +++ b/packages/appkit/src/type-generator/serving/server-file-extractor.ts @@ -0,0 +1,221 @@ +import fs from "node:fs"; +import path from "node:path"; +import { Lang, parse, type SgNode } from "@ast-grep/napi"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; + +const logger = createLogger("type-generator:serving:extractor"); + +/** + * Candidate paths for the server entry file, relative to the project root. + * Checked in order; the first that exists is used. + * Same convention as plugin sync (sync.ts SERVER_FILE_CANDIDATES). + */ +const SERVER_FILE_CANDIDATES = ["server/index.ts", "server/server.ts"]; + +/** + * Find the server entry file by checking candidate paths in order. + * + * @param basePath - Project root directory to search from + * @returns Absolute path to the server file, or null if none found + */ +export function findServerFile(basePath: string): string | null { + for (const candidate of SERVER_FILE_CANDIDATES) { + const fullPath = path.join(basePath, candidate); + if (fs.existsSync(fullPath)) { + return fullPath; + } + } + return null; +} + +/** + * Extract serving endpoint config from a server file by AST-parsing it. + * Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls + * and extracts the endpoint alias names and their environment variable mappings. + * + * @param serverFilePath - Absolute path to the server entry file + * @returns Extracted endpoint config, or null if not found or not extractable + */ +export function extractServingEndpoints( + serverFilePath: string, +): Record | null { + let content: string; + try { + content = fs.readFileSync(serverFilePath, "utf-8"); + } catch { + logger.debug("Could not read server file: %s", serverFilePath); + return null; + } + + const lang = serverFilePath.endsWith(".tsx") ? Lang.Tsx : Lang.TypeScript; + const ast = parse(lang, content); + const root = ast.root(); + + // Find serving(...) call expressions + const servingCall = findServingCall(root); + if (!servingCall) { + logger.debug("No serving() call found in %s", serverFilePath); + return null; + } + + // Get the first argument (the config object) + const args = servingCall.field("arguments"); + if (!args) { + return null; + } + + const configArg = args.children().find((child) => child.kind() === "object"); + if (!configArg) { + // serving() called with no args or non-object arg + return null; + } + + // Find the "endpoints" property in the config object + const endpointsPair = findProperty(configArg, "endpoints"); + if (!endpointsPair) { + // Config object has no "endpoints" property (e.g. serving({ timeout: 5000 })) + return null; + } + + // Get the value of the endpoints property + const endpointsValue = getPropertyValue(endpointsPair); + if (!endpointsValue || endpointsValue.kind() !== "object") { + // endpoints is a variable reference, not an inline object + logger.debug( + "serving() endpoints is not an inline object literal in %s. " + + "Pass endpoints explicitly via appKitServingTypesPlugin({ endpoints }) in vite.config.ts.", + serverFilePath, + ); + return null; + } + + // Extract each endpoint entry + const endpoints: Record = {}; + const pairs = endpointsValue + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const entry = extractEndpointEntry(pair); + if (entry) { + endpoints[entry.alias] = entry.config; + } + } + + if (Object.keys(endpoints).length === 0) { + return null; + } + + logger.debug( + "Extracted %d endpoint(s) from %s: %s", + Object.keys(endpoints).length, + serverFilePath, + Object.keys(endpoints).join(", "), + ); + + return endpoints; +} + +/** + * Find the serving() call expression in the AST. + * Looks for call expressions where the callee identifier is "serving". + */ +function findServingCall(root: SgNode): SgNode | null { + const callExpressions = root.findAll({ + rule: { kind: "call_expression" }, + }); + + for (const call of callExpressions) { + const callee = call.children()[0]; + if (callee?.kind() === "identifier" && callee.text() === "serving") { + return call; + } + } + + return null; +} + +/** + * Find a property (pair node) with the given key name in an object expression. + */ +function findProperty(objectNode: SgNode, propertyName: string): SgNode | null { + const pairs = objectNode + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const key = pair.children()[0]; + if (!key) continue; + + const keyText = + key.kind() === "property_identifier" + ? key.text() + : key.kind() === "string" + ? key.text().replace(/^['"]|['"]$/g, "") + : null; + + if (keyText === propertyName) { + return pair; + } + } + + return null; +} + +/** + * Get the value node from a pair (property: value). + * The value is typically the last meaningful child after the colon. + */ +function getPropertyValue(pairNode: SgNode): SgNode | null { + const children = pairNode.children(); + // pair children: [key, ":", value] + return children.length >= 3 ? children[children.length - 1] : null; +} + +/** + * Extract a single endpoint entry from a pair node like: + * `demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }` + */ +function extractEndpointEntry( + pair: SgNode, +): { alias: string; config: EndpointConfig } | null { + const children = pair.children(); + if (children.length < 3) return null; + + // Get alias name (the key) + const keyNode = children[0]; + const alias = + keyNode.kind() === "property_identifier" + ? keyNode.text() + : keyNode.kind() === "string" + ? keyNode.text().replace(/^['"]|['"]$/g, "") + : null; + + if (!alias) return null; + + // Get the value (should be an object like { env: "..." }) + const valueNode = children[children.length - 1]; + if (valueNode.kind() !== "object") return null; + + // Extract env field + const envPair = findProperty(valueNode, "env"); + if (!envPair) return null; + + const envValue = getPropertyValue(envPair); + if (!envValue || envValue.kind() !== "string") return null; + + const env = envValue.text().replace(/^['"]|['"]$/g, ""); + + // Extract optional servedModel field + const config: EndpointConfig = { env }; + const servedModelPair = findProperty(valueNode, "servedModel"); + if (servedModelPair) { + const servedModelValue = getPropertyValue(servedModelPair); + if (servedModelValue?.kind() === "string") { + config.servedModel = servedModelValue.text().replace(/^['"]|['"]$/g, ""); + } + } + + return { alias, config }; +} diff --git a/packages/appkit/src/type-generator/serving/tests/cache.test.ts b/packages/appkit/src/type-generator/serving/tests/cache.test.ts new file mode 100644 index 00000000..1c0ab21c --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/cache.test.ts @@ -0,0 +1,107 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "../cache"; + +vi.mock("node:fs/promises"); + +describe("serving cache", () => { + beforeEach(() => { + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("hashSchema", () => { + test("returns consistent SHA256 hash", () => { + const hash1 = hashSchema('{"openapi": "3.1.0"}'); + const hash2 = hashSchema('{"openapi": "3.1.0"}'); + expect(hash1).toBe(hash2); + expect(hash1).toHaveLength(64); // SHA256 hex + }); + + test("different inputs produce different hashes", () => { + const hash1 = hashSchema('{"a": 1}'); + const hash2 = hashSchema('{"a": 2}'); + expect(hash1).not.toBe(hash2); + }); + }); + + describe("loadServingCache", () => { + test("returns empty cache when file does not exist", async () => { + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("returns parsed cache when file exists with correct version", async () => { + const cached: ServingCache = { + version: CACHE_VERSION, + endpoints: { + llm: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{ model: string }", + chunkType: null, + }, + }, + }; + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(cached)); + + const cache = await loadServingCache(); + expect(cache).toEqual(cached); + }); + + test("flushes cache when version mismatches", async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ version: "0", endpoints: { old: {} } }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("flushes cache when file is corrupted", async () => { + vi.mocked(fs.readFile).mockResolvedValue("not json"); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + }); + + describe("saveServingCache", () => { + test("writes cache to file", async () => { + vi.mocked(fs.writeFile).mockResolvedValue(); + + const cache: ServingCache = { + version: CACHE_VERSION, + endpoints: { + test: { + hash: "xyz", + requestType: "{}", + responseType: "{}", + chunkType: null, + }, + }, + }; + + await saveServingCache(cache); + + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringContaining(".appkit-serving-types-cache.json"), + JSON.stringify(cache, null, 2), + "utf8", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/converter.test.ts b/packages/appkit/src/type-generator/serving/tests/converter.test.ts new file mode 100644 index 00000000..ca794fb3 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/converter.test.ts @@ -0,0 +1,278 @@ +import { describe, expect, test } from "vitest"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, +} from "../converter"; +import type { OpenApiOperation, OpenApiSchema } from "../fetcher"; + +function makeOperation( + requestProps: Record, + responseProps?: Record, + required?: string[], +): OpenApiOperation { + return { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: requestProps, + required, + }, + }, + }, + }, + responses: responseProps + ? { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: responseProps, + }, + }, + }, + }, + } + : undefined, + }; +} + +describe("converter", () => { + describe("convertRequestSchema", () => { + test("converts string type", () => { + const op = makeOperation({ name: { type: "string" } }); + const result = convertRequestSchema(op); + expect(result).toContain("name?: string;"); + }); + + test("converts integer type to number", () => { + const op = makeOperation({ count: { type: "integer" } }); + expect(convertRequestSchema(op)).toContain("count?: number;"); + }); + + test("converts number type", () => { + const op = makeOperation({ + temp: { type: "number", format: "double" }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number;"); + }); + + test("converts boolean type", () => { + const op = makeOperation({ flag: { type: "boolean" } }); + expect(convertRequestSchema(op)).toContain("flag?: boolean;"); + }); + + test("converts enum to string literal union", () => { + const op = makeOperation({ + role: { type: "string", enum: ["user", "assistant"] }, + }); + const result = convertRequestSchema(op); + expect(result).toContain('"user" | "assistant"'); + }); + + test("converts array type", () => { + const op = makeOperation({ + items: { type: "array", items: { type: "string" } }, + }); + expect(convertRequestSchema(op)).toContain("items?: string[];"); + }); + + test("converts nested object", () => { + const op = makeOperation({ + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("role?: string;"); + expect(result).toContain("content?: string;"); + }); + + test("handles nullable properties", () => { + const op = makeOperation({ + temp: { type: "number", nullable: true }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number | null;"); + }); + + test("handles oneOf union types", () => { + const op = makeOperation({ + stop: { + oneOf: [ + { type: "string" }, + { type: "array", items: { type: "string" } }, + ], + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("string | string[]"); + }); + + test("strips stream property from request", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + stream: { type: "boolean", nullable: true }, + temperature: { type: "number" }, + }); + const result = convertRequestSchema(op); + expect(result).not.toContain("stream"); + expect(result).toContain("messages"); + expect(result).toContain("temperature"); + }); + + test("marks required properties without ?", () => { + const op = makeOperation( + { + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + }, + undefined, + ["messages"], + ); + const result = convertRequestSchema(op); + expect(result).toContain("messages: string[];"); + expect(result).toContain("temperature?: number;"); + }); + + test("returns Record for missing schema", () => { + const op: OpenApiOperation = {}; + expect(convertRequestSchema(op)).toBe("Record"); + }); + }); + + describe("convertResponseSchema", () => { + test("converts response schema", () => { + const op = makeOperation( + {}, + { + model: { type: "string" }, + id: { type: "string" }, + }, + ); + const result = convertResponseSchema(op); + expect(result).toContain("model?: string;"); + expect(result).toContain("id?: string;"); + }); + + test("returns unknown for missing response", () => { + const op: OpenApiOperation = {}; + expect(convertResponseSchema(op)).toBe("unknown"); + }); + }); + + describe("deriveChunkType", () => { + test("derives chunk type from OpenAI-compatible response", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + index: { type: "integer" }, + message: { + type: "object", + properties: { + role: { + type: "string", + enum: ["user", "assistant"], + }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + usage: { + type: "object", + properties: { + prompt_tokens: { type: "integer" }, + }, + nullable: true, + }, + id: { type: "string" }, + }, + }, + }, + }, + }, + }, + }; + + const result = deriveChunkType(op); + expect(result).not.toBeNull(); + // Should have delta instead of message + expect(result).toContain("delta"); + expect(result).not.toContain("message"); + // Should make finish_reason nullable + expect(result).toContain("finish_reason"); + expect(result).toContain("| null"); + // Should drop usage + expect(result).not.toContain("usage"); + // Should keep model and id + expect(result).toContain("model"); + expect(result).toContain("id"); + }); + + test("returns null for non-OpenAI response (no choices)", () => { + const op = makeOperation( + {}, + { + predictions: { type: "array", items: { type: "number" } }, + }, + ); + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for choices without message", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + choices: { + type: "array", + items: { + type: "object", + properties: { + score: { type: "number" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }; + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for missing response", () => { + const op: OpenApiOperation = {}; + expect(deriveChunkType(op)).toBeNull(); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts new file mode 100644 index 00000000..802540b0 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts @@ -0,0 +1,209 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { fetchOpenApiSchema } from "../fetcher"; + +const mockAuthenticate = vi.fn(async () => {}); + +function createMockClient(host?: string) { + return { + config: { + host, + authenticate: mockAuthenticate, + }, + } as any; +} + +function makeValidSpec( + paths: Record = { "/invocations": { post: {} } }, +) { + return { + openapi: "3.0.0", + info: { title: "test", version: "1" }, + paths, + }; +} + +describe("fetchOpenApiSchema", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec()), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns null when host is not configured", async () => { + const result = await fetchOpenApiSchema(createMockClient(undefined), "ep"); + expect(result).toBeNull(); + }); + + test("returns null on HTTP 404", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Not found", { status: 404 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on HTTP 403", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Forbidden", { status: 403 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on generic error status", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Server error", { status: 500 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on timeout (AbortError)", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue( + Object.assign(new Error("The operation was aborted"), { + name: "AbortError", + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns null on network error", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue(new Error("fetch failed")); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my-endpoint", + ); + expect(result).toBeNull(); + }); + + test("returns spec and pathKey for valid response", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: { requestBody: {} } }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).not.toBeNull(); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + expect(result?.spec.openapi).toBe("3.0.0"); + }); + + test("matches servedModel path when provided", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/served-models/gpt4/invocations": { post: {} }, + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "gpt4", + ); + expect(result?.pathKey).toBe( + "/serving-endpoints/ep/served-models/gpt4/invocations", + ); + }); + + test("falls back to first path when servedModel not found", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: {} }, + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(spec), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + "nonexistent-model", + ); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + }); + + test("returns null for invalid spec structure (missing paths)", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ openapi: "3.0.0", info: {} }), { + status: 200, + }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("returns null when paths object is empty", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(makeValidSpec({})), { status: 200 }), + ); + + const result = await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(result).toBeNull(); + }); + + test("authenticates request headers", async () => { + await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "ep", + ); + expect(mockAuthenticate).toHaveBeenCalledWith(expect.any(Headers)); + }); + + test("constructs correct URL with encoded endpoint name", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + await fetchOpenApiSchema( + createMockClient("https://host.databricks.com"), + "my endpoint", + ); + + expect(fetchSpy).toHaveBeenCalledWith( + expect.stringContaining("/serving-endpoints/my%20endpoint/openapi"), + expect.any(Object), + ); + }); + + test("prepends https when host lacks protocol", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + await fetchOpenApiSchema(createMockClient("host.databricks.com"), "ep"); + + const url = fetchSpy.mock.calls[0][0] as string; + expect(url.startsWith("https://")).toBe(true); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/generator.test.ts b/packages/appkit/src/type-generator/serving/tests/generator.test.ts new file mode 100644 index 00000000..f9d1b378 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/generator.test.ts @@ -0,0 +1,215 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { generateServingTypes } from "../generator"; + +vi.mock("node:fs/promises"); + +// Mock cache module +vi.mock("../cache", () => ({ + CACHE_VERSION: "1", + hashSchema: vi.fn(() => "mock-hash"), + loadServingCache: vi.fn(async () => ({ version: "1", endpoints: {} })), + saveServingCache: vi.fn(async () => {}), +})); + +// Mock fetcher +const mockFetchOpenApiSchema = vi.fn(); +vi.mock("../fetcher", () => ({ + fetchOpenApiSchema: (...args: any[]) => mockFetchOpenApiSchema(...args), +})); + +// Mock WorkspaceClient +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn(() => ({ config: {} })), +})); + +const CHAT_OPENAPI_SPEC = { + openapi: "3.1.0", + info: { title: "test", version: "1" }, + paths: { + "/served-models/llm/invocations": { + post: { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: { + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + temperature: { type: "number", nullable: true }, + stream: { type: "boolean", nullable: true }, + }, + }, + }, + }, + }, + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + message: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, +}; + +describe("generateServingTypes", () => { + const outFile = "/tmp/test-serving-types.d.ts"; + + beforeEach(() => { + vi.mocked(fs.writeFile).mockResolvedValue(); + process.env.TEST_SERVING_ENDPOINT = "my-endpoint"; + }); + + afterEach(() => { + delete process.env.TEST_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT; + vi.restoreAllMocks(); + }); + + test("generates .d.ts with module augmentation for a chat endpoint", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(fs.writeFile).toHaveBeenCalledWith( + outFile, + expect.any(String), + "utf-8", + ); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + + // Verify module augmentation structure + expect(output).toContain("// Auto-generated by AppKit - DO NOT EDIT"); + expect(output).toContain('import "@databricks/appkit"'); + expect(output).toContain('import "@databricks/appkit-ui/react"'); + expect(output).toContain('declare module "@databricks/appkit"'); + expect(output).toContain('declare module "@databricks/appkit-ui/react"'); + expect(output).toContain("interface ServingEndpointRegistry"); + expect(output).toContain("llm:"); + expect(output).toContain("request:"); + expect(output).toContain("response:"); + expect(output).toContain("chunk:"); + }); + + test("strips stream property from generated request type", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + // `stream` should be stripped from request type + expect(output).toContain("messages"); + expect(output).toContain("temperature"); + expect(output).not.toMatch(/\bstream\??\s*:/); + }); + + test("emits generic types when env var is not set", async () => { + delete process.env.TEST_SERVING_ENDPOINT; + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("skips generation when no endpoints configured and no env var", async () => { + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + expect(fs.writeFile).not.toHaveBeenCalled(); + }); + + test("emits generic types when schema fetch returns null", async () => { + mockFetchOpenApiSchema.mockResolvedValue(null); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("resolves default endpoint from DATABRICKS_SERVING_ENDPOINT", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT = "my-default-endpoint"; + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).toHaveBeenCalledWith( + expect.anything(), + "my-default-endpoint", + undefined, + ); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("default:"); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts new file mode 100644 index 00000000..f0a94709 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts @@ -0,0 +1,213 @@ +import fs from "node:fs"; +import path from "node:path"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { + extractServingEndpoints, + findServerFile, +} from "../server-file-extractor"; + +describe("findServerFile", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns server/index.ts when it exists", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "index.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "index.ts"), + ); + }); + + test("returns server/server.ts when index.ts does not exist", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "server.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "server.ts"), + ); + }); + + test("returns null when no server file exists", () => { + vi.spyOn(fs, "existsSync").mockReturnValue(false); + expect(findServerFile("/app")).toBeNull(); + }); +}); + +describe("extractServingEndpoints", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + function mockServerFile(content: string) { + vi.spyOn(fs, "readFileSync").mockReturnValue(content); + } + + test("extracts inline endpoints from serving() call", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); + + test("extracts servedModel when present", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT", servedModel: "my-model" }, + }); + }); + + test("returns null when serving() has no arguments", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving()], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has config but no endpoints", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ timeout: 5000 }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has empty config object", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when endpoints is a variable reference", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +const myEndpoints = { demo: { env: "DATABRICKS_SERVING_ENDPOINT" } }; +createApp({ + plugins: [ + serving({ endpoints: myEndpoints }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when no serving() call exists", () => { + mockServerFile(` +import { createApp, analytics } from '@databricks/appkit'; + +createApp({ + plugins: [analytics({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when server file cannot be read", () => { + vi.spyOn(fs, "readFileSync").mockImplementation(() => { + throw new Error("ENOENT"); + }); + + const result = extractServingEndpoints("/app/server/nonexistent.ts"); + expect(result).toBeNull(); + }); + + test("handles single-quoted env values", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: 'DATABRICKS_SERVING_ENDPOINT' }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + }); + }); + + test("handles endpoints with trailing commas", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }, + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts new file mode 100644 index 00000000..bcd10915 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts @@ -0,0 +1,186 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +const mockGenerateServingTypes = vi.fn(async () => {}); +const mockFindServerFile = vi.fn((): string | null => null); +const mockExtractServingEndpoints = vi.fn( + (): Record | null => null, +); + +vi.mock("../generator", () => ({ + generateServingTypes: (...args: any[]) => mockGenerateServingTypes(...args), +})); + +vi.mock("../server-file-extractor", () => ({ + findServerFile: (...args: any[]) => mockFindServerFile(...args), + extractServingEndpoints: (...args: any[]) => + mockExtractServingEndpoints(...args), +})); + +import { appKitServingTypesPlugin } from "../vite-plugin"; + +describe("appKitServingTypesPlugin", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + mockGenerateServingTypes.mockReset(); + mockFindServerFile.mockReset(); + mockExtractServingEndpoints.mockReset(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + vi.restoreAllMocks(); + }); + + describe("apply()", () => { + test("returns true when explicit endpoints provided", () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM_ENDPOINT" } }, + }); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when DATABRICKS_SERVING_ENDPOINT is set", () => { + process.env.DATABRICKS_SERVING_ENDPOINT = "my-endpoint"; + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in cwd", () => { + mockFindServerFile.mockReturnValueOnce("/app/server/index.ts"); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in parent dir", () => { + mockFindServerFile + .mockReturnValueOnce(null) // cwd check + .mockReturnValueOnce("/app/server/index.ts"); // parent check + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns false when nothing configured", () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT; + mockFindServerFile.mockReturnValue(null); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(false); + }); + }); + + describe("configResolved()", () => { + test("resolves outFile relative to config.root", async () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining( + "/app/client/src/appKitServingTypes.d.ts", + ), + }), + ); + }); + + test("uses custom outFile when provided", async () => { + const plugin = appKitServingTypesPlugin({ + outFile: "types/serving.d.ts", + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining("types/serving.d.ts"), + }), + ); + }); + }); + + describe("buildStart()", () => { + test("calls generateServingTypes with explicit endpoints", async () => { + const endpoints = { llm: { env: "LLM_ENDPOINT" } }; + const plugin = appKitServingTypesPlugin({ endpoints }); + (plugin as any).configResolved({ root: "/app/client" }); + + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + endpoints, + noCache: false, + }), + ); + }); + + test("extracts endpoints from server file when not explicit", async () => { + const extracted = { llm: { env: "LLM_EP" } }; + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(extracted); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: extracted }), + ); + }); + + test("passes undefined endpoints when no server file found", async () => { + mockFindServerFile.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("passes undefined when AST extraction returns null", async () => { + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("swallows errors in dev mode", async () => { + process.env.NODE_ENV = "development"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + // Should not throw + await expect((plugin as any).buildStart()).resolves.toBeUndefined(); + }); + + test("rethrows errors in production mode", async () => { + process.env.NODE_ENV = "production"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + await expect((plugin as any).buildStart()).rejects.toThrow( + "fetch failed", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/vite-plugin.ts b/packages/appkit/src/type-generator/serving/vite-plugin.ts new file mode 100644 index 00000000..9903a253 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/vite-plugin.ts @@ -0,0 +1,109 @@ +import path from "node:path"; +import type { Plugin } from "vite"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { generateServingTypes } from "./generator"; +import { + extractServingEndpoints, + findServerFile, +} from "./server-file-extractor"; + +const logger = createLogger("type-generator:serving:vite-plugin"); + +interface AppKitServingTypesPluginOptions { + /** Path to the output .d.ts file (relative to client root). Default: "src/appKitServingTypes.d.ts" */ + outFile?: string; + /** Endpoint config override. If omitted, auto-discovers from the server file or falls back to DATABRICKS_SERVING_ENDPOINT env var. */ + endpoints?: Record; +} + +/** + * Vite plugin to generate TypeScript types for AppKit serving endpoints. + * Fetches OpenAPI schemas from Databricks and generates a .d.ts with + * ServingEndpointRegistry module augmentation. + * + * Endpoint discovery order: + * 1. Explicit `endpoints` option (override) + * 2. AST extraction from server file (server/index.ts or server/server.ts) + * 3. DATABRICKS_SERVING_ENDPOINT env var (single default endpoint) + */ +export function appKitServingTypesPlugin( + options?: AppKitServingTypesPluginOptions, +): Plugin { + let outFile: string; + let projectRoot: string; + + async function generate() { + try { + // Resolve endpoints: explicit option > server file AST > env var fallback (handled by generator) + let endpoints = options?.endpoints; + if (!endpoints) { + const serverFile = findServerFile(projectRoot); + if (serverFile) { + endpoints = extractServingEndpoints(serverFile) ?? undefined; + } + } + + await generateServingTypes({ + outFile, + endpoints, + noCache: false, + }); + } catch (error) { + if (process.env.NODE_ENV === "production") { + throw error; + } + logger.error("Error generating serving types: %O", error); + } + } + + return { + name: "appkit-serving-types", + + apply() { + // Fast checks — no AST parsing here + if (options?.endpoints && Object.keys(options.endpoints).length > 0) { + return true; + } + + if (process.env.DATABRICKS_SERVING_ENDPOINT) { + return true; + } + + // Check if a server file exists (may contain serving() config) + // Use process.cwd() for apply() since configResolved hasn't run yet + if (findServerFile(process.cwd())) { + return true; + } + + // Also check parent dir (for when cwd is client/) + const parentDir = path.resolve(process.cwd(), ".."); + if (findServerFile(parentDir)) { + return true; + } + + logger.debug( + "No serving endpoints configured. Skipping type generation.", + ); + return false; + }, + + configResolved(config) { + // Resolve project root: go up one level from Vite root (client dir) + // This handles both: + // - pnpm dev: process.cwd() is app root, config.root is client/ + // - pnpm build: process.cwd() is client/ (cd client && vite build), config.root is client/ + projectRoot = path.resolve(config.root, ".."); + outFile = path.resolve( + config.root, + options?.outFile ?? "src/appKitServingTypes.d.ts", + ); + }, + + async buildStart() { + await generate(); + }, + + // No configureServer / watcher — schemas change on endpoint redeploy, not on file edit + }; +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 199fcfb8..9ca11b81 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -242,6 +242,9 @@ importers: packages/appkit: dependencies: + '@ast-grep/napi': + specifier: 0.37.0 + version: 0.37.0 '@databricks/lakebase': specifier: workspace:* version: link:../lakebase From 72610a7d3de4a289aef4471a061a89c1d78fc9f6 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 12:47:22 +0200 Subject: [PATCH 2/6] fix: use structured requestKeys in cache, add stream body override - Store requestKeys[] in serving cache instead of regex-parsing TypeScript type strings in schema-filter (fragile indentation dependency) - Add overrideBody parameter to useServingStream's stream() to allow callers to pass fresh body without waiting for useMemo recomputation - Lazy-init WorkspaceClient in type generator (skip when no endpoints resolve) Signed-off-by: Pawel Kosiec --- .../__tests__/use-serving-stream.test.ts | 20 +++++ .../src/react/hooks/use-serving-stream.ts | 84 ++++++++++--------- .../src/plugins/serving/schema-filter.ts | 26 +----- .../serving/tests/schema-filter.test.ts | 30 +++++-- .../src/type-generator/serving/cache.ts | 1 + .../src/type-generator/serving/converter.ts | 10 +++ .../src/type-generator/serving/generator.ts | 14 +++- .../serving/tests/cache.test.ts | 2 + .../serving/tests/converter.test.ts | 30 +++++++ 9 files changed, 147 insertions(+), 70 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts index 0a1a736c..1ab0bf44 100644 --- a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -61,6 +61,26 @@ describe("useServingStream", () => { ); }); + test("uses override body when passed to stream()", () => { + const { result } = renderHook(() => + useServingStream({ messages: [{ role: "user", content: "old" }] }), + ); + + const overrideBody = { + messages: [{ role: "user" as const, content: "new" }], + }; + + act(() => { + result.current.stream(overrideBody); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + payload: JSON.stringify(overrideBody), + }), + ); + }); + test("uses alias in URL when provided", () => { const { result } = renderHook(() => useServingStream({ messages: [] }, { alias: "embedder" }), diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index 4801d94c..25cb90a7 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -18,9 +18,12 @@ export interface UseServingStreamOptions< onComplete?: (chunks: T[]) => void; } -export interface UseServingStreamResult { - /** Trigger the streaming invocation. */ - stream: () => void; +export interface UseServingStreamResult< + T = unknown, + TBody = Record, +> { + /** Trigger the streaming invocation. Pass an optional body override for this invocation. */ + stream: (overrideBody?: TBody) => void; /** Accumulated chunks received so far. */ chunks: T[]; /** Whether streaming is in progress. */ @@ -42,7 +45,7 @@ export interface UseServingStreamResult { export function useServingStream( body: InferServingRequest, options: UseServingStreamOptions = {} as UseServingStreamOptions, -): UseServingStreamResult> { +): UseServingStreamResult, InferServingRequest> { type TChunk = InferServingChunk; const { alias, autoStart = false, onComplete } = options; @@ -69,45 +72,50 @@ export function useServingStream( const bodyJson = JSON.stringify(body); - const stream = useCallback(() => { - // Abort any existing stream - abortControllerRef.current?.abort(); + const stream = useCallback( + (overrideBody?: InferServingRequest) => { + // Abort any existing stream + abortControllerRef.current?.abort(); - setStreaming(true); - setError(null); - setChunks([]); - chunksRef.current = []; + setStreaming(true); + setError(null); + setChunks([]); + chunksRef.current = []; - const abortController = new AbortController(); - abortControllerRef.current = abortController; + const abortController = new AbortController(); + abortControllerRef.current = abortController; - connectSSE({ - url: urlSuffix, - payload: bodyJson, - signal: abortController.signal, - onMessage: async (message) => { - if (abortController.signal.aborted) return; - try { - const parsed = JSON.parse(message.data); - - chunksRef.current = [...chunksRef.current, parsed as TChunk]; - setChunks(chunksRef.current); - } catch { - // Skip malformed messages - } - }, - onError: (err) => { + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + connectSSE({ + url: urlSuffix, + payload, + signal: abortController.signal, + onMessage: async (message) => { + if (abortController.signal.aborted) return; + try { + const parsed = JSON.parse(message.data); + + chunksRef.current = [...chunksRef.current, parsed as TChunk]; + setChunks(chunksRef.current); + } catch { + // Skip malformed messages + } + }, + onError: (err) => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError(err instanceof Error ? err.message : "Streaming failed"); + }, + }).then(() => { if (abortController.signal.aborted) return; + // Stream completed setStreaming(false); - setError(err instanceof Error ? err.message : "Streaming failed"); - }, - }).then(() => { - if (abortController.signal.aborted) return; - // Stream completed - setStreaming(false); - onCompleteRef.current?.(chunksRef.current); - }); - }, [urlSuffix, bodyJson]); + onCompleteRef.current?.(chunksRef.current); + }); + }, + [urlSuffix, bodyJson], + ); useEffect(() => { if (autoStart) { diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts index 07683ede..92a25c69 100644 --- a/packages/appkit/src/plugins/serving/schema-filter.ts +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -37,11 +37,8 @@ export async function loadEndpointSchemas( const cache = parsed; for (const [alias, entry] of Object.entries(cache.endpoints)) { - // Extract property keys from the requestType string - // The requestType is a TypeScript object type like "{ messages: ...; temperature: ...; }" - const keys = extractPropertyKeys(entry.requestType); - if (keys.size > 0) { - allowlists.set(alias, keys); + if (entry.requestKeys && entry.requestKeys.length > 0) { + allowlists.set(alias, new Set(entry.requestKeys)); } } } catch (err) { @@ -57,25 +54,6 @@ export async function loadEndpointSchemas( return allowlists; } -/** - * Extracts top-level property keys from a TypeScript object type string. - * Matches patterns like `key:` or `key?:` at the first nesting level. - */ -function extractPropertyKeys(typeStr: string): Set { - const keys = new Set(); - // Match property names at the top level of the object type - // Looking for patterns: ` propertyName:` or ` propertyName?:` - const propRegex = /^\s{2}(?:\/\*\*[^*]*\*\/\s*)?(\w+)\??:/gm; - for ( - let match = propRegex.exec(typeStr); - match !== null; - match = propRegex.exec(typeStr) - ) { - keys.add(match[1]); - } - return keys; -} - /** * Filters a request body against the allowed keys for an endpoint alias. * Returns the filtered body and logs a warning for stripped params. diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts index 948b47f9..4fc030d8 100644 --- a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -109,7 +109,7 @@ describe("schema-filter", () => { expect(result.size).toBe(0); }); - test("extracts property keys from cached types", async () => { + test("reads requestKeys from cache entries", async () => { const fs = (await import("node:fs/promises")).default; vi.mocked(fs.readFile).mockResolvedValue( JSON.stringify({ @@ -117,13 +117,10 @@ describe("schema-filter", () => { endpoints: { default: { hash: "abc", - requestType: `{ - messages: string[]; - temperature?: number | null; - max_tokens: number; -}`, + requestType: "{}", responseType: "{}", chunkType: null, + requestKeys: ["messages", "temperature", "max_tokens"], }, }, }), @@ -137,5 +134,26 @@ describe("schema-filter", () => { expect(keys?.has("temperature")).toBe(true); expect(keys?.has("max_tokens")).toBe(true); }); + + test("skips entries without requestKeys (backwards compat)", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{}", + chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + // No requestKeys → passthrough mode (no allowlist) + expect(result.size).toBe(0); + }); }); }); diff --git a/packages/appkit/src/type-generator/serving/cache.ts b/packages/appkit/src/type-generator/serving/cache.ts index 2737f117..dc9bf7e2 100644 --- a/packages/appkit/src/type-generator/serving/cache.ts +++ b/packages/appkit/src/type-generator/serving/cache.ts @@ -19,6 +19,7 @@ export interface ServingCacheEntry { requestType: string; responseType: string; chunkType: string | null; + requestKeys: string[]; } export interface ServingCache { diff --git a/packages/appkit/src/type-generator/serving/converter.ts b/packages/appkit/src/type-generator/serving/converter.ts index 1849e720..b56b0460 100644 --- a/packages/appkit/src/type-generator/serving/converter.ts +++ b/packages/appkit/src/type-generator/serving/converter.ts @@ -53,6 +53,16 @@ function schemaToTypeString(schema: OpenApiSchema, indent = 0): string { } } +/** + * Extracts the top-level property keys from the request schema. + * Strips the `stream` property (plugin-controlled). + */ +export function extractRequestKeys(operation: OpenApiOperation): string[] { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema?.properties) return []; + return Object.keys(schema.properties).filter((k) => k !== "stream"); +} + /** * Extracts and converts the request schema from an OpenAPI path operation. * Strips the `stream` property from the request type. diff --git a/packages/appkit/src/type-generator/serving/generator.ts b/packages/appkit/src/type-generator/serving/generator.ts index 44026f89..2cd88619 100644 --- a/packages/appkit/src/type-generator/serving/generator.ts +++ b/packages/appkit/src/type-generator/serving/generator.ts @@ -14,6 +14,7 @@ import { convertRequestSchema, convertResponseSchema, deriveChunkType, + extractRequestKeys, } from "./converter"; import { fetchOpenApiSchema } from "./fetcher"; @@ -51,7 +52,7 @@ export async function generateServingTypes( ? { version: CACHE_VERSION, endpoints: {} } : await loadServingCache(); - const client = new WorkspaceClient({}); + let client: WorkspaceClient | undefined; let updated = false; const registryEntries: string[] = []; @@ -80,6 +81,7 @@ export async function generateServingTypes( continue; } + client ??= new WorkspaceClient({}); const result = await fetchOpenApiSchema( client, endpointName, @@ -135,10 +137,12 @@ export async function generateServingTypes( let requestType: string; let responseType: string; let chunkType: string | null; + let requestKeys: string[]; try { requestType = convertRequestSchema(operation); responseType = convertResponseSchema(operation); chunkType = deriveChunkType(operation); + requestKeys = extractRequestKeys(operation); } catch (convErr) { logger.warn( "Schema conversion failed for '%s': %s", @@ -161,7 +165,13 @@ export async function generateServingTypes( continue; } - cache.endpoints[alias] = { hash, requestType, responseType, chunkType }; + cache.endpoints[alias] = { + hash, + requestType, + responseType, + chunkType, + requestKeys, + }; updated = true; registryEntries.push( diff --git a/packages/appkit/src/type-generator/serving/tests/cache.test.ts b/packages/appkit/src/type-generator/serving/tests/cache.test.ts index 1c0ab21c..0c99c997 100644 --- a/packages/appkit/src/type-generator/serving/tests/cache.test.ts +++ b/packages/appkit/src/type-generator/serving/tests/cache.test.ts @@ -53,6 +53,7 @@ describe("serving cache", () => { requestType: "{ messages: string[] }", responseType: "{ model: string }", chunkType: null, + requestKeys: ["messages"], }, }, }; @@ -91,6 +92,7 @@ describe("serving cache", () => { requestType: "{}", responseType: "{}", chunkType: null, + requestKeys: [], }, }, }; diff --git a/packages/appkit/src/type-generator/serving/tests/converter.test.ts b/packages/appkit/src/type-generator/serving/tests/converter.test.ts index ca794fb3..1be30738 100644 --- a/packages/appkit/src/type-generator/serving/tests/converter.test.ts +++ b/packages/appkit/src/type-generator/serving/tests/converter.test.ts @@ -3,6 +3,7 @@ import { convertRequestSchema, convertResponseSchema, deriveChunkType, + extractRequestKeys, } from "../converter"; import type { OpenApiOperation, OpenApiSchema } from "../fetcher"; @@ -275,4 +276,33 @@ describe("converter", () => { expect(deriveChunkType(op)).toBeNull(); }); }); + + describe("extractRequestKeys", () => { + test("extracts top-level property keys excluding stream", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + stream: { type: "boolean", nullable: true }, + }); + expect(extractRequestKeys(op)).toEqual(["messages", "temperature"]); + }); + + test("returns empty array for missing schema", () => { + const op: OpenApiOperation = {}; + expect(extractRequestKeys(op)).toEqual([]); + }); + + test("returns empty array for schema without properties", () => { + const op: OpenApiOperation = { + requestBody: { + content: { + "application/json": { + schema: { type: "object" }, + }, + }, + }, + }; + expect(extractRequestKeys(op)).toEqual([]); + }); + }); }); From 8c4ccfd571c76a830a48944453d7765b05ed837c Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 13:08:10 +0200 Subject: [PATCH 3/6] fix: clear chunks after stream completion in useServingStream Chunks persisted after onComplete, causing the streaming bubble to remain visible alongside the committed message (duplicate response). Now chunks are cleared atomically with setStreaming(false) so React batches all state updates in one render. Signed-off-by: Pawel Kosiec --- .../__tests__/use-serving-stream.test.ts | 37 +++++++++++++++++++ .../src/react/hooks/use-serving-stream.ts | 4 +- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts index 1ab0bf44..ecc00e9f 100644 --- a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -288,4 +288,41 @@ describe("useServingStream", () => { expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); }); + + test("clears chunks after stream completes", async () => { + // Use a controllable mock so stream doesn't auto-resolve + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + }); + }); + + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + // Send a chunk + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + expect(result.current.chunks).toEqual([{ id: 1 }]); + + // Complete the stream + await act(async () => { + resolveStream?.(); + await new Promise((r) => setTimeout(r, 0)); + }); + + // Chunks should be cleared after completion + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + }); }); diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index 25cb90a7..d34b5559 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -109,9 +109,11 @@ export function useServingStream( }, }).then(() => { if (abortController.signal.aborted) return; - // Stream completed + // Stream completed — let onComplete consume chunks, then clear them setStreaming(false); onCompleteRef.current?.(chunksRef.current); + chunksRef.current = []; + setChunks([]); }); }, [urlSuffix, bodyJson], From 3a4c52d4eb5ab4741671f94ef4f6b4d530f5b607 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 13:18:21 +0200 Subject: [PATCH 4/6] fix: revert chunk-clearing in useServingStream completion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Clearing chunks in the hook's .then() handler caused a race with React batching — chunks were empty before the component could commit them. Let consumers decide when to clear via reset() instead. Signed-off-by: Pawel Kosiec --- .../__tests__/use-serving-stream.test.ts | 37 ------------------- .../src/react/hooks/use-serving-stream.ts | 4 +- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts index ecc00e9f..1ab0bf44 100644 --- a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -288,41 +288,4 @@ describe("useServingStream", () => { expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); }); - - test("clears chunks after stream completes", async () => { - // Use a controllable mock so stream doesn't auto-resolve - mockConnectSSE.mockImplementationOnce((opts: any) => { - capturedCallbacks = { - onMessage: opts.onMessage, - onError: opts.onError, - signal: opts.signal, - }; - return new Promise((resolve) => { - resolveStream = resolve; - }); - }); - - const { result } = renderHook(() => useServingStream({ messages: [] })); - - act(() => { - result.current.stream(); - }); - - // Send a chunk - act(() => { - capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); - }); - - expect(result.current.chunks).toEqual([{ id: 1 }]); - - // Complete the stream - await act(async () => { - resolveStream?.(); - await new Promise((r) => setTimeout(r, 0)); - }); - - // Chunks should be cleared after completion - expect(result.current.chunks).toEqual([]); - expect(result.current.streaming).toBe(false); - }); }); diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index d34b5559..25cb90a7 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -109,11 +109,9 @@ export function useServingStream( }, }).then(() => { if (abortController.signal.aborted) return; - // Stream completed — let onComplete consume chunks, then clear them + // Stream completed setStreaming(false); onCompleteRef.current?.(chunksRef.current); - chunksRef.current = []; - setChunks([]); }); }, [urlSuffix, bodyJson], From 86870670b5c2d03db70dd9efbe917d52eaa70d46 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 14:21:02 +0200 Subject: [PATCH 5/6] fix: add catch handler to connectSSE promise in useServingStream Without a .catch(), if connectSSE rejects the promise is unhandled and setStreaming(false) never fires, leaving the hook in a broken state. This matches the pattern used by the genie chat hook. Signed-off-by: Pawel Kosiec --- .../src/react/hooks/use-serving-stream.ts | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts index 25cb90a7..f0bb7bf2 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -107,12 +107,18 @@ export function useServingStream( setStreaming(false); setError(err instanceof Error ? err.message : "Streaming failed"); }, - }).then(() => { - if (abortController.signal.aborted) return; - // Stream completed - setStreaming(false); - onCompleteRef.current?.(chunksRef.current); - }); + }) + .then(() => { + if (abortController.signal.aborted) return; + // Stream completed + setStreaming(false); + onCompleteRef.current?.(chunksRef.current); + }) + .catch(() => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError("Connection error"); + }); }, [urlSuffix, bodyJson], ); From 1d503def4b6a0235e9d8f50ae64f509bea4b44cc Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Fri, 3 Apr 2026 15:29:35 +0200 Subject: [PATCH 6/6] fix: add overrideBody parameter to useServingInvoke The invoke callback is recreated whenever body changes (via useCallback deps), which triggers the useEffect cleanup that aborts in-flight requests. Adding overrideBody allows callers to use a stable body while passing the real payload per-invocation, matching useServingStream. Signed-off-by: Pawel Kosiec --- .../src/react/hooks/use-serving-invoke.ts | 84 ++++++++++--------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts index 343a5e71..8e80e82e 100644 --- a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts +++ b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts @@ -14,9 +14,12 @@ export interface UseServingInvokeOptions< autoStart?: boolean; } -export interface UseServingInvokeResult { - /** Trigger the invocation. Returns the response data, or null on error/abort. */ - invoke: () => Promise; +export interface UseServingInvokeResult< + T = unknown, + TBody = Record, +> { + /** Trigger the invocation. Pass an optional body override for this invocation. */ + invoke: (overrideBody?: TBody) => Promise; /** Response data, null until loaded. */ data: T | null; /** Whether a request is in progress. */ @@ -35,7 +38,7 @@ export interface UseServingInvokeResult { export function useServingInvoke( body: InferServingRequest, options: UseServingInvokeOptions = {} as UseServingInvokeOptions, -): UseServingInvokeResult> { +): UseServingInvokeResult, InferServingRequest> { type TResponse = InferServingResponse; const { alias, autoStart = false } = options; @@ -50,44 +53,49 @@ export function useServingInvoke( const bodyJson = JSON.stringify(body); - const invoke = useCallback((): Promise => { - if (abortControllerRef.current) { - abortControllerRef.current.abort(); - } + const invoke = useCallback( + (overrideBody?: InferServingRequest): Promise => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } - setLoading(true); - setError(null); - setData(null); + setLoading(true); + setError(null); + setData(null); - const abortController = new AbortController(); - abortControllerRef.current = abortController; + const abortController = new AbortController(); + abortControllerRef.current = abortController; - return fetch(urlSuffix, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: bodyJson, - signal: abortController.signal, - }) - .then(async (res) => { - if (!res.ok) { - const errorBody = await res.json().catch(() => null); - throw new Error(errorBody?.error || `HTTP ${res.status}`); - } - return res.json(); - }) - .then((result: TResponse) => { - if (abortController.signal.aborted) return null; - setData(result); - setLoading(false); - return result; + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + return fetch(urlSuffix, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: payload, + signal: abortController.signal, }) - .catch((err: Error) => { - if (abortController.signal.aborted) return null; - setError(err.message || "Request failed"); - setLoading(false); - return null; - }); - }, [urlSuffix, bodyJson]); + .then(async (res) => { + if (!res.ok) { + const errorBody = await res.json().catch(() => null); + throw new Error(errorBody?.error || `HTTP ${res.status}`); + } + return res.json(); + }) + .then((result: TResponse) => { + if (abortController.signal.aborted) return null; + setData(result); + setLoading(false); + return result; + }) + .catch((err: Error) => { + if (abortController.signal.aborted) return null; + setError(err.message || "Request failed"); + setLoading(false); + return null; + }); + }, + [urlSuffix, bodyJson], + ); useEffect(() => { if (autoStart) {