Skip to content

Commit

Permalink
Merge pull request #304 from skmcgrail/awsQueryErrorTrait
Browse files Browse the repository at this point in the history
Support the delegation of determining the errors that can occur for an operation
  • Loading branch information
skmcgrail committed Jun 3, 2021
2 parents e18eb0c + 8b541aa commit d9d3400
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ private void generateOperationDeserializerMiddleware(GenerationContext context,
goWriter.write("");

Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorMessageCodeDeserializer);
context, operation, responseType, this::writeErrorMessageCodeDeserializer,
this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
deserializeDocumentBindingShapes.addAll(errorShapes);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
package software.amazon.smithy.go.codegen.integration;

import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
Expand All @@ -29,33 +35,36 @@
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;

public final class HttpProtocolGeneratorUtils {

private HttpProtocolGeneratorUtils() {}
private HttpProtocolGeneratorUtils() {
}

/**
* Generates a function that handles error deserialization by getting the error code then
* dispatching to the error-specific deserializer.
*
* <p>
* If the error code does not map to a known error, a generic error will be returned using
* the error code and error message discovered in the response.
*
* <p>
* The default error message and code are both "UnknownError".
*
* @param context The generation context.
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param context The generation context.
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param errorMessageCodeGenerator A consumer that generates a snippet that sets the {@code errorCode}
* and {@code errorMessage} variables from the http response.
* @return A set of all error structure shapes for the operation that were dispatched to.
*/
static Set<StructureShape> generateErrorDispatcher(
static Set<StructureShape> generateErrorDispatcher(
GenerationContext context,
OperationShape operation,
Symbol responseType,
Consumer<GenerationContext> errorMessageCodeGenerator
Consumer<GenerationContext> errorMessageCodeGenerator,
BiFunction<GenerationContext, OperationShape, Map<String, ShapeId>> operationErrorsToShapes
) {
GoWriter writer = context.getWriter();
ServiceShape service = context.getService();
Expand All @@ -68,50 +77,49 @@ static Set<StructureShape> generateErrorDispatcher(
writer.addUseImports(SmithyGoDependency.SMITHY_MIDDLEWARE);
writer.openBlock("func $L(response $P, metadata *middleware.Metadata) error {", "}",
errorFunctionName, responseType, () -> {
writer.addUseImports(SmithyGoDependency.BYTES);
writer.addUseImports(SmithyGoDependency.IO);

// Copy the response body into a seekable type
writer.write("var errorBuffer bytes.Buffer");
writer.openBlock("if _, err := io.Copy(&errorBuffer, response.Body); err != nil {", "}", () -> {
writer.write("return &smithy.DeserializationError{Err: fmt.Errorf("
+ "\"failed to copy error response body, %w\", err)}");
});
writer.write("errorBody := bytes.NewReader(errorBuffer.Bytes())");
writer.write("");

// Set the default values for code and message.
writer.write("errorCode := \"UnknownError\"");
writer.write("errorMessage := errorCode");
writer.write("");

// Dispatch to the message/code generator to try to get the specific code and message.
errorMessageCodeGenerator.accept(context);

writer.openBlock("switch {", "}", () -> {
new TreeSet<>(operation.getErrors()).forEach(errorId -> {
StructureShape error = context.getModel().expectShape(errorId).asStructureShape().get();
errorShapes.add(error);
String errorDeserFunctionName = ProtocolGenerator.getErrorDeserFunctionName(
error, service, protocolName);
writer.addUseImports(SmithyGoDependency.STRINGS);
writer.openBlock("case strings.EqualFold($S, errorCode):", "", errorId.getName(service), () -> {
writer.write("return $L(response, errorBody)", errorDeserFunctionName);
writer.addUseImports(SmithyGoDependency.BYTES);
writer.addUseImports(SmithyGoDependency.IO);

// Copy the response body into a seekable type
writer.write("var errorBuffer bytes.Buffer");
writer.openBlock("if _, err := io.Copy(&errorBuffer, response.Body); err != nil {", "}", () -> {
writer.write("return &smithy.DeserializationError{Err: fmt.Errorf("
+ "\"failed to copy error response body, %w\", err)}");
});
});

// Create a generic error
writer.addUseImports(SmithyGoDependency.SMITHY);
writer.openBlock("default:", "", () -> {
writer.openBlock("genericError := &smithy.GenericAPIError{", "}", () -> {
writer.write("Code: errorCode,");
writer.write("Message: errorMessage,");
writer.write("errorBody := bytes.NewReader(errorBuffer.Bytes())");
writer.write("");

// Set the default values for code and message.
writer.write("errorCode := \"UnknownError\"");
writer.write("errorMessage := errorCode");
writer.write("");

// Dispatch to the message/code generator to try to get the specific code and message.
errorMessageCodeGenerator.accept(context);

writer.openBlock("switch {", "}", () -> {
operationErrorsToShapes.apply(context, operation).forEach((name, errorId) -> {
StructureShape error = context.getModel().expectShape(errorId).asStructureShape().get();
errorShapes.add(error);
String errorDeserFunctionName = ProtocolGenerator.getErrorDeserFunctionName(
error, service, protocolName);
writer.addUseImports(SmithyGoDependency.STRINGS);
writer.openBlock("case strings.EqualFold($S, errorCode):", "", name, () -> {
writer.write("return $L(response, errorBody)", errorDeserFunctionName);
});
});

// Create a generic error
writer.addUseImports(SmithyGoDependency.SMITHY);
writer.openBlock("default:", "", () -> {
writer.openBlock("genericError := &smithy.GenericAPIError{", "}", () -> {
writer.write("Code: errorCode,");
writer.write("Message: errorMessage,");
});
writer.write("return genericError");
});
});
writer.write("return genericError");
});
});
});
writer.write("");
}).write("");

return errorShapes;
}
Expand All @@ -136,4 +144,24 @@ public static boolean isShapeWithResponseBindings(Model model, Shape shape, Http
}
return false;
}

/**
* Returns a map of error names to their {@link ShapeId}.
*
* @param context the generation context
* @param operation the operation shape to retrieve errors for
* @return map of error names to {@link ShapeId}
*/
public static Map<String, ShapeId> getOperationErrors(GenerationContext context, OperationShape operation) {
return operation.getErrors().stream()
.collect(Collectors.toMap(
shapeId -> shapeId.getName(context.getService()),
Function.identity(),
(x, y) -> {
if (!x.equals(y)) {
throw new CodegenException(String.format("conflicting error shape ids: %s, %s", x, y));
}
return x;
}, TreeMap::new));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
writer.write("");

Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorMessageCodeDeserializer);
context, operation, responseType, this::writeErrorMessageCodeDeserializer,
this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
deserializingDocumentShapes.addAll(errorShapes);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.SymbolProvider;
Expand Down Expand Up @@ -248,6 +249,17 @@ static String getDeserializeMiddlewareName(ShapeId operationShapeId, ServiceShap
+ operationShapeId.getName(service);
}

/**
* Returns a map of error names to their {@link ShapeId}.
*
* @param context the generation context
* @param operation the operation shape to retrieve errors for
* @return map of error names to {@link ShapeId}
*/
default Map<String, ShapeId> getOperationErrors(GenerationContext context, OperationShape operation) {
return HttpProtocolGeneratorUtils.getOperationErrors(context, operation);
}

/**
* Context object used for service serialization and deserialization.
*/
Expand Down

0 comments on commit d9d3400

Please sign in to comment.