Skip to content

Commit

Permalink
Convert SerdeElision to KnowledgeIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
srchase committed May 8, 2023
1 parent 59b5ad5 commit 5545a57
Show file tree
Hide file tree
Showing 14 changed files with 585 additions and 382 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.typescript.codegen.knowledge.SerdeElisionIndex;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand All @@ -67,10 +67,11 @@
*/
@SmithyUnstableApi
public class DocumentMemberDeserVisitor implements ShapeVisitor<String> {
protected final SerdeElision serdeElision;
protected boolean serdeElisionEnabled;
private final GenerationContext context;
private final String dataSource;
private final Format defaultTimestampFormat;
private final SerdeElisionIndex serdeElisionIndex;

/**
* Constructor.
Expand All @@ -89,8 +90,8 @@ public DocumentMemberDeserVisitor(
this.context = context;
this.dataSource = dataSource;
this.defaultTimestampFormat = defaultTimestampFormat;
this.serdeElision = SerdeElision.forModel(context.getModel())
.setEnabledForModel(false);
this.serdeElisionEnabled = false;
this.serdeElisionIndex = SerdeElisionIndex.of(context.getModel());
}

/**
Expand Down Expand Up @@ -288,7 +289,7 @@ private String getDelegateDeserializer(Shape shape, String customDataSource) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);

if (serdeElision.mayElide(shape)) {
if (serdeElisionEnabled && serdeElisionIndex.mayElide(shape)) {
return "_json(" + customDataSource + ")";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.typescript.codegen.knowledge.SerdeElisionIndex;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand All @@ -66,10 +66,11 @@
*/
@SmithyUnstableApi
public class DocumentMemberSerVisitor implements ShapeVisitor<String> {
protected final SerdeElision serdeElision;
protected boolean serdeElisionEnabled;
private final GenerationContext context;
private final String dataSource;
private final Format defaultTimestampFormat;
private final SerdeElisionIndex serdeElisionIndex;

/**
* Constructor.
Expand All @@ -88,8 +89,8 @@ public DocumentMemberSerVisitor(
this.context = context;
this.dataSource = dataSource;
this.defaultTimestampFormat = defaultTimestampFormat;
this.serdeElision = SerdeElision.forModel(context.getModel())
.setEnabledForModel(false);
this.serdeElisionEnabled = false;
this.serdeElisionIndex = SerdeElisionIndex.of(context.getModel());
}

/**
Expand Down Expand Up @@ -257,7 +258,7 @@ private String getDelegateSerializer(Shape shape) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);

if (serdeElision.mayElide(shape)) {
if (serdeElisionEnabled && serdeElisionIndex.mayElide(shape)) {
return "_json(" + dataSource + ")";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.typescript.codegen.knowledge.SerdeElisionIndex;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand Down Expand Up @@ -66,10 +66,12 @@
*/
@SmithyUnstableApi
public abstract class DocumentShapeDeserVisitor extends ShapeVisitor.Default<Void> {
protected boolean serdeElisionEnabled;
private final GenerationContext context;

public DocumentShapeDeserVisitor(GenerationContext context) {
this.context = context;
this.serdeElisionEnabled = false;
}

/**
Expand Down Expand Up @@ -306,7 +308,7 @@ protected final void generateDeserFunction(
String methodLongName =
ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());

boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(shape);
boolean mayElide = serdeElisionEnabled && SerdeElisionIndex.of(context.getModel()).mayElide(shape);
if (mayElide) {
writer.write("// " + methodName + " omitted.");
writer.write("");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.typescript.codegen.knowledge.SerdeElisionIndex;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand Down Expand Up @@ -66,10 +66,12 @@
*/
@SmithyUnstableApi
public abstract class DocumentShapeSerVisitor extends ShapeVisitor.Default<Void> {
protected boolean serdeElisionEnabled;
private final GenerationContext context;

public DocumentShapeSerVisitor(GenerationContext context) {
this.context = context;
this.serdeElisionEnabled = false;
}

/**
Expand Down Expand Up @@ -303,7 +305,7 @@ private void generateSerFunction(

writer.addImport(symbol, symbol.getName());

boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(shape);
boolean mayElide = serdeElisionEnabled && SerdeElisionIndex.of(context.getModel()).mayElide(shape);
if (mayElide) {
writer.write("// " + methodName + " omitted.");
writer.write("");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.typescript.codegen.knowledge.SerdeElisionIndex;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand Down Expand Up @@ -117,13 +117,15 @@ public void generateEventStreamSerializers(
eventUnionsToSerialize.forEach(eventsUnion -> {
generateEventStreamSerializer(context, eventsUnion);
});
SerdeElisionIndex serdeElisionIndex = SerdeElisionIndex.of(model);
eventShapesToMarshall.forEach(event -> {
generateEventMarshaller(
context,
event,
documentContentType,
serializeInputEventDocumentPayload,
documentShapesToSerialize);
documentShapesToSerialize,
serdeElisionIndex);
});
}

Expand All @@ -143,7 +145,9 @@ public void generateEventStreamDeserializers(
ServiceShape service,
Set<StructureShape> errorShapesToDeserialize,
Set<Shape> eventShapesToDeserialize,
boolean isErrorCodeInBody
boolean isErrorCodeInBody,
boolean serdeElisionEnabled,
SerdeElisionIndex serdeElisionIndex
) {
Model model = context.getModel();

Expand Down Expand Up @@ -171,7 +175,9 @@ public void generateEventStreamDeserializers(
event,
errorShapesToDeserialize,
eventShapesToDeserialize,
isErrorCodeInBody
isErrorCodeInBody,
serdeElisionEnabled,
serdeElisionIndex
);
});
}
Expand Down Expand Up @@ -234,7 +240,8 @@ public void generateEventMarshaller(
StructureShape event,
String documentContentType,
Runnable serializeInputEventDocumentPayload,
Set<Shape> documentShapesToSerialize
Set<Shape> documentShapesToSerialize,
SerdeElisionIndex serdeElisionIndex
) {
String methodName = getEventSerFunctionName(context, event);
Symbol symbol = getSymbol(context, event);
Expand All @@ -252,7 +259,7 @@ public void generateEventMarshaller(
});
writeEventHeaders(context, event);
writeEventBody(context, event, serializeInputEventDocumentPayload,
documentShapesToSerialize);
documentShapesToSerialize, serdeElisionIndex);
writer.openBlock("return { headers, body };");
});
}
Expand Down Expand Up @@ -336,7 +343,8 @@ private void writeEventBody(
GenerationContext context,
StructureShape event,
Runnable serializeInputEventDocumentPayload,
Set<Shape> documentShapesToSerialize
Set<Shape> documentShapesToSerialize,
SerdeElisionIndex serdeElisionIndex
) {
TypeScriptWriter writer = context.getWriter();
Optional<MemberShape> payloadMemberOptional = getEventPayloadMember(event);
Expand All @@ -352,7 +360,7 @@ private void writeEventBody(
} else if (payloadShape instanceof BlobShape || payloadShape instanceof StringShape) {
Symbol symbol = getSymbol(context, payloadShape);
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(payloadShape);
boolean mayElide = serdeElisionIndex.mayElide(payloadShape);
documentShapesToSerialize.add(payloadShape);
if (mayElide) {
writer.write("body = $L(input.$L);", "_json", payloadMemberName);
Expand All @@ -375,7 +383,7 @@ private void writeEventBody(
Symbol symbol = getSymbol(context, event);
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
documentShapesToSerialize.add(event);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(event);
boolean mayElide = serdeElisionIndex.mayElide(event);
if (mayElide) {
writer.write("body = $L(input);", "_json");
} else {
Expand Down Expand Up @@ -431,7 +439,9 @@ public void generateEventUnmarshaller(
StructureShape event,
Set<StructureShape> errorShapesToDeserialize,
Set<Shape> eventShapesToDeserialize,
boolean isErrorCodeInBody
boolean isErrorCodeInBody,
boolean serdeElisionEnabled,
SerdeElisionIndex serdeElisionIndex
) {
String methodName = getEventDeserFunctionName(context, event);
Symbol symbol = getSymbol(context, event);
Expand All @@ -445,7 +455,7 @@ public void generateEventUnmarshaller(
} else {
writer.write("const contents: $L = {} as any;", symbol.getName());
readEventHeaders(context, event);
readEventBody(context, event, eventShapesToDeserialize);
readEventBody(context, event, eventShapesToDeserialize, serdeElisionEnabled, serdeElisionIndex);
writer.write("return contents;");
}
});
Expand Down Expand Up @@ -492,7 +502,9 @@ private void readEventHeaders(GenerationContext context, StructureShape event) {
private void readEventBody(
GenerationContext context,
StructureShape event,
Set<Shape> eventShapesToDeserialize
Set<Shape> eventShapesToDeserialize,
boolean serdeElisionEnabled,
SerdeElisionIndex serdeElisionIndex
) {
TypeScriptWriter writer = context.getWriter();
Optional<MemberShape> payloadmemberOptional = getEventPayloadMember(event);
Expand All @@ -507,7 +519,7 @@ private void readEventBody(
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, payloadShape);
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(payloadShape);
boolean mayElide = serdeElisionEnabled && serdeElisionIndex.mayElide(payloadShape);
if (mayElide) {
writer.addImport("_json", null, "@aws-sdk/smithy-client");
writer.write("contents.$L = $L(data);", payloadMemberName, "_json");
Expand All @@ -520,7 +532,7 @@ private void readEventBody(
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, event);
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
boolean mayElide = SerdeElision.forModel(context.getModel()).mayElide(event);
boolean mayElide = serdeElisionEnabled && serdeElisionIndex.mayElide(event);
if (mayElide) {
writer.addImport("_json", null, "@aws-sdk/smithy-client");
writer.write("Object.assign(contents, $L(data));", "_json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.endpointsV2.RuleSetParameterFinder;
import software.amazon.smithy.typescript.codegen.validation.SerdeElision;
import software.amazon.smithy.typescript.codegen.knowledge.SerdeElisionIndex;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.OptionalUtils;
import software.amazon.smithy.utils.SetUtils;
Expand Down Expand Up @@ -183,14 +183,17 @@ public void generateSharedComponents(GenerationContext context) {
},
serializingDocumentShapes
);
SerdeElisionIndex serdeElisionIndex = SerdeElisionIndex.of(context.getModel());
// Error shapes that only referred in the error event of an eventstream
Set<StructureShape> errorEventShapes = new TreeSet<>();
eventStreamGenerator.generateEventStreamDeserializers(
context,
service,
errorEventShapes,
deserializingDocumentShapes,
isErrorCodeInBody
isErrorCodeInBody,
enableSerdeElision(),
serdeElisionIndex
);
errorEventShapes.removeIf(deserializingErrorShapes::contains);
errorEventShapes.forEach(error -> generateErrorDeserializer(context, error));
Expand Down Expand Up @@ -1355,9 +1358,8 @@ private String getNamedMembersInputParam(
case PAYLOAD:
Symbol symbol = context.getSymbolProvider().toSymbol(target);

boolean mayElideInput = SerdeElision.forModel(context.getModel())
.setEnabledForModel(enableSerdeElision() && !context.getSettings().generateServerSdk())
.mayElide(target);
boolean mayElideInput = SerdeElisionIndex.of(context.getModel()).mayElide(target)
&& (enableSerdeElision() && !context.getSettings().generateServerSdk());

if (mayElideInput) {
return "_json(" + dataSource + ")";
Expand Down Expand Up @@ -2689,9 +2691,8 @@ private String getNamedMembersOutputParam(
// Redirect to a deserialization function.
Symbol symbol = context.getSymbolProvider().toSymbol(target);

boolean mayElideOutput = SerdeElision.forModel(context.getModel())
.setEnabledForModel(enableSerdeElision() && !context.getSettings().generateServerSdk())
.mayElide(target);
boolean mayElideOutput = SerdeElisionIndex.of(context.getModel()).mayElide(target)
&& (enableSerdeElision() && !context.getSettings().generateServerSdk());

if (mayElideOutput) {
return "_json(" + dataSource + ")";
Expand Down Expand Up @@ -2883,8 +2884,7 @@ protected abstract void deserializeErrorDocumentBody(
protected abstract boolean requiresNumericEpochSecondsInPayload();

/**
* Implement a return true if the protocol allows elision of serde functions
* as defined in {@link SerdeElision}.
* Implement a return true if the protocol allows elision of serde functions.
*
* @return whether protocol implementation is compatible with serde elision.
*/
Expand Down
Loading

0 comments on commit 5545a57

Please sign in to comment.