From e87cdaf5f36383a779787cf5badfade87aa8befe Mon Sep 17 00:00:00 2001 From: Timo Stamm Date: Tue, 16 May 2023 12:38:53 +0200 Subject: [PATCH 1/2] Support timeouts in clients --- packages/connect-web-bench/README.md | 2 +- packages/connect-web/src/connect-transport.ts | 257 +++++++-------- .../connect-web/src/grpc-web-transport.ts | 166 +++++----- packages/connect/src/index.ts | 2 +- packages/connect/src/interceptor.ts | 70 ---- packages/connect/src/legacy-interceptor.ts | 96 ++++++ .../connect/src/protocol-connect/transport.ts | 237 +++++++------- .../src/protocol-grpc-web/transport.ts | 302 ++++++++---------- .../connect/src/protocol-grpc/transport.ts | 186 +++++------ packages/connect/src/protocol/index.ts | 1 + .../connect/src/protocol/run-call.spec.ts | 223 +++++++++++++ packages/connect/src/protocol/run-call.ts | 186 +++++++++++ packages/connect/src/protocol/signals.spec.ts | 6 + packages/connect/src/protocol/signals.ts | 10 +- 14 files changed, 1050 insertions(+), 694 deletions(-) create mode 100644 packages/connect/src/legacy-interceptor.ts create mode 100644 packages/connect/src/protocol/run-call.spec.ts create mode 100644 packages/connect/src/protocol/run-call.ts diff --git a/packages/connect-web-bench/README.md b/packages/connect-web-bench/README.md index 934ea3998..d806b8d88 100644 --- a/packages/connect-web-bench/README.md +++ b/packages/connect-web-bench/README.md @@ -10,5 +10,5 @@ it like a web server would usually do. | code generator | bundle size | minified | compressed | |----------------|-------------------:|-----------------------:|---------------------:| -| connect | 109,228 b | 47,812 b | 12,800 b | +| connect | 111,637 b | 49,004 b | 13,163 b | | grpc-web | 414,906 b | 301,127 b | 53,279 b | diff --git a/packages/connect-web/src/connect-transport.ts b/packages/connect-web/src/connect-transport.ts index 8ecddd0ed..59bc3e854 100644 --- a/packages/connect-web/src/connect-transport.ts +++ b/packages/connect-web/src/connect-transport.ts @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Message, MethodIdempotency, MethodKind } from "@bufbuild/protobuf"; import type { AnyMessage, BinaryReadOptions, @@ -24,13 +23,7 @@ import type { PartialMessage, ServiceType, } from "@bufbuild/protobuf"; -import { - appendHeaders, - Code, - connectErrorFromReason, - runStreaming, - runUnary, -} from "@bufbuild/connect"; +import { Message, MethodIdempotency, MethodKind } from "@bufbuild/protobuf"; import type { Interceptor, StreamResponse, @@ -38,20 +31,23 @@ import type { UnaryRequest, UnaryResponse, } from "@bufbuild/connect"; +import { appendHeaders } from "@bufbuild/connect"; import { createClientMethodSerializers, createEnvelopeReadableStream, createMethodUrl, encodeEnvelope, + runStreamingCall, + runUnaryCall, } from "@bufbuild/connect/protocol"; import { - requestHeader, endStreamFlag, endStreamFromJson, errorFromJson, + requestHeader, trailerDemux, - validateResponse, transformConnectPostToGetRequest, + validateResponse, } from "@bufbuild/connect/protocol-connect"; import { assertFetchApi } from "./assert-fetch-api.js"; @@ -145,78 +141,75 @@ export function createConnectTransport( options.jsonOptions, options.binaryOptions ); - try { - return await runUnary( - { - stream: false, - service, - method, - url: createMethodUrl(options.baseUrl, service, method), - init: { - method: "POST", - credentials: options.credentials ?? "same-origin", - redirect: "error", - mode: "cors", - }, - header: requestHeader( - method.kind, - useBinaryFormat, - timeoutMs, - header - ), - message: normalize(message), - signal: signal ?? new AbortController().signal, + return await runUnaryCall({ + interceptors: options.interceptors, + signal, + timeoutMs, + req: { + stream: false, + service, + method, + url: createMethodUrl(options.baseUrl, service, method), + init: { + method: "POST", + credentials: options.credentials ?? "same-origin", + redirect: "error", + mode: "cors", }, - async (req: UnaryRequest): Promise> => { - const useGet = - options.useHttpGet === true && - method.idempotency === MethodIdempotency.NoSideEffects; - let body: BodyInit | null = null; - if (useGet) { - req = transformConnectPostToGetRequest( - req, - serialize(req.message), - useBinaryFormat - ); - } else { - body = serialize(req.message); - } - const response = await fetch(req.url, { - ...req.init, - headers: req.header, - signal: req.signal, - body, - }); - const { isUnaryError, unaryError } = validateResponse( - method.kind, - useBinaryFormat, - response.status, - response.headers + header: requestHeader( + method.kind, + useBinaryFormat, + timeoutMs, + header + ), + message: normalize(message), + }, + next: async (req: UnaryRequest): Promise> => { + const useGet = + options.useHttpGet === true && + method.idempotency === MethodIdempotency.NoSideEffects; + let body: BodyInit | null = null; + if (useGet) { + req = transformConnectPostToGetRequest( + req, + serialize(req.message), + useBinaryFormat ); - if (isUnaryError) { - throw errorFromJson( - (await response.json()) as JsonValue, - appendHeaders(...trailerDemux(response.headers)), - unaryError - ); - } - const [demuxedHeader, demuxedTrailer] = trailerDemux( - response.headers + } else { + body = serialize(req.message); + } + const response = await fetch(req.url, { + ...req.init, + headers: req.header, + signal: req.signal, + body, + }); + const { isUnaryError, unaryError } = validateResponse( + method.kind, + useBinaryFormat, + response.status, + response.headers + ); + if (isUnaryError) { + throw errorFromJson( + (await response.json()) as JsonValue, + appendHeaders(...trailerDemux(response.headers)), + unaryError ); - return >{ - stream: false, - service, - method, - header: demuxedHeader, - message: parse(new Uint8Array(await response.arrayBuffer())), - trailer: demuxedTrailer, - }; - }, - options.interceptors - ); - } catch (e) { - throw connectErrorFromReason(e, Code.Internal); - } + } + const [demuxedHeader, demuxedTrailer] = trailerDemux( + response.headers + ); + return >{ + stream: false, + service, + method, + header: demuxedHeader, + message: parse(new Uint8Array(await response.arrayBuffer())), + trailer: demuxedTrailer, + }; + }, + }); }, async stream< @@ -242,32 +235,28 @@ export function createConnectTransport( trailerTarget: Headers ) { const reader = createEnvelopeReadableStream(body).getReader(); - try { - let endStreamReceived = false; - for (;;) { - const result = await reader.read(); - if (result.done) { - break; - } - const { flags, data } = result.value; - if ((flags & endStreamFlag) === endStreamFlag) { - endStreamReceived = true; - const endStream = endStreamFromJson(data); - if (endStream.error) { - throw endStream.error; - } - endStream.metadata.forEach((value, key) => - trailerTarget.set(key, value) - ); - continue; - } - yield parse(data); + let endStreamReceived = false; + for (;;) { + const result = await reader.read(); + if (result.done) { + break; } - if (!endStreamReceived) { - throw "missing EndStreamResponse"; + const { flags, data } = result.value; + if ((flags & endStreamFlag) === endStreamFlag) { + endStreamReceived = true; + const endStream = endStreamFromJson(data); + if (endStream.error) { + throw endStream.error; + } + endStream.metadata.forEach((value, key) => + trailerTarget.set(key, value) + ); + continue; } - } catch (e) { - throw connectErrorFromReason(e); + yield parse(data); + } + if (!endStreamReceived) { + throw "missing EndStreamResponse"; } } @@ -283,9 +272,11 @@ export function createConnectTransport( } return encodeEnvelope(0, serialize(r.value)); } - - return runStreaming( - { + return await runStreamingCall({ + interceptors: options.interceptors, + timeoutMs, + signal, + req: { stream: true, service, method, @@ -296,7 +287,6 @@ export function createConnectTransport( redirect: "error", mode: "cors", }, - signal: signal ?? new AbortController().signal, header: requestHeader( method.kind, useBinaryFormat, @@ -305,37 +295,32 @@ export function createConnectTransport( ), message: input, }, - async (req) => { - try { - const fRes = await fetch(req.url, { - ...req.init, - headers: req.header, - signal: req.signal, - body: await createRequestBody(req.message), - }); - validateResponse( - method.kind, - useBinaryFormat, - fRes.status, - fRes.headers - ); - if (fRes.body === null) { - throw "missing response body"; - } - const trailer = new Headers(); - const res: StreamResponse = { - ...req, - header: fRes.headers, - trailer, - message: parseResponseBody(fRes.body, trailer), - }; - return res; - } catch (e) { - throw connectErrorFromReason(e, Code.Internal); + next: async (req) => { + const fRes = await fetch(req.url, { + ...req.init, + headers: req.header, + signal: req.signal, + body: await createRequestBody(req.message), + }); + validateResponse( + method.kind, + useBinaryFormat, + fRes.status, + fRes.headers + ); + if (fRes.body === null) { + throw "missing response body"; } + const trailer = new Headers(); + const res: StreamResponse = { + ...req, + header: fRes.headers, + trailer, + message: parseResponseBody(fRes.body, trailer), + }; + return res; }, - options.interceptors - ).catch((e: unknown) => Promise.reject(connectErrorFromReason(e))); + }); }, }; } diff --git a/packages/connect-web/src/grpc-web-transport.ts b/packages/connect-web/src/grpc-web-transport.ts index cb9d896cc..c8f788770 100644 --- a/packages/connect-web/src/grpc-web-transport.ts +++ b/packages/connect-web/src/grpc-web-transport.ts @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Message, MethodKind } from "@bufbuild/protobuf"; import type { AnyMessage, BinaryReadOptions, @@ -23,24 +22,22 @@ import type { PartialMessage, ServiceType, } from "@bufbuild/protobuf"; -import type { UnaryRequest } from "@bufbuild/connect"; -import { - Code, - connectErrorFromReason, - runStreaming, - runUnary, -} from "@bufbuild/connect"; +import { Message, MethodKind } from "@bufbuild/protobuf"; import type { Interceptor, StreamResponse, Transport, + UnaryRequest, UnaryResponse, } from "@bufbuild/connect"; +import { connectErrorFromReason } from "@bufbuild/connect"; import { createClientMethodSerializers, createEnvelopeReadableStream, createMethodUrl, encodeEnvelope, + runStreamingCall, + runUnaryCall, } from "@bufbuild/connect/protocol"; import { requestHeader, @@ -141,83 +138,76 @@ export function createGrpcWebTransport( options.jsonOptions, options.binaryOptions ); - try { - return await runUnary( - { - stream: false, - service, - method, - url: createMethodUrl(options.baseUrl, service, method), - init: { - method: "POST", - credentials: options.credentials ?? "same-origin", - redirect: "error", - mode: "cors", - }, - header: requestHeader(useBinaryFormat, timeoutMs, header), - message: normalize(message), - signal: signal ?? new AbortController().signal, + return await runUnaryCall({ + interceptors: options.interceptors, + signal, + timeoutMs, + req: { + stream: false, + service, + method, + url: createMethodUrl(options.baseUrl, service, method), + init: { + method: "POST", + credentials: options.credentials ?? "same-origin", + redirect: "error", + mode: "cors", }, - async (req: UnaryRequest): Promise> => { - const response = await fetch(req.url, { - ...req.init, - headers: req.header, - signal: req.signal, - body: encodeEnvelope(0, serialize(req.message)), - }); - validateResponse( - useBinaryFormat, - response.status, - response.headers - ); - if (!response.body) { - throw "missing response body"; + header: requestHeader(useBinaryFormat, timeoutMs, header), + message: normalize(message), + }, + next: async (req: UnaryRequest): Promise> => { + const response = await fetch(req.url, { + ...req.init, + headers: req.header, + signal: req.signal, + body: encodeEnvelope(0, serialize(req.message)), + }); + validateResponse(useBinaryFormat, response.status, response.headers); + if (!response.body) { + throw "missing response body"; + } + const reader = createEnvelopeReadableStream( + response.body + ).getReader(); + let trailer: Headers | undefined; + let message: O | undefined; + for (;;) { + const r = await reader.read(); + if (r.done) { + break; } - const reader = createEnvelopeReadableStream( - response.body - ).getReader(); - let trailer: Headers | undefined; - let message: O | undefined; - for (;;) { - const r = await reader.read(); - if (r.done) { - break; - } - const { flags, data } = r.value; - if (flags === trailerFlag) { - if (trailer !== undefined) { - throw "extra trailer"; - } - // Unary responses require exactly one response message, but in - // case of an error, it is perfectly valid to have a response body - // that only contains error trailers. - trailer = trailerParse(data); - continue; - } - if (message !== undefined) { - throw "extra message"; + const { flags, data } = r.value; + if (flags === trailerFlag) { + if (trailer !== undefined) { + throw "extra trailer"; } - message = parse(data); - } - if (trailer === undefined) { - throw "missing trailer"; + // Unary responses require exactly one response message, but in + // case of an error, it is perfectly valid to have a response body + // that only contains error trailers. + trailer = trailerParse(data); + continue; } - validateTrailer(trailer); - if (message === undefined) { - throw "missing message"; + if (message !== undefined) { + throw "extra message"; } - return >{ - stream: false, - header: response.headers, - message, - trailer, - }; - }, - options.interceptors - ); - } catch (e) { - throw connectErrorFromReason(e, Code.Internal); - } + message = parse(data); + } + if (trailer === undefined) { + throw "missing trailer"; + } + validateTrailer(trailer); + if (message === undefined) { + throw "missing message"; + } + return >{ + stream: false, + header: response.headers, + message, + trailer, + }; + }, + }); }, async stream< @@ -237,6 +227,7 @@ export function createGrpcWebTransport( options.jsonOptions, options.binaryOptions ); + async function* parseResponseBody( body: ReadableStream, foundStatus: boolean, @@ -288,6 +279,7 @@ export function createGrpcWebTransport( throw connectErrorFromReason(e); } } + async function createRequestBody( input: AsyncIterable ): Promise { @@ -300,8 +292,12 @@ export function createGrpcWebTransport( } return encodeEnvelope(0, serialize(r.value)); } - return runStreaming( - { + + return runStreamingCall({ + interceptors: options.interceptors, + signal, + timeoutMs, + req: { stream: true, service, method, @@ -312,11 +308,10 @@ export function createGrpcWebTransport( redirect: "error", mode: "cors", }, - signal: signal ?? new AbortController().signal, header: requestHeader(useBinaryFormat, timeoutMs, header), message: input, }, - async (req) => { + next: async (req) => { const fRes = await fetch(req.url, { ...req.init, headers: req.header, @@ -340,8 +335,7 @@ export function createGrpcWebTransport( }; return res; }, - options.interceptors - ).catch((e: unknown) => Promise.reject(connectErrorFromReason(e))); + }); }, }; } diff --git a/packages/connect/src/index.ts b/packages/connect/src/index.ts index d3abb836d..0805ea8be 100644 --- a/packages/connect/src/index.ts +++ b/packages/connect/src/index.ts @@ -51,7 +51,7 @@ export { cors } from "./cors.js"; // Symbols above should be relevant to end users. // Symbols below should only be relevant for other libraries. -export { runUnary, runStreaming } from "./interceptor.js"; +export { runUnary, runStreaming } from "./legacy-interceptor.js"; export { makeAnyClient } from "./any-client.js"; export type { AnyClient } from "./any-client.js"; diff --git a/packages/connect/src/interceptor.ts b/packages/connect/src/interceptor.ts index 4cabbb90e..644210eb9 100644 --- a/packages/connect/src/interceptor.ts +++ b/packages/connect/src/interceptor.ts @@ -54,30 +54,6 @@ type AnyFn = ( req: UnaryRequest | StreamRequest ) => Promise; -/** - * UnaryFn represents the client-side invocation of a unary RPC - a method - * that takes a single input message, and responds with a single output - * message. - * A Transport implements such a function, and makes it available to - * interceptors. - */ -type UnaryFn< - I extends Message = AnyMessage, - O extends Message = AnyMessage -> = (req: UnaryRequest) => Promise>; - -/** - * StreamingFn represents the client-side invocation of a streaming RPC - a - * method that takes zero or more input messages, and responds with zero or - * more output messages. - * A Transport implements such a function, and makes it available to - * interceptors. - */ -type StreamingFn< - I extends Message = AnyMessage, - O extends Message = AnyMessage -> = (req: StreamRequest) => Promise>; - /** * UnaryRequest is used in interceptors to represent a request with a * single input message. @@ -213,49 +189,3 @@ interface ResponseCommon, O extends Message> { */ readonly trailer: Headers; } - -/** - * applyInterceptors takes the given UnaryFn or ServerStreamingFn, and wraps - * it with each of the given interceptors, returning a new UnaryFn or - * ServerStreamingFn. - */ -function applyInterceptors(next: T, interceptors: Interceptor[]): T { - return interceptors - .concat() - .reverse() - .reduce( - // eslint-disable-next-line @typescript-eslint/no-unsafe-argument - (n, i) => i(n), - next as any // eslint-disable-line @typescript-eslint/no-explicit-any - ) as T; -} - -/** - * Runs a unary method with the given interceptors. Note that this function - * is only used when implementing a Transport. - */ -export function runUnary, O extends Message>( - req: UnaryRequest, - next: UnaryFn, - interceptors: Interceptor[] | undefined -): Promise> { - if (interceptors) { - next = applyInterceptors(next, interceptors); - } - return next(req); -} - -/** - * Runs a server-streaming method with the given interceptors. Note that this - * function is only used when implementing a Transport. - */ -export function runStreaming, O extends Message>( - req: StreamRequest, - next: StreamingFn, - interceptors: Interceptor[] | undefined -): Promise> { - if (interceptors) { - next = applyInterceptors(next, interceptors); - } - return next(req); -} diff --git a/packages/connect/src/legacy-interceptor.ts b/packages/connect/src/legacy-interceptor.ts new file mode 100644 index 000000000..d34a7c358 --- /dev/null +++ b/packages/connect/src/legacy-interceptor.ts @@ -0,0 +1,96 @@ +// Copyright 2021-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import type { AnyMessage, Message } from "@bufbuild/protobuf"; +import type { + Interceptor, + StreamRequest, + StreamResponse, + UnaryRequest, + UnaryResponse, +} from "./interceptor.js"; + +/** + * Runs a unary method with the given interceptors. Note that this function + * is only used when implementing a Transport. + * + * @deprecated Use runUnaryCall from @bufbuild/connect/protocol instead. + */ +export function runUnary, O extends Message>( + req: UnaryRequest, + next: UnaryFn, + interceptors: Interceptor[] | undefined +): Promise> { + if (interceptors) { + next = applyInterceptors(next, interceptors); + } + return next(req); +} + +/** + * Runs a server-streaming method with the given interceptors. Note that this + * function is only used when implementing a Transport. + * + * @deprecated Use runStreamingCall from @bufbuild/connect/protocol instead. + */ +export function runStreaming, O extends Message>( + req: StreamRequest, + next: StreamingFn, + interceptors: Interceptor[] | undefined +): Promise> { + if (interceptors) { + next = applyInterceptors(next, interceptors); + } + return next(req); +} + +/** + * applyInterceptors takes the given UnaryFn or ServerStreamingFn, and wraps + * it with each of the given interceptors, returning a new UnaryFn or + * ServerStreamingFn. + */ +function applyInterceptors(next: T, interceptors: Interceptor[]): T { + return interceptors + .concat() + .reverse() + .reduce( + // eslint-disable-next-line @typescript-eslint/no-unsafe-argument + (n, i) => i(n), + next as any // eslint-disable-line @typescript-eslint/no-explicit-any + ) as T; +} + +/** + * UnaryFn represents the client-side invocation of a unary RPC - a method + * that takes a single input message, and responds with a single output + * message. + * A Transport implements such a function, and makes it available to + * interceptors. + */ +type UnaryFn< + I extends Message = AnyMessage, + O extends Message = AnyMessage +> = (req: UnaryRequest) => Promise>; + +/** + * StreamingFn represents the client-side invocation of a streaming RPC - a + * method that takes zero or more input messages, and responds with zero or + * more output messages. + * A Transport implements such a function, and makes it available to + * interceptors. + */ +type StreamingFn< + I extends Message = AnyMessage, + O extends Message = AnyMessage +> = (req: StreamRequest) => Promise>; diff --git a/packages/connect/src/protocol-connect/transport.ts b/packages/connect/src/protocol-connect/transport.ts index e034bb90a..84a5ea5a0 100644 --- a/packages/connect/src/protocol-connect/transport.ts +++ b/packages/connect/src/protocol-connect/transport.ts @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { MethodIdempotency } from "@bufbuild/protobuf"; import type { AnyMessage, Message, @@ -20,6 +19,7 @@ import type { PartialMessage, ServiceType, } from "@bufbuild/protobuf"; +import { MethodIdempotency } from "@bufbuild/protobuf"; import type { StreamRequest, StreamResponse, @@ -27,23 +27,17 @@ import type { UnaryRequest, UnaryResponse, } from "../index.js"; -import { - appendHeaders, - Code, - ConnectError, - runStreaming, - runUnary, -} from "../index.js"; +import { appendHeaders, Code, ConnectError } from "../index.js"; import type { CommonTransportOptions } from "../protocol/index.js"; import { createAsyncIterable, - createLinkedAbortController, createMethodSerializationLookup, createMethodUrl, pipe, pipeTo, + runStreamingCall, + runUnaryCall, sinkAllBytes, - transformCatchFinally, transformCompressEnvelope, transformDecompressEnvelope, transformJoinEnvelopes, @@ -82,9 +76,11 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.jsonOptions, opt ); - const ac = createLinkedAbortController(signal); - return await runUnary( - { + return await runUnaryCall({ + interceptors: opt.interceptors, + signal, + timeoutMs, + req: { stream: false, service, method, @@ -100,9 +96,8 @@ export function createTransport(opt: CommonTransportOptions): Transport { ), message: message instanceof method.I ? message : new method.I(message), - signal: ac.signal, }, - async (req: UnaryRequest): Promise> => { + next: async (req: UnaryRequest): Promise> => { let requestBody = serialization .getI(opt.useBinaryFormat) .serialize(req.message); @@ -136,54 +131,48 @@ export function createTransport(opt: CommonTransportOptions): Transport { signal: req.signal, body, }); - try { - const { compression, isUnaryError, unaryError } = - validateResponseWithCompression( - method.kind, - opt.useBinaryFormat, - opt.acceptCompression, - universalResponse.status, - universalResponse.header - ); - const [header, trailer] = trailerDemux(universalResponse.header); - let responseBody = await pipeTo( - universalResponse.body, - sinkAllBytes( - opt.readMaxBytes, - universalResponse.header.get(headerUnaryContentLength) - ), - { propagateDownStreamError: false } + const { compression, isUnaryError, unaryError } = + validateResponseWithCompression( + method.kind, + opt.useBinaryFormat, + opt.acceptCompression, + universalResponse.status, + universalResponse.header + ); + const [header, trailer] = trailerDemux(universalResponse.header); + let responseBody = await pipeTo( + universalResponse.body, + sinkAllBytes( + opt.readMaxBytes, + universalResponse.header.get(headerUnaryContentLength) + ), + { propagateDownStreamError: false } + ); + if (compression) { + responseBody = await compression.decompress( + responseBody, + opt.readMaxBytes ); - if (compression) { - responseBody = await compression.decompress( - responseBody, - opt.readMaxBytes - ); - } - if (isUnaryError) { - throw errorFromJsonBytes( - responseBody, - appendHeaders(header, trailer), - unaryError - ); - } - return >{ - stream: false, - service, - method, - header, - message: serialization - .getO(opt.useBinaryFormat) - .parse(responseBody), - trailer, - }; - } catch (e) { - ac.abort(e); - throw e; } + if (isUnaryError) { + throw errorFromJsonBytes( + responseBody, + appendHeaders(header, trailer), + unaryError + ); + } + return >{ + stream: false, + service, + method, + header, + message: serialization + .getO(opt.useBinaryFormat) + .parse(responseBody), + trailer, + }; }, - opt.interceptors - ); + }); }, async stream< @@ -206,9 +195,11 @@ export function createTransport(opt: CommonTransportOptions): Transport { const endStreamSerialization = createEndStreamSerialization( opt.jsonOptions ); - const ac = createLinkedAbortController(signal); - return runStreaming( - { + return runStreamingCall({ + interceptors: opt.interceptors, + signal, + timeoutMs, + req: { stream: true, service, method, @@ -218,7 +209,6 @@ export function createTransport(opt: CommonTransportOptions): Transport { redirect: "error", mode: "cors", }, - signal: ac.signal, header: requestHeaderWithCompression( method.kind, opt.useBinaryFormat, @@ -231,8 +221,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { propagateDownStreamError: true, }), }, - // eslint-disable-next-line @typescript-eslint/require-await - async (req: StreamRequest) => { + next: async (req: StreamRequest) => { const uRes = await opt.httpClient({ url: req.url, method: "POST", @@ -252,81 +241,69 @@ export function createTransport(opt: CommonTransportOptions): Transport { { propagateDownStreamError: true } ), }); - try { - const { compression } = validateResponseWithCompression( - method.kind, - opt.useBinaryFormat, - opt.acceptCompression, - uRes.status, - uRes.header - ); - const res: StreamResponse = { - ...req, - header: uRes.header, - trailer: new Headers(), - message: pipe( - uRes.body, - transformSplitEnvelope(opt.readMaxBytes), - transformDecompressEnvelope( - compression ?? null, - opt.readMaxBytes - ), - transformParseEnvelope( - serialization.getO(opt.useBinaryFormat), - endStreamFlag, - endStreamSerialization - ), - async function* (iterable) { - let endStreamReceived = false; - for await (const chunk of iterable) { - if (chunk.end) { - if (endStreamReceived) { - throw new ConnectError( - "protocol error: received extra EndStreamResponse", - Code.InvalidArgument - ); - } - endStreamReceived = true; - if (chunk.value.error) { - throw chunk.value.error; - } - chunk.value.metadata.forEach((value, key) => - res.trailer.set(key, value) - ); - continue; - } + const { compression } = validateResponseWithCompression( + method.kind, + opt.useBinaryFormat, + opt.acceptCompression, + uRes.status, + uRes.header + ); + const res: StreamResponse = { + ...req, + header: uRes.header, + trailer: new Headers(), + message: pipe( + uRes.body, + transformSplitEnvelope(opt.readMaxBytes), + transformDecompressEnvelope( + compression ?? null, + opt.readMaxBytes + ), + transformParseEnvelope( + serialization.getO(opt.useBinaryFormat), + endStreamFlag, + endStreamSerialization + ), + async function* (iterable) { + let endStreamReceived = false; + for await (const chunk of iterable) { + if (chunk.end) { if (endStreamReceived) { throw new ConnectError( - "protocol error: received extra message after EndStreamResponse", + "protocol error: received extra EndStreamResponse", Code.InvalidArgument ); } - yield chunk.value; + endStreamReceived = true; + if (chunk.value.error) { + throw chunk.value.error; + } + chunk.value.metadata.forEach((value, key) => + res.trailer.set(key, value) + ); + continue; } - if (!endStreamReceived) { + if (endStreamReceived) { throw new ConnectError( - "protocol error: missing EndStreamResponse", + "protocol error: received extra message after EndStreamResponse", Code.InvalidArgument ); } - }, - transformCatchFinally((e): void => { - if (e !== undefined) { - ac.abort(e); - throw e; - } - }), - { propagateDownStreamError: true } - ), - }; - return res; - } catch (e) { - ac.abort(e); - throw e; - } + yield chunk.value; + } + if (!endStreamReceived) { + throw new ConnectError( + "protocol error: missing EndStreamResponse", + Code.InvalidArgument + ); + } + }, + { propagateDownStreamError: true } + ), + }; + return res; }, - opt.interceptors - ); + }); }, }; } diff --git a/packages/connect/src/protocol-grpc-web/transport.ts b/packages/connect/src/protocol-grpc-web/transport.ts index c503794bf..057b5fe14 100644 --- a/packages/connect/src/protocol-grpc-web/transport.ts +++ b/packages/connect/src/protocol-grpc-web/transport.ts @@ -26,16 +26,16 @@ import type { UnaryRequest, UnaryResponse, } from "../index.js"; -import { Code, ConnectError, runStreaming, runUnary } from "../index.js"; +import { Code, ConnectError } from "../index.js"; import type { CommonTransportOptions } from "../protocol/index.js"; import { createAsyncIterable, - createLinkedAbortController, createMethodSerializationLookup, createMethodUrl, pipe, pipeTo, - transformCatchFinally, + runStreamingCall, + runUnaryCall, transformCompressEnvelope, transformDecompressEnvelope, transformJoinEnvelopes, @@ -71,9 +71,11 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.jsonOptions, opt ); - const ac = createLinkedAbortController(signal); - return await runUnary( - { + return await runUnaryCall({ + interceptors: opt.interceptors, + signal, + timeoutMs, + req: { stream: false, service, method, @@ -88,9 +90,8 @@ export function createTransport(opt: CommonTransportOptions): Transport { ), message: message instanceof method.I ? message : new method.I(message), - signal: ac.signal, }, - async (req: UnaryRequest): Promise> => { + next: async (req: UnaryRequest): Promise> => { const uRes = await opt.httpClient({ url: req.url, method: "POST", @@ -111,81 +112,72 @@ export function createTransport(opt: CommonTransportOptions): Transport { } ), }); - try { - const { compression } = validateResponseWithCompression( - opt.useBinaryFormat, - opt.acceptCompression, - uRes.status, - uRes.header - ); - const { trailer, message } = await pipeTo( - uRes.body, - transformSplitEnvelope(opt.readMaxBytes), - transformDecompressEnvelope( - compression ?? null, - opt.readMaxBytes - ), - transformParseEnvelope( - serialization.getO(opt.useBinaryFormat), - trailerFlag, - createTrailerSerialization() - ), - async (iterable) => { - let message: O | undefined; - let trailer: Headers | undefined; - for await (const env of iterable) { - if (env.end) { - if (trailer !== undefined) { - throw new ConnectError( - "protocol error: received extra trailer", - Code.InvalidArgument - ); - } - trailer = env.value; - } else { - if (message !== undefined) { - throw new ConnectError( - "protocol error: received extra output message for unary method", - Code.InvalidArgument - ); - } - message = env.value; + const { compression } = validateResponseWithCompression( + opt.useBinaryFormat, + opt.acceptCompression, + uRes.status, + uRes.header + ); + const { trailer, message } = await pipeTo( + uRes.body, + transformSplitEnvelope(opt.readMaxBytes), + transformDecompressEnvelope(compression ?? null, opt.readMaxBytes), + transformParseEnvelope( + serialization.getO(opt.useBinaryFormat), + trailerFlag, + createTrailerSerialization() + ), + async (iterable) => { + let message: O | undefined; + let trailer: Headers | undefined; + for await (const env of iterable) { + if (env.end) { + if (trailer !== undefined) { + throw new ConnectError( + "protocol error: received extra trailer", + Code.InvalidArgument + ); + } + trailer = env.value; + } else { + if (message !== undefined) { + throw new ConnectError( + "protocol error: received extra output message for unary method", + Code.InvalidArgument + ); } + message = env.value; } - return { trailer, message }; - }, - { - propagateDownStreamError: false, } - ); - if (trailer === undefined) { - throw new ConnectError( - "protocol error: missing trailer", - Code.InvalidArgument - ); - } - validateTrailer(trailer); - if (message === undefined) { - throw new ConnectError( - "protocol error: missing output message for unary method", - Code.InvalidArgument - ); + return { trailer, message }; + }, + { + propagateDownStreamError: false, } - return >{ - stream: false, - service, - method, - header: uRes.header, - message, - trailer, - }; - } catch (e) { - ac.abort(e); - throw e; + ); + if (trailer === undefined) { + throw new ConnectError( + "protocol error: missing trailer", + Code.InvalidArgument + ); } + validateTrailer(trailer); + if (message === undefined) { + throw new ConnectError( + "protocol error: missing output message for unary method", + Code.InvalidArgument + ); + } + return >{ + stream: false, + service, + method, + header: uRes.header, + message, + trailer, + }; }, - opt.interceptors - ); + }); }, async stream< I extends Message = AnyMessage, @@ -204,9 +196,11 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.jsonOptions, opt ); - const ac = createLinkedAbortController(signal); - return runStreaming( - { + return runStreamingCall({ + interceptors: opt.interceptors, + signal, + timeoutMs, + req: { stream: true, service, method, @@ -216,7 +210,6 @@ export function createTransport(opt: CommonTransportOptions): Transport { redirect: "error", mode: "cors", }, - signal: ac.signal, header: requestHeaderWithCompression( opt.useBinaryFormat, timeoutMs, @@ -228,7 +221,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { propagateDownStreamError: true, }), }, - async (req: StreamRequest) => { + next: async (req: StreamRequest) => { const uRes = await opt.httpClient({ url: req.url, method: "POST", @@ -248,97 +241,84 @@ export function createTransport(opt: CommonTransportOptions): Transport { { propagateDownStreamError: true } ), }); - try { - const { compression, foundStatus } = - validateResponseWithCompression( - opt.useBinaryFormat, - opt.acceptCompression, - uRes.status, - uRes.header - ); - const res: StreamResponse = { - ...req, - header: uRes.header, - trailer: new Headers(), - message: pipe( - uRes.body, - transformSplitEnvelope(opt.readMaxBytes), - transformDecompressEnvelope( - compression ?? null, - opt.readMaxBytes - ), - transformParseEnvelope( - serialization.getO(opt.useBinaryFormat), - trailerFlag, - createTrailerSerialization() - ), - async function* (iterable) { - if (foundStatus) { - // A grpc-status: 0 response header was present. This is a "trailers-only" - // response (a response without a body and no trailers). - // - // The spec seems to disallow a trailers-only response for status 0 - we are - // lenient and only verify that the body is empty. - // - // > [...] Trailers-Only is permitted for calls that produce an immediate error. - // See https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md - const r = await iterable[Symbol.asyncIterator]().next(); - if (r.done !== true) { - throw new ConnectError( - "protocol error: extra data for trailers-only", - Code.InvalidArgument - ); - } - return; + const { compression, foundStatus } = validateResponseWithCompression( + opt.useBinaryFormat, + opt.acceptCompression, + uRes.status, + uRes.header + ); + const res: StreamResponse = { + ...req, + header: uRes.header, + trailer: new Headers(), + message: pipe( + uRes.body, + transformSplitEnvelope(opt.readMaxBytes), + transformDecompressEnvelope( + compression ?? null, + opt.readMaxBytes + ), + transformParseEnvelope( + serialization.getO(opt.useBinaryFormat), + trailerFlag, + createTrailerSerialization() + ), + async function* (iterable) { + if (foundStatus) { + // A grpc-status: 0 response header was present. This is a "trailers-only" + // response (a response without a body and no trailers). + // + // The spec seems to disallow a trailers-only response for status 0 - we are + // lenient and only verify that the body is empty. + // + // > [...] Trailers-Only is permitted for calls that produce an immediate error. + // See https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md + const r = await iterable[Symbol.asyncIterator]().next(); + if (r.done !== true) { + throw new ConnectError( + "protocol error: extra data for trailers-only", + Code.InvalidArgument + ); } - let trailerReceived = false; - for await (const chunk of iterable) { - if (chunk.end) { - if (trailerReceived) { - throw new ConnectError( - "protocol error: received extra trailer", - Code.InvalidArgument - ); - } - trailerReceived = true; - validateTrailer(chunk.value); - chunk.value.forEach((value, key) => - res.trailer.set(key, value) - ); - continue; - } + return; + } + let trailerReceived = false; + for await (const chunk of iterable) { + if (chunk.end) { if (trailerReceived) { throw new ConnectError( - "protocol error: received extra message after trailer", + "protocol error: received extra trailer", Code.InvalidArgument ); } - yield chunk.value; + trailerReceived = true; + validateTrailer(chunk.value); + chunk.value.forEach((value, key) => + res.trailer.set(key, value) + ); + continue; } - if (!trailerReceived) { + if (trailerReceived) { throw new ConnectError( - "protocol error: missing trailer", + "protocol error: received extra message after trailer", Code.InvalidArgument ); } - }, - transformCatchFinally((e): void => { - if (e !== undefined) { - ac.abort(e); - throw e; - } - }), - { propagateDownStreamError: true } - ), - }; - return res; - } catch (e) { - ac.abort(e); - throw e; - } + yield chunk.value; + } + if (!trailerReceived) { + throw new ConnectError( + "protocol error: missing trailer", + Code.InvalidArgument + ); + } + }, + { propagateDownStreamError: true } + ), + }; + return res; }, - opt.interceptors - ); + }); }, }; } diff --git a/packages/connect/src/protocol-grpc/transport.ts b/packages/connect/src/protocol-grpc/transport.ts index 85dc30264..aa7368b3c 100644 --- a/packages/connect/src/protocol-grpc/transport.ts +++ b/packages/connect/src/protocol-grpc/transport.ts @@ -26,16 +26,16 @@ import type { UnaryRequest, UnaryResponse, } from "../index.js"; -import { Code, ConnectError, runStreaming, runUnary } from "../index.js"; +import { Code, ConnectError } from "../index.js"; import type { CommonTransportOptions } from "../protocol/index.js"; import { createAsyncIterable, - createLinkedAbortController, createMethodSerializationLookup, createMethodUrl, pipe, pipeTo, - transformCatchFinally, + runStreamingCall, + runUnaryCall, transformCompressEnvelope, transformDecompressEnvelope, transformJoinEnvelopes, @@ -70,9 +70,11 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.jsonOptions, opt ); - const ac = createLinkedAbortController(signal); - return await runUnary( - { + return await runUnaryCall({ + interceptors: opt.interceptors, + signal, + timeoutMs, + req: { stream: false, service, method, @@ -86,9 +88,8 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.sendCompression ), message: input instanceof method.I ? input : new method.I(input), - signal: ac.signal, }, - async (req: UnaryRequest): Promise> => { + next: async (req: UnaryRequest): Promise> => { const uRes = await opt.httpClient({ url: req.url, method: "POST", @@ -109,60 +110,49 @@ export function createTransport(opt: CommonTransportOptions): Transport { } ), }); - try { - const { compression } = validateResponseWithCompression( - opt.useBinaryFormat, - opt.acceptCompression, - uRes.status, - uRes.header - ); - const message = await pipeTo( - uRes.body, - transformSplitEnvelope(opt.readMaxBytes), - transformDecompressEnvelope( - compression ?? null, - opt.readMaxBytes - ), - transformParseEnvelope( - serialization.getO(opt.useBinaryFormat) - ), - async (iterable) => { - let message: O | undefined; - for await (const chunk of iterable) { - if (message !== undefined) { - throw new ConnectError( - "protocol error: received extra output message for unary method", - Code.InvalidArgument - ); - } - message = chunk; + const { compression } = validateResponseWithCompression( + opt.useBinaryFormat, + opt.acceptCompression, + uRes.status, + uRes.header + ); + const message = await pipeTo( + uRes.body, + transformSplitEnvelope(opt.readMaxBytes), + transformDecompressEnvelope(compression ?? null, opt.readMaxBytes), + transformParseEnvelope(serialization.getO(opt.useBinaryFormat)), + async (iterable) => { + let message: O | undefined; + for await (const chunk of iterable) { + if (message !== undefined) { + throw new ConnectError( + "protocol error: received extra output message for unary method", + Code.InvalidArgument + ); } - return message; - }, - { propagateDownStreamError: false } + message = chunk; + } + return message; + }, + { propagateDownStreamError: false } + ); + validateTrailer(uRes.trailer); + if (message === undefined) { + throw new ConnectError( + "protocol error: missing output message for unary method", + Code.InvalidArgument ); - validateTrailer(uRes.trailer); - if (message === undefined) { - throw new ConnectError( - "protocol error: missing output message for unary method", - Code.InvalidArgument - ); - } - return >{ - stream: false, - service, - method, - header: uRes.header, - message, - trailer: uRes.trailer, - }; - } catch (e) { - ac.abort(e); - throw e; } + return >{ + stream: false, + service, + method, + header: uRes.header, + message, + trailer: uRes.trailer, + }; }, - opt.interceptors - ); + }); }, async stream< I extends Message = AnyMessage, @@ -181,15 +171,16 @@ export function createTransport(opt: CommonTransportOptions): Transport { opt.jsonOptions, opt ); - const ac = createLinkedAbortController(signal); - return runStreaming( - { + return runStreamingCall({ + interceptors: opt.interceptors, + signal, + timeoutMs, + req: { stream: true, service, method, url: createMethodUrl(opt.baseUrl, service, method), init: {}, - signal: ac.signal, header: requestHeaderWithCompression( opt.useBinaryFormat, timeoutMs, @@ -201,7 +192,7 @@ export function createTransport(opt: CommonTransportOptions): Transport { propagateDownStreamError: true, }), }, - async (req: StreamRequest) => { + next: async (req: StreamRequest) => { const uRes = await opt.httpClient({ url: req.url, method: "POST", @@ -221,49 +212,36 @@ export function createTransport(opt: CommonTransportOptions): Transport { { propagateDownStreamError: true } ), }); - try { - const { compression, foundStatus } = - validateResponseWithCompression( - opt.useBinaryFormat, - opt.acceptCompression, - uRes.status, - uRes.header - ); - const res: StreamResponse = { - ...req, - header: uRes.header, - trailer: uRes.trailer, - message: pipe( - uRes.body, - transformSplitEnvelope(opt.readMaxBytes), - transformDecompressEnvelope( - compression ?? null, - opt.readMaxBytes - ), - transformParseEnvelope(serialization.getO(opt.useBinaryFormat)), - async function* (iterable) { - yield* iterable; - if (!foundStatus) { - validateTrailer(uRes.trailer); - } - }, - transformCatchFinally((e): void => { - if (e !== undefined) { - ac.abort(e); - throw e; - } - }), - { propagateDownStreamError: true } + const { compression, foundStatus } = validateResponseWithCompression( + opt.useBinaryFormat, + opt.acceptCompression, + uRes.status, + uRes.header + ); + const res: StreamResponse = { + ...req, + header: uRes.header, + trailer: uRes.trailer, + message: pipe( + uRes.body, + transformSplitEnvelope(opt.readMaxBytes), + transformDecompressEnvelope( + compression ?? null, + opt.readMaxBytes ), - }; - return res; - } catch (e) { - ac.abort(e); - throw e; - } + transformParseEnvelope(serialization.getO(opt.useBinaryFormat)), + async function* (iterable) { + yield* iterable; + if (!foundStatus) { + validateTrailer(uRes.trailer); + } + }, + { propagateDownStreamError: true } + ), + }; + return res; }, - opt.interceptors - ); + }); }, }; } diff --git a/packages/connect/src/protocol/index.ts b/packages/connect/src/protocol/index.ts index 4633a95c3..9a6cece3d 100644 --- a/packages/connect/src/protocol/index.ts +++ b/packages/connect/src/protocol/index.ts @@ -25,6 +25,7 @@ export type { Compression } from "./compression.js"; export type { UniversalHandler } from "./universal-handler.js"; export { createUniversalHandlerClient } from "./universal-handler-client.js"; export { createFetchClient, createFetchHandler } from "./universal-fetch.js"; +export { runUnaryCall, runStreamingCall } from "./run-call.js"; // All exports below are private — internal code that does not follow semantic // versioning. diff --git a/packages/connect/src/protocol/run-call.spec.ts b/packages/connect/src/protocol/run-call.spec.ts new file mode 100644 index 000000000..d88834740 --- /dev/null +++ b/packages/connect/src/protocol/run-call.spec.ts @@ -0,0 +1,223 @@ +// Copyright 2021-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { + Int32Value, + MethodKind, + type ServiceType, + StringValue, +} from "@bufbuild/protobuf"; +import { runStreamingCall, runUnaryCall } from "./run-call.js"; +import type { + StreamRequest, + StreamResponse, + UnaryRequest, + UnaryResponse, +} from "../interceptor.js"; +import { createAsyncIterable } from "./async-iterable.js"; + +const TestService = { + typeName: "TestService", + methods: { + unary: { + name: "Unary", + I: Int32Value, + O: StringValue, + kind: MethodKind.Unary, + }, + serverStreaming: { + name: "ServerStreaming", + I: Int32Value, + O: StringValue, + kind: MethodKind.ServerStreaming, + }, + }, +} satisfies ServiceType; + +describe("runUnaryCall()", function () { + function makeReq() { + return { + stream: false as const, + service: TestService, + method: TestService.methods.unary, + url: `https://example.com/TestService/Unary`, + init: {}, + header: new Headers(), + message: new Int32Value({ value: 123 }), + }; + } + + function makeRes(req: UnaryRequest) { + return >{ + stream: false, + service: TestService, + method: TestService.methods.unary, + header: new Headers(), + message: new StringValue({ value: req.message.value.toString(10) }), + trailer: new Headers(), + }; + } + it("should return the response", async function () { + const res = await runUnaryCall({ + timeoutMs: undefined, + signal: undefined, + interceptors: [], + req: makeReq(), + async next(req) { + await new Promise((resolve) => setTimeout(resolve, 1)); + return makeRes(req); + }, + }); + expect(res.message.value).toBe("123"); + }); + it("should trigger the signal when done", async function () { + let signal: AbortSignal | undefined; + await runUnaryCall({ + req: makeReq(), + async next(req) { + signal = req.signal; + await new Promise((resolve) => setTimeout(resolve, 1)); + return makeRes(req); + }, + }); + expect(signal?.aborted).toBeTrue(); + }); + it("should raise Code.Canceled on user abort", async function () { + const userAbort = new AbortController(); + const resPromise = runUnaryCall({ + signal: userAbort.signal, + req: makeReq(), + async next(req) { + for (;;) { + await new Promise((resolve) => setTimeout(resolve, 1)); + req.signal.throwIfAborted(); + } + }, + }); + userAbort.abort(); + await expectAsync(resPromise).toBeRejectedWithError( + "[canceled] This operation was aborted" + ); + }); + it("should raise Code.DeadlineExceeded on timeout", async function () { + const resPromise = runUnaryCall({ + timeoutMs: 1, + req: makeReq(), + async next(req) { + for (;;) { + await new Promise((resolve) => setTimeout(resolve, 1)); + req.signal.throwIfAborted(); + } + }, + }); + await expectAsync(resPromise).toBeRejectedWithError( + "[deadline_exceeded] the operation timed out" + ); + }); +}); + +describe("runStreamingCall()", function () { + function makeReq() { + return { + stream: true as const, + service: TestService, + method: TestService.methods.serverStreaming, + url: `https://example.com/TestService/ServerStreaming`, + init: {}, + header: new Headers(), + message: createAsyncIterable([new Int32Value({ value: 123 })]), + }; + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + function makeRes(req: StreamRequest) { + return >{ + stream: true, + service: TestService, + method: TestService.methods.serverStreaming, + header: new Headers(), + message: createAsyncIterable([ + new StringValue({ value: "1" }), + new StringValue({ value: "2" }), + new StringValue({ value: "3" }), + ]), + trailer: new Headers(), + }; + } + + it("should return the response", async function () { + const res = await runStreamingCall({ + timeoutMs: undefined, + signal: undefined, + interceptors: [], + req: makeReq(), + async next(req) { + await new Promise((resolve) => setTimeout(resolve, 1)); + return makeRes(req); + }, + }); + const values: string[] = []; + for await (const m of res.message) { + values.push(m.value); + } + expect(values).toEqual(["1", "2", "3"]); + }); + it("should trigger the signal when done", async function () { + let signal: AbortSignal | undefined; + const res = await runStreamingCall({ + req: makeReq(), + async next(req) { + signal = req.signal; + await new Promise((resolve) => setTimeout(resolve, 1)); + return makeRes(req); + }, + }); + for await (const m of res.message) { + expect(m).toBeDefined(); + } + expect(signal?.aborted).toBeTrue(); + }); + it("should raise Code.Canceled on user abort", async function () { + const userAbort = new AbortController(); + const resPromise = runStreamingCall({ + signal: userAbort.signal, + req: makeReq(), + async next(req) { + for (;;) { + await new Promise((resolve) => setTimeout(resolve, 1)); + req.signal.throwIfAborted(); + } + }, + }); + userAbort.abort(); + await expectAsync(resPromise).toBeRejectedWithError( + "[canceled] This operation was aborted" + ); + }); + it("should raise Code.DeadlineExceeded on timeout", async function () { + const resPromise = runStreamingCall({ + timeoutMs: 1, + req: makeReq(), + async next(req) { + for (;;) { + await new Promise((resolve) => setTimeout(resolve, 1)); + req.signal.throwIfAborted(); + } + }, + }); + await expectAsync(resPromise).toBeRejectedWithError( + "[deadline_exceeded] the operation timed out" + ); + }); +}); diff --git a/packages/connect/src/protocol/run-call.ts b/packages/connect/src/protocol/run-call.ts new file mode 100644 index 000000000..dccce7fdd --- /dev/null +++ b/packages/connect/src/protocol/run-call.ts @@ -0,0 +1,186 @@ +// Copyright 2021-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import type { AnyMessage, Message } from "@bufbuild/protobuf"; +import type { + Interceptor, + StreamRequest, + StreamResponse, + UnaryRequest, + UnaryResponse, +} from "../interceptor.js"; +import { connectErrorFromReason } from "../connect-error.js"; +import { + createDeadlineSignal, + createLinkedAbortController, + getAbortSignalReason, +} from "./signals.js"; + +/** + * UnaryFn represents the client-side invocation of a unary RPC - a method + * that takes a single input message, and responds with a single output + * message. + * A Transport implements such a function, and makes it available to + * interceptors. + */ +type UnaryFn< + I extends Message = AnyMessage, + O extends Message = AnyMessage +> = (req: UnaryRequest) => Promise>; + +/** + * Runs a unary method with the given interceptors. Note that this function + * is only used when implementing a Transport. + */ +export function runUnaryCall, O extends Message>(opt: { + req: Omit, "signal">; + next: UnaryFn; + timeoutMs?: number; + signal?: AbortSignal; + interceptors?: Interceptor[]; +}): Promise> { + const next = applyInterceptors(opt.next, opt.interceptors); + const [signal, abort, done] = setupSignal(opt); + const req = { + ...opt.req, + signal, + }; + return next(req).then((res) => { + done(); + return res; + }, abort); +} + +/** + * StreamingFn represents the client-side invocation of a streaming RPC - a + * method that takes zero or more input messages, and responds with zero or + * more output messages. + * A Transport implements such a function, and makes it available to + * interceptors. + */ +type StreamingFn< + I extends Message = AnyMessage, + O extends Message = AnyMessage +> = (req: StreamRequest) => Promise>; + +/** + * Runs a server-streaming method with the given interceptors. Note that this + * function is only used when implementing a Transport. + */ +export function runStreamingCall< + I extends Message, + O extends Message +>(opt: { + req: Omit, "signal">; + next: StreamingFn; + timeoutMs?: number; + signal?: AbortSignal; + interceptors?: Interceptor[]; +}): Promise> { + const next = applyInterceptors(opt.next, opt.interceptors); + const [signal, abort, done] = setupSignal(opt); + const req = { + ...opt.req, + signal, + }; + return next(req).then((res) => { + return { + ...res, + message: { + [Symbol.asyncIterator]() { + const it = res.message[Symbol.asyncIterator](); + const w: AsyncIterator = { + next() { + return it.next().then((r) => { + if (r.done == true) { + done(); + } + return r; + }, abort); + }, + }; + if (it.throw !== undefined) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion -- can't handle mutated object sensibly + w.throw = (e: unknown) => it.throw!(e); + } + if (it.return !== undefined) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion,@typescript-eslint/no-explicit-any -- can't handle mutated object sensibly + w.return = (value?: any) => it.return!(value); + } + return w; + }, + }, + }; + }, abort); +} + +/** + * Create an AbortSignal for Transport implementations. The signal is available + * in UnaryRequest and StreamingRequest, and is triggered when the call is + * aborted (via a timeout or explicit cancellation), errored (e.g. when reading + * an error from the server from the wire), or finished successfully. + * + * Transport implementations can pass the signal to HTTP clients to ensure that + * there are no unused connections leak. + * + * Returns a tuple: + * [0]: The signal, which is also aborted if the optional deadline is reached. + * [1]: Function to call if the Transport encountered an error. + * [2]: Function to call if the Transport finished without an error. + */ +function setupSignal(opt: { + timeoutMs?: number; + signal?: AbortSignal; +}): [AbortSignal, (reason: unknown) => Promise, () => void] { + const { signal, cleanup } = createDeadlineSignal(opt.timeoutMs); + const controller = createLinkedAbortController(opt.signal, signal); + return [ + controller.signal, + function abort(reason: unknown): Promise { + // We peek at the deadline signal because fetch() will throw an error on + // abort that discards the signal reason. + const e = connectErrorFromReason( + signal.aborted ? getAbortSignalReason(signal) : reason + ); + controller.abort(e); + cleanup(); + return Promise.reject(e); + }, + function done() { + cleanup(); + controller.abort(); + }, + ]; +} + +/** + * applyInterceptors takes the given UnaryFn or ServerStreamingFn, and wraps + * it with each of the given interceptors, returning a new UnaryFn or + * ServerStreamingFn. + */ +function applyInterceptors( + next: T, + interceptors: Interceptor[] | undefined +): T { + return ( + (interceptors + ?.concat() + .reverse() + .reduce( + // eslint-disable-next-line @typescript-eslint/no-unsafe-argument + (n, i) => i(n), + next as any // eslint-disable-line @typescript-eslint/no-explicit-any + ) as T) ?? next + ); +} diff --git a/packages/connect/src/protocol/signals.spec.ts b/packages/connect/src/protocol/signals.spec.ts index e14ad9dec..581e7a514 100644 --- a/packages/connect/src/protocol/signals.spec.ts +++ b/packages/connect/src/protocol/signals.spec.ts @@ -66,6 +66,12 @@ describe("createDeadlineSignal()", function () { expect(d.signal.aborted).toBeTrue(); }); }); + describe("with -1 timeout", function () { + it("should be aborted immediately", function () { + const d = createDeadlineSignal(-1); + expect(d.signal.aborted).toBeTrue(); + }); + }); describe("with undefined timeout", function () { it("should still return a signal", function () { const d = createDeadlineSignal(undefined); diff --git a/packages/connect/src/protocol/signals.ts b/packages/connect/src/protocol/signals.ts index 982976dc1..67bf2c1ba 100644 --- a/packages/connect/src/protocol/signals.ts +++ b/packages/connect/src/protocol/signals.ts @@ -70,15 +70,15 @@ export function createDeadlineSignal(timeoutMs: number | undefined): { cleanup: () => void; } { const controller = new AbortController(); - const listener = () => + const listener = () => { controller.abort( new ConnectError("the operation timed out", Code.DeadlineExceeded) ); + }; let timeoutId: ReturnType | undefined; - if (timeoutMs === 0) { - listener(); - } else if (timeoutMs !== undefined) { - timeoutId = setTimeout(listener, timeoutMs); + if (timeoutMs !== undefined) { + if (timeoutMs <= 0) listener(); + else timeoutId = setTimeout(listener, timeoutMs); } return { signal: controller.signal, From 4ee325ab5fa29414b74c74e79e4ae4591c313c9d Mon Sep 17 00:00:00 2001 From: Timo Stamm Date: Tue, 16 May 2023 15:56:30 +0200 Subject: [PATCH 2/2] Remove a try-catch block that has become unnecessary --- .../connect-web/src/grpc-web-transport.ts | 71 +++++++++---------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/packages/connect-web/src/grpc-web-transport.ts b/packages/connect-web/src/grpc-web-transport.ts index c8f788770..e4eaed88a 100644 --- a/packages/connect-web/src/grpc-web-transport.ts +++ b/packages/connect-web/src/grpc-web-transport.ts @@ -30,7 +30,6 @@ import type { UnaryRequest, UnaryResponse, } from "@bufbuild/connect"; -import { connectErrorFromReason } from "@bufbuild/connect"; import { createClientMethodSerializers, createEnvelopeReadableStream, @@ -234,49 +233,45 @@ export function createGrpcWebTransport( trailerTarget: Headers ) { const reader = createEnvelopeReadableStream(body).getReader(); - try { - if (foundStatus) { - // A grpc-status: 0 response header was present. This is a "trailers-only" - // response (a response without a body and no trailers). - // - // The spec seems to disallow a trailers-only response for status 0 - we are - // lenient and only verify that the body is empty. - // - // > [...] Trailers-Only is permitted for calls that produce an immediate error. - // See https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md - if (!(await reader.read()).done) { - throw "extra data for trailers-only"; - } - return; + if (foundStatus) { + // A grpc-status: 0 response header was present. This is a "trailers-only" + // response (a response without a body and no trailers). + // + // The spec seems to disallow a trailers-only response for status 0 - we are + // lenient and only verify that the body is empty. + // + // > [...] Trailers-Only is permitted for calls that produce an immediate error. + // See https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md + if (!(await reader.read()).done) { + throw "extra data for trailers-only"; } - let trailerReceived = false; - for (;;) { - const result = await reader.read(); - if (result.done) { - break; - } - const { flags, data } = result.value; - if ((flags & trailerFlag) === trailerFlag) { - if (trailerReceived) { - throw "extra trailer"; - } - trailerReceived = true; - const trailer = trailerParse(data); - validateTrailer(trailer); - trailer.forEach((value, key) => trailerTarget.set(key, value)); - continue; - } + return; + } + let trailerReceived = false; + for (;;) { + const result = await reader.read(); + if (result.done) { + break; + } + const { flags, data } = result.value; + if ((flags & trailerFlag) === trailerFlag) { if (trailerReceived) { - throw "extra message"; + throw "extra trailer"; } - yield parse(data); + trailerReceived = true; + const trailer = trailerParse(data); + validateTrailer(trailer); + trailer.forEach((value, key) => trailerTarget.set(key, value)); continue; } - if (!trailerReceived) { - throw "missing trailer"; + if (trailerReceived) { + throw "extra message"; } - } catch (e) { - throw connectErrorFromReason(e); + yield parse(data); + continue; + } + if (!trailerReceived) { + throw "missing trailer"; } }