From 5d69de52a672511d6eb117ae32b621cbae5d5b8d Mon Sep 17 00:00:00 2001
From: Sokwhan Huh <sokwhan@google.com>
Date: Tue, 5 Sep 2023 10:28:10 -0700
Subject: [PATCH] Add HomogeneousLiteral Validator

PiperOrigin-RevId: 562821010
---
 common/navigation/BUILD.bazel                 |   2 +-
 .../main/java/dev/cel/common/CelOptions.java  |   4 +
 validator/BUILD.bazel                         |   2 +-
 .../dev/cel/validator/validators/BUILD.bazel  |  19 ++
 .../HomogeneousLiteralValidator.java          | 147 ++++++++++
 .../dev/cel/validator/validators/BUILD.bazel  |   2 +
 .../HomogeneousLiteralValidatorTest.java      | 260 ++++++++++++++++++
 validator/validators/BUILD.bazel              |   7 +-
 8 files changed, 440 insertions(+), 3 deletions(-)
 create mode 100644 validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java
 create mode 100644 validator/src/test/java/dev/cel/validator/validators/HomogeneousLiteralValidatorTest.java

diff --git a/common/navigation/BUILD.bazel b/common/navigation/BUILD.bazel
index 1bccf932..f4e757ad 100644
--- a/common/navigation/BUILD.bazel
+++ b/common/navigation/BUILD.bazel
@@ -1,6 +1,6 @@
 package(
     default_applicable_licenses = ["//:license"],
-    default_visibility = ["//visibility:public"],  # TODO: Expose when ready
+    default_visibility = ["//visibility:public"],
 )
 
 java_library(
diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java
index ec13b1bd..b4a944da 100644
--- a/common/src/main/java/dev/cel/common/CelOptions.java
+++ b/common/src/main/java/dev/cel/common/CelOptions.java
@@ -301,7 +301,11 @@ public abstract static class Builder {
      * checker will implicitly coerce them to type dyn.
      *
      * <p>This flag is recommended for all new uses of CEL.
+     *
+     * @deprecated Use standalone {@code dev.cel.validators.validator.HomogeneousLiteralValidator}
+     *     instead.
      */
+    @Deprecated
     public abstract Builder enableHomogeneousLiterals(boolean value);
 
     /**
diff --git a/validator/BUILD.bazel b/validator/BUILD.bazel
index fbf29028..a5afcbd2 100644
--- a/validator/BUILD.bazel
+++ b/validator/BUILD.bazel
@@ -1,6 +1,6 @@
 package(
     default_applicable_licenses = ["//:license"],
-    default_visibility = ["//visibility:public"],  # TODO: Expose when ready
+    default_visibility = ["//visibility:public"],
 )
 
 java_library(
diff --git a/validator/src/main/java/dev/cel/validator/validators/BUILD.bazel b/validator/src/main/java/dev/cel/validator/validators/BUILD.bazel
index 012fd524..c6e06f02 100644
--- a/validator/src/main/java/dev/cel/validator/validators/BUILD.bazel
+++ b/validator/src/main/java/dev/cel/validator/validators/BUILD.bazel
@@ -47,6 +47,25 @@ java_library(
     ],
 )
 
+java_library(
+    name = "homogeneous_literal",
+    srcs = [
+        "HomogeneousLiteralValidator.java",
+    ],
+    tags = [
+    ],
+    deps = [
+        "//bundle:cel",
+        "//common",
+        "//common/ast",
+        "//common/navigation",
+        "//common/types:cel_types",
+        "//common/types:type_providers",
+        "//validator:ast_validator",
+        "@maven//:com_google_guava_guava",
+    ],
+)
+
 java_library(
     name = "literal_validator",
     srcs = [
diff --git a/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java
new file mode 100644
index 00000000..db787d89
--- /dev/null
+++ b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java
@@ -0,0 +1,147 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev.cel.validator.validators;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import dev.cel.bundle.Cel;
+import dev.cel.common.CelAbstractSyntaxTree;
+import dev.cel.common.ast.CelExpr;
+import dev.cel.common.ast.CelExpr.CelCreateMap;
+import dev.cel.common.ast.CelExpr.ExprKind.Kind;
+import dev.cel.common.navigation.CelNavigableAst;
+import dev.cel.common.navigation.CelNavigableExpr;
+import dev.cel.common.types.CelType;
+import dev.cel.common.types.CelTypes;
+import dev.cel.validator.CelAstValidator;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Optional;
+
+/**
+ * HomogeneousLiteralValidator checks that all list and map literals entries have the same types,
+ * i.e. no mixed list element types or mixed map key or map value types.
+ */
+public class HomogeneousLiteralValidator implements CelAstValidator {
+  private final ImmutableSet<String> exemptFunctions;
+
+  /**
+   * Construct a new instance of {@link HomogeneousLiteralValidator}. This validator will not for
+   * functions in {@code exemptFunctions}.
+   */
+  public static HomogeneousLiteralValidator newInstance(Iterable<String> exemptFunctions) {
+    return new HomogeneousLiteralValidator(exemptFunctions);
+  }
+
+  /**
+   * Construct a new instance of {@link HomogeneousLiteralValidator}. This validator will not for
+   * functions in {@code exemptFunctions}.
+   */
+  public static HomogeneousLiteralValidator newInstance(String... exemptFunctions) {
+    return newInstance(Arrays.asList(exemptFunctions));
+  }
+
+  @Override
+  public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) {
+    navigableAst
+        .getRoot()
+        .descendants()
+        .filter(
+            node ->
+                node.getKind().equals(Kind.CREATE_LIST) || node.getKind().equals(Kind.CREATE_MAP))
+        .filter(node -> !isExemptFunction(node))
+        .map(CelNavigableExpr::expr)
+        .forEach(
+            expr -> {
+              if (expr.exprKind().getKind().equals(Kind.CREATE_LIST)) {
+                validateList(navigableAst.getAst(), issuesFactory, expr);
+              } else if (expr.exprKind().getKind().equals(Kind.CREATE_MAP)) {
+                validateMap(navigableAst.getAst(), issuesFactory, expr);
+              }
+            });
+  }
+
+  private void validateList(CelAbstractSyntaxTree ast, IssuesFactory issuesFactory, CelExpr expr) {
+    CelType previousType = null;
+    HashSet<Integer> optionalIndices = new HashSet<>(expr.createList().optionalIndices());
+    ImmutableList<CelExpr> elements = expr.createList().elements();
+    for (int i = 0; i < elements.size(); i++) {
+      CelExpr element = elements.get(i);
+      CelType currentType = ast.getType(element.id()).get();
+      if (optionalIndices.contains(i)) {
+        currentType = currentType.parameters().get(0);
+      }
+
+      if (previousType == null) {
+        previousType = currentType;
+        continue;
+      }
+
+      reportErrorIfUnassignable(issuesFactory, element.id(), previousType, currentType);
+    }
+  }
+
+  private void validateMap(CelAbstractSyntaxTree ast, IssuesFactory issuesFactory, CelExpr expr) {
+    CelType previousKeyType = null;
+    CelType previousValueType = null;
+    for (CelCreateMap.Entry entry : expr.createMap().entries()) {
+      CelType currentKeyType = ast.getType(entry.key().id()).get();
+      CelType currentValueType = ast.getType(entry.value().id()).get();
+      if (entry.optionalEntry()) {
+        currentValueType = currentValueType.parameters().get(0);
+      }
+
+      if (previousKeyType == null) {
+        previousKeyType = currentKeyType;
+        previousValueType = currentValueType;
+        continue;
+      }
+
+      reportErrorIfUnassignable(issuesFactory, entry.id(), previousKeyType, currentKeyType);
+      reportErrorIfUnassignable(issuesFactory, entry.id(), previousValueType, currentValueType);
+    }
+  }
+
+  private void reportErrorIfUnassignable(
+      IssuesFactory issuesFactory, long elementId, CelType previousType, CelType currentType) {
+    if (!previousType.isAssignableFrom(currentType)) {
+      issuesFactory.addError(
+          elementId,
+          String.format(
+              "expected type '%s' but found '%s'",
+              CelTypes.format(previousType), CelTypes.format(currentType)));
+    }
+  }
+
+  private boolean isExemptFunction(CelNavigableExpr listExpr) {
+    Optional<CelNavigableExpr> parent = listExpr.parent();
+    while (parent.isPresent()) {
+      CelNavigableExpr node = parent.get();
+      if (node.getKind().equals(Kind.CALL)) {
+        if (exemptFunctions.contains(node.expr().callOrDefault().function())) {
+          return true;
+        }
+      }
+
+      parent = node.parent();
+    }
+
+    return false;
+  }
+
+  private HomogeneousLiteralValidator(Iterable<String> exemptFunctions) {
+    this.exemptFunctions = ImmutableSet.copyOf(exemptFunctions);
+  }
+}
diff --git a/validator/src/test/java/dev/cel/validator/validators/BUILD.bazel b/validator/src/test/java/dev/cel/validator/validators/BUILD.bazel
index 321e2b75..eca252aa 100644
--- a/validator/src/test/java/dev/cel/validator/validators/BUILD.bazel
+++ b/validator/src/test/java/dev/cel/validator/validators/BUILD.bazel
@@ -13,10 +13,12 @@ java_library(
         "//common:compiler_common",
         "//common:options",
         "//common/types",
+        "//extensions:optional_library",
         "//runtime",
         "//validator",
         "//validator:validator_builder",
         "//validator/validators:duration",
+        "//validator/validators:homogeneous_literal",
         "//validator/validators:regex",
         "//validator/validators:timestamp",
         "@maven//:com_google_guava_guava",
diff --git a/validator/src/test/java/dev/cel/validator/validators/HomogeneousLiteralValidatorTest.java b/validator/src/test/java/dev/cel/validator/validators/HomogeneousLiteralValidatorTest.java
new file mode 100644
index 00000000..b6de4a80
--- /dev/null
+++ b/validator/src/test/java/dev/cel/validator/validators/HomogeneousLiteralValidatorTest.java
@@ -0,0 +1,260 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev.cel.validator.validators;
+
+import static com.google.common.truth.Truth.assertThat;
+import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration;
+import static dev.cel.common.CelOverloadDecl.newGlobalOverload;
+import static dev.cel.common.CelOverloadDecl.newMemberOverload;
+
+import com.google.testing.junit.testparameterinjector.TestParameterInjector;
+import com.google.testing.junit.testparameterinjector.TestParameters;
+import dev.cel.bundle.Cel;
+import dev.cel.bundle.CelFactory;
+import dev.cel.common.CelAbstractSyntaxTree;
+import dev.cel.common.CelValidationResult;
+import dev.cel.common.types.SimpleType;
+import dev.cel.extensions.CelOptionalLibrary;
+import dev.cel.runtime.CelRuntime.CelFunctionBinding;
+import dev.cel.validator.CelValidator;
+import dev.cel.validator.CelValidatorFactory;
+import java.util.List;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(TestParameterInjector.class)
+public class HomogeneousLiteralValidatorTest {
+  private static final Cel CEL =
+      CelFactory.standardCelBuilder()
+          .addCompilerLibraries(CelOptionalLibrary.INSTANCE)
+          .addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
+          .build();
+
+  private static final CelValidator CEL_VALIDATOR =
+      CelValidatorFactory.standardCelValidatorBuilder(CEL)
+          .addAstValidators(HomogeneousLiteralValidator.newInstance())
+          .build();
+
+  @Test
+  @TestParameters("{source: '[1, 2, 3]'}")
+  @TestParameters("{source: '[dyn(1), dyn(2), dyn(3)]'}")
+  @TestParameters("{source: '[''hello'', ''world'', ''test'']'}")
+  @TestParameters("{source: '[''hello'', ?optional.ofNonZeroValue(''''), ?optional.of('''')]'}")
+  @TestParameters("{source: '[?optional.ofNonZeroValue(''''), ?optional.of(''''), ''hello'']'}")
+  public void list_containsHomogeneousLiterals(String source) throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isFalse();
+    assertThat(result.getAllIssues()).isEmpty();
+    assertThat(CEL.createProgram(ast).eval()).isInstanceOf(List.class);
+  }
+
+  @Test
+  @TestParameters("{source: '{1: false, 2: true}'}")
+  @TestParameters("{source: '{''hello'': false, ''world'': true}'}")
+  @TestParameters("{source: '{''hello'': false, ?''world'': optional.ofNonZeroValue(true)}'}")
+  @TestParameters("{source: '{?''hello'': optional.ofNonZeroValue(false), ''world'': true}'}")
+  public void map_containsHomogeneousLiterals(String source) throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isFalse();
+    assertThat(result.getAllIssues()).isEmpty();
+    assertThat(CEL.createProgram(ast).eval()).isInstanceOf(Map.class);
+  }
+
+  @Test
+  public void list_containsHeterogeneousLiterals() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("[1, 2, 'hello']").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:8: expected type 'int' but found 'string'\n"
+                + " | [1, 2, 'hello']\n"
+                + " | .......^");
+  }
+
+  @Test
+  public void list_containsHeterogeneousLiteralsInNestedLists() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("[[1], ['hello']]").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:7: expected type 'list(int)' but found 'list(string)'\n"
+                + " | [[1], ['hello']]\n"
+                + " | ......^");
+  }
+
+  @Test
+  public void list_containsHeterogeneousLiteralsInDyn() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("[1, 2, dyn(3)]").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:11: expected type 'int' but found 'dyn'\n"
+                + " | [1, 2, dyn(3)]\n"
+                + " | ..........^");
+  }
+
+  @Test
+  public void mapKey_containsHeterogeneousLiterals() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("{1: true, 'hello': false}").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:18: expected type 'int' but found 'string'\n"
+                + " | {1: true, 'hello': false}\n"
+                + " | .................^");
+  }
+
+  @Test
+  public void mapKey_containsHeterogeneousLiteralsInNestedMaps() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("{{'a': 1}: true, {'b': 'hello'}: false}").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:32: expected type 'map(string, int)' but found 'map(string,"
+                + " string)'\n"
+                + " | {{'a': 1}: true, {'b': 'hello'}: false}\n"
+                + " | ...............................^");
+  }
+
+  @Test
+  public void mapKey_containsHeterogeneousLiteralsInDyn() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("{1: true, dyn(2): false}").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:17: expected type 'int' but found 'dyn'\n"
+                + " | {1: true, dyn(2): false}\n"
+                + " | ................^");
+  }
+
+  @Test
+  public void mapValue_containsHeterogeneousLiterals() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("{1: true, 2: 'hello'}").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:12: expected type 'bool' but found 'string'\n"
+                + " | {1: true, 2: 'hello'}\n"
+                + " | ...........^");
+  }
+
+  @Test
+  public void mapValue_containsHeterogeneousLiteralsInNestedMaps() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("{1: {'a': true}, 2: {'b': 'hello'}}").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:19: expected type 'map(string, bool)' but found 'map(string,"
+                + " string)'\n"
+                + " | {1: {'a': true}, 2: {'b': 'hello'}}\n"
+                + " | ..................^");
+  }
+
+  @Test
+  public void mapValue_containsHeterogeneousLiteralsInDyn() throws Exception {
+    CelAbstractSyntaxTree ast = CEL.compile("{1: true, 2: dyn(false)}").getAst();
+
+    CelValidationResult result = CEL_VALIDATOR.validate(ast);
+
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getAllIssues()).hasSize(1);
+    assertThat(result.getErrorString())
+        .contains(
+            "ERROR: <input>:1:12: expected type 'bool' but found 'dyn'\n"
+                + " | {1: true, 2: dyn(false)}\n"
+                + " | ...........^");
+  }
+
+  @Test
+  @TestParameters("{source: 'exemptFunction([''a'', 2])'}")
+  @TestParameters("{source: 'exemptFunction({1: true, ''hello'': false})'}")
+  @TestParameters("{source: 'exemptFunction({1: {''a'': true, 2: false}})'}")
+  @TestParameters("{source: 'exemptFunction({{''a'': true, 2: false} : false})'}")
+  @TestParameters("{source: '''%s''.format([[1], [2.0]])'}")
+  @TestParameters("{source: '''%s''.format([[1, 2, [3.0, 4]]])'}")
+  @TestParameters("{source: '''%d''.format([[[1, 2, [3.0, 4]]].size()])'}")
+  @TestParameters("{source: '''%d''.format([[1, 2, size([3.0, 4])]])'}")
+  @TestParameters("{source: '''%s''.format([[[1, 2, [3.0, 4]]][0]])'}")
+  public void heterogeneousLiterals_inExemptFunction(String source) throws Exception {
+    Cel cel =
+        CelFactory.standardCelBuilder()
+            .addFunctionDeclarations(
+                newFunctionDeclaration(
+                    "exemptFunction",
+                    newGlobalOverload("exemptFunctionOverloadId", SimpleType.BOOL, SimpleType.DYN)),
+                newFunctionDeclaration(
+                    "format",
+                    newMemberOverload(
+                        "stringFormatOverloadId",
+                        SimpleType.BOOL,
+                        SimpleType.STRING,
+                        SimpleType.DYN)))
+            .addFunctionBindings(
+                CelFunctionBinding.from("exemptFunctionOverloadId", Object.class, (arg) -> true),
+                CelFunctionBinding.from(
+                    "stringFormatOverloadId", String.class, Object.class, (str, arg) -> true))
+            .build();
+    CelValidator validator =
+        CelValidatorFactory.standardCelValidatorBuilder(cel)
+            .addAstValidators(HomogeneousLiteralValidator.newInstance("exemptFunction", "format"))
+            .build();
+    CelAbstractSyntaxTree ast = cel.compile(source).getAst();
+
+    CelValidationResult result = validator.validate(ast);
+
+    assertThat(result.hasError()).isFalse();
+    assertThat(result.getAllIssues()).isEmpty();
+    assertThat(cel.createProgram(ast).eval()).isInstanceOf(Boolean.class);
+  }
+}
diff --git a/validator/validators/BUILD.bazel b/validator/validators/BUILD.bazel
index 333bbc1e..e4f9ea54 100644
--- a/validator/validators/BUILD.bazel
+++ b/validator/validators/BUILD.bazel
@@ -1,6 +1,6 @@
 package(
     default_applicable_licenses = ["//:license"],
-    default_visibility = ["//visibility:public"],  # TODO: Expose when ready
+    default_visibility = ["//visibility:public"],
 )
 
 java_library(
@@ -17,3 +17,8 @@ java_library(
     name = "regex",
     exports = ["//validator/src/main/java/dev/cel/validator/validators:regex"],
 )
+
+java_library(
+    name = "homogeneous_literal",
+    exports = ["//validator/src/main/java/dev/cel/validator/validators:homogeneous_literal"],
+)