Skip to content

Commit

Permalink
add shape deserializer overrides (#512)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored May 23, 2024
1 parent 7afc826 commit 3f5ffbf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
Expand Down Expand Up @@ -75,6 +77,7 @@ public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator
private final Set<Shape> serializeDocumentBindingShapes = new TreeSet<>();
private final Set<Shape> deserializeDocumentBindingShapes = new TreeSet<>();
private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>();
private final Map<ShapeId, Symbol> deserializerOverrides = new HashMap<>();

/**
* Creates a Http binding protocol generator.
Expand Down Expand Up @@ -1082,6 +1085,13 @@ private String conditionallyBase64Encode(

@Override
public void generateResponseDeserializers(GenerationContext context) {
deserializerOverrides.putAll(
context.getIntegrations().stream()
.flatMap(it -> it.getClientPlugins(context.getModel(), context.getService()).stream())
.flatMap(it -> it.getShapeDeserializers().entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
);

EventStreamIndex streamIndex = EventStreamIndex.of(context.getModel());

for (OperationShape operation : getHttpBindingOperations(context)) {
Expand Down Expand Up @@ -1347,13 +1357,24 @@ private void writeHeaderDeserializerFunction(
) {
writer.openBlock("if headerValues := response.Header.Values($S); len(headerValues) != 0 {", "}",
binding.getLocationName(), () -> {
Shape targetShape = context.getModel().expectShape(memberShape.getTarget());
var target = memberShape.getTarget();
Shape targetShape = context.getModel().expectShape(target);

String operand = "headerValues";
operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand);

String value = generateHttpHeaderValue(context, writer, memberShape, binding,
operand);
if (deserializerOverrides.containsKey(target)) {
writer.write("""
deserOverride, err := $T($L)
if err != nil {
return err
}
v.$L = deserOverride
""", deserializerOverrides.get(target), operand, memberName);
return;
}

var value = generateHttpHeaderValue(context, writer, memberShape, binding, operand);
writer.write("v.$L = $L", memberName,
CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
GoPointableIndex.of(context.getModel()), memberShape, value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.BiPredicate;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.auth.AuthParameter;
import software.amazon.smithy.go.codegen.auth.AuthParametersResolver;
Expand Down Expand Up @@ -54,6 +55,7 @@ public final class RuntimeClientPlugin implements ToSmithyBuilder<RuntimeClientP
private final MiddlewareRegistrar registerMiddleware;
private final Map<String, GoWriter.Writable> endpointBuiltinBindings;
private final Map<ShapeId, AuthSchemeDefinition> authSchemeDefinitions;
private final Map<ShapeId, Symbol> shapeDeserializers;

private RuntimeClientPlugin(Builder builder) {
operationPredicate = builder.operationPredicate;
Expand All @@ -67,6 +69,7 @@ private RuntimeClientPlugin(Builder builder) {
configFieldResolvers = builder.configFieldResolvers;
endpointBuiltinBindings = builder.endpointBuiltinBindings;
authSchemeDefinitions = builder.authSchemeDefinitions;
shapeDeserializers = builder.shapeDeserializers;
}

@FunctionalInterface
Expand Down Expand Up @@ -130,6 +133,14 @@ public Map<ShapeId, AuthSchemeDefinition> getAuthSchemeDefinitions() {
return authSchemeDefinitions;
}

/**
* Gets the registered shape deserializers.
* @return the deserializers.
*/
public Map<ShapeId, Symbol> getShapeDeserializers() {
return shapeDeserializers;
}

/**
* Gets the optionally present middleware registrar object that resolves to middleware registering function.
*
Expand Down Expand Up @@ -242,6 +253,7 @@ public static final class Builder implements SmithyBuilder<RuntimeClientPlugin>
private Map<String, GoWriter.Writable> endpointBuiltinBindings = new HashMap<>();
private MiddlewareRegistrar registerMiddleware;
private Map<ShapeId, AuthSchemeDefinition> authSchemeDefinitions = new HashMap<>();
private Map<ShapeId, Symbol> shapeDeserializers = new HashMap<>();

@Override
public RuntimeClientPlugin build() {
Expand Down Expand Up @@ -496,5 +508,18 @@ public Builder addAuthSchemeDefinition(ShapeId schemeId, AuthSchemeDefinition de
this.authSchemeDefinitions.put(schemeId, definition);
return this;
}

/**
* Registers a codegen definition for a custom shape deserializer. This feature is currently only supported for
* overriding deserialization in HTTP bindings.
* @param id The shape id.
* @param deserializer The deserializer symbol. The written code MUST be a function which accepts the
* corresponding type for the shape and returns (*type, error) accordingly.
* @return Returns the builder.
*/
public Builder addShapeDeserializer(ShapeId id, Symbol deserializer) {
this.shapeDeserializers.put(id, deserializer);
return this;
}
}
}

0 comments on commit 3f5ffbf

Please sign in to comment.