Skip to content

Commit

Permalink
[WASM] Implement Java 14 switch expressions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698004955
  • Loading branch information
rluble authored and copybara-github committed Nov 19, 2024
1 parent acb6e59 commit 1a95e5a
Show file tree
Hide file tree
Showing 8 changed files with 446 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package com.google.j2cl.transpiler.ast;

import com.google.common.collect.Iterables;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.j2cl.common.visitor.Processor;
import com.google.j2cl.common.visitor.Visitable;
Expand Down Expand Up @@ -43,11 +42,6 @@ public List<Expression> getCaseExpressions() {
return caseExpressions;
}

// TODO(163151103): Remove pre Java 14 switch "emultation" code once the support is complete.
public Expression getCaseExpression() {
return Iterables.getOnlyElement(caseExpressions, null);
}

public List<Statement> getStatements() {
return statements;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
import com.google.j2cl.transpiler.ast.NumberLiteral;
import com.google.j2cl.transpiler.ast.PrimitiveTypeDescriptor;
import com.google.j2cl.transpiler.ast.StringLiteral;
import com.google.j2cl.transpiler.ast.SwitchExpression;
import com.google.j2cl.transpiler.ast.SwitchStatement;
import com.google.j2cl.transpiler.ast.ThisOrSuperReference;
import com.google.j2cl.transpiler.ast.TypeDescriptor;
import com.google.j2cl.transpiler.ast.TypeDescriptors;
Expand Down Expand Up @@ -321,6 +323,27 @@ public boolean enterMultiExpression(MultiExpression multiExpression) {
return false;
}

@Override
public boolean enterSwitchExpression(SwitchExpression switchExpression) {
String label = environment.getDeclarationName(switchExpression);
sourceBuilder.newLine();
// Create a block that will be the target of the yield statement, which will leave the
// result in the stack and break here.
sourceBuilder.openParens("block " + label);
sourceBuilder.append(
" (result " + environment.getWasmType(switchExpression.getTypeDescriptor()) + ")");

// Render the switch expression as if it where a switch statement, note that since
// all yields will break out of the switch there will be no fallthrough.
StatementTranspiler.render(
SwitchStatement.Builder.from(switchExpression).build(),
sourceBuilder,
environment,
label);
sourceBuilder.closeParens();
return false;
}

@Override
public boolean enterArrayLiteral(ArrayLiteral arrayLiteral) {
checkArgument(arrayLiteral.getTypeDescriptor().isNativeWasmArray());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
package com.google.j2cl.transpiler.backend.wasm;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.not;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.joining;

import com.google.common.base.Strings;
import com.google.common.collect.Iterables;
Expand All @@ -29,6 +29,7 @@
import com.google.j2cl.transpiler.ast.BooleanLiteral;
import com.google.j2cl.transpiler.ast.BreakStatement;
import com.google.j2cl.transpiler.ast.CatchClause;
import com.google.j2cl.transpiler.ast.ConditionalExpression;
import com.google.j2cl.transpiler.ast.ContinueStatement;
import com.google.j2cl.transpiler.ast.DoWhileStatement;
import com.google.j2cl.transpiler.ast.Expression;
Expand All @@ -39,6 +40,7 @@
import com.google.j2cl.transpiler.ast.LabeledStatement;
import com.google.j2cl.transpiler.ast.LoopStatement;
import com.google.j2cl.transpiler.ast.NumberLiteral;
import com.google.j2cl.transpiler.ast.PrimitiveTypes;
import com.google.j2cl.transpiler.ast.ReturnStatement;
import com.google.j2cl.transpiler.ast.RuntimeMethods;
import com.google.j2cl.transpiler.ast.Statement;
Expand All @@ -50,6 +52,7 @@
import com.google.j2cl.transpiler.ast.TypeDescriptor;
import com.google.j2cl.transpiler.ast.TypeDescriptors;
import com.google.j2cl.transpiler.ast.WhileStatement;
import com.google.j2cl.transpiler.ast.YieldStatement;
import com.google.j2cl.transpiler.backend.common.SourceBuilder;
import java.util.Arrays;
import java.util.List;
Expand All @@ -61,6 +64,14 @@ public static void render(
Statement statement,
final SourceBuilder builder,
final WasmGenerationEnvironment environment) {
render(statement, builder, environment, null);
}

public static void render(
Statement statement,
final SourceBuilder builder,
final WasmGenerationEnvironment environment,
final String enclosingSwitchStatementLabel) {

class SourceTransformer extends AbstractVisitor {
@Override
Expand Down Expand Up @@ -158,6 +169,23 @@ public boolean enterReturnStatement(ReturnStatement returnStatement) {
return false;
}

@Override
public boolean enterYieldStatement(YieldStatement yieldStatement) {
// Render the yield statement as just leaving the result in the stack and breaking
// out the the switch expression label.
builder.emitWithMapping(
yieldStatement.getSourcePosition(),
() -> {
builder.newLine();
if (yieldStatement.getExpression() != null) {
ExpressionTranspiler.render(yieldStatement.getExpression(), builder, environment);
}
builder.newLine();
builder.append("(br " + enclosingSwitchStatementLabel + ")");
});
return false;
}

@Override
public boolean enterSwitchStatement(SwitchStatement switchStatement) {
// Switch statements are emitted as a series of nested blocks, with the innermost block
Expand Down Expand Up @@ -202,7 +230,11 @@ public boolean enterSwitchStatement(SwitchStatement switchStatement) {
builder.append(
switchCase.isDefault()
? ";; default:"
: ";; case " + switchCase.getCaseExpression() + ":");
: ";; case "
+ switchCase.getCaseExpressions().stream()
.map(Expression::toString)
.collect(joining(","))
+ ":");
renderStatements(switchCase.getStatements());
builder.closeParens();
}
Expand All @@ -214,7 +246,7 @@ private void renderSwitchDispatchTable(SwitchStatement switchStatement) {
Stats stats =
Stats.of(
switchStatement.getCases().stream()
.filter(not(SwitchCase::isDefault))
.flatMap(s -> s.getCaseExpressions().stream())
.mapToInt(StatementTranspiler::getSwitchCaseAsIntValue));
if (isDense(stats)) {
renderDenseSwitchDispatchTable(switchStatement, stats);
Expand Down Expand Up @@ -314,7 +346,9 @@ private void renderDenseSwitchDispatchTable(
if (switchCase.isDefault()) {
continue;
}
slots[getSwitchCaseAsIntValue(switchCase) - offset] = casePosition;
for (Expression caseExpression : switchCase.getCaseExpressions()) {
slots[getSwitchCaseAsIntValue(caseExpression) - offset] = casePosition;
}
}

builder.newLine();
Expand All @@ -338,6 +372,7 @@ private void emitBranchIndexExpression(Expression expression, int offset) {
}
}

// TODO(b/379473636): Move the handling of non-dense switches to a normalization pass.
private void renderNonDenseSwitchDispatchTable(SwitchStatement switchStatement) {
// Evaluate the switch expression and jump to the right case.
builder.newLine();
Expand All @@ -357,7 +392,7 @@ private void renderNonDenseSwitchDispatchTable(SwitchStatement switchStatement)
// If the condition for this case is met, jump to the start of the case, i.e. jump out
// of all of the previous enclosing blocks.
Expression condition =
createCaseCondition(switchCase.getCaseExpression(), switchStatement.getExpression());
createCaseCondition(switchCase.getCaseExpressions(), switchStatement.getExpression());
renderConditionalBranch(switchStatement.getSourcePosition(), condition, casePosition);
}

Expand All @@ -370,14 +405,46 @@ private void renderNonDenseSwitchDispatchTable(SwitchStatement switchStatement)

/** Creates the condition to compare the switch expression with the case expression. */
private Expression createCaseCondition(
Expression switchCaseExpression, Expression expression) {
if (TypeDescriptors.isJavaLangString(switchCaseExpression.getTypeDescriptor())) {
// Strings are compared using equals.
return RuntimeMethods.createStringEqualsMethodCall(switchCaseExpression, expression);
List<Expression> switchCaseExpressions, Expression expression) {
Expression condition = null;
for (Expression switchCaseExpression : switchCaseExpressions) {
Expression caseCondition;
if (TypeDescriptors.isJavaLangString(switchCaseExpression.getTypeDescriptor())) {
// Strings are compared using equals.
caseCondition =
RuntimeMethods.createStringEqualsMethodCall(switchCaseExpression, expression);
} else {
checkState(switchCaseExpression.getTypeDescriptor().isPrimitive());
caseCondition = expression.infixEquals(switchCaseExpression);
}
// Transform cases with more that one label short-circuit explicitly, since the backend
// does not implement it but rather a normalization pass that has already been run.
//
// A case expression of the form
//
// case 1, 2, 3:
//
// will be rewritten as
//
// e == 1 ? true : e == 2 ? true : e == 3
//
// which is the equivalent to
//
// e == 1 || e == 2 || e == 3
//
// (There is no short-circuit "or" operator in Wasm.)
//
condition =
condition == null
? caseCondition
: ConditionalExpression.newBuilder()
.setConditionExpression(condition)
.setTrueExpression(BooleanLiteral.get(true))
.setFalseExpression(caseCondition)
.setTypeDescriptor(PrimitiveTypes.BOOLEAN)
.build();
}

checkState(switchCaseExpression.getTypeDescriptor().isPrimitive());
return expression.infixEquals(switchCaseExpression);
return condition;
}

@Override
Expand Down Expand Up @@ -571,7 +638,7 @@ private void renderExpression(Expression expression) {
}

void render(Statement stmt) {
StatementTranspiler.render(stmt, builder, environment);
StatementTranspiler.render(stmt, builder, environment, enclosingSwitchStatementLabel);
}
}

Expand All @@ -598,9 +665,8 @@ public static void renderSourceMappingComment(
}
}

private static int getSwitchCaseAsIntValue(SwitchCase switchCase) {
NumberLiteral caseExpression = (NumberLiteral) switchCase.getCaseExpression();
return caseExpression.getValue().intValue();
private static int getSwitchCaseAsIntValue(Expression caseExpression) {
return ((NumberLiteral) caseExpression).getValue().intValue();
}

private StatementTranspiler() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multiset;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.j2cl.transpiler.ast.AbstractVisitor;
import com.google.j2cl.transpiler.ast.ArrayLiteral;
import com.google.j2cl.transpiler.ast.ArrayTypeDescriptor;
import com.google.j2cl.transpiler.ast.DeclaredTypeDescriptor;
import com.google.j2cl.transpiler.ast.Field;
import com.google.j2cl.transpiler.ast.FieldDescriptor;
import com.google.j2cl.transpiler.ast.HasName;
import com.google.j2cl.transpiler.ast.Library;
import com.google.j2cl.transpiler.ast.Method;
import com.google.j2cl.transpiler.ast.MethodDescriptor;
import com.google.j2cl.transpiler.ast.NameDeclaration;
import com.google.j2cl.transpiler.ast.PrimitiveTypeDescriptor;
import com.google.j2cl.transpiler.ast.PrimitiveTypes;
import com.google.j2cl.transpiler.ast.SwitchExpression;
import com.google.j2cl.transpiler.ast.Type;
import com.google.j2cl.transpiler.ast.TypeDeclaration;
import com.google.j2cl.transpiler.ast.TypeDescriptor;
Expand All @@ -49,8 +49,10 @@
import com.google.j2cl.transpiler.backend.wasm.JsImportsGenerator.Imports;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -282,9 +284,9 @@ boolean isWasmArrayElementsField(FieldDescriptor descriptor) {
&& descriptor.getName().equals("elements");
}

private final Map<HasName, String> nameByDeclaration = new HashMap<>();
private final Map<Object, String> nameByDeclaration = new HashMap<>();

String getDeclarationName(NameDeclaration declaration) {
String getDeclarationName(Object declaration) {
return "$" + checkNotNull(nameByDeclaration.get(declaration));
}

Expand Down Expand Up @@ -396,6 +398,8 @@ String getSourceMappingPathPrefix() {
this(library, jsImports, /* sourceMappingPathPrefix= */ null, /* isModular= */ false);
}

private static final String SWITCH_EXPRESSION_LABEL = "SWITCH";

WasmGenerationEnvironment(
Library library, Imports jsImports, String sourceMappingPathPrefix, boolean isModular) {
this.isModular = isModular;
Expand All @@ -410,6 +414,31 @@ String getSourceMappingPathPrefix() {
nameByDeclaration.putAll(
UniqueNamesResolver.computeUniqueNames(ImmutableSet.of(), t)));

Set<String> usedNames = new HashSet<>(nameByDeclaration.values());

// Compute ids for switch expression labels that are unique per member and do not collide
// with any of the already assigned names.
library
.streamTypes()
.flatMap(t -> t.getMembers().stream())
.forEach(
member ->
member.accept(
new AbstractVisitor() {
private int index = 0;

// Use preorder traversal so that the switches are numbered in the order
// that they appear in the source.
@Override
public boolean enterSwitchExpression(SwitchExpression switchExpression) {
String switchExpressionLabel = SWITCH_EXPRESSION_LABEL + "." + index++;
checkState(!usedNames.contains(switchExpressionLabel));

nameByDeclaration.put(switchExpression, switchExpressionLabel);
return true;
}
}));

// Create a representation for Java types that is useful to lay out the structs and
// vtables needed in the wasm output.
wasmTypeLayoutByTypeDeclaration = new LinkedHashMap<>();
Expand Down
Loading

0 comments on commit 1a95e5a

Please sign in to comment.