diff --git a/integration/batching-with-context/batching.ts b/integration/batching-with-context/batching.ts index 00edf4034..6c449c201 100644 --- a/integration/batching-with-context/batching.ts +++ b/integration/batching-with-context/batching.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import * as DataLoader from 'dataloader'; import * as hash from 'object-hash'; diff --git a/integration/batching/batching.ts b/integration/batching/batching.ts index 25a025910..ce27831bd 100644 --- a/integration/batching/batching.ts +++ b/integration/batching/batching.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; export const protobufPackage = 'batching'; diff --git a/integration/grpc-web-go-server/example.ts b/integration/grpc-web-go-server/example.ts index 1135f99ba..c4560d92f 100644 --- a/integration/grpc-web-go-server/example.ts +++ b/integration/grpc-web-go-server/example.ts @@ -1,7 +1,8 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import { Observable } from 'rxjs'; +import { map } from 'rxjs/operators'; export const protobufPackage = 'rpx'; @@ -744,10 +745,10 @@ export class DashStateClientImpl implements DashState { return promise.then((data) => DashUserSettingsState.decode(new Reader(data))); } - ActiveUserSettingsStream(request: Empty): Promise { + ActiveUserSettingsStream(request: Empty): Observable { const data = Empty.encode(request).finish(); - const promise = this.rpc.request('rpx.DashState', 'ActiveUserSettingsStream', data); - return promise.then((data) => DashUserSettingsState.decode(new Reader(data))); + const result = this.rpc.serverStreamingRequest('rpx.DashState', 'ActiveUserSettingsStream', data); + return result.pipe(map((data) => DashUserSettingsState.decode(new Reader(data)))); } } @@ -791,6 +792,9 @@ export class DashAPICredsClientImpl implements DashAPICreds { interface Rpc { request(service: string, method: string, data: Uint8Array): Promise; + clientStreamingRequest(service: string, method: string, data: Observable): Promise; + serverStreamingRequest(service: string, method: string, data: Uint8Array): Observable; + bidirectionalStreamingRequest(service: string, method: string, data: Observable): Observable; } type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; diff --git a/integration/lower-case-svc-methods/math.ts b/integration/lower-case-svc-methods/math.ts index a82e96343..b64b5a085 100644 --- a/integration/lower-case-svc-methods/math.ts +++ b/integration/lower-case-svc-methods/math.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import * as DataLoader from 'dataloader'; import * as hash from 'object-hash'; diff --git a/integration/meta-typings/simple.ts b/integration/meta-typings/simple.ts index 448be64c5..fe2b5f42f 100644 --- a/integration/meta-typings/simple.ts +++ b/integration/meta-typings/simple.ts @@ -1,6 +1,6 @@ /* eslint-disable */ import { FileDescriptorProto } from 'ts-proto-descriptors'; -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import { protoMetadata as protoMetadata1, DateMessage } from './google/type/date'; import { protoMetadata as protoMetadata2, StringValue, Int32Value, BoolValue } from './google/protobuf/wrappers'; diff --git a/integration/no-proto-package/no-proto-package.ts b/integration/no-proto-package/no-proto-package.ts index 4b5c6a19a..20055d525 100644 --- a/integration/no-proto-package/no-proto-package.ts +++ b/integration/no-proto-package/no-proto-package.ts @@ -1,7 +1,8 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import { Observable } from 'rxjs'; +import { map } from 'rxjs/operators'; export const protobufPackage = ''; @@ -114,15 +115,18 @@ export class UserStateClientImpl implements UserState { this.rpc = rpc; this.GetUsers = this.GetUsers.bind(this); } - GetUsers(request: Empty): Promise { + GetUsers(request: Empty): Observable { const data = Empty.encode(request).finish(); - const promise = this.rpc.request('UserState', 'GetUsers', data); - return promise.then((data) => User.decode(new Reader(data))); + const result = this.rpc.serverStreamingRequest('UserState', 'GetUsers', data); + return result.pipe(map((data) => User.decode(new Reader(data)))); } } interface Rpc { request(service: string, method: string, data: Uint8Array): Promise; + clientStreamingRequest(service: string, method: string, data: Observable): Promise; + serverStreamingRequest(service: string, method: string, data: Uint8Array): Observable; + bidirectionalStreamingRequest(service: string, method: string, data: Observable): Observable; } type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; diff --git a/integration/simple-optionals/simple.ts b/integration/simple-optionals/simple.ts index edb7bea62..7753bf0a6 100644 --- a/integration/simple-optionals/simple.ts +++ b/integration/simple-optionals/simple.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import { ImportedThing } from './import_dir/thing'; import { Timestamp } from './google/protobuf/timestamp'; diff --git a/integration/simple-snake/simple.ts b/integration/simple-snake/simple.ts index a1c7a5cb4..f2850deb7 100644 --- a/integration/simple-snake/simple.ts +++ b/integration/simple-snake/simple.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import { Timestamp } from './google/protobuf/timestamp'; import { ImportedThing } from './import_dir/thing'; diff --git a/integration/simple-unrecognized-enum/simple.ts b/integration/simple-unrecognized-enum/simple.ts index 58cf9a554..af24e8467 100644 --- a/integration/simple-unrecognized-enum/simple.ts +++ b/integration/simple-unrecognized-enum/simple.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import { Timestamp } from './google/protobuf/timestamp'; import { ImportedThing } from './import_dir/thing'; diff --git a/integration/simple/simple.ts b/integration/simple/simple.ts index a8480f434..63b55dc51 100644 --- a/integration/simple/simple.ts +++ b/integration/simple/simple.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -import { util, configure, Reader, Writer } from 'protobufjs/minimal'; +import { util, configure, Writer, Reader } from 'protobufjs/minimal'; import * as Long from 'long'; import { Timestamp } from './google/protobuf/timestamp'; import { ImportedThing } from './import_dir/thing'; diff --git a/src/generate-grpc-web.ts b/src/generate-grpc-web.ts index fb5bae222..0ffbc2f96 100644 --- a/src/generate-grpc-web.ts +++ b/src/generate-grpc-web.ts @@ -1,5 +1,5 @@ import { MethodDescriptorProto, FileDescriptorProto, ServiceDescriptorProto } from 'ts-proto-descriptors'; -import { requestType, responseObservable, responsePromise, responseType } from './types'; +import { requestType, responsePromiseOrObservable, responseType } from './types'; import { Code, code, imp, joinCode } from 'ts-poet'; import { Context } from './context'; import { assertInstanceOf, FormattedMethodDescriptor, maybePrefixPackage } from './utils'; @@ -52,10 +52,7 @@ function generateRpcMethod(ctx: Context, serviceDesc: ServiceDescriptorProto, me const { options, utils } = ctx; const inputType = requestType(ctx, methodDesc); const partialInputType = code`${utils.DeepPartial}<${inputType}>`; - const returns = - options.returnObservable || methodDesc.serverStreaming - ? responseObservable(ctx, methodDesc) - : responsePromise(ctx, methodDesc); + const returns = responsePromiseOrObservable(ctx, methodDesc); const method = methodDesc.serverStreaming ? 'invoke' : 'unary'; return code` ${methodDesc.formattedName}( diff --git a/src/generate-services.ts b/src/generate-services.ts index 3b015099f..1d3b05cc2 100644 --- a/src/generate-services.ts +++ b/src/generate-services.ts @@ -4,8 +4,8 @@ import { BatchMethod, detectBatchMethod, requestType, - responseObservable, - responsePromise, + rawRequestType, + responsePromiseOrObservable, responseType, } from './types'; import { assertInstanceOf, FormattedMethodDescriptor, maybeAddComment, maybePrefixPackage, singular } from './utils'; @@ -72,15 +72,12 @@ export function generateService( params.push(code`...rest: any`); } - // Return observable for interface only configuration, passing returnObservable=true and methodDesc.serverStreaming=true - let returnType: Code; - if (options.returnObservable || methodDesc.serverStreaming) { - returnType = responseObservable(ctx, methodDesc); - } else { - returnType = responsePromise(ctx, methodDesc); - } - - chunks.push(code`${methodDesc.formattedName}(${joinCode(params, { on: ',' })}): ${returnType};`); + chunks.push( + code`${methodDesc.formattedName}(${joinCode(params, { on: ',' })}): ${responsePromiseOrObservable( + ctx, + methodDesc + )};` + ); // If this is a batch method, auto-generate the singular version of it if (options.context) { @@ -108,24 +105,51 @@ function generateRegularRpcMethod( assertInstanceOf(methodDesc, FormattedMethodDescriptor); const { options } = ctx; const Reader = imp('Reader@protobufjs/minimal'); + const rawInputType = rawRequestType(ctx, methodDesc); const inputType = requestType(ctx, methodDesc); const outputType = responseType(ctx, methodDesc); const params = [...(options.context ? [code`ctx: Context`] : []), code`request: ${inputType}`]; const maybeCtx = options.context ? 'ctx,' : ''; + let encode = code`${rawInputType}.encode(request).finish()`; + let decode = code`data => ${outputType}.decode(new ${Reader}(data))`; + + if (methodDesc.clientStreaming) { + encode = code`request.pipe(${imp('map@rxjs/operators')}(request => ${encode}))`; + } + let returnVariable: string; + if (options.returnObservable || methodDesc.serverStreaming) { + returnVariable = 'result'; + decode = code`result.pipe(${imp('map@rxjs/operators')}(${decode}))`; + } else { + returnVariable = 'promise'; + decode = code`promise.then(${decode})`; + } + + let rpcMethod: string; + if (methodDesc.clientStreaming && methodDesc.serverStreaming) { + rpcMethod = 'bidirectionalStreamingRequest'; + } else if (methodDesc.serverStreaming) { + rpcMethod = 'serverStreamingRequest'; + } else if (methodDesc.clientStreaming) { + rpcMethod = 'clientStreamingRequest'; + } else { + rpcMethod = 'request'; + } + return code` ${methodDesc.formattedName}( ${joinCode(params, { on: ',' })} - ): ${responsePromise(ctx, methodDesc)} { - const data = ${inputType}.encode(request).finish(); - const promise = this.rpc.request( + ): ${responsePromiseOrObservable(ctx, methodDesc)} { + const data = ${encode}; + const ${returnVariable} = this.rpc.${rpcMethod}( ${maybeCtx} "${maybePrefixPackage(fileDesc, serviceDesc.name)}", "${methodDesc.name}", data ); - return promise.then(data => ${outputType}.decode(new ${Reader}(data))); + return ${decode}; } `; } @@ -273,24 +297,41 @@ function generateCachingRpcMethod( * * This lets clients pass in their own request-promise-ish client. * + * This also requires clientStreamingRequest, serverStreamingRequest and + * bidirectionalStreamingRequest methods if any of the RPCs is streaming. + * * We don't export this because if a project uses multiple `*.proto` files, * we don't want our the barrel imports in `index.ts` to have multiple `Rpc` * types. */ -export function generateRpcType(ctx: Context): Code { +export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Code { const { options } = ctx; const maybeContext = options.context ? '' : ''; const maybeContextParam = options.context ? 'ctx: Context,' : ''; - return code` - interface Rpc${maybeContext} { - request( + const methods = [[code`request`, code`Uint8Array`, code`Promise`]]; + if (hasStreamingMethods) { + const observable = imp('Observable@rxjs'); + methods.push([code`clientStreamingRequest`, code`${observable}`, code`Promise`]); + methods.push([code`serverStreamingRequest`, code`Uint8Array`, code`${observable}`]); + methods.push([ + code`bidirectionalStreamingRequest`, + code`${observable}`, + code`${observable}`, + ]); + } + const chunks: Code[] = []; + chunks.push(code` interface Rpc${maybeContext} {`); + methods.forEach((method) => { + chunks.push(code` + ${method[0]}( ${maybeContextParam} service: string, method: string, - data: Uint8Array - ): Promise; - } - `; + data: ${method[1]} + ): ${method[2]};`); + }); + chunks.push(code` }`); + return joinCode(chunks, { on: '\n' }); } export function generateDataLoadersType(): Code { diff --git a/src/main.ts b/src/main.ts index f99a46259..111f0218f 100644 --- a/src/main.ts +++ b/src/main.ts @@ -164,6 +164,7 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri ); } + let hasServerStreamingMethods = false; let hasStreamingMethods = false; visitServices(fileDesc, sourceInfo, (serviceDesc, sInfo) => { @@ -200,18 +201,23 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri serviceDesc.method.forEach((method) => { chunks.push(generateGrpcMethodDesc(ctx, serviceDesc, method)); if (method.serverStreaming) { - hasStreamingMethods = true; + hasServerStreamingMethods = true; } }); } } + serviceDesc.method.forEach((methodDesc, index) => { + if (methodDesc.serverStreaming || methodDesc.clientStreaming) { + hasStreamingMethods = true; + } + }); }); if (options.outputServices === ServiceOption.DEFAULT && options.outputClientImpl && fileDesc.service.length > 0) { if (options.outputClientImpl === true) { - chunks.push(generateRpcType(ctx)); + chunks.push(generateRpcType(ctx, hasStreamingMethods)); } else if (options.outputClientImpl === 'grpc-web') { - chunks.push(addGrpcWebMisc(ctx, hasStreamingMethods)); + chunks.push(addGrpcWebMisc(ctx, hasServerStreamingMethods)); } } diff --git a/src/types.ts b/src/types.ts index 5b882a7ed..6e162b074 100644 --- a/src/types.ts +++ b/src/types.ts @@ -537,8 +537,12 @@ export function detectMapType( return undefined; } +export function rawRequestType(ctx: Context, methodDesc: MethodDescriptorProto): Code { + return messageToTypeName(ctx, methodDesc.inputType); +} + export function requestType(ctx: Context, methodDesc: MethodDescriptorProto): Code { - let typeName = messageToTypeName(ctx, methodDesc.inputType); + let typeName = rawRequestType(ctx, methodDesc); if (methodDesc.clientStreaming) { return code`${imp('Observable@rxjs')}<${typeName}>`; } @@ -557,6 +561,14 @@ export function responseObservable(ctx: Context, methodDesc: MethodDescriptorPro return code`${imp('Observable@rxjs')}<${responseType(ctx, methodDesc)}>`; } +export function responsePromiseOrObservable(ctx: Context, methodDesc: MethodDescriptorProto): Code { + const { options } = ctx; + if (options.returnObservable || methodDesc.serverStreaming) { + return responseObservable(ctx, methodDesc); + } + return responsePromise(ctx, methodDesc); +} + export interface BatchMethod { methodDesc: MethodDescriptorProto; // a ${package + service + method name} key to identify this method in caches