Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix edge cases with AsyncIterable protocol #729

Merged
merged 5 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/connect-web-bench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ it like a web server would usually do.

| code generator | bundle size | minified | compressed |
|----------------|-------------------:|-----------------------:|---------------------:|
| connect | 113,771 b | 49,912 b | 13,369 b |
| connect | 112,346 b | 49,541 b | 13,333 b |
| grpc-web | 414,906 b | 301,127 b | 53,279 b |
7 changes: 4 additions & 3 deletions packages/connect-web/src/connect-transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ export function createConnectTransport(
header: HeadersInit | undefined,
message: PartialMessage<I>
): Promise<UnaryResponse<I, O>> {
const { normalize, serialize, parse } = createClientMethodSerializers(
const { serialize, parse } = createClientMethodSerializers(
method,
useBinaryFormat,
options.jsonOptions,
Expand All @@ -163,7 +163,7 @@ export function createConnectTransport(
timeoutMs,
header
),
message: normalize(message),
message,
},
next: async (req: UnaryRequest<I, O>): Promise<UnaryResponse<I, O>> => {
const useGet =
Expand Down Expand Up @@ -228,7 +228,7 @@ export function createConnectTransport(
signal: AbortSignal | undefined,
timeoutMs: number | undefined,
header: HeadersInit | undefined,
input: AsyncIterable<I>
input: AsyncIterable<PartialMessage<I>>
): Promise<StreamResponse<I, O>> {
const { serialize, parse } = createClientMethodSerializers(
method,
Expand Down Expand Up @@ -279,6 +279,7 @@ export function createConnectTransport(
}
return encodeEnvelope(0, serialize(r.value));
}

return await runStreamingCall<I, O>({
interceptors: options.interceptors,
timeoutMs,
Expand Down
6 changes: 3 additions & 3 deletions packages/connect-web/src/grpc-web-transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export function createGrpcWebTransport(
header: Headers,
message: PartialMessage<I>
): Promise<UnaryResponse<I, O>> {
const { normalize, serialize, parse } = createClientMethodSerializers(
const { serialize, parse } = createClientMethodSerializers(
method,
useBinaryFormat,
options.jsonOptions,
Expand All @@ -153,7 +153,7 @@ export function createGrpcWebTransport(
mode: "cors",
},
header: requestHeader(useBinaryFormat, timeoutMs, header),
message: normalize(message),
message,
},
next: async (req: UnaryRequest<I, O>): Promise<UnaryResponse<I, O>> => {
const fetch = options.fetch ?? globalThis.fetch;
Expand Down Expand Up @@ -219,7 +219,7 @@ export function createGrpcWebTransport(
signal: AbortSignal | undefined,
timeoutMs: number | undefined,
header: HeadersInit | undefined,
input: AsyncIterable<I>
input: AsyncIterable<PartialMessage<I>>
): Promise<StreamResponse<I, O>> {
const { serialize, parse } = createClientMethodSerializers(
method,
Expand Down
4 changes: 1 addition & 3 deletions packages/connect/src/callback-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ function createServerStreamingFn<I extends Message<I>, O extends Message<O>>(
): ServerStreamingFn<I, O> {
return function (input, onResponse, onClose, options) {
const abort = new AbortController();
const inputMessage =
input instanceof method.I ? input : new method.I(input);
async function run() {
options = wrapSignal(abort, options);
const response = await transport.stream(
Expand All @@ -147,7 +145,7 @@ function createServerStreamingFn<I extends Message<I>, O extends Message<O>>(
options.signal,
options.timeoutMs,
options.headers,
createAsyncIterable([inputMessage])
createAsyncIterable([input])
);
options.onHeader?.(response.header);
for await (const message of response.message) {
Expand Down
176 changes: 176 additions & 0 deletions packages/connect/src/promise-client.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import {
import { createAsyncIterable } from "./protocol/async-iterable.js";
import { createRouterTransport } from "./router-transport.js";
import type { HandlerContext } from "./implementation";
import { ConnectError } from "./connect-error.js";
import { Code } from "./code.js";

const TestService = {
typeName: "handwritten.TestService",
Expand Down Expand Up @@ -76,6 +78,80 @@ describe("createClientStreamingFn()", function () {
expect(res).toBeInstanceOf(StringValue);
expect(res.value).toEqual(output.value);
});
it("closes the request iterable when response is received", async () => {
const output = new StringValue({ value: "yield 1" });
const transport = createRouterTransport(({ service }) => {
service(TestService, {
clientStream: async (input: AsyncIterable<Int32Value>) => {
for await (const next of input) {
expect(next.value).toBe(1);
return output;
}
throw new ConnectError(
"expected at least 1 value",
Code.InvalidArgument
);
},
});
});
const fn = createClientStreamingFn(
transport,
TestService,
TestService.methods.clientStream
);
let reqItrClosed = false;
const res = await fn(
// eslint-disable-next-line @typescript-eslint/require-await
(async function* () {
try {
yield { value: 1 };
fail("expected early return");
} finally {
reqItrClosed = true;
}
})()
);
expect(res).toBeInstanceOf(StringValue);
expect(res.value).toEqual(output.value);
expect(reqItrClosed).toBe(true);
});
it("closes the request iterable when response is received", async () => {
const transport = createRouterTransport(({ service }) => {
service(TestService, {
clientStream: async (input: AsyncIterable<Int32Value>) => {
for await (const next of input) {
expect(next.value).toBe(1);
throw new ConnectError("foo", Code.Internal);
}
throw new ConnectError(
"expected at least 1 value",
Code.InvalidArgument
);
},
});
});
const fn = createClientStreamingFn(
transport,
TestService,
TestService.methods.clientStream
);
let reqItrClosed = false;
const res = fn(
// eslint-disable-next-line @typescript-eslint/require-await
(async function* () {
try {
yield { value: 1 };
fail("expected early return");
} finally {
reqItrClosed = true;
}
})()
);
await expectAsync(res).toBeRejectedWith(
new ConnectError("foo", Code.Internal)
);
expect(reqItrClosed).toBe(true);
});
});

describe("createServerStreamingFn()", function () {
Expand Down Expand Up @@ -106,6 +182,20 @@ describe("createServerStreamingFn()", function () {
}
expect(receivedMessages).toEqual(output);
});
it("doesn't support throw/return on the returned response", function () {
const fn = createServerStreamingFn(
createRouterTransport(({ service }) => {
service(TestService, {
serverStream: () => createAsyncIterable([]),
});
}),
TestService,
TestService.methods.bidiStream
);
const it = fn({})[Symbol.asyncIterator]();
expect(it.throw).not.toBeDefined(); // eslint-disable-line @typescript-eslint/unbound-method
expect(it.return).not.toBeDefined(); // eslint-disable-line @typescript-eslint/unbound-method
});
});

describe("createBiDiStreamingFn()", () => {
Expand Down Expand Up @@ -142,4 +232,90 @@ describe("createBiDiStreamingFn()", () => {
expect(index).toBe(3);
expect(bidiIndex).toBe(3);
});
it("closes the request iterable when response is received", async () => {
const values = [123, 456, 789];

const input = createAsyncIterable(
values.map((value) => new Int32Value({ value }))
);
const transport = createRouterTransport(({ service }) => {
service(TestService, {
bidiStream: async function* (input: AsyncIterable<Int32Value>) {
for await (const next of input) {
yield { value: `yield ${next.value}` };
break;
}
},
});
});
const fn = createBiDiStreamingFn(
transport,
TestService,
TestService.methods.bidiStream
);

let count = 0;
for await (const res of fn(input)) {
expect(res).toEqual(new StringValue({ value: "yield 123" }));
count += 1;
}
expect(count).toBe(1);
expect(await input[Symbol.asyncIterator]().next()).toEqual({
done: true,
value: undefined,
});
});
it("closes the request iterable when an error is thrown", async () => {
const values = [123, 456, 789];

const input = createAsyncIterable(
values.map((value) => new Int32Value({ value }))
);
const transport = createRouterTransport(({ service }) => {
service(TestService, {
bidiStream: async function* (input: AsyncIterable<Int32Value>) {
for await (const next of input) {
yield { value: `yield ${next.value}` };
throw new ConnectError("foo", Code.Internal);
}
},
});
});
const fn = createBiDiStreamingFn(
transport,
TestService,
TestService.methods.bidiStream
);

let count = 0;
try {
for await (const res of fn(input)) {
expect(res).toEqual(new StringValue({ value: "yield 123" }));
count += 1;
}
} catch (e) {
expect(e).toBeInstanceOf(ConnectError);
expect((e as ConnectError).code).toBe(Code.Internal);
expect((e as ConnectError).rawMessage).toBe("foo");
}
expect(count).toBe(1);
expect(await input[Symbol.asyncIterator]().next()).toEqual({
done: true,
value: undefined,
});
});
it("doesn't support throw/return on the returned response", function () {
const fn = createBiDiStreamingFn(
createRouterTransport(({ service }) => {
service(TestService, {
bidiStream: () => createAsyncIterable([]),
});
}),
TestService,
TestService.methods.bidiStream
);
const it = fn(createAsyncIterable([]))[Symbol.asyncIterator]();
expect(it.throw).not.toBeDefined(); // eslint-disable-line @typescript-eslint/unbound-method
expect(it.return).not.toBeDefined(); // eslint-disable-line @typescript-eslint/unbound-method
});
});
Loading