diff --git a/packages/auth-foundation/src/utils/TaskBridge.ts b/packages/auth-foundation/src/utils/TaskBridge.ts index ae25432..4755e29 100644 --- a/packages/auth-foundation/src/utils/TaskBridge.ts +++ b/packages/auth-foundation/src/utils/TaskBridge.ts @@ -1,5 +1,6 @@ import type { BroadcastChannelLike } from '../types/index.ts'; import { shortID } from '../crypto/index.ts'; +import { AuthSdkError } from '../errors/AuthSdkError.ts'; /** @useDeclaredType */ type TypeMap = Record; @@ -86,8 +87,9 @@ export abstract class TaskBridge { }); this.#pending.set(request.id, request); + let abortHandler: () => void; const result = (new Promise((resolve, reject) => { - const setTimeoutTimer = () => { + const resetTimeoutTimer = () => { // `options.timeout` set to `null` disables the timeout mechanism if (options.timeout === null) { return; @@ -97,20 +99,21 @@ export abstract class TaskBridge { clearTimeout(timeoutId); } // TODO: error type - timeoutId = setTimeout(() => reject(new Error('timeout')), options.timeout ?? 5000); + timeoutId = setTimeout(() => reject( + new TaskBridge.TimeoutError('timeout') + ), options.timeout ?? 5000); }; // sets timeout timer - setTimeoutTimer(); + resetTimeoutTimer(); // forces the pending promise to reject, so resources clean up if the request is aborted - request.signal.addEventListener('abort', () => { - reject(new DOMException('Aborted', 'AbortError')); - }); + abortHandler = () => reject(new DOMException('Aborted', 'AbortError')); + request.signal.addEventListener('abort', abortHandler); // This channel is meant for the Receiver to send the results (aka `HandlerMessage` messages) // ignore all Requestor events received (aka `RequestorMessage`) responseChannel.onmessage = (event) => { - if ('action' in event.data) { + if (request.signal.aborted || 'action' in event.data) { return; // ignore message } @@ -126,8 +129,11 @@ export abstract class TaskBridge { case 'PENDING': // defer the timeout timer when a heartbeat is received (host is still working) - setTimeoutTimer(); + resetTimeoutTimer(); + break; + case 'ABORTED': + request.abort('Host Aborted'); break; } }; @@ -141,15 +147,16 @@ export abstract class TaskBridge { } requestChannel.close(); responseChannel.close(); + request.signal.removeEventListener('abort', abortHandler); this.#pending.delete(request.id); }); - // TODO: review - const cancel = () => { + const abort = () => { responseChannel.postMessage({ action: 'CANCEL', __v: TaskBridge.BridgeVersion }); + request.controller.abort('cancel'); }; - return { result, cancel }; + return { result, abort }; } subscribe(handler: TaskBridge.TaskHandler) { @@ -182,8 +189,6 @@ export abstract class TaskBridge { // event type is now `RequestorMessage` switch (event.data.action) { case 'CANCEL': - // TODO: probably don't need to reply, just cancel action, if possible - // responseChannel.postMessage({ status: 'CANCELED' }); message.abort('cancel'); break; } @@ -198,8 +203,19 @@ export abstract class TaskBridge { ); } catch (err) { - if (err instanceof DOMException && err.name === 'AbortError') { - return null; + if (err instanceof DOMException) { + if (err.name === 'AbortError') { + // task was aborted, do nothing + return null; + } + + if (err.name === 'InvalidStateError') { + // this is error is thrown if a `.postMessage` is attempted after the channel is closed + // this can happen when the `handler` function attempts to `reply()` after `.close()` + // is called. Ignore the error, the `AbortSignal` is provided to the `handler` for + // if needed + return null; + } } if (err instanceof Error) { @@ -213,11 +229,24 @@ export abstract class TaskBridge { }; } + /** + * Returns the number of pending tasks + */ + get pending (): number { + return this.#pending.size; + } + close () { this.#channel?.close(); for (const message of this.#pending.values()) { message.abort(); message.channel.close(); + this.clearMessage(message.id); + } + this.#pending.clear(); + if (this.#heartbeatInt !== null) { + clearInterval(this.#heartbeatInt); + this.#heartbeatInt = null; } } } @@ -229,7 +258,7 @@ export namespace TaskBridge { /** * Possible `status` values indicating the process of an orchestrated request */ - export type TaskStatus = 'PENDING' | 'SUCCESS' | 'FAILED'; + export type TaskStatus = 'PENDING' | 'SUCCESS' | 'FAILED' | 'ABORTED'; export type BridgeVersions = 1 | 2; @@ -247,6 +276,9 @@ export namespace TaskBridge { } | { status: 'PENDING' __v: BridgeVersions; + } | { + status: 'ABORTED' + __v: BridgeVersions; } /** @@ -309,9 +341,9 @@ export namespace TaskBridge { } reply (data: S, status: TaskBridge.TaskStatus): void; - reply (status: 'PENDING'): void; - reply (data: S | 'PENDING', status: TaskBridge.TaskStatus = 'SUCCESS') { - const fn = this.replyFn ?? this.channel.postMessage; + reply (status: 'PENDING' | 'ABORTED'): void; + reply (data: S | 'PENDING' | 'ABORTED', status: TaskBridge.TaskStatus = 'SUCCESS') { + const fn = this.replyFn ?? this.channel.postMessage.bind(this.channel); if (data === 'PENDING' || status === 'PENDING') { // only send `PENDING` heartbeats when using <= v2 of the TaskBridge payload structure @@ -319,6 +351,12 @@ export namespace TaskBridge { fn({ status: 'PENDING', __v: this.__v } satisfies HandlerMessage); } } + else if (data === 'ABORTED' || status === 'ABORTED') { + // only send `PENDING` heartbeats when using <= v2 of the TaskBridge payload structure + if (this.__v === 2) { + fn({ status: 'ABORTED', __v: this.__v } satisfies HandlerMessage); + } + } else { // TODO: remove this condition - OKTA-1053515 if (this.__v < 2) { @@ -332,6 +370,7 @@ export namespace TaskBridge { } abort (...args: Parameters) { + this.reply('ABORTED'); return this.controller.abort(...args); } @@ -345,6 +384,7 @@ export namespace TaskBridge { */ export type TaskOptions = { timeout?: number | null; + signal?: AbortSignal; }; /** @@ -352,7 +392,7 @@ export namespace TaskBridge { */ export type TaskResponse = { result: Promise; - cancel: () => void; + abort: () => void; }; /** @@ -364,4 +404,20 @@ export namespace TaskBridge { options?: { signal: AbortSignal } ) => any; + /** + * @group Errors + */ + export class TimeoutError extends AuthSdkError { + #timeout: boolean = false; + + constructor (...args: ConstructorParameters) { + const [message, ...rest] = args; + super(message ?? 'timeout', ...rest); + this.#timeout = true; + } + + get timeout (): boolean { + return this.#timeout; + } + } } diff --git a/packages/auth-foundation/test/spec/utils/TaskBridge.spec.ts b/packages/auth-foundation/test/spec/utils/TaskBridge.spec.ts index f05e98e..a16d429 100644 --- a/packages/auth-foundation/test/spec/utils/TaskBridge.spec.ts +++ b/packages/auth-foundation/test/spec/utils/TaskBridge.spec.ts @@ -1,6 +1,6 @@ -import { BroadcastChannelLike, JsonRecord } from 'src/types'; import { TaskBridge } from 'src/utils/TaskBridge.ts'; + type TestRequest = { ADD: { foo: number; @@ -23,58 +23,21 @@ type TestResponse = { } }; -class TestChannel implements BroadcastChannelLike { - channel: BroadcastChannel; - #handler: BroadcastChannelLike['onmessage'] = null; - - constructor (public name: string) { - this.channel = new BroadcastChannel(name); - } - - get onmessage () { - return this.#handler; +class TestBus extends TaskBridge { + protected createBridgeChannel (): TaskBridge.BridgeChannel { + return new BroadcastChannel(this.name) as TaskBridge.BridgeChannel; } - set onmessage (handler) { - if (handler === null) { - this.channel.onmessage = null; - this.#handler = null; - } - - console.log('handler set', handler); - - this.#handler = async (event) => { - console.log('got message', event.data); - // const reply = (response) => this.channel.postMessage(response); - // @ts-ignore - await handler(event.data); - }; - - this.channel.onmessage = this.#handler; - } - - postMessage(message: M): void { - this.channel.postMessage(message); - } - - close () { - this.channel.close(); - } -} - -class TestBus extends TaskBridge { - - protected createBridgeChannel (): TaskBridge.BridgeChannel { - return new TestChannel(this.name); - } - - protected createTaskChannel(name: string): TaskBridge.TaskChannel { - return new TestChannel(name); + protected createTaskChannel(name: string): TaskBridge.TaskChannel { + return new BroadcastChannel(name) as TaskBridge.TaskChannel; } } +const sleep = (ms: number) => new Promise(resolve => { + setTimeout(resolve, ms); +}); -describe.skip('TaskBridge', () => { +describe('TaskBridge', () => { let receiver: TaskBridge; let sender: TaskBridge; @@ -84,34 +47,307 @@ describe.skip('TaskBridge', () => { }); afterEach(() => { + expect(receiver.pending).toEqual(0); + expect(sender.pending).toEqual(0); + receiver.close(); sender.close(); + jest.clearAllTimers(); }); - describe('test', () => { - it('sends and receives messages', async () => { - const channel = new BroadcastChannel('test'); - channel.onmessage = (event) => { - console.log('[monitor]: ', event.data); - }; - - receiver.subscribe(async (message, reply) => { - console.log('handler called'); - reply({ foo: '2', bar: '1' }); - }); - - const result = await sender.send({ foo: 1, bar: 2 }).result; - expect(result).toEqual({ bar: 'baz' }); + it('sends and receives messages between separate instances', async () => { + jest.useFakeTimers(); + + const response = { foo: '2', bar: '1' }; + + receiver.subscribe(async (message, reply) => { + reply(response); + }); + + const { result } = sender.send({ foo: 1, bar: 2 }); + await expect(result).resolves.toEqual(response); + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); + + it('handles multiple tasks simultaneously', async () => { + jest.useFakeTimers(); + + const response = { foo: '2', bar: '1' }; + + receiver.subscribe(async (message, reply) => { + reply(response); + }); + + const promises = Promise.allSettled(Array.from({ length: 3 }, (_, i) => { + const { result } = sender.send({ foo: 1 + i, bar: 2 + i }); + return result; + })); + + await expect(promises).resolves.toEqual([ + { status: 'fulfilled', value: { ...response} }, + { status: 'fulfilled', value: { ...response } }, + { status: 'fulfilled', value: { ...response } }, + ]); + + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); + + it('handles a single task throwing gracefully', async () => { + jest.useFakeTimers(); + + const response = { foo: '2', bar: '1' }; + + let taskCount = 0; + receiver.subscribe(async (message, reply) => { + const isEven = taskCount === 0 || taskCount % 2 === 0; + taskCount += 1; + if (isEven) { + reply(response); + } + else { + throw new Error('test error'); + } + }); + + const promises = Promise.allSettled(Array.from({ length: 3 }, (_, i) => { + const { result } = sender.send({ foo: 1 + i, bar: 2 + i }); + return result; + })); + + await expect(promises).resolves.toEqual([ + { status: 'fulfilled', value: { ...response} }, + { status: 'fulfilled', value: { error: 'test error' } }, + { status: 'fulfilled', value: { ...response } }, + ]); + + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); + + it('gracefully handles an error being thrown by the subscribe handler', async () => { + jest.useFakeTimers(); + + const handler = jest.fn().mockImplementation(async () => { + throw new Error('test'); + }); + receiver.subscribe(handler); + + const { result } = sender.send({ foo: 1, bar: 2 }); + await expect(result).resolves.toEqual({ error: 'test' }); + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); + + it('can handle aborting pending tasks', async () => { + jest.useFakeTimers(); + + const abortListener = jest.fn(); + const handler = jest.fn().mockImplementation( async (message, reply, { signal }) => { + signal.addEventListener('abort', abortListener, { once: true }); + + await sleep(50); // sleep to delay responding to the message, so the abort fires first + reply({ foo: '1', bar: '2' }); + }); + receiver.subscribe(handler); + + const { result, abort } = sender.send({ foo: 1, bar: 2 }); + + // flush microtasks to ensure subscribe abortHandler is set up + await jest.advanceTimersByTimeAsync(10); + + abort(); - channel.close(); + await expect(result).rejects.toThrow(DOMException); + await expect(result).rejects.toThrow('Aborted'); + + // wait a bit more to ensure abort listener is called + await jest.advanceTimersByTimeAsync(100); + + expect(handler).toHaveBeenCalled(); + expect(abortListener).toHaveBeenCalled(); + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); + + it('will not timeout a pending request when host is available', async () => { + jest.useFakeTimers(); + + const response = { foo: '2', bar: '1' }; + const largeDelay = 10000; + + // clever way of capturing the requestId + let requestId; + const bc = new BroadcastChannel('test'); + bc.onmessage = (evt => { + if (evt.data.requestId) { + requestId = evt.data.requestId; + } + }); + + receiver.subscribe(async (message, reply) => { + await sleep(largeDelay); // very long delay + reply(response); + }); + + const { result } = sender.send({ foo: 1, bar: 2 }); + // advance timers to send BroadcastChannel messages + await jest.advanceTimersByTimeAsync(100); + + // listen on "response channel" and count number of `PENDING` "pings" + let pendingCount = 0; + const channel = new BroadcastChannel(requestId); + channel.onmessage = (evt) => { + if (evt.data.status === 'PENDING') { + pendingCount++; + } + }; + + // advance the timers to the length of the delay, so response is finally returned + await jest.advanceTimersByTimeAsync(largeDelay); + + await expect(result).resolves.toEqual(response); + // expect a predictable number of 'PENDING' pings given the large delay + expect(pendingCount).toEqual(largeDelay / receiver.heartbeatInterval); + expect(jest.getTimerCount()).toBe(0); + + // cleanup + jest.useRealTimers(); + bc.close(); + channel.close(); + }); + + it('will timeout when host does not response within default timeout window', async () => { + expect.assertions(4); // ensures `result.catch()` is invoked + jest.useFakeTimers(); + + receiver.close(); + + const { result } = sender.send({ foo: 1, bar: 2 }); + + // use `.catch` to bind listener synchronously + const promise = result.catch(err => { + expect(err).toBeInstanceOf(TaskBridge.TimeoutError); + }); + + await jest.advanceTimersByTimeAsync(10000); + await promise; + + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); + + it('will timeout when host does not response within user defined timeout window', async () => { + expect.assertions(4); // ensures `result.catch()` is invoked + jest.useFakeTimers(); + + const largeTimeout = 10000; + + receiver.close(); + + const { result } = sender.send({ foo: 1, bar: 2 }, { timeout: largeTimeout - 100 }); + + // use `.catch` to bind listener synchronously + const promise = result.catch(err => { + expect(err).toBeInstanceOf(TaskBridge.TimeoutError); + }); + + await jest.advanceTimersByTimeAsync(10000); + await promise; + + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); + + it('will timeout when no host is avaiable', async () => { + expect.assertions(4); // ensures `result.catch()` is invoked + jest.useFakeTimers(); + + const timeout = 100; + + // NOTE: no `receiver.subscribe` call + + const { result } = sender.send({ foo: 1, bar: 2 }, { timeout }); + + // use `.catch` to bind listener synchronously + const promise = result.catch(err => { + expect(err).toBeInstanceOf(TaskBridge.TimeoutError); }); + + await jest.advanceTimersByTimeAsync(timeout); + await promise; + + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); }); - // xdescribe('', async () => { + it('will abort pending tasks when closed', async () => { + jest.useFakeTimers(); + + const abortListener = jest.fn(); + const handler = jest.fn().mockImplementation(async (message, reply, { signal }) => { + // confirm the `signal` instance fires an `abort` event + signal.addEventListener('abort', abortListener); + + // returns Promise which rejects when event is fired + function rejectWhenFired (target: EventTarget, event: string) { + return new Promise((_, reject) => { + target.addEventListener(event, reject, { once: true }); + }); + } + + // track the timers set by `sleep()` within this test + let sleepTimeout; + function sleep (delay) { + return new Promise((resolve) => { + sleepTimeout = setTimeout(resolve, delay); + }); + } - // }); + // sleep to delay responding to the message, so the abort fires first + try { + await Promise.race([ + sleep(sender.heartbeatInterval * 10), + rejectWhenFired(signal, 'abort'), + ]); + + reply({ foo: '1', bar: '2' }); + } + finally { + // timeouts set via `sleep()` need to be cleared. Test requires no timers remain + clearTimeout(sleepTimeout); + } + }); + receiver.subscribe(handler); + + const promises = Promise.allSettled(Array.from({ length: 3 }, (_, i) => { + const { result } = sender.send({ foo: 1 + i, bar: 2 + i }); + return result; + })); - // xdescribe('', async () => { + // flush microtasks to ensure subscribe handler is set up + await jest.advanceTimersByTimeAsync(10); + + expect(handler).toHaveBeenCalledTimes(3); + + receiver.close(); + const result = await promises; + await jest.advanceTimersByTimeAsync(10); + + expect(result).toEqual(Array(3).fill({ status: 'rejected', reason: expect.any(DOMException) })); + expect(abortListener).toHaveBeenCalledTimes(3); + expect(jest.getTimerCount()).toBe(0); + + jest.useRealTimers(); + }); - // }); }); diff --git a/tooling/jest-helpers/browser/jest.environment.js b/tooling/jest-helpers/browser/jest.environment.js index f025a39..f06b89e 100644 --- a/tooling/jest-helpers/browser/jest.environment.js +++ b/tooling/jest-helpers/browser/jest.environment.js @@ -15,6 +15,8 @@ class CustomJSDomEnv extends JSDOMEnv { this.global.Request = Request; this.global.Response = Response; this.global.Headers = Headers; + this.global.BroadcastChannel = BroadcastChannel; + this.global.DOMException = DOMException; } } diff --git a/tooling/jest-helpers/browser/jest.setup.ts b/tooling/jest-helpers/browser/jest.setup.ts index 43e699c..a7bcbc8 100644 --- a/tooling/jest-helpers/browser/jest.setup.ts +++ b/tooling/jest-helpers/browser/jest.setup.ts @@ -25,7 +25,8 @@ class MockBroadcastChannel implements BroadcastChannel { global.TextEncoder = TextEncoder; global.TextDecoder = TextDecoder; -global.BroadcastChannel = MockBroadcastChannel; +// global.BroadcastChannel = MockBroadcastChannel; +// global.BroadcastChannel = BroadcastChannel; global.fetch = () => { throw new Error(`