Skip to content

Commit

Permalink
fix(middleware-websocket): update eventStreamHandler to use MessageSi…
Browse files Browse the repository at this point in the history
…gner (#4803)

Co-authored-by: AndrewFossAWS <108305217+AndrewFossAWS@users.noreply.github.com>
  • Loading branch information
trivikr and AndrewFossAWS authored Jun 7, 2023
1 parent 7529755 commit d8317fe
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
* @jest-environment jsdom
*/
import { EventStreamCodec } from "@aws-sdk/eventstream-codec";
import { Decoder, Encoder, EventSigner, FinalizeHandler, FinalizeHandlerArguments, HttpRequest } from "@aws-sdk/types";
import {
Decoder,
Encoder,
FinalizeHandler,
FinalizeHandlerArguments,
HttpRequest,
MessageSigner,
} from "@aws-sdk/types";
import { ReadableStream, TransformStream } from "web-streams-polyfill";

import { EventStreamPayloadHandler } from "./EventStreamPayloadHandler";
Expand All @@ -11,8 +18,9 @@ jest.mock("./get-event-signing-stream");
jest.mock("@aws-sdk/eventstream-codec");

describe(EventStreamPayloadHandler.name, () => {
const mockSigner: EventSigner = {
const mockSigner: MessageSigner = {
sign: jest.fn(),
signMessage: jest.fn(),
};
const mockUtf8Decoder: Decoder = jest.fn();
const mockUtf8encoder: Encoder = jest.fn();
Expand All @@ -32,7 +40,7 @@ describe(EventStreamPayloadHandler.name, () => {

it("should throw if request payload is not a stream", () => {
const handler = new EventStreamPayloadHandler({
eventSigner: () => Promise.resolve(mockSigner),
messageSigner: () => Promise.resolve(mockSigner),
utf8Decoder: mockUtf8Decoder,
utf8Encoder: mockUtf8encoder,
});
Expand All @@ -49,7 +57,7 @@ describe(EventStreamPayloadHandler.name, () => {
(mockNextHandler as any).mockImplementationOnce(() => Promise.reject(mockError));

const handler = new EventStreamPayloadHandler({
eventSigner: () => Promise.resolve(mockSigner),
messageSigner: () => Promise.resolve(mockSigner),
utf8Decoder: mockUtf8Decoder,
utf8Encoder: mockUtf8encoder,
});
Expand Down Expand Up @@ -79,7 +87,7 @@ describe(EventStreamPayloadHandler.name, () => {
} as any;

const handler = new EventStreamPayloadHandler({
eventSigner: () => Promise.resolve(mockSigner),
messageSigner: () => Promise.resolve(mockSigner),
utf8Decoder: mockUtf8Decoder,
utf8Encoder: mockUtf8encoder,
});
Expand All @@ -106,7 +114,7 @@ describe(EventStreamPayloadHandler.name, () => {
} as any;

const handler = new EventStreamPayloadHandler({
eventSigner: () => Promise.resolve(mockSigner),
messageSigner: () => Promise.resolve(mockSigner),
utf8Decoder: mockUtf8Decoder,
utf8Encoder: mockUtf8encoder,
});
Expand All @@ -129,7 +137,7 @@ describe(EventStreamPayloadHandler.name, () => {
headers: { authorization },
} as any;
const handler = new EventStreamPayloadHandler({
eventSigner: () => Promise.resolve(mockSigner),
messageSigner: () => Promise.resolve(mockSigner),
utf8Decoder: mockUtf8Decoder,
utf8Encoder: mockUtf8encoder,
});
Expand Down
10 changes: 5 additions & 5 deletions packages/middleware-websocket/src/EventStreamPayloadHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ import { EventStreamCodec } from "@aws-sdk/eventstream-codec";
import {
Decoder,
Encoder,
EventSigner,
EventStreamPayloadHandler as IEventStreamPayloadHandler,
FinalizeHandler,
FinalizeHandlerArguments,
FinalizeHandlerOutput,
HandlerExecutionContext,
HttpRequest,
MessageSigner,
MetadataBearer,
Provider,
} from "@aws-sdk/types";

import { getEventSigningTransformStream } from "./get-event-signing-stream";

export interface EventStreamPayloadHandlerOptions {
eventSigner: Provider<EventSigner>;
messageSigner: Provider<MessageSigner>;
utf8Encoder: Encoder;
utf8Decoder: Decoder;
}
Expand All @@ -29,11 +29,11 @@ export interface EventStreamPayloadHandlerOptions {
* 4. Sign the payload after payload stream starting to flow.
*/
export class EventStreamPayloadHandler implements IEventStreamPayloadHandler {
private readonly eventSigner: Provider<EventSigner>;
private readonly messageSigner: Provider<MessageSigner>;
private readonly eventStreamCodec: EventStreamCodec;

constructor(options: EventStreamPayloadHandlerOptions) {
this.eventSigner = options.eventSigner;
this.messageSigner = options.messageSigner;
this.eventStreamCodec = new EventStreamCodec(options.utf8Encoder, options.utf8Decoder);
}

Expand Down Expand Up @@ -68,7 +68,7 @@ export class EventStreamPayloadHandler implements IEventStreamPayloadHandler {
const priorSignature = (match || [])[1] || (query && (query["X-Amz-Signature"] as string)) || "";
const signingStream = getEventSigningTransformStream(
priorSignature,
await this.eventSigner(),
await this.messageSigner(),
this.eventStreamCodec
);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Decoder, Encoder, EventSigner, EventStreamPayloadHandlerProvider, Provider } from "@aws-sdk/types";
import { Decoder, Encoder, EventStreamPayloadHandlerProvider, MessageSigner, Provider } from "@aws-sdk/types";

import { EventStreamPayloadHandler } from "./EventStreamPayloadHandler";

/** NodeJS event stream utils provider */
export const eventStreamPayloadHandlerProvider: EventStreamPayloadHandlerProvider = (options: {
utf8Encoder: Encoder;
utf8Decoder: Decoder;
eventSigner: Provider<EventSigner>;
messageSigner: Provider<MessageSigner>;
}) => new EventStreamPayloadHandler(options);
38 changes: 29 additions & 9 deletions packages/middleware-websocket/src/get-event-signing-stream.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* @jest-environment jsdom
*/
import { EventStreamCodec } from "@aws-sdk/eventstream-codec";
import { Message, MessageHeaders } from "@aws-sdk/types";
import { Message, MessageHeaders, SignedMessage } from "@aws-sdk/types";
import { fromUtf8, toUtf8 } from "@aws-sdk/util-utf8";
import { TransformStream } from "web-streams-polyfill";

Expand Down Expand Up @@ -50,10 +50,19 @@ describe(getEventSigningTransformStream.name, () => {
},
},
];
const mockEventSigner = jest
const message1: Message = {
headers: {},
body: fromUtf8("foo"),
};
const message2: Message = {
headers: {},
body: fromUtf8("bar"),
};
const mockMessageSigner = jest
.fn()
.mockReturnValueOnce("7369676e617475726531") //'signature1'
.mockReturnValueOnce("7369676e617475726532"); //'signature2'
.mockReturnValueOnce({ message: message1, signature: "7369676e617475726531" } as SignedMessage) //'signature1'
.mockReturnValueOnce({ message: message2, signature: "7369676e617475726532" } as SignedMessage); //'signature2'

// mock 'new Date()'
let mockDateCount = 0;
// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand All @@ -65,7 +74,14 @@ describe(getEventSigningTransformStream.name, () => {
mockDateCount += 1;
return expected[mockDateCount - 1][":date"].value;
});
const signingStream = getEventSigningTransformStream("initial", { sign: mockEventSigner }, eventStreamCodec);
const signingStream = getEventSigningTransformStream(
"initial",
{
sign: mockMessageSigner,
signMessage: mockMessageSigner,
},
eventStreamCodec
);
const output: Array<MessageHeaders> = [];

const reader = signingStream.readable.getReader();
Expand All @@ -87,9 +103,13 @@ describe(getEventSigningTransformStream.name, () => {
await writer.close();
await writer.closed;
expect(output).toEqual(expected);
expect(mockEventSigner.mock.calls[0][1].priorSignature).toBe("initial");
expect(mockEventSigner.mock.calls[0][1].signingDate.getTime()).toBe((expected[0][":date"].value as Date).getTime());
expect(mockEventSigner.mock.calls[1][1].priorSignature).toBe("7369676e617475726531");
expect(mockEventSigner.mock.calls[1][1].signingDate.getTime()).toBe((expected[1][":date"].value as Date).getTime());
expect(mockMessageSigner.mock.calls[0][0].priorSignature).toBe("initial");
expect(mockMessageSigner.mock.calls[0][1].signingDate.getTime()).toBe(
(expected[0][":date"].value as Date).getTime()
);
expect(mockMessageSigner.mock.calls[1][0].priorSignature).toBe("7369676e617475726531");
expect(mockMessageSigner.mock.calls[1][1].signingDate.getTime()).toBe(
(expected[1][":date"].value as Date).getTime()
);
});
});
18 changes: 10 additions & 8 deletions packages/middleware-websocket/src/get-event-signing-stream.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { EventStreamCodec } from "@aws-sdk/eventstream-codec";
import { EventSigner, MessageHeaders } from "@aws-sdk/types";
import { MessageHeaders, MessageSigner } from "@aws-sdk/types";
import { fromHex } from "@aws-sdk/util-hex-encoding";

/**
Expand All @@ -9,7 +9,7 @@ import { fromHex } from "@aws-sdk/util-hex-encoding";
*/
export const getEventSigningTransformStream = (
initialSignature: string,
eventSigner: EventSigner,
messageSigner: MessageSigner,
eventStreamCodec: EventStreamCodec
): TransformStream<Uint8Array, Uint8Array> => {
let priorSignature = initialSignature;
Expand All @@ -21,23 +21,25 @@ export const getEventSigningTransformStream = (
const dateHeader: MessageHeaders = {
":date": { type: "timestamp", value: now },
};
const signature = await eventSigner.sign(
const signedMessage = await messageSigner.sign(
{
payload: chunk,
headers: eventStreamCodec.formatHeaders(dateHeader),
message: {
body: chunk,
headers: dateHeader,
},
priorSignature: priorSignature,
},
{
priorSignature,
signingDate: now,
}
);
priorSignature = signature;
priorSignature = signedMessage.signature;
const serializedSigned = eventStreamCodec.encode({
headers: {
...dateHeader,
":chunk-signature": {
type: "binary",
value: fromHex(signature),
value: fromHex(signedMessage.signature),
},
},
body: chunk,
Expand Down

0 comments on commit d8317fe

Please sign in to comment.