Skip to content

Commit

Permalink
feat: RPC: add useAbortSignal option
Browse files Browse the repository at this point in the history
Adds a new option "useAbortSignal" which adds an optional AbortSignal parameter
to RPC functions. AbortController and AbortSignal are built-ins in both Node.JS
and all web browsers, which implement aborting long-lived processes.

For example:

const abortController = new AbortController()
const responsePromise = rpcClient.DoSomething(request, abortController.signal)
// abort the RPC call early
abortController.abort()

Fixes #730

Signed-off-by: Christian Stewart <christian@paral.in>
  • Loading branch information
paralin committed Dec 14, 2022
1 parent 78efa04 commit f51c0d9
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 7 deletions.
2 changes: 2 additions & 0 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ Generated code will be placed in the Gradle build directory.

- With `--ts_proto_opt=outputServices=false`, or `=none`, ts-proto will output NO service definitions.

- With `--ts_proto_opt=useAbortSignal=true`, the generated services will accept an `AbortSignal` to cancel RPC calls.

- With `--ts_proto_opt=useAsyncIterable=true`, the generated services will use `AsyncIterable` instead of `Observable`.

- With `--ts_proto_opt=emitImportedFiles=false`, ts-proto will not emit `google/protobuf/*` files unless you explicit add files to `protoc` like this
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
useAsyncIterable=true,useAbortSignal=true
Binary file not shown.
19 changes: 19 additions & 0 deletions integration/async-iterable-services-abort-signal/simple.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
syntax = "proto3";
package simple;

// Echoer service returns the given message.
service Echoer {
// Echo returns the given message.
rpc Echo(EchoMsg) returns (EchoMsg);
// EchoServerStream is an example of a server -> client one-way stream.
rpc EchoServerStream(EchoMsg) returns (stream EchoMsg);
// EchoClientStream is an example of client->server one-way stream.
rpc EchoClientStream(stream EchoMsg) returns (EchoMsg);
// EchoBidiStream is an example of a two-way stream.
rpc EchoBidiStream(stream EchoMsg) returns (stream EchoMsg);
}

// EchoMsg is the message body for Echo.
message EchoMsg {
string body = 1;
}
178 changes: 178 additions & 0 deletions integration/async-iterable-services-abort-signal/simple.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/* eslint-disable */
import * as _m0 from "protobufjs/minimal";

export const protobufPackage = "simple";

/** EchoMsg is the message body for Echo. */
export interface EchoMsg {
body: string;
}

function createBaseEchoMsg(): EchoMsg {
return { body: "" };
}

export const EchoMsg = {
encode(message: EchoMsg, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
if (message.body !== "") {
writer.uint32(10).string(message.body);
}
return writer;
},

decode(input: _m0.Reader | Uint8Array, length?: number): EchoMsg {
const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input);
let end = length === undefined ? reader.len : reader.pos + length;
const message = createBaseEchoMsg();
while (reader.pos < end) {
const tag = reader.uint32();
switch (tag >>> 3) {
case 1:
message.body = reader.string();
break;
default:
reader.skipType(tag & 7);
break;
}
}
return message;
},

// encodeTransform encodes a source of message objects.
// Transform<EchoMsg, Uint8Array>
async *encodeTransform(
source: AsyncIterable<EchoMsg | EchoMsg[]> | Iterable<EchoMsg | EchoMsg[]>,
): AsyncIterable<Uint8Array> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [EchoMsg.encode(p).finish()];
}
} else {
yield* [EchoMsg.encode(pkt).finish()];
}
}
},

// decodeTransform decodes a source of encoded messages.
// Transform<Uint8Array, EchoMsg>
async *decodeTransform(
source: AsyncIterable<Uint8Array | Uint8Array[]> | Iterable<Uint8Array | Uint8Array[]>,
): AsyncIterable<EchoMsg> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [EchoMsg.decode(p)];
}
} else {
yield* [EchoMsg.decode(pkt)];
}
}
},

fromJSON(object: any): EchoMsg {
return { body: isSet(object.body) ? String(object.body) : "" };
},

toJSON(message: EchoMsg): unknown {
const obj: any = {};
message.body !== undefined && (obj.body = message.body);
return obj;
},

fromPartial<I extends Exact<DeepPartial<EchoMsg>, I>>(object: I): EchoMsg {
const message = createBaseEchoMsg();
message.body = object.body ?? "";
return message;
},
};

/** Echoer service returns the given message. */
export interface Echoer {
/** Echo returns the given message. */
Echo(request: EchoMsg, abortSignal?: AbortSignal): Promise<EchoMsg>;
/** EchoServerStream is an example of a server -> client one-way stream. */
EchoServerStream(request: EchoMsg, abortSignal?: AbortSignal): AsyncIterable<EchoMsg>;
/** EchoClientStream is an example of client->server one-way stream. */
EchoClientStream(request: AsyncIterable<EchoMsg>, abortSignal?: AbortSignal): Promise<EchoMsg>;
/** EchoBidiStream is an example of a two-way stream. */
EchoBidiStream(request: AsyncIterable<EchoMsg>, abortSignal?: AbortSignal): AsyncIterable<EchoMsg>;
}

export class EchoerClientImpl implements Echoer {
private readonly rpc: Rpc;
private readonly service: string;
constructor(rpc: Rpc, opts?: { service?: string }) {
this.service = opts?.service || "simple.Echoer";
this.rpc = rpc;
this.Echo = this.Echo.bind(this);
this.EchoServerStream = this.EchoServerStream.bind(this);
this.EchoClientStream = this.EchoClientStream.bind(this);
this.EchoBidiStream = this.EchoBidiStream.bind(this);
}
Echo(request: EchoMsg, abortSignal?: AbortSignal): Promise<EchoMsg> {
const data = EchoMsg.encode(request).finish();
const promise = this.rpc.request(this.service, "Echo", data, abortSignal || undefined);
return promise.then((data) => EchoMsg.decode(new _m0.Reader(data)));
}

EchoServerStream(request: EchoMsg, abortSignal?: AbortSignal): AsyncIterable<EchoMsg> {
const data = EchoMsg.encode(request).finish();
const result = this.rpc.serverStreamingRequest(this.service, "EchoServerStream", data, abortSignal || undefined);
return EchoMsg.decodeTransform(result);
}

EchoClientStream(request: AsyncIterable<EchoMsg>, abortSignal?: AbortSignal): Promise<EchoMsg> {
const data = EchoMsg.encodeTransform(request);
const promise = this.rpc.clientStreamingRequest(this.service, "EchoClientStream", data, abortSignal || undefined);
return promise.then((data) => EchoMsg.decode(new _m0.Reader(data)));
}

EchoBidiStream(request: AsyncIterable<EchoMsg>, abortSignal?: AbortSignal): AsyncIterable<EchoMsg> {
const data = EchoMsg.encodeTransform(request);
const result = this.rpc.bidirectionalStreamingRequest(
this.service,
"EchoBidiStream",
data,
abortSignal || undefined,
);
return EchoMsg.decodeTransform(result);
}
}

interface Rpc {
request(service: string, method: string, data: Uint8Array, abortSignal?: AbortSignal): Promise<Uint8Array>;
clientStreamingRequest(
service: string,
method: string,
data: AsyncIterable<Uint8Array>,
abortSignal?: AbortSignal,
): Promise<Uint8Array>;
serverStreamingRequest(
service: string,
method: string,
data: Uint8Array,
abortSignal?: AbortSignal,
): AsyncIterable<Uint8Array>;
bidirectionalStreamingRequest(
service: string,
method: string,
data: AsyncIterable<Uint8Array>,
abortSignal?: AbortSignal,
): AsyncIterable<Uint8Array>;
}

type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;

export type DeepPartial<T> = T extends Builtin ? T
: T extends Array<infer U> ? Array<DeepPartial<U>> : T extends ReadonlyArray<infer U> ? ReadonlyArray<DeepPartial<U>>
: T extends {} ? { [K in keyof T]?: DeepPartial<T[K]> }
: Partial<T>;

type KeysOfUnion<T> = T extends T ? keyof T : never;
export type Exact<P, I extends P> = P extends Builtin ? P
: P & { [K in keyof P]: Exact<P[K], I[K]> } & { [K in Exclude<keyof I, KeysOfUnion<P>>]: never };

function isSet(value: any): boolean {
return value !== null && value !== undefined;
}
22 changes: 15 additions & 7 deletions src/generate-services.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ export function generateService(
const partialInput = options.outputClientImpl === "grpc-web";
const inputType = requestType(ctx, methodDesc, partialInput);
params.push(code`request: ${inputType}`);
if (options.useAbortSignal) {
params.push(code`abortSignal?: AbortSignal`);
}

// Use metadata as last argument for interface only configuration
if (options.outputClientImpl === "grpc-web") {
Expand Down Expand Up @@ -103,19 +106,21 @@ export function generateService(
return joinCode(chunks, { on: "\n" });
}

function generateRegularRpcMethod(
ctx: Context,
methodDesc: MethodDescriptorProto
): Code {
function generateRegularRpcMethod(ctx: Context, methodDesc: MethodDescriptorProto): Code {
assertInstanceOf(methodDesc, FormattedMethodDescriptor);
const { options } = ctx;
const Reader = impFile(ctx.options, "Reader@protobufjs/minimal");
const rawInputType = rawRequestType(ctx, methodDesc, { keepValueType: true });
const inputType = requestType(ctx, methodDesc);
const rawOutputType = responseType(ctx, methodDesc, { keepValueType: true });

const params = [...(options.context ? [code`ctx: Context`] : []), code`request: ${inputType}`];
const params = [
...(options.context ? [code`ctx: Context`] : []),
code`request: ${inputType}`,
...(options.useAbortSignal ? [code`abortSignal?: AbortSignal`] : []),
];
const maybeCtx = options.context ? "ctx," : "";
const maybeAbortSignal = options.useAbortSignal ? "abortSignal || undefined," : "";

let encode = code`${rawInputType}.encode(request).finish()`;
let decode = code`data => ${rawOutputType}.decode(new ${Reader}(data))`;
Expand Down Expand Up @@ -163,7 +168,8 @@ function generateRegularRpcMethod(
${maybeCtx}
this.service,
"${methodDesc.name}",
data
data,
${maybeAbortSignal}
);
return ${decode};
}
Expand Down Expand Up @@ -329,6 +335,7 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod
const { options } = ctx;
const maybeContext = options.context ? "<Context>" : "";
const maybeContextParam = options.context ? "ctx: Context," : "";
const maybeAbortSignalParam = options.useAbortSignal ? "abortSignal?: AbortSignal," : "";
const methods = [[code`request`, code`Uint8Array`, code`Promise<Uint8Array>`]];
if (hasStreamingMethods) {
const observable = observableType(ctx);
Expand All @@ -348,7 +355,8 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod
${maybeContextParam}
service: string,
method: string,
data: ${method[1]}
data: ${method[1]},
${maybeAbortSignalParam}
): ${method[2]};`);
});
chunks.push(code` }`);
Expand Down
2 changes: 2 additions & 0 deletions src/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ export type Options = {
outputSchema: boolean;
onlyTypes: boolean;
emitImportedFiles: boolean;
useAbortSignal: boolean;
useExactTypes: boolean;
useAsyncIterable: boolean;
unknownFields: boolean;
Expand Down Expand Up @@ -112,6 +113,7 @@ export function defaultOptions(): Options {
onlyTypes: false,
emitImportedFiles: true,
useExactTypes: true,
useAbortSignal: false,
useAsyncIterable: false,
unknownFields: false,
usePrototypeForDefaults: false,
Expand Down
1 change: 1 addition & 0 deletions tests/options-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ describe("options", () => {
"stringEnums": false,
"unknownFields": false,
"unrecognizedEnum": true,
"useAbortSignal": false,
"useAsyncIterable": false,
"useDate": "timestamp",
"useExactTypes": true,
Expand Down

0 comments on commit f51c0d9

Please sign in to comment.