From 8521e10f07eeafae5710072fa9e8621e180fd52e Mon Sep 17 00:00:00 2001 From: Felicitas Pojtinger Date: Fri, 9 Aug 2024 14:54:18 -0700 Subject: [PATCH] feat: Allow cancelling in-flight RPC operations with per-`link*()` function context in TypeScript registry implementation Signed-off-by: Felicitas Pojtinger --- ts/src/rpc/registry.ts | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/ts/src/rpc/registry.ts b/ts/src/rpc/registry.ts index 52976e0..ee007c5 100644 --- a/ts/src/rpc/registry.ts +++ b/ts/src/rpc/registry.ts @@ -8,7 +8,7 @@ import { import { ILocalContext, IRemoteContext } from "./context"; import { ClosureManager, registerClosure } from "./manager"; -export const ErrorCallCancelled = "call timed out"; +export const ErrorCallAborted = "call aborted"; export const ErrorCannotCallNonFunction = "can not call non function"; const constructorFunctionName = "constructor"; @@ -50,6 +50,10 @@ export const remoteClosure = ( const makeRPC = ( + // This is separate from the AbortSignal that is the first argument to each RPC because we also + // want to be able to cancel all in-flight RPCs if the signal passed to a `link*()` function is cancelled + linkSignal: AbortSignal | undefined, + name: string, responseResolver: EventTarget, @@ -61,8 +65,8 @@ const makeRPC = ) => async (ctx: IRemoteContext, ...rest: any[]) => new Promise((res, rej) => { - if (ctx?.signal?.aborted) { - rej(new Error(ErrorCallCancelled)); + if (ctx?.signal?.aborted || linkSignal?.aborted) { + rej(new Error(ErrorCallAborted)); return; } @@ -91,7 +95,7 @@ const makeRPC = closureFreers.map((free) => free()); const callResponse: ICallResponse = { - err: ErrorCallCancelled, + err: ErrorCallAborted, }; responseResolver.dispatchEvent( @@ -99,6 +103,7 @@ const makeRPC = ); }; ctx?.signal?.addEventListener("abort", abortListener); + linkSignal?.addEventListener("abort", abortListener); const returnListener = (event: Event) => { const { value, err } = (event as CustomEvent).detail; @@ -147,6 +152,7 @@ export class Registry { /** * Expose local RPCs and implement remote RPCs via a message-based transport + * @param signal AbortSignal for in-flight RPC operations * @param requestWriter Stream to write requests to * @param responseWriter Stream to write responses to * @param requestReader Stream to read requests from @@ -156,6 +162,8 @@ export class Registry { * @param hooks Link hooks */ linkMessage = ( + signal: AbortSignal | undefined, + requestWriter: WritableStreamDefaultWriter, responseWriter: WritableStreamDefaultWriter, @@ -180,6 +188,8 @@ export class Registry { } (r as any)[functionName] = makeRPC( + signal, + functionName, responseResolver, @@ -255,6 +265,8 @@ export class Registry { remoteClosureParameterIndexes?.includes(index + 1) ? (closureCtx: IRemoteContext, ...closureArgs: any[]) => { const rpc = makeRPC( + signal, + "CallClosure", responseResolver, @@ -338,6 +350,7 @@ export class Registry { /** * Expose local RPCs and implement remote RPCs via a stream-based transport + * @param signal AbortSignal for in-flight RPC operations * @param encoder Stream to write messages to * @param decoder Stream to read messages from * @param marshal Function to marshal nested values with @@ -345,6 +358,8 @@ export class Registry { * @param hooks Link hooks */ linkStream = ( + signal: AbortSignal | undefined, + encoder: WritableStream, decoder: ReadableStream, @@ -454,6 +469,8 @@ export class Registry { }); this.linkMessage( + signal, + requestWriter.getWriter(), responseWriter.getWriter(),