Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use serde helper function & serde shortcuts #735

Merged
merged 1 commit into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand All @@ -66,6 +67,7 @@
*/
@SmithyUnstableApi
public class DocumentMemberDeserVisitor implements ShapeVisitor<String> {
protected final SerdeElision serdeElision;
private final GenerationContext context;
private final String dataSource;
private final Format defaultTimestampFormat;
Expand All @@ -87,6 +89,8 @@ public DocumentMemberDeserVisitor(
this.context = context;
this.dataSource = dataSource;
this.defaultTimestampFormat = defaultTimestampFormat;
this.serdeElision = SerdeElision.forModel(context.getModel())
.setEnabledForModel(false);
}

/**
Expand Down Expand Up @@ -283,6 +287,11 @@ private String getDelegateDeserializer(Shape shape) {
private String getDelegateDeserializer(Shape shape, String customDataSource) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);

if (serdeElision.mayElide(shape)) {
return "_json(" + customDataSource + ")";
}

return ProtocolGenerator.getDeserFunctionShortName(symbol)
+ "(" + customDataSource + ", context)";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand All @@ -65,6 +66,7 @@
*/
@SmithyUnstableApi
public class DocumentMemberSerVisitor implements ShapeVisitor<String> {
protected final SerdeElision serdeElision;
private final GenerationContext context;
private final String dataSource;
private final Format defaultTimestampFormat;
Expand All @@ -86,6 +88,8 @@ public DocumentMemberSerVisitor(
this.context = context;
this.dataSource = dataSource;
this.defaultTimestampFormat = defaultTimestampFormat;
this.serdeElision = SerdeElision.forModel(context.getModel())
.setEnabledForModel(false);
}

/**
Expand Down Expand Up @@ -252,6 +256,11 @@ public final String unionShape(UnionShape shape) {
private String getDelegateSerializer(Shape shape) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);

if (serdeElision.mayElide(shape)) {
return "_json(" + dataSource + ")";
}

return ProtocolGenerator.getSerFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand Down Expand Up @@ -305,13 +306,19 @@ protected final void generateDeserFunction(
String methodLongName =
ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());

writer.addImport(symbol, symbol.getName());
writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: __SerdeContext\n"
+ "): $T => {", "}", methodName, symbol, () -> functionBody.accept(context, shape));
writer.write("");
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(shape);
if (mayElide) {
writer.write("// " + methodName + " omitted.");
writer.write("");
} else {
writer.addImport(symbol, symbol.getName());
writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: __SerdeContext\n"
+ "): $T => {", "}", methodName, symbol, () -> functionBody.accept(context, shape));
writer.write("");
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand Down Expand Up @@ -302,12 +303,18 @@ private void generateSerFunction(

writer.addImport(symbol, symbol.getName());

writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " input: $T,\n"
+ " context: __SerdeContext\n"
+ "): any => {", "}", methodName, symbol, () -> functionBody.accept(context, shape));
writer.write("");
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(shape);
if (mayElide) {
writer.write("// " + methodName + " omitted.");
writer.write("");
} else {
writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " input: $T,\n"
+ " context: __SerdeContext\n"
+ "): any => {", "}", methodName, symbol, () -> functionBody.accept(context, shape));
writer.write("");
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand Down Expand Up @@ -351,8 +352,13 @@ private void writeEventBody(
} else if (payloadShape instanceof BlobShape || payloadShape instanceof StringShape) {
Symbol symbol = getSymbol(context, payloadShape);
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(payloadShape);
documentShapesToSerialize.add(payloadShape);
writer.write("body = $L(input.$L, context);", payloadMemberName, serFunctionName);
if (mayElide) {
writer.write("body = $L(input.$L);", "_json", payloadMemberName);
} else {
writer.write("body = $L(input.$L, context);", serFunctionName, payloadMemberName);
}
serializeInputEventDocumentPayload.run();
} else {
throw new CodegenException(String.format("Unexpected shape type bound to event payload: `%s`",
Expand All @@ -369,7 +375,12 @@ private void writeEventBody(
Symbol symbol = getSymbol(context, event);
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
documentShapesToSerialize.add(event);
writer.write("body = $L(input, context);", serFunctionName);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(event);
if (mayElide) {
writer.write("body = $L(input);", "_json");
} else {
writer.write("body = $L(input, context);", serFunctionName);
}
serializeInputEventDocumentPayload.run();
}
}
Expand Down Expand Up @@ -496,14 +507,26 @@ private void readEventBody(
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, payloadShape);
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.write("contents.$L = $L(data, context);", payloadMemberName, deserFunctionName);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(payloadShape);
if (mayElide) {
writer.addImport("_json", null, "@aws-sdk/smithy-client");
writer.write("contents.$L = $L(data);", payloadMemberName, "_json");
} else {
writer.write("contents.$L = $L(data, context);", payloadMemberName, deserFunctionName);
}
eventShapesToDeserialize.add(payloadShape);
}
} else {
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, event);
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.write("Object.assign(contents, $L(data, context));", deserFunctionName);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(event);
if (mayElide) {
writer.addImport("_json", null, "@aws-sdk/smithy-client");
writer.write("Object.assign(contents, $L(data));", "_json");
} else {
writer.write("Object.assign(contents, $L(data, context));", deserFunctionName);
}
eventShapesToDeserialize.add(event);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.endpointsV2.RuleSetParameterFinder;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.OptionalUtils;
import software.amazon.smithy.utils.SetUtils;
Expand Down Expand Up @@ -162,8 +163,13 @@ public final ApplicationProtocol getApplicationProtocol() {
@Override
public void generateSharedComponents(GenerationContext context) {
TypeScriptWriter writer = context.getWriter();
writer.addImport("map", "__map", "@aws-sdk/smithy-client");
writer.write("const map = __map");
writer.addImport("map", null, "@aws-sdk/smithy-client");

if (context.getSettings().generateClient()) {
writer.addImport("withBaseException", null, "@aws-sdk/smithy-client");
SymbolReference exception = HttpProtocolGeneratorUtils.getClientBaseException(context);
writer.write("const throwDefaultError = withBaseException($T);", exception);
}

deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error));
serializingErrorShapes.forEach(error -> generateErrorSerializer(context, error));
Expand Down Expand Up @@ -1348,6 +1354,15 @@ private String getNamedMembersInputParam(
switch (bindingType) {
case PAYLOAD:
Symbol symbol = context.getSymbolProvider().toSymbol(target);

boolean mayElideInput = SerdeElision.forModel(context.getModel())
.setEnabledForModel(enableSerdeElision() && !context.getSettings().generateServerSdk())
.mayElide(target);

if (mayElideInput) {
return "_json(" + dataSource + ")";
}

return ProtocolGenerator.getSerFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
default:
Expand Down Expand Up @@ -2088,7 +2103,6 @@ private void generateOperationResponseDeserializer(
});

List<HttpBinding> documentBindings = readResponseBody(context, operation, bindingIndex);

// Track all shapes bound to the document so their deserializers may be generated.
documentBindings.forEach(binding -> {
Shape target = model.expectShape(binding.getMember().getTarget());
Expand Down Expand Up @@ -2674,6 +2688,15 @@ private String getNamedMembersOutputParam(
case PAYLOAD:
// Redirect to a deserialization function.
Symbol symbol = context.getSymbolProvider().toSymbol(target);

boolean mayElideOutput = SerdeElision.forModel(context.getModel())
.setEnabledForModel(enableSerdeElision() && !context.getSettings().generateServerSdk())
.mayElide(target);

if (mayElideOutput) {
return "_json(" + dataSource + ")";
}

return ProtocolGenerator.getDeserFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
default:
Expand Down Expand Up @@ -2858,4 +2881,14 @@ protected abstract void deserializeErrorDocumentBody(
* @return true if this protocol disallows string epoch timestamps in payloads.
*/
protected abstract boolean requiresNumericEpochSecondsInPayload();

/**
* Implement a return true if the protocol allows elision of serde functions
* as defined in {@link SerdeElision}.
*
* @return whether protocol implementation is compatible with serde elision.
*/
protected boolean enableSerdeElision() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ static Set<StructureShape> generateErrorDispatcher(
}

// Error responses must be at least BaseException interface
SymbolReference baseExceptionReference = getClientBaseException(context);
errorCodeGenerator.accept(context);

Runnable defaultErrorHandler = () -> {
Expand All @@ -376,18 +375,15 @@ static Set<StructureShape> generateErrorDispatcher(
writer.write("const parsedBody = await parseBody(output.body, context);");
}

writer.addImport("throwDefaultError", "throwDefaultError", "@aws-sdk/smithy-client");

// Get the protocol specific error location for retrieving contents.
String errorLocation = bodyErrorLocationModifier.apply(context, "parsedBody");
writer.openBlock("throwDefaultError({", "})", () -> {
writer.openBlock("return throwDefaultError({", "})", () -> {
writer.write("output,");
if (errorLocation.equals("parsedBody")) {
writer.write("parsedBody,");
} else {
writer.write("parsedBody: $L,", errorLocation);
}
writer.write("exceptionCtor: $T,", baseExceptionReference);
writer.write("errorCode");
});
};
Expand Down Expand Up @@ -465,7 +461,7 @@ static void writeHostPrefix(GenerationContext context, OperationShape operation)
/**
* Construct a symbol reference of client's base exception class.
*/
private static SymbolReference getClientBaseException(GenerationContext context) {
public static SymbolReference getClientBaseException(GenerationContext context) {
ServiceShape service = context.getService();
SymbolProvider symbolProvider = context.getSymbolProvider();
String serviceExceptionName = symbolProvider.toSymbol(service).getName()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import software.amazon.smithy.typescript.codegen.CodegenUtils;
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.utils.OptionalUtils;
import software.amazon.smithy.utils.SmithyUnstableApi;

Expand Down Expand Up @@ -129,6 +130,12 @@ public void generateSharedComponents(GenerationContext context) {

TypeScriptWriter writer = context.getWriter();

if (context.getSettings().generateClient()) {
writer.addImport("withBaseException", null, "@aws-sdk/smithy-client");
SymbolReference exception = HttpProtocolGeneratorUtils.getClientBaseException(context);
writer.write("const throwDefaultError = withBaseException($T);", exception);
}

// Write a function to generate HTTP requests since they're so similar.
SymbolReference requestType = getApplicationProtocol().getRequestType();
writer.addUseImports(requestType);
Expand Down Expand Up @@ -449,7 +456,7 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
writer.write("...contents,");
});
});
writer.write("return Promise.resolve(response);");
writer.write("return response;");
});
writer.write("");

Expand Down Expand Up @@ -487,9 +494,15 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape
// the error shape here.
writer.write("const body = parseBody($L.body, context);", outputReference);
}
writer.write("const deserialized: any = $L($L, context);",
ProtocolGenerator.getDeserFunctionShortName(errorSymbol),
getErrorBodyLocation(context, "body"));

if (SerdeElision.forModel(context.getModel()).mayElide(error)) {
writer.write("const deserialized: any = _json($L);",
getErrorBodyLocation(context, "body"));
} else {
writer.write("const deserialized: any = $L($L, context);",
ProtocolGenerator.getDeserFunctionShortName(errorSymbol),
getErrorBodyLocation(context, "body"));
}

// Then load it into the object with additional error and response properties.
writer.openBlock("const exception = new $T({", "});", errorSymbol, () -> {
Expand Down Expand Up @@ -598,4 +611,13 @@ protected abstract void deserializeOutputDocument(
OperationShape operation,
StructureShape outputStructure
);

/**
* See {@link software.amazon.smithy.typescript.codegen.validation.SerdeElision}.
*
* @return whether protocol implementation is compatible with serde elision.
*/
protected boolean enableSerdeElision() {
return false;
}
}
Loading