From f4ff3823b9eebb296a812b27a43ee28b07c5b962 Mon Sep 17 00:00:00 2001 From: Timo Stamm Date: Wed, 8 Mar 2023 00:25:21 +0100 Subject: [PATCH] Add support for Connect-Protocol-Version header --- .../src/badweather/broken-input.spec.ts | 2 + .../unsupported-content-encoding.spec.ts | 2 + .../src/helpers/testserver.ts | 25 +- .../headers.spec.ts => fetch-headers.spec.ts} | 0 .../protocol-connect/handler-factory.spec.ts | 193 ++++++++++++++++ .../src/protocol-connect/handler-factory.ts | 4 + .../connect/src/protocol-connect/version.ts | 23 ++ .../src/protocol-grpc-web/trailer.spec.ts | 2 +- .../connect/src/protocol/limit-io.spec.ts | 56 +++++ .../src/protocol/universal-handler.spec.ts | 215 ++++++++++++++++++ .../connect/src/protocol/universal-handler.ts | 93 +++++--- packages/connect/tsconfig.test.json | 0 12 files changed, 579 insertions(+), 36 deletions(-) rename packages/connect/src/{protocol-connect/headers.spec.ts => fetch-headers.spec.ts} (100%) create mode 100644 packages/connect/src/protocol-connect/handler-factory.spec.ts create mode 100644 packages/connect/src/protocol/limit-io.spec.ts create mode 100644 packages/connect/src/protocol/universal-handler.spec.ts delete mode 100644 packages/connect/tsconfig.test.json diff --git a/packages/connect-node-test/src/badweather/broken-input.spec.ts b/packages/connect-node-test/src/badweather/broken-input.spec.ts index 927c60975..cc1af42d0 100644 --- a/packages/connect-node-test/src/badweather/broken-input.spec.ts +++ b/packages/connect-node-test/src/badweather/broken-input.spec.ts @@ -46,6 +46,7 @@ describe("broken input", () => { ), method: "POST", ctype: "application/json", + headers: { "Connect-Protocol-Version": "1" }, }).then((res) => { return { status: res.status, @@ -85,6 +86,7 @@ describe("broken input", () => { url: createMethodUrl(server.getUrl(), TestService, method), method: "POST", ctype: "application/connect+json", + headers: { "Connect-Protocol-Version": "1" }, }).then((res) => ({ status: res.status, endStream: endStreamFromJson(res.body.subarray(5)), diff --git a/packages/connect-node-test/src/badweather/unsupported-content-encoding.spec.ts b/packages/connect-node-test/src/badweather/unsupported-content-encoding.spec.ts index 4ad4cfb16..b3c43b7b3 100644 --- a/packages/connect-node-test/src/badweather/unsupported-content-encoding.spec.ts +++ b/packages/connect-node-test/src/badweather/unsupported-content-encoding.spec.ts @@ -44,6 +44,7 @@ describe("unsupported content encoding", () => { headers: { "content-type": "application/json", "content-encoding": "banana", + "connect-protocol-version": "1", }, rejectUnauthorized, }); @@ -74,6 +75,7 @@ describe("unsupported content encoding", () => { headers: { "content-type": "application/connect+json", "connect-content-encoding": "banana", + "connect-protocol-version": "1", }, rejectUnauthorized, }); diff --git a/packages/connect-node-test/src/helpers/testserver.ts b/packages/connect-node-test/src/helpers/testserver.ts index 52a505e54..2ff4446af 100644 --- a/packages/connect-node-test/src/helpers/testserver.ts +++ b/packages/connect-node-test/src/helpers/testserver.ts @@ -116,7 +116,10 @@ export function createTestServers() { cert: certLocalhost.cert, key: certLocalhost.key, }, - connectNodeAdapter({ routes: testRoutes }) + connectNodeAdapter({ + routes: testRoutes, + requireConnectProtocolHeader: true, + }) ) .listen(0, resolve); }); @@ -145,7 +148,13 @@ export function createTestServers() { start() { return new Promise((resolve) => { nodeH2cServer = http2 - .createServer({}, connectNodeAdapter({ routes: testRoutes })) + .createServer( + {}, + connectNodeAdapter({ + routes: testRoutes, + requireConnectProtocolHeader: true, + }) + ) .listen(0, resolve); }); }, @@ -201,7 +210,10 @@ export function createTestServers() { "Access-Control-Expose-Headers": corsExposeHeaders.join(", "), "Access-Control-Max-Age": 2 * 3600, }; - const serviceHandler = connectNodeAdapter({ routes: testRoutes }); + const serviceHandler = connectNodeAdapter({ + routes: testRoutes, + requireConnectProtocolHeader: true, + }); nodeHttpServer = http .createServer({}, (req, res) => { if (req.method === "OPTIONS") { @@ -243,7 +255,10 @@ export function createTestServers() { cert: certLocalhost.cert, key: certLocalhost.key, }, - connectNodeAdapter({ routes: testRoutes }) + connectNodeAdapter({ + routes: testRoutes, + requireConnectProtocolHeader: true, + }) ) .listen(0, resolve); }); @@ -280,6 +295,7 @@ export function createTestServers() { }); await fastifyH2cServer.register(fastifyConnectPlugin, { routes: testRoutes, + requireConnectProtocolHeader: true, }); await fastifyH2cServer.listen(); }, @@ -306,6 +322,7 @@ export function createTestServers() { app.use( expressConnectMiddleware({ routes: testRoutes, + requireConnectProtocolHeader: true, }) ); expressServer = http.createServer(app); diff --git a/packages/connect/src/protocol-connect/headers.spec.ts b/packages/connect/src/fetch-headers.spec.ts similarity index 100% rename from packages/connect/src/protocol-connect/headers.spec.ts rename to packages/connect/src/fetch-headers.spec.ts diff --git a/packages/connect/src/protocol-connect/handler-factory.spec.ts b/packages/connect/src/protocol-connect/handler-factory.spec.ts new file mode 100644 index 000000000..13684bfd5 --- /dev/null +++ b/packages/connect/src/protocol-connect/handler-factory.spec.ts @@ -0,0 +1,193 @@ +// Copyright 2021-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { + Int32Value, + Message, + MethodInfo, + MethodKind, + ServiceType, + StringValue, +} from "@bufbuild/protobuf"; +import { createHandlerFactory } from "./handler-factory.js"; +import { + createMethodImplSpec, + HandlerContext, + MethodImpl, +} from "../implementation.js"; +import type { UniversalHandlerOptions } from "../protocol/index.js"; +import { errorFromJsonBytes } from "./error-json.js"; +import { ConnectError } from "../connect-error.js"; +import { Code } from "../code.js"; +import { endStreamFromJson } from "./end-stream.js"; +import { + createAsyncIterableBytes, + readAllBytes, +} from "../protocol/async-iterable-helper.spec.js"; + +describe("createHandlerFactory()", function () { + const testService: ServiceType = { + typeName: "TestService", + methods: { + foo: { + name: "Foo", + I: Int32Value, + O: StringValue, + kind: MethodKind.Unary, + }, + bar: { + name: "Bar", + I: Int32Value, + O: StringValue, + kind: MethodKind.ServerStreaming, + }, + }, + } as const; + + function stub( + opt: { + service?: ServiceType; + method?: M; + impl?: MethodImpl; + } & Partial + ) { + const method = opt.method ?? testService.methods.foo; + let implDefault: MethodImpl; + switch (method.kind) { + case MethodKind.Unary: + // eslint-disable-next-line @typescript-eslint/require-await + implDefault = async function (req: Message, ctx: HandlerContext) { + ctx.responseHeader.set("stub-handler", "1"); + return new ctx.method.O(); + } as unknown as MethodImpl; + break; + case MethodKind.ServerStreaming: + // eslint-disable-next-line @typescript-eslint/require-await + implDefault = async function* (req: Message, ctx: HandlerContext) { + ctx.responseHeader.set("stub-handler", "1"); + yield new ctx.method.O(); + } as unknown as MethodImpl; + break; + default: + throw new Error("not implemented"); + } + const spec = createMethodImplSpec( + opt.service ?? testService, + method, + opt.impl ?? implDefault + ); + const f = createHandlerFactory(opt); + return f(spec); + } + + describe("requireConnectProtocolHeader", function () { + describe("with unary RPC", function () { + const h = stub({ requireConnectProtocolHeader: true }); + it("should raise error for missing header", async function () { + const res = await h({ + httpVersion: "1.1", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ "Content-Type": "application/json" }), + body: 777, + }); + expect(res.status).toBe(400); + expect(res.body).toBeInstanceOf(Uint8Array); + if (res.body instanceof Uint8Array) { + const err = errorFromJsonBytes( + res.body, + undefined, + new ConnectError("failed to parse connect err", Code.Internal) + ); + expect(err.message).toBe( + '[invalid_argument] missing required header: set Connect-Protocol-Version to "1"' + ); + } + }); + it("should raise error for wrong header", async function () { + const res = await h({ + httpVersion: "1.1", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ + "Content-Type": "application/json", + "Connect-Protocol-Version": "UNEXPECTED", + }), + body: 777, + }); + expect(res.status).toBe(400); + expect(res.body).toBeInstanceOf(Uint8Array); + if (res.body instanceof Uint8Array) { + const err = errorFromJsonBytes( + res.body, + undefined, + new ConnectError("failed to parse connect err", Code.Internal) + ); + expect(err.message).toBe( + '[invalid_argument] Connect-Protocol-Version must be "1": got "UNEXPECTED"' + ); + } + }); + }); + describe("with streaming RPC", function () { + const h = stub({ + requireConnectProtocolHeader: true, + method: testService.methods.bar, + }); + it("should raise error for missing header", async function () { + const res = await h({ + httpVersion: "1.1", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ "Content-Type": "application/connect+json" }), + body: createAsyncIterableBytes(new Uint8Array()), + }); + expect(res.status).toBe(200); + expect(res.body).not.toBeInstanceOf(Uint8Array); + expect(res.body).not.toBeUndefined(); + if (res.body !== undefined && Symbol.asyncIterator in res.body) { + const end = endStreamFromJson( + (await readAllBytes(res.body)).slice(5) + ); + expect(end.error?.message).toBe( + '[invalid_argument] missing required header: set Connect-Protocol-Version to "1"' + ); + } + }); + it("should raise error for wrong header", async function () { + const res = await h({ + httpVersion: "1.1", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ + "Content-Type": "application/connect+json", + "Connect-Protocol-Version": "UNEXPECTED", + }), + body: createAsyncIterableBytes(new Uint8Array()), + }); + expect(res.status).toBe(200); + expect(res.body).not.toBeInstanceOf(Uint8Array); + expect(res.body).not.toBeUndefined(); + if (res.body !== undefined && Symbol.asyncIterator in res.body) { + const end = endStreamFromJson( + (await readAllBytes(res.body)).slice(5) + ); + expect(end.error?.message).toBe( + '[invalid_argument] Connect-Protocol-Version must be "1": got "UNEXPECTED"' + ); + } + }); + }); + }); +}); diff --git a/packages/connect/src/protocol-connect/handler-factory.ts b/packages/connect/src/protocol-connect/handler-factory.ts index d98d2c1ff..66e08c627 100644 --- a/packages/connect/src/protocol-connect/handler-factory.ts +++ b/packages/connect/src/protocol-connect/handler-factory.ts @@ -73,6 +73,7 @@ import { import { codeToHttpStatus } from "./http-status.js"; import { errorToJsonBytes } from "./error-json.js"; import { trailerMux } from "./trailer-mux.js"; +import { requireProtocolVersion } from "./version.js"; const protocolName = "connect"; const methodPost = "POST"; @@ -153,6 +154,7 @@ function createUnaryHandler, O extends Message>( let status = uResponseOk.status; let body: Uint8Array; try { + if (opt.requireConnectProtocolHeader) requireProtocolVersion(req.header); // raise compression error to serialize it as a error response if (compression.error) { throw compression.error; @@ -278,6 +280,8 @@ function createStreamHandler, O extends Message>( const outputIt = pipe( req.body, transformPrepend(() => { + if (opt.requireConnectProtocolHeader) + requireProtocolVersion(req.header); // raise compression error to serialize it as the end stream response if (compression.error) throw compression.error; return undefined; diff --git a/packages/connect/src/protocol-connect/version.ts b/packages/connect/src/protocol-connect/version.ts index af8e5133f..295752666 100644 --- a/packages/connect/src/protocol-connect/version.ts +++ b/packages/connect/src/protocol-connect/version.ts @@ -12,7 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. +import { headerProtocolVersion } from "./headers.js"; +import { ConnectError } from "../connect-error.js"; +import { Code } from "../code.js"; + /** * The only know value for the header Connect-Protocol-Version. */ export const protocolVersion = "1"; + +/** + * Requires the Connect-Protocol-Version header to be present with the expected + * value. Raises a ConnectError with Code.InvalidArgument otherwise. + */ +export function requireProtocolVersion(requestHeader: Headers) { + const v = requestHeader.get(headerProtocolVersion); + if (v === null) { + throw new ConnectError( + `missing required header: set ${headerProtocolVersion} to "${protocolVersion}"`, + Code.InvalidArgument + ); + } else if (v !== protocolVersion) { + throw new ConnectError( + `${headerProtocolVersion} must be "${protocolVersion}": got "${v}"`, + Code.InvalidArgument + ); + } +} diff --git a/packages/connect/src/protocol-grpc-web/trailer.spec.ts b/packages/connect/src/protocol-grpc-web/trailer.spec.ts index 2e527ea79..4f7c423ef 100644 --- a/packages/connect/src/protocol-grpc-web/trailer.spec.ts +++ b/packages/connect/src/protocol-grpc-web/trailer.spec.ts @@ -60,7 +60,7 @@ describe("trailerSerialize()", () => { }); }); -describe("roundtrip", () => { +describe("trailer roundtrip", () => { it("should work", () => { const a = new Headers({ foo: "a, b", diff --git a/packages/connect/src/protocol/limit-io.spec.ts b/packages/connect/src/protocol/limit-io.spec.ts new file mode 100644 index 000000000..1c99db821 --- /dev/null +++ b/packages/connect/src/protocol/limit-io.spec.ts @@ -0,0 +1,56 @@ +// Copyright 2021-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { validateReadWriteMaxBytes } from "./limit-io.js"; + +describe("validateReadWriteMaxBytes()", function () { + it("should set defaults", function () { + const o = validateReadWriteMaxBytes(undefined, undefined, undefined); + expect(o).toEqual({ + readMaxBytes: 0xffffffff, + writeMaxBytes: 0xffffffff, + compressMinBytes: 1024, + }); + }); + it("should accept inputs", function () { + const o = validateReadWriteMaxBytes(666, 777, 888); + expect(o).toEqual({ + readMaxBytes: 666, + writeMaxBytes: 777, + compressMinBytes: 888, + }); + }); + it("should assert sane limits for readMaxBytes", function () { + expect(() => + validateReadWriteMaxBytes(-1, undefined, undefined) + ).toThrowError("[internal] readMaxBytes -1 must be >= 1 and <= 4294967295"); + expect(() => + validateReadWriteMaxBytes(0xffffffff + 1, undefined, undefined) + ).toThrowError( + "[internal] readMaxBytes 4294967296 must be >= 1 and <= 4294967295" + ); + }); + it("should assert sane limits for writeMaxBytes", function () { + expect(() => + validateReadWriteMaxBytes(undefined, -1, undefined) + ).toThrowError( + "[internal] writeMaxBytes -1 must be >= 1 and <= 4294967295" + ); + expect(() => + validateReadWriteMaxBytes(undefined, 0xffffffff + 1, undefined) + ).toThrowError( + "[internal] writeMaxBytes 4294967296 must be >= 1 and <= 4294967295" + ); + }); +}); diff --git a/packages/connect/src/protocol/universal-handler.spec.ts b/packages/connect/src/protocol/universal-handler.spec.ts new file mode 100644 index 000000000..55f32ac0e --- /dev/null +++ b/packages/connect/src/protocol/universal-handler.spec.ts @@ -0,0 +1,215 @@ +// Copyright 2021-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import type { ServiceType } from "@bufbuild/protobuf"; +import { Int32Value, MethodKind, StringValue } from "@bufbuild/protobuf"; +import { + negotiateProtocol, + UniversalHandler, + UniversalHandlerOptions, + validateUniversalHandlerOptions, +} from "./universal-handler.js"; +import type { Compression } from "./compression.js"; +import { contentTypeMatcher } from "./content-type-matcher.js"; +import { Headers } from "undici"; + +describe("validateUniversalHandlerOptions()", function () { + it("should set defaults", function () { + const o = validateUniversalHandlerOptions({}); + expect(o).toEqual({ + acceptCompression: [], + compressMinBytes: 1024, + readMaxBytes: 0xffffffff, + writeMaxBytes: 0xffffffff, + jsonOptions: undefined, + binaryOptions: undefined, + maxDeadlineDurationMs: Number.MAX_SAFE_INTEGER, + shutdownSignal: new AbortController().signal, + requireConnectProtocolHeader: false, + }); + }); + it("should accept inputs", function () { + const fakeCompression: Compression = { + name: "fake", + compress: (bytes) => Promise.resolve(bytes), + decompress: (bytes) => Promise.resolve(bytes), + }; + const i: UniversalHandlerOptions = { + acceptCompression: [fakeCompression], + compressMinBytes: 444, + readMaxBytes: 777, + writeMaxBytes: 666, + jsonOptions: { + ignoreUnknownFields: true, + emitDefaultValues: true, + }, + binaryOptions: { + readUnknownFields: true, + writeUnknownFields: false, + }, + maxDeadlineDurationMs: 888, + shutdownSignal: new AbortController().signal, + requireConnectProtocolHeader: true, + }; + const o = validateUniversalHandlerOptions(i); + expect(o).toEqual(i); + }); +}); + +describe("negotiateProtocol()", function () { + const testService: ServiceType = { + typeName: "TestService", + methods: { + foo: { + name: "Foo", + I: Int32Value, + O: StringValue, + kind: MethodKind.Unary, + }, + bar: { + name: "Bar", + I: Int32Value, + O: StringValue, + kind: MethodKind.Unary, + }, + }, + } as const; + + function stubHandler(o: Partial): UniversalHandler { + return Object.assign( + function () { + return Promise.resolve({ + status: 200, + header: new Headers({ "stub-handler": "1" }), + }); + }, + { + protocolNames: ["protocol-x"], + service: testService, + method: testService.methods.foo, + requestPath: `/${testService.typeName}/${testService.methods.foo.name}`, + allowedMethods: ["POST"], + supportedContentType: contentTypeMatcher(/application\/x/), + ...o, + } + ); + } + + it("should require at least one handler", function () { + expect(() => negotiateProtocol([])).toThrowError( + "[internal] require at least one protocol" + ); + }); + + it("should require all handlers to be for the same RPC", function () { + const foo = stubHandler({ method: testService.methods.foo }); + const bar = stubHandler({ method: testService.methods.bar }); + expect(() => negotiateProtocol([foo, bar])).toThrowError( + "[internal] cannot negotiate protocol for different RPCs" + ); + }); + + it("should require all handlers to have the same request path", function () { + const a = stubHandler({ requestPath: `/a` }); + const b = stubHandler({ requestPath: `/b` }); + expect(() => negotiateProtocol([a, b])).toThrowError( + "[internal] cannot negotiate protocol for different requestPaths" + ); + }); + + it("should merge protocolNames", function () { + const h = negotiateProtocol([ + stubHandler({ protocolNames: ["x"] }), + stubHandler({ protocolNames: ["y", "z"] }), + ]); + expect(h.protocolNames).toEqual(["x", "y", "z"]); + }); + + it("should merge allowedMethods", function () { + const h = negotiateProtocol([ + stubHandler({ allowedMethods: ["POST", "PUT"] }), + stubHandler({ allowedMethods: ["POST", "GET"] }), + ]); + expect(h.allowedMethods).toEqual(["POST", "PUT", "GET"]); + }); + + describe("negotiating handler", function () { + const h = negotiateProtocol([stubHandler({})]); + it("should return HTTP 415 for unsupported request content-type", async function () { + const r = await h({ + httpVersion: "1.1", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ "Content-Type": "UNSUPPORTED" }), + body: null, + }); + expect(r.status).toBe(415); + }); + it("should return HTTP 405 for matching request content-type but unsupported method", async function () { + const r = await h({ + httpVersion: "1.1", + method: "UNSUPPORTED", + url: new URL("https://example.com"), + header: new Headers({ "Content-Type": "application/x" }), + body: null, + }); + expect(r.status).toBe(405); + }); + it("should call implementation for matching content-type and method", async function () { + const r = await h({ + httpVersion: "1.1", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ "Content-Type": "application/x" }), + body: null, + }); + expect(r.status).toBe(200); + expect(r.header?.get("stub-handler")).toBe("1"); + }); + + describe("for bidi stream", function () { + const h = negotiateProtocol([ + stubHandler({ + method: { + ...testService.methods.foo, + kind: MethodKind.BiDiStreaming, + }, + }), + ]); + it("should return HTTP 505 for HTTP 1.1", async function () { + const r = await h({ + httpVersion: "1.1", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ "Content-Type": "application/x" }), + body: null, + }); + expect(r.status).toBe(505); + expect(r.header?.get("Connection")).toBe("close"); + expect(r.body).toBeUndefined(); + }); + it("should require HTTP/2", async function () { + const r = await h({ + httpVersion: "2", + method: "POST", + url: new URL("https://example.com"), + header: new Headers({ "Content-Type": "application/x" }), + body: null, + }); + expect(r.status).toBe(200); + expect(r.header?.get("stub-handler")).toBe("1"); + }); + }); + }); +}); diff --git a/packages/connect/src/protocol/universal-handler.ts b/packages/connect/src/protocol/universal-handler.ts index 835a4f5ab..d37c43a2d 100644 --- a/packages/connect/src/protocol/universal-handler.ts +++ b/packages/connect/src/protocol/universal-handler.ts @@ -29,7 +29,6 @@ import { uResponseUnsupportedMediaType, uResponseVersionNotSupported, } from "./universal.js"; -import { createMethodUrl } from "./create-method-url.js"; import { ContentTypeMatcher, contentTypeMatcher, @@ -37,6 +36,8 @@ import { import type { Compression } from "./compression.js"; import type { ProtocolHandlerFactory } from "./protocol-handler-factory.js"; import { validateReadWriteMaxBytes } from "./limit-io.js"; +import { ConnectError } from "../connect-error.js"; +import { Code } from "../code.js"; /** * Common options for handlers. @@ -85,7 +86,18 @@ export interface UniversalHandlerOptions { maxDeadlineDurationMs: number; // TODO TCN-785 shutdownSignal: AbortSignal; // TODO TCN-919 - // TODO + + /** + * Require requests using the Connect protocol to include the header + * Connect-Protocol-Version. This ensures that HTTP proxies and other + * code inspecting traffic can easily identify Connect RPC requests, + * even if they use a common Content-Type like application/json. + * + * If a Connect request does not include the Connect-Protocol-Version + * header, an error with code invalid_argument (HTTP 400) is returned. + * This option has no effect if the client uses the gRPC or the gRPC-web + * protocol. + */ requireConnectProtocolHeader: boolean; } @@ -139,8 +151,16 @@ export function validateUniversalHandlerOptions( opt: Partial | undefined ): UniversalHandlerOptions { opt ??= {}; + const acceptCompression = opt.acceptCompression + ? [...opt.acceptCompression] + : []; + const requireConnectProtocolHeader = + opt.requireConnectProtocolHeader ?? false; + const shutdownSignal = opt.shutdownSignal ?? neverSignal; + const maxDeadlineDurationMs = + opt.maxDeadlineDurationMs ?? Number.MAX_SAFE_INTEGER; return { - acceptCompression: opt.acceptCompression ? [...opt.acceptCompression] : [], + acceptCompression, ...validateReadWriteMaxBytes( opt.readMaxBytes, opt.writeMaxBytes, @@ -148,15 +168,17 @@ export function validateUniversalHandlerOptions( ), jsonOptions: opt.jsonOptions, binaryOptions: opt.binaryOptions, - maxDeadlineDurationMs: Number.MAX_SAFE_INTEGER, - shutdownSignal: neverSignal, - requireConnectProtocolHeader: false, + maxDeadlineDurationMs, + shutdownSignal, + requireConnectProtocolHeader, }; } /** * For the given service implementation, return a universal handler for each * RPC. The handler serves the given protocols. + * + * At least one protocol is required. */ export function createUniversalServiceHandlers( spec: ServiceImplSpec, @@ -170,16 +192,14 @@ export function createUniversalServiceHandlers( /** * Return a universal handler for the given RPC implementation. * The handler serves the given protocols. + * + * At least one protocol is required. */ export function createUniversalMethodHandler( spec: MethodImplSpec, protocols: ProtocolHandlerFactory[] ): UniversalHandler { - return negotiateProtocol( - spec.service, - spec.method, - protocols.map((f) => f(spec)) - ); + return negotiateProtocol(protocols.map((f) => f(spec))); } /** @@ -189,12 +209,33 @@ export function createUniversalMethodHandler( * different protocols - and returns a single handler that looks at the * Content-Type header and the HTTP verb of the incoming request to select * the appropriate protocol-specific handler. + * + * Raises an error if no protocol handlers were provided, or if they do not + * handle exactly the same RPC. */ -function negotiateProtocol( - service: ServiceType, - method: MethodInfo, +export function negotiateProtocol( protocolHandlers: UniversalHandler[] ): UniversalHandler { + if (protocolHandlers.length == 0) { + throw new ConnectError("require at least one protocol", Code.Internal); + } + const service = protocolHandlers[0].service; + const method = protocolHandlers[0].method; + const requestPath = protocolHandlers[0].requestPath; + if ( + protocolHandlers.some((h) => h.service !== service || h.method !== method) + ) { + throw new ConnectError( + "cannot negotiate protocol for different RPCs", + Code.Internal + ); + } + if (protocolHandlers.some((h) => h.requestPath !== requestPath)) { + throw new ConnectError( + "cannot negotiate protocol for different requestPaths", + Code.Internal + ); + } async function protocolNegotiatingHandler(request: UniversalServerRequest) { if ( method.kind == MethodKind.BiDiStreaming && @@ -209,36 +250,26 @@ function negotiateProtocol( }; } const contentType = request.header.get("Content-Type") ?? ""; - const firstMatch = protocolHandlers - .filter( - (h) => - h.supportedContentType(contentType) && - h.allowedMethods.includes(request.method) - ) - .shift(); - if (firstMatch) { - return firstMatch(request); - } - const contentTypeMatches = protocolHandlers.some((h) => + const matchingContentTypes = protocolHandlers.filter((h) => h.supportedContentType(contentType) ); - if (!contentTypeMatches) { + if (matchingContentTypes.length == 0) { return uResponseUnsupportedMediaType; } - const methodMatches = protocolHandlers.some((h) => + const matchingMethod = matchingContentTypes.filter((h) => h.allowedMethods.includes(request.method) ); - if (!methodMatches) { + if (matchingMethod.length == 0) { return uResponseMethodNotAllowed; } - return uResponseUnsupportedMediaType; + const firstMatch = matchingMethod[0]; + return firstMatch(request); } return Object.assign(protocolNegotiatingHandler, { service, method, - // we expect all protocols to be served under the same path - requestPath: createMethodUrl("/", service, method), + requestPath, supportedContentType: contentTypeMatcher( ...protocolHandlers.map((h) => h.supportedContentType) ), diff --git a/packages/connect/tsconfig.test.json b/packages/connect/tsconfig.test.json deleted file mode 100644 index e69de29bb..000000000