Skip to content

Commit

Permalink
feat: add option to use async iterables
Browse files Browse the repository at this point in the history
Adds option useAsyncIterable which uses AsyncIterable instead of Observable.

For example:

  bidirectionalStreamingRequest(
    service: string,
    method: string,
    data: AsyncIterable<Uint8Array>
  ): AsyncIterable<Uint8Array>

Generates Transform async iterables for encoding and decoding:

  // encodeTransform encodes a source of message objects.
  // Transform<TestMessage, Uint8Array>
  async *encodeTransform(
    source: AsyncIterable<TestMessage | TestMessage[]> | Iterable<TestMessage | TestMessage[]>
  ): AsyncIterable<Uint8Array> {
    for await (const pkt of source) {
      if (Array.isArray(pkt)) {
        for (const p of pkt) {
          yield* [TestMessage.encode(p).finish()];
        }
      } else {
        yield* [TestMessage.encode(pkt).finish()];
      }
    }
  },

  // decodeTransform decodes a source of encoded messages.
  // Transform<Uint8Array, TestMessage>
  async *decodeTransform(
    source: AsyncIterable<Uint8Array | Uint8Array[]> | Iterable<Uint8Array | Uint8Array[]>
  ): AsyncIterable<TestMessage> {
    for await (const pkt of source) {
      if (Array.isArray(pkt)) {
        for (const p of pkt) {
          yield* [TestMessage.decode(p)];
        }
      } else {
        yield* [TestMessage.decode(pkt)];
      }
    }
  },

Generates RPC service implementations which use the Transform iterators:

  BidiStreaming(request: AsyncIterable<TestMessage>): AsyncIterable<TestMessage> {
    const data = TestMessage.encodeTransform(request);
    const result = this.rpc.bidirectionalStreamingRequest('simple.Test', 'BidiStreaming', data);
    return TestMessage.decodeTransform(result);
  }

AsyncIterables indicate a stream has ended by closing with an optional error.

Fixes stephenh#600

Signed-off-by: Christian Stewart <christian@paral.in>
  • Loading branch information
paralin committed Jul 1, 2022
1 parent e2323b6 commit cc81e2c
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 23 deletions.
2 changes: 2 additions & 0 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ Generated code will be placed in the Gradle build directory.

- With `--ts_proto_opt=outputServices=false`, or `=none`, ts-proto will output NO service definitions.

- With `--ts_proto_opt=useAsyncIterable=true`, the generated services will use `AsyncIterable` instead of `Observable`.

- With `--ts_proto_opt=emitImportedFiles=false`, ts-proto will not emit `google/protobuf/*` files unless you explicit add files to `protoc` like this
`protoc --plugin=./node_modules/.bin/protoc-gen-ts_proto my_message.proto google/protobuf/duration.proto`

Expand Down
1 change: 1 addition & 0 deletions integration/async-iterable-services/parameters.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
useAsyncIterable=true
Binary file added integration/async-iterable-services/simple.bin
Binary file not shown.
11 changes: 11 additions & 0 deletions integration/async-iterable-services/simple.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

package simple;

service Test {
rpc BidiStreaming(stream TestMessage) returns (stream TestMessage) {}
}

message TestMessage {
string value = 1;
}
138 changes: 138 additions & 0 deletions integration/async-iterable-services/simple.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/* eslint-disable */
import * as _m0 from 'protobufjs/minimal';

export const protobufPackage = 'simple';

export interface TestMessage {
value: string;
}

function createBaseTestMessage(): TestMessage {
return { value: '' };
}

export const TestMessage = {
encode(message: TestMessage, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
if (message.value !== '') {
writer.uint32(10).string(message.value);
}
return writer;
},

decode(input: _m0.Reader | Uint8Array, length?: number): TestMessage {
const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input);
let end = length === undefined ? reader.len : reader.pos + length;
const message = createBaseTestMessage();
while (reader.pos < end) {
const tag = reader.uint32();
switch (tag >>> 3) {
case 1:
message.value = reader.string();
break;
default:
reader.skipType(tag & 7);
break;
}
}
return message;
},

// encodeTransform encodes a source of message objects.
// Transform<TestMessage, Uint8Array>
async *encodeTransform(
source: AsyncIterable<TestMessage | TestMessage[]> | Iterable<TestMessage | TestMessage[]>
): AsyncIterable<Uint8Array> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [TestMessage.encode(p).finish()];
}
} else {
yield* [TestMessage.encode(pkt).finish()];
}
}
},

// decodeTransform decodes a source of encoded messages.
// Transform<Uint8Array, TestMessage>
async *decodeTransform(
source: AsyncIterable<Uint8Array | Uint8Array[]> | Iterable<Uint8Array | Uint8Array[]>
): AsyncIterable<TestMessage> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [TestMessage.decode(p)];
}
} else {
yield* [TestMessage.decode(pkt)];
}
}
},

fromJSON(object: any): TestMessage {
return {
value: isSet(object.value) ? String(object.value) : '',
};
},

toJSON(message: TestMessage): unknown {
const obj: any = {};
message.value !== undefined && (obj.value = message.value);
return obj;
},

fromPartial<I extends Exact<DeepPartial<TestMessage>, I>>(object: I): TestMessage {
const message = createBaseTestMessage();
message.value = object.value ?? '';
return message;
},
};

export interface Test {
BidiStreaming(request: AsyncIterable<TestMessage>): AsyncIterable<TestMessage>;
}

export class TestClientImpl implements Test {
private readonly rpc: Rpc;
constructor(rpc: Rpc) {
this.rpc = rpc;
this.BidiStreaming = this.BidiStreaming.bind(this);
}
BidiStreaming(request: AsyncIterable<TestMessage>): AsyncIterable<TestMessage> {
const data = TestMessage.encodeTransform(request);
const result = this.rpc.bidirectionalStreamingRequest('simple.Test', 'BidiStreaming', data);
return TestMessage.decodeTransform(result);
}
}

interface Rpc {
request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
clientStreamingRequest(service: string, method: string, data: AsyncIterable<Uint8Array>): Promise<Uint8Array>;
serverStreamingRequest(service: string, method: string, data: Uint8Array): AsyncIterable<Uint8Array>;
bidirectionalStreamingRequest(
service: string,
method: string,
data: AsyncIterable<Uint8Array>
): AsyncIterable<Uint8Array>;
}

type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;

export type DeepPartial<T> = T extends Builtin
? T
: T extends Array<infer U>
? Array<DeepPartial<U>>
: T extends ReadonlyArray<infer U>
? ReadonlyArray<DeepPartial<U>>
: T extends {}
? { [K in keyof T]?: DeepPartial<T[K]> }
: Partial<T>;

type KeysOfUnion<T> = T extends T ? keyof T : never;
export type Exact<P, I extends P> = P extends Builtin
? P
: P & { [K in keyof P]: Exact<P[K], I[K]> } & Record<Exclude<keyof I, KeysOfUnion<P>>, never>;

function isSet(value: any): boolean {
return value !== null && value !== undefined;
}
43 changes: 43 additions & 0 deletions src/generate-async-iterable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { code, Code } from 'ts-poet';

/** Creates a function to transform a message Source to a Uint8Array Source. */
export function generateEncodeTransform(fullName: string): Code {
return code`
// encodeTransform encodes a source of message objects.
// Transform<${fullName}, Uint8Array>
async *encodeTransform(
source: AsyncIterable<${fullName} | ${fullName}[]> | Iterable<${fullName} | ${fullName}[]>
): AsyncIterable<Uint8Array> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [${fullName}.encode(p).finish()]
}
} else {
yield* [${fullName}.encode(pkt).finish()]
}
}
}
`;
}

/** Creates a function to transform a Uint8Array Source to a message Source. */
export function generateDecodeTransform(fullName: string): Code {
return code`
// decodeTransform decodes a source of encoded messages.
// Transform<Uint8Array, ${fullName}>
async *decodeTransform(
source: AsyncIterable<Uint8Array | Uint8Array[]> | Iterable<Uint8Array | Uint8Array[]>
): AsyncIterable<${fullName}> {
for await (const pkt of source) {
if (Array.isArray(pkt)) {
for (const p of pkt) {
yield* [${fullName}.decode(p)]
}
} else {
yield* [${fullName}.decode(pkt)]
}
}
}
`;
}
29 changes: 14 additions & 15 deletions src/generate-grpc-web.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { MethodDescriptorProto, FileDescriptorProto, ServiceDescriptorProto } from 'ts-proto-descriptors';
import { rawRequestType, requestType, responsePromiseOrObservable, responseType } from './types';
import { rawRequestType, requestType, responsePromiseOrObservable, responseType, observableType } from './types';
import { Code, code, imp, joinCode } from 'ts-poet';
import { Context } from './context';
import { assertInstanceOf, FormattedMethodDescriptor, maybePrefixPackage } from './utils';
Expand All @@ -8,12 +8,11 @@ const grpc = imp('grpc@@improbable-eng/grpc-web');
const share = imp('share@rxjs/operators');
const take = imp('take@rxjs/operators');
const BrowserHeaders = imp('BrowserHeaders@browser-headers');
const Observable = imp('Observable@rxjs');

/** Generates a client that uses the `@improbable-web/grpc-web` library. */
export function generateGrpcClientImpl(
ctx: Context,
fileDesc: FileDescriptorProto,
_fileDesc: FileDescriptorProto,
serviceDesc: ServiceDescriptorProto
): Code {
const chunks: Code[] = [];
Expand Down Expand Up @@ -154,18 +153,18 @@ export function addGrpcWebMisc(ctx: Context, hasStreamingMethods: boolean): Code
interface UnaryMethodDefinitionishR extends ${grpc}.UnaryMethodDefinition<any, any> { requestStream: any; responseStream: any; }
`);
chunks.push(code`type UnaryMethodDefinitionish = UnaryMethodDefinitionishR;`);
chunks.push(generateGrpcWebRpcType(options.returnObservable, hasStreamingMethods));
chunks.push(generateGrpcWebImpl(options.returnObservable, hasStreamingMethods));
chunks.push(generateGrpcWebRpcType(ctx, options.returnObservable, hasStreamingMethods));
chunks.push(generateGrpcWebImpl(ctx, options.returnObservable, hasStreamingMethods));
return joinCode(chunks, { on: '\n\n' });
}

/** Makes an `Rpc` interface to decouple from the low-level grpc-web `grpc.invoke and grpc.unary`/etc. methods. */
function generateGrpcWebRpcType(returnObservable: boolean, hasStreamingMethods: boolean): Code {
function generateGrpcWebRpcType(ctx: Context, returnObservable: boolean, hasStreamingMethods: boolean): Code {
const chunks: Code[] = [];

chunks.push(code`interface Rpc {`);

const wrapper = returnObservable ? Observable : 'Promise';
const wrapper = returnObservable ? observableType(ctx) : 'Promise';
chunks.push(code`
unary<T extends UnaryMethodDefinitionish>(
methodDesc: T,
Expand All @@ -180,7 +179,7 @@ function generateGrpcWebRpcType(returnObservable: boolean, hasStreamingMethods:
methodDesc: T,
request: any,
metadata: grpc.Metadata | undefined,
): ${Observable}<any>;
): ${observableType(ctx)}<any>;
`);
}

Expand All @@ -189,7 +188,7 @@ function generateGrpcWebRpcType(returnObservable: boolean, hasStreamingMethods:
}

/** Implements the `Rpc` interface by making calls using the `grpc.unary` method. */
function generateGrpcWebImpl(returnObservable: boolean, hasStreamingMethods: boolean): Code {
function generateGrpcWebImpl(ctx: Context, returnObservable: boolean, hasStreamingMethods: boolean): Code {
const options = code`
{
transport?: grpc.TransportFactory,
Expand All @@ -212,13 +211,13 @@ function generateGrpcWebImpl(returnObservable: boolean, hasStreamingMethods: boo
`);

if (returnObservable) {
chunks.push(createObservableUnaryMethod());
chunks.push(createObservableUnaryMethod(ctx));
} else {
chunks.push(createPromiseUnaryMethod());
}

if (hasStreamingMethods) {
chunks.push(createInvokeMethod());
chunks.push(createInvokeMethod(ctx));
}

chunks.push(code`}`);
Expand Down Expand Up @@ -260,13 +259,13 @@ function createPromiseUnaryMethod(): Code {
`;
}

function createObservableUnaryMethod(): Code {
function createObservableUnaryMethod(ctx: Context): Code {
return code`
unary<T extends UnaryMethodDefinitionish>(
methodDesc: T,
_request: any,
metadata: grpc.Metadata | undefined
): ${Observable}<any> {
): ${observableType(ctx)}<any> {
const request = { ..._request, ...methodDesc.requestType };
const maybeCombinedMetadata =
metadata && this.options.metadata
Expand All @@ -293,13 +292,13 @@ function createObservableUnaryMethod(): Code {
`;
}

function createInvokeMethod() {
function createInvokeMethod(ctx: Context) {
return code`
invoke<T extends UnaryMethodDefinitionish>(
methodDesc: T,
_request: any,
metadata: grpc.Metadata | undefined
): ${Observable}<any> {
): ${observableType(ctx)}<any> {
// Status Response Codes (https://developers.google.com/maps-booking/reference/grpc-api/status_codes)
const upStreamCodes = [2, 4, 8, 9, 10, 13, 14, 15];
const DEFAULT_TIMEOUT_TIME: number = 3_000;
Expand Down
20 changes: 14 additions & 6 deletions src/generate-services.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import {
rawRequestType,
responsePromiseOrObservable,
responseType,
observableType,
} from './types';
import { assertInstanceOf, FormattedMethodDescriptor, maybeAddComment, maybePrefixPackage, singular } from './utils';
import SourceInfo, { Fields } from './sourceInfo';
import { camelCase } from './case';
import { contextTypeVar } from './main';
import { Context } from './context';

Expand All @@ -34,7 +34,7 @@ export function generateService(
sourceInfo: SourceInfo,
serviceDesc: ServiceDescriptorProto
): Code {
const { options, utils } = ctx;
const { options } = ctx;
const chunks: Code[] = [];

maybeAddComment(sourceInfo, chunks, serviceDesc.options?.deprecated);
Expand Down Expand Up @@ -121,12 +121,20 @@ function generateRegularRpcMethod(
decode = code`data => ${utils.fromTimestamp}(${rawOutputType}.decode(new ${Reader}(data)))`;
}
if (methodDesc.clientStreaming) {
encode = code`request.pipe(${imp('map@rxjs/operators')}(request => ${encode}))`;
if (options.useAsyncIterable) {
encode = code`${rawInputType}.encodeTransform(request)`;
} else {
encode = code`request.pipe(${imp('map@rxjs/operators')}(request => ${encode}))`;
}
}
let returnVariable: string;
if (options.returnObservable || methodDesc.serverStreaming) {
returnVariable = 'result';
decode = code`result.pipe(${imp('map@rxjs/operators')}(${decode}))`;
if (options.useAsyncIterable) {
decode = code`${rawOutputType}.decodeTransform(result)`;
} else {
decode = code`result.pipe(${imp('map@rxjs/operators')}(${decode}))`;
}
} else {
returnVariable = 'promise';
decode = code`promise.then(${decode})`;
Expand Down Expand Up @@ -207,7 +215,7 @@ export function generateServiceClientImpl(
}

/** We've found a BatchXxx method, create a synthetic GetXxx method that calls it. */
function generateBatchingRpcMethod(ctx: Context, batchMethod: BatchMethod): Code {
function generateBatchingRpcMethod(_ctx: Context, batchMethod: BatchMethod): Code {
const {
methodDesc,
singleMethodName,
Expand Down Expand Up @@ -315,7 +323,7 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod
const maybeContextParam = options.context ? 'ctx: Context,' : '';
const methods = [[code`request`, code`Uint8Array`, code`Promise<Uint8Array>`]];
if (hasStreamingMethods) {
const observable = imp('Observable@rxjs');
const observable = observableType(ctx);
methods.push([code`clientStreamingRequest`, code`${observable}<Uint8Array>`, code`Promise<Uint8Array>`]);
methods.push([code`serverStreamingRequest`, code`Uint8Array`, code`${observable}<Uint8Array>`]);
methods.push([
Expand Down
5 changes: 5 additions & 0 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import {
generateGrpcMethodDesc,
generateGrpcServiceDesc,
} from './generate-grpc-web';
import { generateEncodeTransform, generateDecodeTransform } from './generate-async-iterable';
import { generateEnum } from './enums';
import { visit, visitServices } from './visit';
import { DateOption, EnvOption, LongOption, OneofOption, Options, ServiceOption } from './options';
Expand Down Expand Up @@ -165,6 +166,10 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri
staticMembers.push(generateEncode(ctx, fullName, message));
staticMembers.push(generateDecode(ctx, fullName, message));
}
if (options.useAsyncIterable) {
staticMembers.push(generateEncodeTransform(fullName));
staticMembers.push(generateDecodeTransform(fullName));
}
if (options.outputJsonMethods) {
staticMembers.push(generateFromJson(ctx, fullName, fullTypeName, message));
staticMembers.push(generateToJson(ctx, fullName, fullTypeName, message));
Expand Down
Loading

0 comments on commit cc81e2c

Please sign in to comment.