diff --git a/integration/async-iterable-services-abort-signal/parameters.txt b/integration/async-iterable-services-abort-signal/parameters.txt new file mode 100644 index 000000000..aab111ad9 --- /dev/null +++ b/integration/async-iterable-services-abort-signal/parameters.txt @@ -0,0 +1 @@ +useAsyncIterable=true,useAbortSignal=true diff --git a/integration/async-iterable-services-abort-signal/simple.bin b/integration/async-iterable-services-abort-signal/simple.bin new file mode 100644 index 000000000..51973c2d8 Binary files /dev/null and b/integration/async-iterable-services-abort-signal/simple.bin differ diff --git a/integration/async-iterable-services-abort-signal/simple.proto b/integration/async-iterable-services-abort-signal/simple.proto new file mode 100644 index 000000000..73d6a5180 --- /dev/null +++ b/integration/async-iterable-services-abort-signal/simple.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; +package simple; + +// Echoer service returns the given message. +service Echoer { + // Echo returns the given message. + rpc Echo(EchoMsg) returns (EchoMsg); + // EchoServerStream is an example of a server -> client one-way stream. + rpc EchoServerStream(EchoMsg) returns (stream EchoMsg); + // EchoClientStream is an example of client->server one-way stream. + rpc EchoClientStream(stream EchoMsg) returns (EchoMsg); + // EchoBidiStream is an example of a two-way stream. + rpc EchoBidiStream(stream EchoMsg) returns (stream EchoMsg); +} + +// EchoMsg is the message body for Echo. +message EchoMsg { + string body = 1; +} diff --git a/integration/async-iterable-services-abort-signal/simple.ts b/integration/async-iterable-services-abort-signal/simple.ts new file mode 100644 index 000000000..c6ec3f2f8 --- /dev/null +++ b/integration/async-iterable-services-abort-signal/simple.ts @@ -0,0 +1,178 @@ +/* eslint-disable */ +import * as _m0 from "protobufjs/minimal"; + +export const protobufPackage = "simple"; + +/** EchoMsg is the message body for Echo. */ +export interface EchoMsg { + body: string; +} + +function createBaseEchoMsg(): EchoMsg { + return { body: "" }; +} + +export const EchoMsg = { + encode(message: EchoMsg, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.body !== "") { + writer.uint32(10).string(message.body); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): EchoMsg { + const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseEchoMsg(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + message.body = reader.string(); + break; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }, + + // encodeTransform encodes a source of message objects. + // Transform + async *encodeTransform( + source: AsyncIterable | Iterable, + ): AsyncIterable { + for await (const pkt of source) { + if (Array.isArray(pkt)) { + for (const p of pkt) { + yield* [EchoMsg.encode(p).finish()]; + } + } else { + yield* [EchoMsg.encode(pkt).finish()]; + } + } + }, + + // decodeTransform decodes a source of encoded messages. + // Transform + async *decodeTransform( + source: AsyncIterable | Iterable, + ): AsyncIterable { + for await (const pkt of source) { + if (Array.isArray(pkt)) { + for (const p of pkt) { + yield* [EchoMsg.decode(p)]; + } + } else { + yield* [EchoMsg.decode(pkt)]; + } + } + }, + + fromJSON(object: any): EchoMsg { + return { body: isSet(object.body) ? String(object.body) : "" }; + }, + + toJSON(message: EchoMsg): unknown { + const obj: any = {}; + message.body !== undefined && (obj.body = message.body); + return obj; + }, + + fromPartial, I>>(object: I): EchoMsg { + const message = createBaseEchoMsg(); + message.body = object.body ?? ""; + return message; + }, +}; + +/** Echoer service returns the given message. */ +export interface Echoer { + /** Echo returns the given message. */ + Echo(request: EchoMsg, abortSignal?: AbortSignal): Promise; + /** EchoServerStream is an example of a server -> client one-way stream. */ + EchoServerStream(request: EchoMsg, abortSignal?: AbortSignal): AsyncIterable; + /** EchoClientStream is an example of client->server one-way stream. */ + EchoClientStream(request: AsyncIterable, abortSignal?: AbortSignal): Promise; + /** EchoBidiStream is an example of a two-way stream. */ + EchoBidiStream(request: AsyncIterable, abortSignal?: AbortSignal): AsyncIterable; +} + +export class EchoerClientImpl implements Echoer { + private readonly rpc: Rpc; + private readonly service: string; + constructor(rpc: Rpc, opts?: { service?: string }) { + this.service = opts?.service || "simple.Echoer"; + this.rpc = rpc; + this.Echo = this.Echo.bind(this); + this.EchoServerStream = this.EchoServerStream.bind(this); + this.EchoClientStream = this.EchoClientStream.bind(this); + this.EchoBidiStream = this.EchoBidiStream.bind(this); + } + Echo(request: EchoMsg, abortSignal?: AbortSignal): Promise { + const data = EchoMsg.encode(request).finish(); + const promise = this.rpc.request(this.service, "Echo", data, abortSignal || undefined); + return promise.then((data) => EchoMsg.decode(new _m0.Reader(data))); + } + + EchoServerStream(request: EchoMsg, abortSignal?: AbortSignal): AsyncIterable { + const data = EchoMsg.encode(request).finish(); + const result = this.rpc.serverStreamingRequest(this.service, "EchoServerStream", data, abortSignal || undefined); + return EchoMsg.decodeTransform(result); + } + + EchoClientStream(request: AsyncIterable, abortSignal?: AbortSignal): Promise { + const data = EchoMsg.encodeTransform(request); + const promise = this.rpc.clientStreamingRequest(this.service, "EchoClientStream", data, abortSignal || undefined); + return promise.then((data) => EchoMsg.decode(new _m0.Reader(data))); + } + + EchoBidiStream(request: AsyncIterable, abortSignal?: AbortSignal): AsyncIterable { + const data = EchoMsg.encodeTransform(request); + const result = this.rpc.bidirectionalStreamingRequest( + this.service, + "EchoBidiStream", + data, + abortSignal || undefined, + ); + return EchoMsg.decodeTransform(result); + } +} + +interface Rpc { + request(service: string, method: string, data: Uint8Array, abortSignal?: AbortSignal): Promise; + clientStreamingRequest( + service: string, + method: string, + data: AsyncIterable, + abortSignal?: AbortSignal, + ): Promise; + serverStreamingRequest( + service: string, + method: string, + data: Uint8Array, + abortSignal?: AbortSignal, + ): AsyncIterable; + bidirectionalStreamingRequest( + service: string, + method: string, + data: AsyncIterable, + abortSignal?: AbortSignal, + ): AsyncIterable; +} + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends Array ? Array> : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? P + : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/src/generate-services.ts b/src/generate-services.ts index ba8a4f311..2e7283424 100644 --- a/src/generate-services.ts +++ b/src/generate-services.ts @@ -62,6 +62,9 @@ export function generateService( const partialInput = options.outputClientImpl === "grpc-web"; const inputType = requestType(ctx, methodDesc, partialInput); params.push(code`request: ${inputType}`); + if (options.useAbortSignal) { + params.push(code`abortSignal?: AbortSignal`); + } // Use metadata as last argument for interface only configuration if (options.outputClientImpl === "grpc-web") { @@ -114,8 +117,13 @@ function generateRegularRpcMethod( const inputType = requestType(ctx, methodDesc); const rawOutputType = responseType(ctx, methodDesc, { keepValueType: true }); - const params = [...(options.context ? [code`ctx: Context`] : []), code`request: ${inputType}`]; + const params = [ + ...(options.context ? [code`ctx: Context`] : []), + code`request: ${inputType}`, + ...(options.useAbortSignal ? [code`abortSignal?: AbortSignal`] : []) + ]; const maybeCtx = options.context ? "ctx," : ""; + const maybeAbortSignal = options.useAbortSignal ? "abortSignal || undefined," : ""; let encode = code`${rawInputType}.encode(request).finish()`; let decode = code`data => ${rawOutputType}.decode(new ${Reader}(data))`; @@ -163,7 +171,8 @@ function generateRegularRpcMethod( ${maybeCtx} this.service, "${methodDesc.name}", - data + data, + ${maybeAbortSignal} ); return ${decode}; } @@ -329,6 +338,7 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod const { options } = ctx; const maybeContext = options.context ? "" : ""; const maybeContextParam = options.context ? "ctx: Context," : ""; + const maybeAbortSignalParam = options.useAbortSignal ? "abortSignal?: AbortSignal," : ""; const methods = [[code`request`, code`Uint8Array`, code`Promise`]]; if (hasStreamingMethods) { const observable = observableType(ctx); @@ -348,7 +358,8 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod ${maybeContextParam} service: string, method: string, - data: ${method[1]} + data: ${method[1]}, + ${maybeAbortSignalParam} ): ${method[2]};`); }); chunks.push(code` }`); diff --git a/src/options.ts b/src/options.ts index 20c48ef01..e84fa5059 100644 --- a/src/options.ts +++ b/src/options.ts @@ -65,6 +65,7 @@ export type Options = { outputSchema: boolean; onlyTypes: boolean; emitImportedFiles: boolean; + useAbortSignal: boolean; useExactTypes: boolean; useAsyncIterable: boolean; unknownFields: boolean; @@ -112,6 +113,7 @@ export function defaultOptions(): Options { onlyTypes: false, emitImportedFiles: true, useExactTypes: true, + useAbortSignal: false, useAsyncIterable: false, unknownFields: false, usePrototypeForDefaults: false, diff --git a/tests/options-test.ts b/tests/options-test.ts index 6625b6c55..2bd64ecf1 100644 --- a/tests/options-test.ts +++ b/tests/options-test.ts @@ -40,6 +40,7 @@ describe("options", () => { "stringEnums": false, "unknownFields": false, "unrecognizedEnum": true, + "useAbortSignal": false, "useAsyncIterable": false, "useDate": "timestamp", "useExactTypes": true,