From 308ea298d3772626f55ad21a904abec9af086a17 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Wed, 5 Jul 2023 18:21:02 -0700 Subject: [PATCH] Validators for common literals and use cases. Introduce validators for duration, timestamp, and regex string literals as well as support for homogeneous aggregate literals with a carveout for string format calls. --- cel/BUILD.bazel | 1 + cel/cel_test.go | 99 -------------- cel/env.go | 31 ++++- cel/options.go | 13 +- cel/validator.go | 274 +++++++++++++++++++++++++++++++++++++ cel/validator_test.go | 290 ++++++++++++++++++++++++++++++++++++++++ checker/checker_test.go | 88 ------------ checker/options.go | 9 -- ext/strings_test.go | 1 + 9 files changed, 594 insertions(+), 212 deletions(-) create mode 100644 cel/validator.go create mode 100644 cel/validator_test.go diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index 9b3445e11..de86e2c24 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "macro.go", "options.go", "program.go", + "validator.go", ], importpath = "github.com/google/cel-go/cel", visibility = ["//visibility:public"], diff --git a/cel/cel_test.go b/cel/cel_test.go index a4a48f885..d814334c9 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -275,105 +275,6 @@ func TestCustomEnv(t *testing.T) { }) } -func TestHomogeneousAggregateLiterals(t *testing.T) { - e, err := NewCustomEnv( - Variable("name", StringType), - Function(operators.In, - Overload(overloads.InList, []*Type{StringType, ListType(StringType)}, BoolType, - BinaryBinding(func(lhs, rhs ref.Val) ref.Val { - return rhs.(traits.Container).Contains(lhs) - }), - ), - Overload(overloads.InMap, []*Type{StringType, MapType(StringType, BoolType)}, BoolType, - BinaryBinding(func(lhs, rhs ref.Val) ref.Val { - return rhs.(traits.Container).Contains(lhs) - }), - ), - ), - HomogeneousAggregateLiterals(), - ) - if err != nil { - t.Fatalf("NewCustomEnv() failed: %v", err) - } - - tests := []struct { - name string - expr string - iss string - vars map[string]any - out ref.Val - }{ - { - name: "err_list", - expr: `name in ['hello', 0]`, - iss: ` - ERROR: :1:19: expected type 'string' but found 'int' - | name in ['hello', 0] - | ..................^`, - }, - { - name: "err_map_key", - expr: `name in {'hello':'world', 1:'!'}`, - iss: ` - ERROR: :1:6: found no matching overload for '@in' applied to '(string, map(!error!, string))' - | name in {'hello':'world', 1:'!'} - | .....^ - ERROR: :1:27: expected type 'string' but found 'int' - | name in {'hello':'world', 1:'!'} - | ..........................^`, - }, - { - name: "err_map_value", - expr: `name in {'hello':'world', 'goodbye':true}`, - iss: ` - ERROR: :1:37: expected type 'string' but found 'bool' - | name in {'hello':'world', 'goodbye':true} - | ....................................^`, - }, - { - name: "ok_list", - expr: `name in ['hello', 'world']`, - vars: map[string]any{"name": "world"}, - out: types.True, - }, - { - name: "ok_map", - expr: `name in {'hello': false, 'world': true}`, - vars: map[string]any{"name": "world"}, - out: types.True, - }, - } - for _, tst := range tests { - tc := tst - t.Run(tc.name, func(t *testing.T) { - ast, iss := e.Compile(tc.expr) - if tc.iss != "" { - if iss.Err() == nil { - t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss) - } - if !test.Compare(iss.Err().Error(), tc.iss) { - t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss) - } - return - } - if iss.Err() != nil { - t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err()) - } - prg, err := e.Program(ast) - if err != nil { - t.Fatalf("e.Program() failed: %v", err) - } - out, _, err := prg.Eval(tc.vars) - if err != nil { - t.Fatalf("prg.Eval(%v) errored: %v", tc.vars, err) - } - if out != tc.out { - t.Errorf("program eval got %v, wanted %v", out, tc.out) - } - }) - } -} - func TestCrossTypeNumericComparisons(t *testing.T) { tests := []struct { name string diff --git a/cel/env.go b/cel/env.go index 56e60323c..cbffc2434 100644 --- a/cel/env.go +++ b/cel/env.go @@ -118,6 +118,7 @@ type Env struct { features map[int]bool appliedFeatures map[int]bool libraries map[string]bool + validators []ASTValidator // Internal parser representation prsr *parser.Parser @@ -178,6 +179,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) { features: map[int]bool{}, appliedFeatures: map[int]bool{}, libraries: map[string]bool{}, + validators: []ASTValidator{}, progOpts: []ProgramOption{}, }).configure(opts) } @@ -207,12 +209,22 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { } // Manually create the Ast to ensure that the Ast source information (which may be more // detailed than the information provided by Check), is returned to the caller. - return &Ast{ + ast = &Ast{ source: ast.Source(), expr: res.Expr, info: res.SourceInfo, refMap: res.ReferenceMap, - typeMap: res.TypeMap}, nil + typeMap: res.TypeMap} + + // Apply additional validators on the type-checked result. + iss := NewIssues(errs) + for _, v := range e.validators { + v.Validate(e, ast, iss) + } + if iss.Err() != nil { + return nil, iss + } + return ast, nil } // Compile combines the Parse and Check phases CEL program compilation to produce an Ast and @@ -331,6 +343,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { for k, v := range e.libraries { libsCopy[k] = v } + validatorsCopy := make([]ASTValidator, len(e.validators)) + copy(validatorsCopy, e.validators) ext := &Env{ Container: e.Container, @@ -342,6 +356,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { features: featuresCopy, appliedFeatures: appliedFeaturesCopy, libraries: libsCopy, + validators: validatorsCopy, provider: provider, chkOpts: chkOptsCopy, prsrOpts: prsrOptsCopy, @@ -362,6 +377,16 @@ func (e *Env) HasLibrary(libName string) bool { return exists && configured } +// HasValidator returns whether a specific ASTValidator has been configured in the environment. +func (e *Env) HasValidator(name string) bool { + for _, v := range e.validators { + if v.Name() == name { + return true + } + } + return false +} + // Parse parses the input expression value `txt` to a Ast and/or a set of Issues. // // This form of Parse creates a Source value for the input `txt` and forwards to the @@ -534,8 +559,6 @@ func (e *Env) initChecker() (*checker.Env, error) { chkOpts := []checker.Option{} chkOpts = append(chkOpts, e.chkOpts...) chkOpts = append(chkOpts, - checker.HomogeneousAggregateLiterals( - e.HasFeature(featureDisableDynamicAggregateLiterals)), checker.CrossTypeNumericComparisons( e.HasFeature(featureCrossTypeNumericComparisons))) diff --git a/cel/options.go b/cel/options.go index 0cde7c8a2..3b9ba2c9e 100644 --- a/cel/options.go +++ b/cel/options.go @@ -41,13 +41,6 @@ import ( const ( _ = iota - // Disallow heterogeneous aggregate (list, map) literals. - // Note, it is still possible to have heterogeneous aggregates when - // provided as variables to the expression, as well as via conversion - // of well-known dynamic types, or with unchecked expressions. - // Affects checking. Provides a subset of standard behavior. - featureDisableDynamicAggregateLiterals - // Enable the tracking of function call expressions replaced by macros. featureEnableMacroCallTracking @@ -63,10 +56,6 @@ const ( // is not already in UTC. featureDefaultUTCTimeZone - // Enable the use of optional types in the syntax, type-system, type-checking, - // and runtime. - featureOptionalTypes - // Enable the serialization of logical operator ASTs as variadic calls, thus // compressing the logic graph to a single call when multiple like-operator // expressions occur: e.g. a && b && c && d -> call(_&&_, [a, b, c, d]) @@ -157,7 +146,7 @@ func EagerlyValidateDeclarations(enabled bool) EnvOption { // expression, as well as via conversion of well-known dynamic types, or with unchecked // expressions. func HomogeneousAggregateLiterals() EnvOption { - return features(featureDisableDynamicAggregateLiterals, true) + return ASTValidators(ValidateHomogeneousAggregateLiterals()) } // variadicLogicalOperatorASTs flatten like-operator chained logical expressions into a single diff --git a/cel/validator.go b/cel/validator.go new file mode 100644 index 000000000..00f5ae769 --- /dev/null +++ b/cel/validator.go @@ -0,0 +1,274 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + "regexp" + + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/overloads" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// ASTValidators configures a set of ASTValidator instances into the target environment. +// +// Validators are applied in the order in which the are specified and are treated as singletons. +// The same ASTValidator with a given name will not be applied more than once. +func ASTValidators(validators ...ASTValidator) EnvOption { + return func(e *Env) (*Env, error) { + for _, v := range validators { + if !e.HasValidator(v.Name()) { + e.validators = append(e.validators, v) + } + } + return e, nil + } +} + +// ASTValidator defines a singleton interface for validating a type-checked Ast against an environment. +// +// Note: the Issues argument is mutable in the sense that it is intended to collect errors which will be +// reported to the caller. +type ASTValidator interface { + // Name returns the name of the validator. Names must be unique. + Name() string + + // Validate validates a given Ast within an Environment and collects a set of potential issues. + Validate(*Env, *Ast, *Issues) + + is_validator() +} + +// ExtendedValidations collects a set of common AST validations which reduce the likelihood of runtime errors. +// +// - Validate duration and timestamp literals +// - Ensure regex strings are valid +// - Disable mixed type list and map literals +func ExtendedValidations() EnvOption { + return ASTValidators( + ValidateDurationLiterals(), + ValidateTimestampLiterals(), + ValidateRegexLiterals(), + ValidateHomogeneousAggregateLiterals(), + ) +} + +// ValidateDurationLiterals ensures that duration literal arguments are valid immediately after type-check. +func ValidateDurationLiterals() ASTValidator { + return newFormatValidator(overloads.TypeConvertDuration, 0, evalCall) +} + +// ValidateTimestampLiterals ensures that timestamp literal arguments are valid immediately after type-check. +func ValidateTimestampLiterals() ASTValidator { + return newFormatValidator(overloads.TypeConvertTimestamp, 0, evalCall) +} + +// ValidateRegexLiterals ensures that regex patterns are validated after type-check. +func ValidateRegexLiterals() ASTValidator { + return newFormatValidator(overloads.Matches, 0, compileRegex) +} + +// ValidateHomogeneousAggregateLiterals checks that all list and map literals entries have the same types, i.e. +// no mixed list element types or mixed map key or map value types. +// +// Note: the string format call relies on a mixed element type list for ease of use, so this check skips all +// literals which occure within string format calls. +func ValidateHomogeneousAggregateLiterals() ASTValidator { + return homogeneousAggregateLiteralValidator{} +} + +type argChecker func(env *Env, call, arg ast.NavigableExpr) error + +func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator { + return formatValidator{ + funcName: funcName, + check: check, + argNum: argNum, + } +} + +type formatValidator struct { + funcName string + argNum int + check argChecker +} + +func (v formatValidator) Name() string { + return fmt.Sprintf("cel.lib.std.functions.%s", v.funcName) +} + +func (v formatValidator) Validate(e *Env, a *Ast, iss *Issues) { + errs := errorReporter{iss: iss, info: a.info} + root := ast.NavigateCheckedAST(astToCheckedAST(a)) + funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName)) + for _, call := range funcCalls { + callArgs := call.AsCall().Args() + if len(callArgs) <= v.argNum { + continue + } + litArg := callArgs[v.argNum] + if litArg.Kind() != ast.LiteralKind { + continue + } + if err := v.check(e, call, litArg); err != nil { + errs.reportErrorAtID(litArg.ID(), "invalid %s argument: %v", v.funcName, litArg.AsLiteral()) + } + } +} + +func evalCall(env *Env, call, arg ast.NavigableExpr) error { + ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()}) + prg, err := env.Program(ast) + if err != nil { + return err + } + _, _, err = prg.Eval(NoVars()) + return err +} + +func compileRegex(_ *Env, _, arg ast.NavigableExpr) error { + pattern := arg.AsLiteral().Value().(string) + _, err := regexp.Compile(pattern) + return err +} + +func (formatValidator) is_validator() {} + +type homogeneousAggregateLiteralValidator struct{} + +func (homogeneousAggregateLiteralValidator) Name() string { + return "cel.lib.std.types.homogeneous" +} + +func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, a *Ast, iss *Issues) { + errs := errorReporter{iss: iss, info: a.info} + root := ast.NavigateCheckedAST(astToCheckedAST(a)) + listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind)) + for _, listExpr := range listExprs { + if hasStringFormatAncestor(listExpr) { + continue + } + l := listExpr.AsList() + elements := l.Elements() + optIndices := l.OptionalIndices() + var elemType *Type + for i, e := range elements { + et := e.Type() + if isOptionalIndex(i, optIndices) { + et = et.Parameters()[0] + } + if elemType == nil { + elemType = et + continue + } + if !elemType.IsEquivalentType(et) { + v.typeMismatch(errs, e.ID(), elemType, et) + break + } + } + } + mapExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.MapKind)) + for _, mapExpr := range mapExprs { + if hasStringFormatAncestor(mapExpr) { + continue + } + m := mapExpr.AsMap() + entries := m.Entries() + var keyType, valType *Type + for _, e := range entries { + key, val := e.Key(), e.Value() + kt, vt := key.Type(), val.Type() + if e.IsOptional() { + vt = vt.Parameters()[0] + } + if keyType == nil && valType == nil { + keyType, valType = kt, vt + continue + } + if !keyType.IsEquivalentType(kt) { + v.typeMismatch(errs, key.ID(), keyType, kt) + } + if !valType.IsEquivalentType(vt) { + v.typeMismatch(errs, val.ID(), valType, vt) + } + } + } +} + +func hasStringFormatAncestor(e ast.NavigableExpr) bool { + if parent, found := e.Parent(); found { + if parent.Kind() == ast.CallKind && parent.AsCall().FunctionName() == "format" { + return true + } + if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind { + return hasStringFormatAncestor(parent) + } + } + return false +} + +func isOptionalIndex(i int, optIndices []int32) bool { + for _, optInd := range optIndices { + if i == int(optInd) { + return true + } + } + return false +} + +func (homogeneousAggregateLiteralValidator) typeMismatch(errs errorReporter, id int64, expected, actual *Type) { + errs.reportErrorAtID(id, "expected type '%s' but found '%s'", FormatCelType(expected), FormatCelType(actual)) +} + +func (homogeneousAggregateLiteralValidator) is_validator() {} + +type errorReporter struct { + iss *Issues + info *exprpb.SourceInfo +} + +func (er *errorReporter) reportErrorAtID(id int64, message string, args ...any) { + er.iss.errs.ReportErrorAtID(id, locationByID(id, er.info), message, args...) +} + +func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location { + positions := sourceInfo.GetPositions() + var line = 1 + if offset, found := positions[id]; found { + col := int(offset) + for _, lineOffset := range sourceInfo.GetLineOffsets() { + if lineOffset < offset { + line++ + col = int(offset - lineOffset) + } else { + break + } + } + return common.NewLocation(line, col) + } + return common.NoLocation +} + +func astToCheckedAST(a *Ast) *ast.CheckedAST { + return &ast.CheckedAST{ + Expr: a.expr, + SourceInfo: a.info, + TypeMap: a.typeMap, + ReferenceMap: a.refMap, + } +} diff --git a/cel/validator_test.go b/cel/validator_test.go new file mode 100644 index 000000000..5bf4df570 --- /dev/null +++ b/cel/validator_test.go @@ -0,0 +1,290 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "testing" + + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/test" +) + +func TestValidateDurationLiterals(t *testing.T) { + e, err := NewEnv( + Variable("x", types.StringType), + ASTValidators(ValidateDurationLiterals())) + if err != nil { + t.Fatalf("NewEnv(ValidateDurationLiterals()) failed: %v", err) + } + + tests := []struct { + expr string + iss string + }{ + { + expr: `duration('1')`, + iss: `ERROR: :1:10: invalid duration argument: 1 + | duration('1') + | .........^`, + }, + { + expr: `duration('1d')`, + iss: `ERROR: :1:10: invalid duration argument: 1d + | duration('1d') + | .........^`, + }, + { + expr: "duration('1us')\n < duration('1nns')", + iss: `ERROR: :2:13: invalid duration argument: 1nns + | < duration('1nns') + | ............^`, + }, + { + expr: `duration('2h3m4s5us')`, + }, + { + expr: `duration(x)`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + _, iss := e.Compile(tc.expr) + if tc.iss != "" { + if iss.Err() == nil { + t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss) + } + if !test.Compare(iss.Err().Error(), tc.iss) { + t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss) + } + return + } + if iss.Err() != nil { + t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err()) + } + }) + } +} + +func TestValidateTimestampLiterals(t *testing.T) { + e, err := NewEnv( + Variable("x", types.StringType), + ASTValidators(ValidateTimestampLiterals())) + if err != nil { + t.Fatalf("NewEnv(ValidateTimestampLiterals()) failed: %v", err) + } + + tests := []struct { + expr string + iss string + }{ + { + expr: `timestamp('1000-00-00T00:00:00Z')`, + iss: `ERROR: :1:11: invalid timestamp argument: 1000-00-00T00:00:00Z + | timestamp('1000-00-00T00:00:00Z') + | ..........^`, + }, + { + expr: `timestamp('1000-01-01T00:00:00ZZ')`, + iss: `ERROR: :1:11: invalid timestamp argument: 1000-01-01T00:00:00ZZ + | timestamp('1000-01-01T00:00:00ZZ') + | ..........^`, + }, + { + expr: `timestamp('1000-01-01T00:00:00Z')`, + }, + { + expr: `timestamp(-6213559680)`, // min unix epoch time. + }, + { + expr: `timestamp(-62135596801)`, + iss: `ERROR: :1:11: invalid timestamp argument: -62135596801 + | timestamp(-62135596801) + | ..........^`, + }, + { + expr: `timestamp(x)`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + _, iss := e.Compile(tc.expr) + if tc.iss != "" { + if iss.Err() == nil { + t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss) + } + if !test.Compare(iss.Err().Error(), tc.iss) { + t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss) + } + return + } + if iss.Err() != nil { + t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err()) + } + }) + } +} + +func TestValidateRegexLiterals(t *testing.T) { + e, err := NewEnv( + Variable("x", types.StringType), + ASTValidators(ValidateRegexLiterals())) + if err != nil { + t.Fatalf("NewEnv(ValidateRegexLiterals()) failed: %v", err) + } + + tests := []struct { + expr string + iss string + }{ + { + expr: `'hello'.matches('el*')`, + }, + { + expr: `'hello'.matches('x++')`, + iss: ` + ERROR: :1:17: invalid matches argument: x++ + | 'hello'.matches('x++') + | ................^`, + }, + { + expr: `'hello'.matches('(?el*)')`, + iss: ` + ERROR: :1:17: invalid matches argument: (?el*) + | 'hello'.matches('(?el*)') + | ................^`, + }, + { + expr: `'hello'.matches('??el*')`, + iss: ` + ERROR: :1:17: invalid matches argument: ??el* + | 'hello'.matches('??el*') + | ................^`, + }, + { + expr: `'hello'.matches(x)`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + _, iss := e.Compile(tc.expr) + if tc.iss != "" { + if iss.Err() == nil { + t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss) + } + if !test.Compare(iss.Err().Error(), tc.iss) { + t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss) + } + return + } + if iss.Err() != nil { + t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err()) + } + }) + } +} + +func TestValidateHomogeneousAggregateLiterals(t *testing.T) { + e, err := NewCustomEnv( + Variable("name", StringType), + Function(operators.In, + Overload(overloads.InList, []*Type{StringType, ListType(StringType)}, BoolType, + BinaryBinding(func(lhs, rhs ref.Val) ref.Val { + return rhs.(traits.Container).Contains(lhs) + }), + ), + Overload(overloads.InMap, []*Type{StringType, MapType(StringType, BoolType)}, BoolType, + BinaryBinding(func(lhs, rhs ref.Val) ref.Val { + return rhs.(traits.Container).Contains(lhs) + }), + ), + ), + OptionalTypes(), + HomogeneousAggregateLiterals(), + ASTValidators(ValidateHomogeneousAggregateLiterals()), + ) + if err != nil { + t.Fatalf("NewCustomEnv() failed: %v", err) + } + + tests := []struct { + expr string + iss string + }{ + { + expr: `name in ['hello', 0]`, + iss: ` + ERROR: :1:19: expected type 'string' but found 'int' + | name in ['hello', 0] + | ..................^`, + }, + { + expr: `{'hello':'world', 1:'!'}`, + iss: ` + ERROR: :1:19: expected type 'string' but found 'int' + | {'hello':'world', 1:'!'} + | ..................^`, + }, + { + expr: `name in {'hello':'world', 'goodbye':true}`, + iss: ` + ERROR: :1:37: expected type 'string' but found 'bool' + | name in {'hello':'world', 'goodbye':true} + | ....................................^`, + }, + { + expr: `name in ['hello', 'world']`, + }, + { + expr: `name in ['hello', ?optional.ofNonZeroValue('')]`, + }, + { + expr: `name in [?optional.ofNonZeroValue(''), 'hello', ?optional.of('')]`, + }, + { + expr: `name in {'hello': false, 'world': true}`, + }, + { + expr: `{'hello': false, ?'world': optional.ofNonZeroValue(true)}`, + }, + { + expr: `{?'hello': optional.ofNonZeroValue(false), 'world': true}`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + _, iss := e.Compile(tc.expr) + if tc.iss != "" { + if iss.Err() == nil { + t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss) + } + if !test.Compare(iss.Err().Error(), tc.iss) { + t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss) + } + return + } + if iss.Err() != nil { + t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err()) + } + }) + } +} diff --git a/checker/checker_test.go b/checker/checker_test.go index 551cd2d1b..032bb5eaf 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -23,8 +23,6 @@ import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" - "github.com/google/cel-go/common/operators" - "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/stdlib" "github.com/google/cel-go/common/types" "github.com/google/cel-go/parser" @@ -1295,92 +1293,6 @@ _&&_(_==_(list~type(list(dyn))^list, )~bool^equals`, outType: types.BoolType, }, - // Homogeneous aggregate type restriction tests. - { - in: `name in [1, 2u, 'string']`, - env: testEnv{ - idents: []*decls.VariableDecl{ - decls.NewVariable("name", types.StringType), - }, - functions: []*decls.FunctionDecl{ - testFunction(t, operators.In, - decls.Overload(overloads.InList, - []*types.Type{ - types.StringType, - types.NewListType(types.StringType), - }, types.BoolType)), - }, - }, - opts: []Option{HomogeneousAggregateLiterals(true)}, - disableStdEnv: true, - out: `@in( - name~string^name, - [ - 1~int, - 2u~uint, - "string"~string - ]~list(string) - )~bool^in_list`, - err: `ERROR: :1:13: expected type 'int' but found 'uint' - | name in [1, 2u, 'string'] - | ............^`, - }, - { - in: `name in [1, 2, 3]`, - env: testEnv{ - idents: []*decls.VariableDecl{ - decls.NewVariable("name", types.StringType), - }, - functions: []*decls.FunctionDecl{ - testFunction(t, operators.In, - decls.Overload(overloads.InList, - []*types.Type{ - types.StringType, - types.NewListType(types.StringType), - }, types.BoolType)), - }, - }, - opts: []Option{HomogeneousAggregateLiterals(true)}, - disableStdEnv: true, - out: `@in( - name~string^name, - [ - 1~int, - 2~int, - 3~int - ]~list(int) - )~!error!`, - err: `ERROR: :1:6: found no matching overload for '@in' applied to '(string, list(int))' - | name in [1, 2, 3] - | .....^`, - }, - { - in: `name in ["1", "2", "3"]`, - env: testEnv{ - idents: []*decls.VariableDecl{ - decls.NewVariable("name", types.StringType), - }, - functions: []*decls.FunctionDecl{ - testFunction(t, operators.In, - decls.Overload(overloads.InList, - []*types.Type{ - types.StringType, - types.NewListType(types.StringType), - }, types.BoolType)), - }, - }, - opts: []Option{HomogeneousAggregateLiterals(true)}, - disableStdEnv: true, - out: `@in( - name~string^name, - [ - "1"~string, - "2"~string, - "3"~string - ]~list(string) - )~bool^in_list`, - outType: types.BoolType, - }, { in: `([[[1]], [[2]], [[3]]][0][0] + [2, 3, {'four': {'five': 'six'}}])[3]`, out: `_[_]( diff --git a/checker/options.go b/checker/options.go index 63b5fbd1a..0560c3813 100644 --- a/checker/options.go +++ b/checker/options.go @@ -32,15 +32,6 @@ func CrossTypeNumericComparisons(enabled bool) Option { } } -// HomogeneousAggregateLiterals toggles support for constructing lists and maps whose elements all -// have the same type. -func HomogeneousAggregateLiterals(enabled bool) Option { - return func(opts *options) error { - opts.homogeneousAggregateLiterals = enabled - return nil - } -} - // ValidatedDeclarations provides a references to validated declarations which will be copied // into new checker instances. func ValidatedDeclarations(env *Env) Option { diff --git a/ext/strings_test.go b/ext/strings_test.go index e6d5e441a..b8fa2ebeb 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -1260,6 +1260,7 @@ func TestStringFormat(t *testing.T) { cel.Container("ext"), cel.Abbrevs("google.expr.proto3.test"), cel.Types(&proto3pb.TestAllTypes{}), + cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals()), NativeTypes( reflect.TypeOf(&TestNestedType{}), reflect.ValueOf(&TestAllTypes{}),