Skip to content

Commit

Permalink
Add a generator for @httpMalformedRequestTests
Browse files Browse the repository at this point in the history
@httpMalformedRequestTests allow server implementations to generate tests
that exercise how the implementation deals with requests that are rejected
before they reach the customer's business logic. See
smithy-lang/smithy#871 for more information.
  • Loading branch information
adamthom-amzn committed Jul 28, 2021
1 parent 1027aa8 commit 5e1db33
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
import software.amazon.smithy.model.traits.IdempotencyTokenTrait;
import software.amazon.smithy.model.traits.StreamingTrait;
import software.amazon.smithy.protocoltests.traits.AppliesTo;
import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestCase;
import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestsTrait;
import software.amazon.smithy.protocoltests.traits.HttpMalformedResponseDefinition;
import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase;
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase;
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait;
Expand All @@ -73,11 +76,12 @@
* Generates HTTP protocol test cases to be run using Jest.
*
* <p>Protocol tests are defined for HTTP protocols using the
* {@code smithy.test#httpRequestTests} and {@code smithy.test#httpResponseTests}
* traits. When found on operations or errors attached to operations, a
* protocol test case will be generated that asserts that the protocol
* serialization and deserialization code creates the correct HTTP requests
* and responses for a specific set of parameters.
* {@code smithy.test#httpRequestTests}, {@code smithy.test#httpResponseTests}
* and {@code smithy.test#httpMalformedRequestTests} traits. When found on
* operations or errors attached to operations, a protocol test case will be
* generated that asserts that the protocol serialization and deserialization
* code creates the correct HTTP requests and responses for a specific set of
* parameters.
*
* TODO: try/catch and if/else are still cumbersome with TypeScriptWriter.
*/
Expand Down Expand Up @@ -186,6 +190,12 @@ private void generateServerOperationTests(OperationShape operation, OperationInd
onlyIfProtocolMatches(testCase, () -> generateServerResponseTest(operation, testCase));
}
});
// 3. Generate malformed request test cases
operation.getTrait(HttpMalformedRequestTestsTrait.class).ifPresent(trait -> {
for (HttpMalformedRequestTestCase testCase : trait.getTestCases()) {
onlyIfProtocolMatches(testCase, () -> generateMalformedRequestTest(operation, testCase));
}
});
// 3. Generate test cases for each error on each operation.
for (StructureShape error : operationIndex.getErrors(operation)) {
if (!error.hasTag("client-only")) {
Expand All @@ -209,6 +219,16 @@ private <T extends HttpMessageTestCase> void onlyIfProtocolMatches(T testCase, R
}
}

// Only generate test cases when its protocol matches the target protocol.
private void onlyIfProtocolMatches(HttpMalformedRequestTestCase testCase, Runnable runnable) {
if (testCase.getProtocol().equals(protocol)) {
LOGGER.fine(() -> format("Generating malformed request test case for %s.%s",
service.getId(), testCase.getId()));
initializeWriterIfNeeded();
runnable.run();
}
}

private void initializeWriterIfNeeded() {
if (writer == null) {
writer = context.getWriter();
Expand Down Expand Up @@ -278,7 +298,7 @@ private void generateServerRequestTest(OperationShape operation, HttpRequestTest
Map<String, String> headers = testCase.getHeaders().entrySet().stream()
.map(entry -> new Pair<>(entry.getKey().toLowerCase(Locale.US), entry.getValue()))
.collect(MapUtils.toUnmodifiableMap(Pair::getLeft, Pair::getRight));
String queryParameters = Node.prettyPrintJson(buildQueryBag(testCase));
String queryParameters = Node.prettyPrintJson(buildQueryBag(testCase.getQueryParams()));
String headerParameters = Node.prettyPrintJson(ObjectNode.fromStringMap(headers));
String body = testCase.getBody().orElse(null);

Expand All @@ -290,26 +310,11 @@ private void generateServerRequestTest(OperationShape operation, HttpRequestTest

// Create a mock function to set in place of the server operation function so we can capture
// input and other information.
writer.write("let testFunction = jest.fn();");
writer.write("const testFunction = jest.fn();");
writer.write("testFunction.mockReturnValue(Promise.resolve({}));");

// We use a partial here so that we don't have to define the entire service, but still get the advantages
// the type checker, including excess property checking. Later on we'll use `as` to cast this to the
// full service so that we can actually use it.
writer.openBlock("const testService: Partial<$T<{}>> = {", "};", serviceSymbol, () -> {
writer.write("$L: testFunction as $T<{}>,", operationSymbol.getName(), operationSymbol);
});

String getHandlerName = "get" + handlerSymbol.getName();
writer.addImport(getHandlerName, null, "./server/");
writer.addImport("ValidationFailure", "__ValidationFailure", "@aws-smithy/server-common");

// Cast the service as any so TS will ignore the fact that the type being passed in is incomplete.
writer.openBlock(
"const handler = $L(testService as $T<{}>, (ctx: {}, failures: __ValidationFailure[]) => {",
"});", getHandlerName, serviceSymbol,
() -> writer.write("if (failures) { throw failures; } return undefined;")
);
boolean usesDefaultValidation = !context.getSettings().isDisableDefaultValidation();
setupStubService(operationSymbol, serviceSymbol, handlerSymbol, usesDefaultValidation);

// Construct a new http request according to the test case definition.
writer.openBlock("const request = new HttpRequest({", "});", () -> {
Expand All @@ -333,10 +338,84 @@ private void generateServerRequestTest(OperationShape operation, HttpRequestTest
});
}

private ObjectNode buildQueryBag(HttpRequestTestCase testCase) {
private void generateMalformedRequestTest(OperationShape operation, HttpMalformedRequestTestCase testCase) {
Symbol operationSymbol = symbolProvider.toSymbol(operation);

Map<String, String> requestHeaders = testCase.getRequest().getHeaders().entrySet().stream()
.map(entry -> new Pair<>(entry.getKey().toLowerCase(Locale.US), entry.getValue()))
.collect(MapUtils.toUnmodifiableMap(Pair::getLeft, Pair::getRight));
String queryParameters = Node.prettyPrintJson(buildQueryBag(testCase.getRequest().getQueryParams()));
String requestHeaderParameters = Node.prettyPrintJson(ObjectNode.fromStringMap(requestHeaders));
String requestBody = testCase.getRequest().getBody().orElse(null);

String testName = testCase.getId() + ":MalformedRequest";
testCase.getDocumentation().ifPresent(writer::writeDocs);
writer.openBlock("it($S, async () => {", "});\n", testName, () -> {
Symbol serviceSymbol = symbolProvider.toSymbol(service);
Symbol handlerSymbol = serviceSymbol.expectProperty("handler", Symbol.class);

// Create a mock function to set in place of the server operation function so we can capture
// input and other information.
writer.write("const testFunction = jest.fn();");
writer.openBlock("testFunction.mockImplementation(() => {", "});", () -> {
writer.write("throw new Error($S);", "This request should have been rejected.");
});

boolean usesDefaultValidation = !context.getSettings().isDisableDefaultValidation();
setupStubService(operationSymbol, serviceSymbol, handlerSymbol, usesDefaultValidation);

// Construct a new http request according to the test case definition.
writer.openBlock("const request = new HttpRequest({", "});", () -> {
writer.write("method: $S,", testCase.getRequest().getMethod());
writer.write("hostname: $S,", testCase.getRequest().getHost().orElse("foo.example.com"));
writer.write("path: $S,", testCase.getRequest().getUri());
writer.write("query: $L,", queryParameters);
writer.write("headers: $L,", requestHeaderParameters);
if (requestBody != null) {
writer.write("body: Readable.from([$S]),", requestBody);
}
});
writer.write("const r = await handler.handle(request, {});").write("");

// Assert that the function has been called exactly once.
writer.write("expect(testFunction.mock.calls.length).toBe(0);");

writeHttpResponseAssertions(testCase.getResponse());
});
}

private void setupStubService(Symbol operationSymbol,
Symbol serviceSymbol,
Symbol handlerSymbol,
boolean usesDefaultValidation) {
// We use a partial here so that we don't have to define the entire service, but still get the advantages
// the type checker, including excess property checking. Later on we'll use `as` to cast this to the
// full service so that we can actually use it.
writer.openBlock("const testService: Partial<$T<{}>> = {", "};", serviceSymbol, () -> {
writer.write("$L: testFunction as $T<{}>,", operationSymbol.getName(), operationSymbol);
});

String getHandlerName = "get" + handlerSymbol.getName();
writer.addImport(getHandlerName, null, "./server/");

if (!usesDefaultValidation) {
writer.addImport("ValidationFailure", "__ValidationFailure", "@aws-smithy/server-common");

// Cast the service as any so TS will ignore the fact that the type being passed in is incomplete.
writer.openBlock(
"const handler = $L(testService as $T<{}>, (ctx: {}, failures: __ValidationFailure[]) => {",
"});", getHandlerName, serviceSymbol,
() -> writer.write("if (failures) { throw failures; } return undefined;")
);
} else {
writer.write("const handler = $L(testService as $T<{}>);", getHandlerName, serviceSymbol);
}
}

private ObjectNode buildQueryBag(List<String> queryParams) {
// The query params in the test definition is a list of strings that looks like
// "Foo=Bar", so we need to split the keys from the values.
Map<String, List<String>> query = testCase.getQueryParams().stream()
Map<String, List<String>> query = queryParams.stream()
.map(pair -> {
String[] split = pair.split("=");
String key;
Expand Down Expand Up @@ -377,13 +456,33 @@ private void writeHttpRequestAssertions(HttpRequestTestCase testCase) {

writeHttpHeaderAssertions(testCase);
writeHttpQueryAssertions(testCase);
writeHttpBodyAssertions(testCase);
testCase.getBody().ifPresent(body -> {
writeHttpBodyAssertions(body, testCase.getBodyMediaType().orElse("UNKNOWN"), true);
});
}

private void writeHttpResponseAssertions(HttpResponseTestCase testCase) {
writer.write("expect(r.statusCode).toBe($L);", testCase.getCode());
writeHttpHeaderAssertions(testCase);
writeHttpBodyAssertions(testCase);
testCase.getBody().ifPresent(body -> {
writeHttpBodyAssertions(body, testCase.getBodyMediaType().orElse("UNKNOWN"), false);
});
}

private void writeHttpResponseAssertions(HttpMalformedResponseDefinition responseDefinition) {
writer.write("expect(r.statusCode).toBe($L);", responseDefinition.getCode());
responseDefinition.getHeaders().forEach((header, value) -> {
header = header.toLowerCase();
writer.write("expect(r.headers[$S]).toBeDefined();", header);
writer.write("expect(r.headers[$S]).toBe($S);", header, value);
});
writer.write("");
responseDefinition.getBody().ifPresent(body -> {
writeHttpBodyAssertions(body, responseDefinition.getBodyMediaType().orElse("UNKNOWN"), false);
});
responseDefinition.getBodyMessageRegex().ifPresent(regex -> {
writeHttpBodyMessageAssertion(regex, responseDefinition.getBodyMediaType().orElse("UNKNOWN"));
});
}

private void writeHttpQueryAssertions(HttpRequestTestCase testCase) {
Expand Down Expand Up @@ -426,37 +525,46 @@ private void writeHttpHeaderAssertions(HttpMessageTestCase testCase) {
writer.write("");
}

private void writeHttpBodyAssertions(HttpMessageTestCase testCase) {
testCase.getBody().ifPresent(body -> {
// If we expect an empty body, expect it to be falsy.
if (body.isEmpty()) {
writer.write("expect(r.body).toBeFalsy();");
return;
}
private void writeHttpBodyAssertions(String body, String mediaType, boolean isClientTest) {
// If we expect an empty body, expect it to be falsy.
if (body.isEmpty()) {
writer.write("expect(r.body).toBeFalsy();");
return;
}

// Fast fail if we don't have a body.
writer.write("expect(r.body).toBeDefined();");
// Fast fail if we don't have a body.
writer.write("expect(r.body).toBeDefined();");

// Otherwise load a media type specific comparator and do a comparison.
String mediaType = testCase.getBodyMediaType().orElse("UNKNOWN");
String comparatorInvoke = registerBodyComparatorStub(mediaType);
// Otherwise load a media type specific comparator and do a comparison.
String comparatorInvoke = registerBodyComparatorStub(mediaType);

// If this is a request case then we know we're generating a client test,
// because a request case for servers would be comparing parsed objects. We
// need to know which is which here to be able to grab the utf8Encoder from
// the right place.
if (testCase instanceof HttpRequestTestCase) {
writer.write("const utf8Encoder = client.config.utf8Encoder;");
} else {
writer.addImport("toUtf8", "__utf8Encoder", "@aws-sdk/util-utf8-node");
writer.write("const utf8Encoder = __utf8Encoder;");
}
// If this is a request case then we know we're generating a client test,
// because a request case for servers would be comparing parsed objects. We
// need to know which is which here to be able to grab the utf8Encoder from
// the right place.
if (isClientTest) {
writer.write("const utf8Encoder = client.config.utf8Encoder;");
} else {
writer.addImport("toUtf8", "__utf8Encoder", "@aws-sdk/util-utf8-node");
writer.write("const utf8Encoder = __utf8Encoder;");
}

// Handle escaping strings with quotes inside them.
writer.write("const bodyString = `$L`;", body.replace("\"", "\\\""));
writer.write("const unequalParts: any = $L;", comparatorInvoke);
writer.write("expect(unequalParts).toBeUndefined();");
});
// Handle escaping strings with quotes inside them.
writer.write("const bodyString = `$L`;", body.replace("\"", "\\\""));
writer.write("const unequalParts: any = $L;", comparatorInvoke);
writer.write("expect(unequalParts).toBeUndefined();");
}

private void writeHttpBodyMessageAssertion(String messageRegex, String mediaType) {
// Fast fail if we don't have a body.
writer.write("expect(r.body).toBeDefined();");

// Otherwise load a media type specific matcher
String comparatorInvoke = registerMessageRegexStub(mediaType);

writer.writeInline("expect(")
.writeInline(comparatorInvoke, messageRegex)
.write(").toEqual(true);");
}

private String registerBodyComparatorStub(String mediaType) {
Expand Down Expand Up @@ -492,6 +600,18 @@ private String registerBodyComparatorStub(String mediaType) {
}
}

private String registerMessageRegexStub(String mediaType) {
// Load an additional stub to handle body comparisons for the
// set of bodyMediaType values we know of.
switch (mediaType) {
case "application/json":
additionalStubs.add("malformed-request-test-regex-json-stub.ts");
return "matchMessageInJsonBody(r.body.toString(), $S)";
default:
throw new IllegalArgumentException("Unsupported media type for message body regex check: " + mediaType);
}
}

public void generateServerResponseTest(OperationShape operation, HttpResponseTestCase testCase) {
Symbol serviceSymbol = symbolProvider.toSymbol(service);
Symbol operationSymbol = symbolProvider.toSymbol(operation);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/**
* Returns 'true' if the 'message' field in the serialized JSON document matches the given regex.
*/
const matchMessageInJsonBody = (body: string, messageRegex: string): Object => {
const parsedBody = JSON.parse(body);
if (!parsedBody.hasOwnProperty("message")) {
return false;
}
return new RegExp(messageRegex).test(parsedBody["message"]);
}

0 comments on commit 5e1db33

Please sign in to comment.