diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel
index 9b3445e1..de86e2c2 100644
--- a/cel/BUILD.bazel
+++ b/cel/BUILD.bazel
@@ -15,6 +15,7 @@ go_library(
"macro.go",
"options.go",
"program.go",
+ "validator.go",
],
importpath = "github.com/google/cel-go/cel",
visibility = ["//visibility:public"],
diff --git a/cel/cel_test.go b/cel/cel_test.go
index a4a48f88..d814334c 100644
--- a/cel/cel_test.go
+++ b/cel/cel_test.go
@@ -275,105 +275,6 @@ func TestCustomEnv(t *testing.T) {
})
}
-func TestHomogeneousAggregateLiterals(t *testing.T) {
- e, err := NewCustomEnv(
- Variable("name", StringType),
- Function(operators.In,
- Overload(overloads.InList, []*Type{StringType, ListType(StringType)}, BoolType,
- BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
- return rhs.(traits.Container).Contains(lhs)
- }),
- ),
- Overload(overloads.InMap, []*Type{StringType, MapType(StringType, BoolType)}, BoolType,
- BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
- return rhs.(traits.Container).Contains(lhs)
- }),
- ),
- ),
- HomogeneousAggregateLiterals(),
- )
- if err != nil {
- t.Fatalf("NewCustomEnv() failed: %v", err)
- }
-
- tests := []struct {
- name string
- expr string
- iss string
- vars map[string]any
- out ref.Val
- }{
- {
- name: "err_list",
- expr: `name in ['hello', 0]`,
- iss: `
- ERROR: :1:19: expected type 'string' but found 'int'
- | name in ['hello', 0]
- | ..................^`,
- },
- {
- name: "err_map_key",
- expr: `name in {'hello':'world', 1:'!'}`,
- iss: `
- ERROR: :1:6: found no matching overload for '@in' applied to '(string, map(!error!, string))'
- | name in {'hello':'world', 1:'!'}
- | .....^
- ERROR: :1:27: expected type 'string' but found 'int'
- | name in {'hello':'world', 1:'!'}
- | ..........................^`,
- },
- {
- name: "err_map_value",
- expr: `name in {'hello':'world', 'goodbye':true}`,
- iss: `
- ERROR: :1:37: expected type 'string' but found 'bool'
- | name in {'hello':'world', 'goodbye':true}
- | ....................................^`,
- },
- {
- name: "ok_list",
- expr: `name in ['hello', 'world']`,
- vars: map[string]any{"name": "world"},
- out: types.True,
- },
- {
- name: "ok_map",
- expr: `name in {'hello': false, 'world': true}`,
- vars: map[string]any{"name": "world"},
- out: types.True,
- },
- }
- for _, tst := range tests {
- tc := tst
- t.Run(tc.name, func(t *testing.T) {
- ast, iss := e.Compile(tc.expr)
- if tc.iss != "" {
- if iss.Err() == nil {
- t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
- }
- if !test.Compare(iss.Err().Error(), tc.iss) {
- t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
- }
- return
- }
- if iss.Err() != nil {
- t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
- }
- prg, err := e.Program(ast)
- if err != nil {
- t.Fatalf("e.Program() failed: %v", err)
- }
- out, _, err := prg.Eval(tc.vars)
- if err != nil {
- t.Fatalf("prg.Eval(%v) errored: %v", tc.vars, err)
- }
- if out != tc.out {
- t.Errorf("program eval got %v, wanted %v", out, tc.out)
- }
- })
- }
-}
-
func TestCrossTypeNumericComparisons(t *testing.T) {
tests := []struct {
name string
diff --git a/cel/env.go b/cel/env.go
index 56e60323..ada293f0 100644
--- a/cel/env.go
+++ b/cel/env.go
@@ -118,6 +118,7 @@ type Env struct {
features map[int]bool
appliedFeatures map[int]bool
libraries map[string]bool
+ validators []ASTValidator
// Internal parser representation
prsr *parser.Parser
@@ -178,14 +179,19 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
features: map[int]bool{},
appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
+ validators: []ASTValidator{},
progOpts: []ProgramOption{},
}).configure(opts)
}
// Check performs type-checking on the input Ast and yields a checked Ast and/or set of Issues.
+// If any `ASTValidators` are configured on the environment, they will be applied after a valid
+// type-check result. If any issues are detected, the validators will provide them on the
+// output Issues object.
//
-// Checking has failed if the returned Issues value and its Issues.Err() value are non-nil.
-// Issues should be inspected if they are non-nil, but may not represent a fatal error.
+// Either checking or validation has failed if the returned Issues value and its Issues.Err()
+// value are non-nil. Issues should be inspected if they are non-nil, but may not represent a
+// fatal error.
//
// It is possible to have both non-nil Ast and Issues values returned from this call: however,
// the mere presence of an Ast does not imply that it is valid for use.
@@ -207,12 +213,22 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
- return &Ast{
+ ast = &Ast{
source: ast.Source(),
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
- typeMap: res.TypeMap}, nil
+ typeMap: res.TypeMap}
+
+ // Apply additional validators on the type-checked result.
+ iss := NewIssues(errs)
+ for _, v := range e.validators {
+ v.Validate(e, ast, iss)
+ }
+ if iss.Err() != nil {
+ return nil, iss
+ }
+ return ast, nil
}
// Compile combines the Parse and Check phases CEL program compilation to produce an Ast and
@@ -331,6 +347,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.libraries {
libsCopy[k] = v
}
+ validatorsCopy := make([]ASTValidator, len(e.validators))
+ copy(validatorsCopy, e.validators)
ext := &Env{
Container: e.Container,
@@ -342,6 +360,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
features: featuresCopy,
appliedFeatures: appliedFeaturesCopy,
libraries: libsCopy,
+ validators: validatorsCopy,
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
@@ -362,6 +381,16 @@ func (e *Env) HasLibrary(libName string) bool {
return exists && configured
}
+// HasValidator returns whether a specific ASTValidator has been configured in the environment.
+func (e *Env) HasValidator(name string) bool {
+ for _, v := range e.validators {
+ if v.Name() == name {
+ return true
+ }
+ }
+ return false
+}
+
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
@@ -534,8 +563,6 @@ func (e *Env) initChecker() (*checker.Env, error) {
chkOpts := []checker.Option{}
chkOpts = append(chkOpts, e.chkOpts...)
chkOpts = append(chkOpts,
- checker.HomogeneousAggregateLiterals(
- e.HasFeature(featureDisableDynamicAggregateLiterals)),
checker.CrossTypeNumericComparisons(
e.HasFeature(featureCrossTypeNumericComparisons)))
diff --git a/cel/options.go b/cel/options.go
index 0cde7c8a..3b9ba2c9 100644
--- a/cel/options.go
+++ b/cel/options.go
@@ -41,13 +41,6 @@ import (
const (
_ = iota
- // Disallow heterogeneous aggregate (list, map) literals.
- // Note, it is still possible to have heterogeneous aggregates when
- // provided as variables to the expression, as well as via conversion
- // of well-known dynamic types, or with unchecked expressions.
- // Affects checking. Provides a subset of standard behavior.
- featureDisableDynamicAggregateLiterals
-
// Enable the tracking of function call expressions replaced by macros.
featureEnableMacroCallTracking
@@ -63,10 +56,6 @@ const (
// is not already in UTC.
featureDefaultUTCTimeZone
- // Enable the use of optional types in the syntax, type-system, type-checking,
- // and runtime.
- featureOptionalTypes
-
// Enable the serialization of logical operator ASTs as variadic calls, thus
// compressing the logic graph to a single call when multiple like-operator
// expressions occur: e.g. a && b && c && d -> call(_&&_, [a, b, c, d])
@@ -157,7 +146,7 @@ func EagerlyValidateDeclarations(enabled bool) EnvOption {
// expression, as well as via conversion of well-known dynamic types, or with unchecked
// expressions.
func HomogeneousAggregateLiterals() EnvOption {
- return features(featureDisableDynamicAggregateLiterals, true)
+ return ASTValidators(ValidateHomogeneousAggregateLiterals())
}
// variadicLogicalOperatorASTs flatten like-operator chained logical expressions into a single
diff --git a/cel/validator.go b/cel/validator.go
new file mode 100644
index 00000000..57a696af
--- /dev/null
+++ b/cel/validator.go
@@ -0,0 +1,285 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cel
+
+import (
+ "fmt"
+ "regexp"
+
+ "github.com/google/cel-go/common"
+ "github.com/google/cel-go/common/ast"
+ "github.com/google/cel-go/common/overloads"
+
+ exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
+)
+
+// ASTValidators configures a set of ASTValidator instances into the target environment.
+//
+// Validators are applied in the order in which the are specified and are treated as singletons.
+// The same ASTValidator with a given name will not be applied more than once.
+func ASTValidators(validators ...ASTValidator) EnvOption {
+ return func(e *Env) (*Env, error) {
+ for _, v := range validators {
+ if !e.HasValidator(v.Name()) {
+ e.validators = append(e.validators, v)
+ }
+ }
+ return e, nil
+ }
+}
+
+// ASTValidator defines a singleton interface for validating a type-checked Ast against an environment.
+//
+// Note: the Issues argument is mutable in the sense that it is intended to collect errors which will be
+// reported to the caller.
+type ASTValidator interface {
+ // Name returns the name of the validator. Names must be unique.
+ Name() string
+
+ // Validate validates a given Ast within an Environment and collects a set of potential issues.
+ Validate(*Env, *Ast, *Issues)
+
+ is_validator()
+}
+
+// ExtendedValidations collects a set of common AST validations which reduce the likelihood of runtime errors.
+//
+// - Validate duration and timestamp literals
+// - Ensure regex strings are valid
+// - Disable mixed type list and map literals
+func ExtendedValidations() EnvOption {
+ return ASTValidators(
+ ValidateDurationLiterals(),
+ ValidateTimestampLiterals(),
+ ValidateRegexLiterals(),
+ ValidateHomogeneousAggregateLiterals(),
+ )
+}
+
+// ValidateDurationLiterals ensures that duration literal arguments are valid immediately after type-check.
+func ValidateDurationLiterals() ASTValidator {
+ return newFormatValidator(overloads.TypeConvertDuration, 0, evalCall)
+}
+
+// ValidateTimestampLiterals ensures that timestamp literal arguments are valid immediately after type-check.
+func ValidateTimestampLiterals() ASTValidator {
+ return newFormatValidator(overloads.TypeConvertTimestamp, 0, evalCall)
+}
+
+// ValidateRegexLiterals ensures that regex patterns are validated after type-check.
+func ValidateRegexLiterals() ASTValidator {
+ return newFormatValidator(overloads.Matches, 0, compileRegex)
+}
+
+// ValidateHomogeneousAggregateLiterals checks that all list and map literals entries have the same types, i.e.
+// no mixed list element types or mixed map key or map value types.
+//
+// Note: the string format call relies on a mixed element type list for ease of use, so this check skips all
+// literals which occur within string format calls.
+func ValidateHomogeneousAggregateLiterals() ASTValidator {
+ return homogeneousAggregateLiteralValidator{}
+}
+
+type argChecker func(env *Env, call, arg ast.NavigableExpr) error
+
+func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
+ return formatValidator{
+ funcName: funcName,
+ check: check,
+ argNum: argNum,
+ }
+}
+
+type formatValidator struct {
+ funcName string
+ argNum int
+ check argChecker
+}
+
+// Name returns the unique name of this function format validator.
+func (v formatValidator) Name() string {
+ return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName)
+}
+
+// Validate searches the AST for uses of a given function name with a constant argument and performs a check
+// on whether the argument is a valid literal value.
+func (v formatValidator) Validate(e *Env, a *Ast, iss *Issues) {
+ errs := errorReporter{iss: iss, info: a.info}
+ root := ast.NavigateCheckedAST(astToCheckedAST(a))
+ funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
+ for _, call := range funcCalls {
+ callArgs := call.AsCall().Args()
+ if len(callArgs) <= v.argNum {
+ continue
+ }
+ litArg := callArgs[v.argNum]
+ if litArg.Kind() != ast.LiteralKind {
+ continue
+ }
+ if err := v.check(e, call, litArg); err != nil {
+ errs.reportErrorAtID(litArg.ID(), "invalid %s argument", v.funcName)
+ }
+ }
+}
+
+func evalCall(env *Env, call, arg ast.NavigableExpr) error {
+ ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()})
+ prg, err := env.Program(ast)
+ if err != nil {
+ return err
+ }
+ _, _, err = prg.Eval(NoVars())
+ return err
+}
+
+func compileRegex(_ *Env, _, arg ast.NavigableExpr) error {
+ pattern := arg.AsLiteral().Value().(string)
+ _, err := regexp.Compile(pattern)
+ return err
+}
+
+func (formatValidator) is_validator() {}
+
+type homogeneousAggregateLiteralValidator struct{}
+
+// Name returns the unique name of the homogeneous type validator.
+func (homogeneousAggregateLiteralValidator) Name() string {
+ return "cel.lib.std.validate.types.homogeneous"
+}
+
+// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
+//
+// This validator makes an exception for list and map literals which occur at any level of nesting within
+// string format calls.
+func (v homogeneousAggregateLiteralValidator) Validate(e *Env, a *Ast, iss *Issues) {
+ errs := errorReporter{iss: iss, info: a.info}
+ root := ast.NavigateCheckedAST(astToCheckedAST(a))
+ listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
+ for _, listExpr := range listExprs {
+ // TODO: Add a validator config object which allows libraries to influence validation options
+ // for validators that *might* be configured. In this case, a way of skipping certain function
+ // overloads.
+ if hasStringFormatAncestor(listExpr) {
+ continue
+ }
+ l := listExpr.AsList()
+ elements := l.Elements()
+ optIndices := l.OptionalIndices()
+ var elemType *Type
+ for i, e := range elements {
+ et := e.Type()
+ if isOptionalIndex(i, optIndices) {
+ et = et.Parameters()[0]
+ }
+ if elemType == nil {
+ elemType = et
+ continue
+ }
+ if !elemType.IsEquivalentType(et) {
+ v.typeMismatch(errs, e.ID(), elemType, et)
+ break
+ }
+ }
+ }
+ mapExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.MapKind))
+ for _, mapExpr := range mapExprs {
+ if hasStringFormatAncestor(mapExpr) {
+ continue
+ }
+ m := mapExpr.AsMap()
+ entries := m.Entries()
+ var keyType, valType *Type
+ for _, e := range entries {
+ key, val := e.Key(), e.Value()
+ kt, vt := key.Type(), val.Type()
+ if e.IsOptional() {
+ vt = vt.Parameters()[0]
+ }
+ if keyType == nil && valType == nil {
+ keyType, valType = kt, vt
+ continue
+ }
+ if !keyType.IsEquivalentType(kt) {
+ v.typeMismatch(errs, key.ID(), keyType, kt)
+ }
+ if !valType.IsEquivalentType(vt) {
+ v.typeMismatch(errs, val.ID(), valType, vt)
+ }
+ }
+ }
+}
+
+func hasStringFormatAncestor(e ast.NavigableExpr) bool {
+ if parent, found := e.Parent(); found {
+ if parent.Kind() == ast.CallKind && parent.AsCall().FunctionName() == "format" {
+ return true
+ }
+ if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind {
+ return hasStringFormatAncestor(parent)
+ }
+ }
+ return false
+}
+
+func isOptionalIndex(i int, optIndices []int32) bool {
+ for _, optInd := range optIndices {
+ if i == int(optInd) {
+ return true
+ }
+ }
+ return false
+}
+
+func (homogeneousAggregateLiteralValidator) typeMismatch(errs errorReporter, id int64, expected, actual *Type) {
+ errs.reportErrorAtID(id, "expected type '%s' but found '%s'", FormatCelType(expected), FormatCelType(actual))
+}
+
+func (homogeneousAggregateLiteralValidator) is_validator() {}
+
+type errorReporter struct {
+ iss *Issues
+ info *exprpb.SourceInfo
+}
+
+func (er *errorReporter) reportErrorAtID(id int64, message string, args ...any) {
+ er.iss.errs.ReportErrorAtID(id, locationByID(id, er.info), message, args...)
+}
+
+func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location {
+ positions := sourceInfo.GetPositions()
+ var line = 1
+ if offset, found := positions[id]; found {
+ col := int(offset)
+ for _, lineOffset := range sourceInfo.GetLineOffsets() {
+ if lineOffset < offset {
+ line++
+ col = int(offset - lineOffset)
+ } else {
+ break
+ }
+ }
+ return common.NewLocation(line, col)
+ }
+ return common.NoLocation
+}
+
+func astToCheckedAST(a *Ast) *ast.CheckedAST {
+ return &ast.CheckedAST{
+ Expr: a.expr,
+ SourceInfo: a.info,
+ TypeMap: a.typeMap,
+ ReferenceMap: a.refMap,
+ }
+}
diff --git a/cel/validator_test.go b/cel/validator_test.go
new file mode 100644
index 00000000..217dc289
--- /dev/null
+++ b/cel/validator_test.go
@@ -0,0 +1,342 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cel
+
+import (
+ "testing"
+
+ "github.com/google/cel-go/common/operators"
+ "github.com/google/cel-go/common/overloads"
+ "github.com/google/cel-go/common/types"
+ "github.com/google/cel-go/common/types/ref"
+ "github.com/google/cel-go/common/types/traits"
+ "github.com/google/cel-go/test"
+)
+
+func TestValidateDurationLiterals(t *testing.T) {
+ env, err := NewEnv(
+ Variable("x", types.StringType),
+ ASTValidators(ValidateDurationLiterals()))
+ if err != nil {
+ t.Fatalf("NewEnv(ValidateDurationLiterals()) failed: %v", err)
+ }
+
+ tests := []struct {
+ expr string
+ iss string
+ }{
+ {
+ expr: `duration('1')`,
+ iss: `ERROR: :1:10: invalid duration argument
+ | duration('1')
+ | .........^`,
+ },
+ {
+ expr: `duration('1d')`,
+ iss: `ERROR: :1:10: invalid duration argument
+ | duration('1d')
+ | .........^`,
+ },
+ {
+ expr: "duration('1us')\n < duration('1nns')",
+ iss: `ERROR: :2:13: invalid duration argument
+ | < duration('1nns')
+ | ............^`,
+ },
+ {
+ expr: `duration('2h3m4s5us')`,
+ },
+ {
+ expr: `duration(x)`,
+ },
+ }
+ for _, tst := range tests {
+ tc := tst
+ t.Run(tc.expr, func(t *testing.T) {
+ _, iss := env.Compile(tc.expr)
+ if tc.iss != "" {
+ if iss.Err() == nil {
+ t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
+ }
+ if !test.Compare(iss.Err().Error(), tc.iss) {
+ t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
+ }
+ return
+ }
+ if iss.Err() != nil {
+ t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
+ }
+ })
+ }
+}
+
+func TestValidateTimestampLiterals(t *testing.T) {
+ env, err := NewEnv(
+ Variable("x", types.StringType),
+ ASTValidators(ValidateTimestampLiterals()))
+ if err != nil {
+ t.Fatalf("NewEnv(ValidateTimestampLiterals()) failed: %v", err)
+ }
+
+ tests := []struct {
+ expr string
+ iss string
+ }{
+ {
+ expr: `timestamp('1000-00-00T00:00:00Z')`,
+ iss: `ERROR: :1:11: invalid timestamp argument
+ | timestamp('1000-00-00T00:00:00Z')
+ | ..........^`,
+ },
+ {
+ expr: `timestamp('1000-01-01T00:00:00ZZ')`,
+ iss: `ERROR: :1:11: invalid timestamp argument
+ | timestamp('1000-01-01T00:00:00ZZ')
+ | ..........^`,
+ },
+ {
+ expr: `timestamp('1000-01-01T00:00:00Z')`,
+ },
+ {
+ expr: `timestamp(-6213559680)`, // min unix epoch time.
+ },
+ {
+ expr: `timestamp(-62135596801)`,
+ iss: `ERROR: :1:11: invalid timestamp argument
+ | timestamp(-62135596801)
+ | ..........^`,
+ },
+ {
+ expr: `timestamp(x)`,
+ },
+ }
+ for _, tst := range tests {
+ tc := tst
+ t.Run(tc.expr, func(t *testing.T) {
+ _, iss := env.Compile(tc.expr)
+ if tc.iss != "" {
+ if iss.Err() == nil {
+ t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
+ }
+ if !test.Compare(iss.Err().Error(), tc.iss) {
+ t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
+ }
+ return
+ }
+ if iss.Err() != nil {
+ t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
+ }
+ })
+ }
+}
+
+func TestValidateRegexLiterals(t *testing.T) {
+ env, err := NewEnv(
+ Variable("x", types.StringType),
+ ASTValidators(ValidateRegexLiterals()))
+ if err != nil {
+ t.Fatalf("NewEnv(ValidateRegexLiterals()) failed: %v", err)
+ }
+
+ tests := []struct {
+ expr string
+ iss string
+ }{
+ {
+ expr: `'hello'.matches('el*')`,
+ },
+ {
+ expr: `'hello'.matches('x++')`,
+ iss: `
+ ERROR: :1:17: invalid matches argument
+ | 'hello'.matches('x++')
+ | ................^`,
+ },
+ {
+ expr: `'hello'.matches('(?el*)')`,
+ iss: `
+ ERROR: :1:17: invalid matches argument
+ | 'hello'.matches('(?el*)')
+ | ................^`,
+ },
+ {
+ expr: `'hello'.matches('??el*')`,
+ iss: `
+ ERROR: :1:17: invalid matches argument
+ | 'hello'.matches('??el*')
+ | ................^`,
+ },
+ {
+ expr: `'hello'.matches(x)`,
+ },
+ }
+ for _, tst := range tests {
+ tc := tst
+ t.Run(tc.expr, func(t *testing.T) {
+ _, iss := env.Compile(tc.expr)
+ if tc.iss != "" {
+ if iss.Err() == nil {
+ t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
+ }
+ if !test.Compare(iss.Err().Error(), tc.iss) {
+ t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
+ }
+ return
+ }
+ if iss.Err() != nil {
+ t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
+ }
+ })
+ }
+}
+
+func TestValidateHomogeneousAggregateLiterals(t *testing.T) {
+ env, err := NewCustomEnv(
+ Variable("name", StringType),
+ Function(operators.In,
+ Overload(overloads.InList, []*Type{StringType, ListType(StringType)}, BoolType,
+ BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
+ return rhs.(traits.Container).Contains(lhs)
+ }),
+ ),
+ Overload(overloads.InMap, []*Type{StringType, MapType(StringType, BoolType)}, BoolType,
+ BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
+ return rhs.(traits.Container).Contains(lhs)
+ }),
+ ),
+ ),
+ OptionalTypes(),
+ HomogeneousAggregateLiterals(),
+ ASTValidators(ValidateHomogeneousAggregateLiterals()),
+ )
+ if err != nil {
+ t.Fatalf("NewCustomEnv() failed: %v", err)
+ }
+
+ tests := []struct {
+ expr string
+ iss string
+ }{
+ {
+ expr: `name in ['hello', 0]`,
+ iss: `
+ ERROR: :1:19: expected type 'string' but found 'int'
+ | name in ['hello', 0]
+ | ..................^`,
+ },
+ {
+ expr: `{'hello':'world', 1:'!'}`,
+ iss: `
+ ERROR: :1:19: expected type 'string' but found 'int'
+ | {'hello':'world', 1:'!'}
+ | ..................^`,
+ },
+ {
+ expr: `name in {'hello':'world', 'goodbye':true}`,
+ iss: `
+ ERROR: :1:37: expected type 'string' but found 'bool'
+ | name in {'hello':'world', 'goodbye':true}
+ | ....................................^`,
+ },
+ {
+ expr: `name in ['hello', 'world']`,
+ },
+ {
+ expr: `name in ['hello', ?optional.ofNonZeroValue('')]`,
+ },
+ {
+ expr: `name in [?optional.ofNonZeroValue(''), 'hello', ?optional.of('')]`,
+ },
+ {
+ expr: `name in {'hello': false, 'world': true}`,
+ },
+ {
+ expr: `{'hello': false, ?'world': optional.ofNonZeroValue(true)}`,
+ },
+ {
+ expr: `{?'hello': optional.ofNonZeroValue(false), 'world': true}`,
+ },
+ }
+ for _, tst := range tests {
+ tc := tst
+ t.Run(tc.expr, func(t *testing.T) {
+ _, iss := env.Compile(tc.expr)
+ if tc.iss != "" {
+ if iss.Err() == nil {
+ t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
+ }
+ if !test.Compare(iss.Err().Error(), tc.iss) {
+ t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
+ }
+ return
+ }
+ if iss.Err() != nil {
+ t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
+ }
+ })
+ }
+}
+
+func TestExtendedValidations(t *testing.T) {
+ env, err := NewEnv(
+ Variable("x", types.StringType),
+ ExtendedValidations(),
+ )
+ if err != nil {
+ t.Fatalf("NewEnv(ExtendedValidations()) failed: %v", err)
+ }
+ tests := []struct {
+ expr string
+ iss string
+ }{
+ {
+ expr: `x in ['hello', 0]
+ && duration(x) < duration('1d')
+ && timestamp(x) != timestamp('1000-01-00T00:00:00Z')
+ && x.matches('x++')`,
+ iss: `
+ ERROR: :1:16: expected type 'string' but found 'int'
+ | x in ['hello', 0]
+ | ...............^
+ ERROR: :2:30: invalid duration argument
+ | && duration(x) < duration('1d')
+ | .............................^
+ ERROR: :3:33: invalid timestamp argument
+ | && timestamp(x) != timestamp('1000-01-00T00:00:00Z')
+ | ................................^
+ ERROR: :4:17: invalid matches argument
+ | && x.matches('x++')
+ | ................^`,
+ },
+ }
+ for _, tst := range tests {
+ tc := tst
+ t.Run(tc.expr, func(t *testing.T) {
+ _, iss := env.Compile(tc.expr)
+ if tc.iss != "" {
+ if iss.Err() == nil {
+ t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
+ }
+ if !test.Compare(iss.Err().Error(), tc.iss) {
+ t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
+ }
+ return
+ }
+ if iss.Err() != nil {
+ t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
+ }
+ })
+ }
+}
diff --git a/checker/checker_test.go b/checker/checker_test.go
index 551cd2d1..032bb5ea 100644
--- a/checker/checker_test.go
+++ b/checker/checker_test.go
@@ -23,8 +23,6 @@ import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
- "github.com/google/cel-go/common/operators"
- "github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
@@ -1295,92 +1293,6 @@ _&&_(_==_(list~type(list(dyn))^list,
)~bool^equals`,
outType: types.BoolType,
},
- // Homogeneous aggregate type restriction tests.
- {
- in: `name in [1, 2u, 'string']`,
- env: testEnv{
- idents: []*decls.VariableDecl{
- decls.NewVariable("name", types.StringType),
- },
- functions: []*decls.FunctionDecl{
- testFunction(t, operators.In,
- decls.Overload(overloads.InList,
- []*types.Type{
- types.StringType,
- types.NewListType(types.StringType),
- }, types.BoolType)),
- },
- },
- opts: []Option{HomogeneousAggregateLiterals(true)},
- disableStdEnv: true,
- out: `@in(
- name~string^name,
- [
- 1~int,
- 2u~uint,
- "string"~string
- ]~list(string)
- )~bool^in_list`,
- err: `ERROR: :1:13: expected type 'int' but found 'uint'
- | name in [1, 2u, 'string']
- | ............^`,
- },
- {
- in: `name in [1, 2, 3]`,
- env: testEnv{
- idents: []*decls.VariableDecl{
- decls.NewVariable("name", types.StringType),
- },
- functions: []*decls.FunctionDecl{
- testFunction(t, operators.In,
- decls.Overload(overloads.InList,
- []*types.Type{
- types.StringType,
- types.NewListType(types.StringType),
- }, types.BoolType)),
- },
- },
- opts: []Option{HomogeneousAggregateLiterals(true)},
- disableStdEnv: true,
- out: `@in(
- name~string^name,
- [
- 1~int,
- 2~int,
- 3~int
- ]~list(int)
- )~!error!`,
- err: `ERROR: :1:6: found no matching overload for '@in' applied to '(string, list(int))'
- | name in [1, 2, 3]
- | .....^`,
- },
- {
- in: `name in ["1", "2", "3"]`,
- env: testEnv{
- idents: []*decls.VariableDecl{
- decls.NewVariable("name", types.StringType),
- },
- functions: []*decls.FunctionDecl{
- testFunction(t, operators.In,
- decls.Overload(overloads.InList,
- []*types.Type{
- types.StringType,
- types.NewListType(types.StringType),
- }, types.BoolType)),
- },
- },
- opts: []Option{HomogeneousAggregateLiterals(true)},
- disableStdEnv: true,
- out: `@in(
- name~string^name,
- [
- "1"~string,
- "2"~string,
- "3"~string
- ]~list(string)
- )~bool^in_list`,
- outType: types.BoolType,
- },
{
in: `([[[1]], [[2]], [[3]]][0][0] + [2, 3, {'four': {'five': 'six'}}])[3]`,
out: `_[_](
diff --git a/checker/options.go b/checker/options.go
index 63b5fbd1..0560c381 100644
--- a/checker/options.go
+++ b/checker/options.go
@@ -32,15 +32,6 @@ func CrossTypeNumericComparisons(enabled bool) Option {
}
}
-// HomogeneousAggregateLiterals toggles support for constructing lists and maps whose elements all
-// have the same type.
-func HomogeneousAggregateLiterals(enabled bool) Option {
- return func(opts *options) error {
- opts.homogeneousAggregateLiterals = enabled
- return nil
- }
-}
-
// ValidatedDeclarations provides a references to validated declarations which will be copied
// into new checker instances.
func ValidatedDeclarations(env *Env) Option {
diff --git a/ext/strings_test.go b/ext/strings_test.go
index e6d5e441..b8fa2ebe 100644
--- a/ext/strings_test.go
+++ b/ext/strings_test.go
@@ -1260,6 +1260,7 @@ func TestStringFormat(t *testing.T) {
cel.Container("ext"),
cel.Abbrevs("google.expr.proto3.test"),
cel.Types(&proto3pb.TestAllTypes{}),
+ cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals()),
NativeTypes(
reflect.TypeOf(&TestNestedType{}),
reflect.ValueOf(&TestAllTypes{}),