diff --git a/__tests__/deferCleanup.test.ts b/__tests__/deferCleanup.test.ts
new file mode 100644
index 00000000..60371bbb
--- /dev/null
+++ b/__tests__/deferCleanup.test.ts
@@ -0,0 +1,508 @@
+import {
+  afterEach,
+  assert,
+  beforeEach,
+  describe,
+  expect,
+  test,
+  vi,
+} from 'vitest';
+import { Type } from '@sinclair/typebox';
+import {
+  createClient,
+  createServer,
+  Ok,
+  Procedure,
+  createServiceSchema,
+  Middleware,
+} from '../router';
+import { createMockTransportNetwork } from '../testUtil';
+import { waitFor } from '../testUtil/fixtures/cleanup';
+
+describe('deferCleanup', () => {
+  let mockTransportNetwork: ReturnType<typeof createMockTransportNetwork>;
+
+  beforeEach(async () => {
+    mockTransportNetwork = createMockTransportNetwork();
+  });
+
+  afterEach(async () => {
+    await mockTransportNetwork.cleanup();
+  });
+
+  test('cleanups run in LIFO order', async () => {
+    const order: Array<number> = [];
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            ctx.deferCleanup(() => {
+              order.push(1);
+            });
+            ctx.deferCleanup(() => {
+              order.push(2);
+            });
+            ctx.deferCleanup(() => {
+              order.push(3);
+            });
+
+            return Ok({});
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    await waitFor(() => {
+      expect(order).toEqual([3, 2, 1]);
+    });
+  });
+
+  test('cleanups run even when handler throws', async () => {
+    const cleanupRan = vi.fn();
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            ctx.deferCleanup(cleanupRan);
+            throw new Error('handler error');
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    await waitFor(() => {
+      expect(cleanupRan).toHaveBeenCalledOnce();
+    });
+  });
+
+  test('cleanups run when handler is cancelled', async () => {
+    const cleanupRan = vi.fn();
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            ctx.deferCleanup(cleanupRan);
+            ctx.cancel('test cancel');
+
+            return Ok({});
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    await waitFor(() => {
+      expect(cleanupRan).toHaveBeenCalledOnce();
+    });
+  });
+
+  test('one cleanup throwing does not stop remaining cleanups', async () => {
+    const cleanup1 = vi.fn();
+    const cleanup3 = vi.fn();
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            ctx.deferCleanup(cleanup1);
+            ctx.deferCleanup(() => {
+              throw new Error('cleanup error');
+            });
+            ctx.deferCleanup(cleanup3);
+
+            return Ok({});
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    // LIFO: cleanup3 runs first, then the throwing one, then cleanup1
+    await waitFor(() => {
+      expect(cleanup3).toHaveBeenCalledOnce();
+      expect(cleanup1).toHaveBeenCalledOnce();
+    });
+  });
+
+  test('async cleanups are awaited', async () => {
+    const order: Array<number> = [];
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            ctx.deferCleanup(async () => {
+              await new Promise<void>((resolve) => setTimeout(resolve, 10));
+              order.push(1);
+            });
+            ctx.deferCleanup(async () => {
+              await new Promise<void>((resolve) => setTimeout(resolve, 10));
+              order.push(2);
+            });
+
+            return Ok({});
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    await waitFor(() => {
+      expect(order).toEqual([2, 1]);
+    });
+  });
+
+  test('deferCleanup works for stream procedures', async () => {
+    const order: Array<number> = [];
+
+    const services = {
+      test: createServiceSchema().define({
+        myStream: Procedure.stream({
+          requestInit: Type.Object({}),
+          requestData: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx, resWritable }) {
+            ctx.deferCleanup(() => {
+              order.push(1);
+            });
+            ctx.deferCleanup(() => {
+              order.push(2);
+            });
+
+            resWritable.write(Ok({}));
+            resWritable.close();
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    const { reqWritable, resReadable } = client.test.myStream.stream({});
+    reqWritable.close();
+
+    // drain the readable
+    for await (const _ of resReadable) {
+      // consume
+    }
+
+    await waitFor(() => {
+      expect(order).toEqual([2, 1]);
+    });
+  });
+
+  test('deferCleanup works for subscription procedures', async () => {
+    const cleanupRan = vi.fn();
+
+    const services = {
+      test: createServiceSchema().define({
+        mySub: Procedure.subscription({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx, resWritable }) {
+            ctx.deferCleanup(cleanupRan);
+
+            resWritable.write(Ok({}));
+            resWritable.close();
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    const { resReadable } = client.test.mySub.subscribe({});
+
+    // drain the readable
+    for await (const _ of resReadable) {
+      // consume
+    }
+
+    await waitFor(() => {
+      expect(cleanupRan).toHaveBeenCalledOnce();
+    });
+  });
+
+  test('deferCleanup works for upload procedures', async () => {
+    const order: Array<number> = [];
+
+    const services = {
+      test: createServiceSchema().define({
+        myUpload: Procedure.upload({
+          requestInit: Type.Object({}),
+          requestData: Type.Object({ n: Type.Number() }),
+          responseData: Type.Object({ result: Type.Number() }),
+          async handler({ ctx, reqReadable }) {
+            ctx.deferCleanup(() => {
+              order.push(1);
+            });
+            ctx.deferCleanup(() => {
+              order.push(2);
+            });
+
+            let sum = 0;
+            for await (const msg of reqReadable) {
+              if (msg.ok) {
+                sum += msg.payload.n;
+              }
+            }
+
+            return Ok({ result: sum });
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    const { reqWritable, finalize } = client.test.myUpload.upload({});
+    reqWritable.write({ n: 1 });
+    reqWritable.write({ n: 2 });
+    reqWritable.close();
+
+    const result = await finalize();
+    expect(result).toStrictEqual({ ok: true, payload: { result: 3 } });
+
+    await waitFor(() => {
+      expect(order).toEqual([2, 1]);
+    });
+  });
+
+  test('middleware deferCleanup runs after handler', async () => {
+    const order: Array<string> = [];
+
+    const middleware: Middleware = ({ ctx, next }) => {
+      ctx.deferCleanup(() => {
+        order.push('middleware-cleanup');
+      });
+      next();
+    };
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            ctx.deferCleanup(() => {
+              order.push('handler-cleanup');
+            });
+
+            return Ok({});
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services, {
+      middlewares: [middleware],
+    });
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    // Both middleware and handler share the same cleanup stack.
+    // Since middleware registers before the handler, handler cleanup (LIFO)
+    // runs first, then middleware cleanup.
+    await waitFor(() => {
+      expect(order).toEqual(['handler-cleanup', 'middleware-cleanup']);
+    });
+  });
+
+  test('cleanup that registers more cleanups', async () => {
+    const order: Array<number> = [];
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            ctx.deferCleanup(() => {
+              order.push(1);
+            });
+            ctx.deferCleanup(() => {
+              order.push(2);
+              // Register another cleanup during execution.
+              // Since cleanups haven't finished, this gets pushed
+              // to the array and picked up by the pop() loop.
+              ctx.deferCleanup(() => {
+                order.push(3);
+              });
+            });
+
+            return Ok({});
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    // LIFO: 2 pops first (pushes 3), then 3 pops next, then 1
+    await waitFor(() => {
+      expect(order).toEqual([2, 3, 1]);
+    });
+  });
+
+  test('stream cleanup runs when streams close, not when handler returns', async () => {
+    const cleanupRan = vi.fn();
+    let handlerReturned = false;
+
+    const services = {
+      test: createServiceSchema().define({
+        myStream: Procedure.stream({
+          requestInit: Type.Object({}),
+          requestData: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx, resWritable }) {
+            ctx.deferCleanup(cleanupRan);
+
+            // Write a response but don't close resWritable — the handler
+            // returns while the procedure is still open.
+            resWritable.write(Ok({}));
+            handlerReturned = true;
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    const abortController = new AbortController();
+    const { resReadable } = client.test.myStream.stream(
+      {},
+      { signal: abortController.signal },
+    );
+
+    // Read the response the handler wrote
+    const iter = resReadable[Symbol.asyncIterator]();
+    const result = await iter.next();
+    expect(result.done).toBe(false);
+
+    // Handler has returned but streams are still open — cleanup should NOT have run.
+    await waitFor(() => expect(handlerReturned).toBe(true));
+    expect(cleanupRan).not.toHaveBeenCalled();
+
+    // Abort closes everything, triggering cleanup.
+    abortController.abort();
+
+    await waitFor(() => {
+      expect(cleanupRan).toHaveBeenCalledOnce();
+    });
+  });
+
+  test('deferCleanup after handler finished calls fn immediately', async () => {
+    const laterCleanup = vi.fn();
+    let savedCtx: { deferCleanup: (fn: () => void) => void } | undefined;
+
+    const services = {
+      test: createServiceSchema().define({
+        myRpc: Procedure.rpc({
+          requestInit: Type.Object({}),
+          responseData: Type.Object({}),
+          async handler({ ctx }) {
+            savedCtx = ctx;
+
+            return Ok({});
+          },
+        }),
+      }),
+    };
+
+    createServer(mockTransportNetwork.getServerTransport(), services);
+    const client = createClient<typeof services>(
+      mockTransportNetwork.getClientTransport('client'),
+      'SERVER',
+    );
+
+    await client.test.myRpc.rpc({});
+
+    await waitFor(() => {
+      expect(savedCtx).toBeDefined();
+    });
+
+    // The handler has finished and cleanups have run.
+    // Registering a new cleanup should call it immediately.
+    assert(savedCtx);
+    savedCtx.deferCleanup(laterCleanup);
+    await waitFor(() => {
+      expect(laterCleanup).toHaveBeenCalledOnce();
+    });
+  });
+});
diff --git a/__tests__/e2e.test.ts b/__tests__/e2e.test.ts
index a6c6cd7b..1071f857 100644
--- a/__tests__/e2e.test.ts
+++ b/__tests__/e2e.test.ts
@@ -253,8 +253,10 @@ describe.each(testMatrix())(
         ok: true,
         payload: { response: 'abc' },
       });
-      abortController.abort();
+      // Wait for the server's close to be fully processed before aborting,
+      // so the abort is genuinely a no-op (testing idempotent close).
       expect(await isReadableDone(resReadable)).toEqual(true);
+      abortController.abort();
 
       // Make sure that the handlers have finished.
       await advanceFakeTimersBySessionGrace();
diff --git a/package-lock.json b/package-lock.json
index cfbcafb8..e590b98b 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "@replit/river",
-  "version": "0.213.1",
+  "version": "0.214.0",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "@replit/river",
-      "version": "0.213.1",
+      "version": "0.214.0",
       "license": "MIT",
       "dependencies": {
         "@msgpack/msgpack": "^3.1.2",
diff --git a/package.json b/package.json
index a3691635..cdca6edd 100644
--- a/package.json
+++ b/package.json
@@ -1,7 +1,7 @@
 {
   "name": "@replit/river",
   "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!",
-  "version": "0.213.1",
+  "version": "0.214.0",
   "type": "module",
   "exports": {
     ".": {
diff --git a/router/context.ts b/router/context.ts
index 922c6574..540c51e7 100644
--- a/router/context.ts
+++ b/router/context.ts
@@ -40,6 +40,22 @@ export type ProcedureHandlerContext<State> =
    * the river documentation to understand the difference between the two concepts.
    */
   cancel: (message?: string) => ErrResult>;
+  /**
+   * Register a cleanup function that will run after the procedure handler
+   * completes (whether it returns normally, throws, or is cancelled).
+   * Cleanup functions run in reverse registration order (LIFO) and each
+   * cleanup is awaited before the next one starts.
+   *
+   * Prefer this over registering async cleanup work on `signal`'s 'abort'
+   * event. Abort signal callbacks fire synchronously and do not await async
+   * work, so multiple async callbacks will interlace their execution
+   * (coroutine-like behavior) rather than running sequentially to completion.
+   * `deferCleanup` guarantees each cleanup finishes before the next begins.
+   *
+   * If a cleanup function throws, the error is recorded on the handler's
+   * span but remaining cleanups continue to run.
+   */
+  deferCleanup: (fn: () => void | Promise<void>) => void;
   /**
    * This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal)
    * triggered when the procedure invocation is done. This signal tracks the invocation/request finishing
diff --git a/router/server.ts b/router/server.ts
index a4fadd5e..ea50f416 100644
--- a/router/server.ts
+++ b/router/server.ts
@@ -35,9 +35,10 @@ import { Value } from '@sinclair/typebox/value';
 import { Err, Result, Ok, ErrResult } from './result';
 import { EventMap } from '../transport/events';
 import { coerceErrorString } from '../transport/stringifyError';
-import { Span } from '@opentelemetry/api';
+import { context as otelContext, Span, trace } from '@opentelemetry/api';
 import {
   createHandlerSpan,
+  getTracer,
   PropagationContext,
   recordRiverError,
 } from '../tracing';
@@ -473,9 +474,57 @@ class RiverServer<
       cancelStream(streamId, result);
     };
 
+    const deferredCleanups: Array<() => void | Promise<void>> = [];
+    let cleanupsHaveRun = false;
+    const runCleanupSafe = async (fn: () => void | Promise<void>) => {
+      try {
+        await fn();
+      } catch (err) {
+        span.recordException(
+          err instanceof Error ? err : new Error(coerceErrorString(err)),
+        );
+      }
+    };
+
+    const deferCleanup = (fn: () => void | Promise<void>) => {
+      if (cleanupsHaveRun) {
+        void runCleanupSafe(fn);
+
+        return;
+      }
+
+      deferredCleanups.push(fn);
+    };
+
+    const runDeferredCleanups = async () => {
+      if (deferredCleanups.length === 0) {
+        cleanupsHaveRun = true;
+        span.end();
+
+        return;
+      }
+
+      const cleanupSpan = getTracer().startSpan(
+        'river.cleanup',
+        {},
+        trace.setSpan(otelContext.active(), span),
+      );
+
+      try {
+        for (let fn = deferredCleanups.pop(); fn; fn = deferredCleanups.pop()) {
+          await runCleanupSafe(fn);
+        }
+      } finally {
+        cleanupsHaveRun = true;
+        cleanupSpan.end();
+        span.end();
+      }
+    };
+
     const cleanup = () => {
       finishedController.abort();
       this.streams.delete(streamId);
+      void runDeferredCleanups();
     };
 
     const procClosesWithResponse =
@@ -606,6 +655,7 @@
 
           return Err(errRes);
         },
+        deferCleanup,
         signal: finishedController.signal,
       };
 
@@ -615,6 +665,7 @@
       from,
       metadata: sessionMetadata,
       span,
+      deferCleanup,
       signal: finishedController.signal,
       streamId,
       procedureName,
@@ -638,8 +689,6 @@
           resWritable.write(responsePayload);
         } catch (err) {
           onHandlerError(err, span);
-        } finally {
-          span.end();
         }
         break;
       case 'stream':
@@ -652,8 +701,6 @@
           });
         } catch (err) {
           onHandlerError(err, span);
-        } finally {
-          span.end();
         }
         break;
       case 'subscription':
@@ -665,8 +712,6 @@
           });
         } catch (err) {
           onHandlerError(err, span);
-        } finally {
-          span.end();
         }
         break;
       case 'upload':
@@ -685,8 +730,6 @@
           resWritable.write(responsePayload);
         } catch (err) {
           onHandlerError(err, span);
-        } finally {
-          span.end();
         }
         break;
     }