Skip to content

Commit

Permalink
Validators for common literals and use cases.
Browse files Browse the repository at this point in the history
Introduce validators for duration, timestamp, and regex string literals
as well as support for homogeneous aggregate literals with a carveout
for string format calls.
  • Loading branch information
TristonianJones committed Jul 7, 2023
1 parent 5dc9173 commit 308ea29
Show file tree
Hide file tree
Showing 9 changed files with 594 additions and 212 deletions.
1 change: 1 addition & 0 deletions cel/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ go_library(
"macro.go",
"options.go",
"program.go",
"validator.go",
],
importpath = "github.com/google/cel-go/cel",
visibility = ["//visibility:public"],
Expand Down
99 changes: 0 additions & 99 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,105 +275,6 @@ func TestCustomEnv(t *testing.T) {
})
}

func TestHomogeneousAggregateLiterals(t *testing.T) {
e, err := NewCustomEnv(
Variable("name", StringType),
Function(operators.In,
Overload(overloads.InList, []*Type{StringType, ListType(StringType)}, BoolType,
BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return rhs.(traits.Container).Contains(lhs)
}),
),
Overload(overloads.InMap, []*Type{StringType, MapType(StringType, BoolType)}, BoolType,
BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return rhs.(traits.Container).Contains(lhs)
}),
),
),
HomogeneousAggregateLiterals(),
)
if err != nil {
t.Fatalf("NewCustomEnv() failed: %v", err)
}

tests := []struct {
name string
expr string
iss string
vars map[string]any
out ref.Val
}{
{
name: "err_list",
expr: `name in ['hello', 0]`,
iss: `
ERROR: <input>:1:19: expected type 'string' but found 'int'
| name in ['hello', 0]
| ..................^`,
},
{
name: "err_map_key",
expr: `name in {'hello':'world', 1:'!'}`,
iss: `
ERROR: <input>:1:6: found no matching overload for '@in' applied to '(string, map(!error!, string))'
| name in {'hello':'world', 1:'!'}
| .....^
ERROR: <input>:1:27: expected type 'string' but found 'int'
| name in {'hello':'world', 1:'!'}
| ..........................^`,
},
{
name: "err_map_value",
expr: `name in {'hello':'world', 'goodbye':true}`,
iss: `
ERROR: <input>:1:37: expected type 'string' but found 'bool'
| name in {'hello':'world', 'goodbye':true}
| ....................................^`,
},
{
name: "ok_list",
expr: `name in ['hello', 'world']`,
vars: map[string]any{"name": "world"},
out: types.True,
},
{
name: "ok_map",
expr: `name in {'hello': false, 'world': true}`,
vars: map[string]any{"name": "world"},
out: types.True,
},
}
for _, tst := range tests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
ast, iss := e.Compile(tc.expr)
if tc.iss != "" {
if iss.Err() == nil {
t.Fatalf("e.Compile(%v) returned ast, expected error: %v", tc.expr, tc.iss)
}
if !test.Compare(iss.Err().Error(), tc.iss) {
t.Fatalf("e.Compile(%v) returned %v, expected error: %v", tc.expr, iss.Err(), tc.iss)
}
return
}
if iss.Err() != nil {
t.Fatalf("e.Compile(%v) failed: %v", tc.expr, iss.Err())
}
prg, err := e.Program(ast)
if err != nil {
t.Fatalf("e.Program() failed: %v", err)
}
out, _, err := prg.Eval(tc.vars)
if err != nil {
t.Fatalf("prg.Eval(%v) errored: %v", tc.vars, err)
}
if out != tc.out {
t.Errorf("program eval got %v, wanted %v", out, tc.out)
}
})
}
}

func TestCrossTypeNumericComparisons(t *testing.T) {
tests := []struct {
name string
Expand Down
31 changes: 27 additions & 4 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ type Env struct {
features map[int]bool
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator

// Internal parser representation
prsr *parser.Parser
Expand Down Expand Up @@ -178,6 +179,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
features: map[int]bool{},
appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
}).configure(opts)
}
Expand Down Expand Up @@ -207,12 +209,22 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
return &Ast{
ast = &Ast{
source: ast.Source(),
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
typeMap: res.TypeMap}, nil
typeMap: res.TypeMap}

// Apply additional validators on the type-checked result.
iss := NewIssues(errs)
for _, v := range e.validators {
v.Validate(e, ast, iss)
}
if iss.Err() != nil {
return nil, iss
}
return ast, nil
}

// Compile combines the Parse and Check phases CEL program compilation to produce an Ast and
Expand Down Expand Up @@ -331,6 +343,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.libraries {
libsCopy[k] = v
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)

ext := &Env{
Container: e.Container,
Expand All @@ -342,6 +356,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
features: featuresCopy,
appliedFeatures: appliedFeaturesCopy,
libraries: libsCopy,
validators: validatorsCopy,
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
Expand All @@ -362,6 +377,16 @@ func (e *Env) HasLibrary(libName string) bool {
return exists && configured
}

// HasValidator returns whether a specific ASTValidator has been configured in the environment.
func (e *Env) HasValidator(name string) bool {
for _, v := range e.validators {
if v.Name() == name {
return true
}
}
return false
}

// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
Expand Down Expand Up @@ -534,8 +559,6 @@ func (e *Env) initChecker() (*checker.Env, error) {
chkOpts := []checker.Option{}
chkOpts = append(chkOpts, e.chkOpts...)
chkOpts = append(chkOpts,
checker.HomogeneousAggregateLiterals(
e.HasFeature(featureDisableDynamicAggregateLiterals)),
checker.CrossTypeNumericComparisons(
e.HasFeature(featureCrossTypeNumericComparisons)))

Expand Down
13 changes: 1 addition & 12 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ import (
const (
_ = iota

// Disallow heterogeneous aggregate (list, map) literals.
// Note, it is still possible to have heterogeneous aggregates when
// provided as variables to the expression, as well as via conversion
// of well-known dynamic types, or with unchecked expressions.
// Affects checking. Provides a subset of standard behavior.
featureDisableDynamicAggregateLiterals

// Enable the tracking of function call expressions replaced by macros.
featureEnableMacroCallTracking

Expand All @@ -63,10 +56,6 @@ const (
// is not already in UTC.
featureDefaultUTCTimeZone

// Enable the use of optional types in the syntax, type-system, type-checking,
// and runtime.
featureOptionalTypes

// Enable the serialization of logical operator ASTs as variadic calls, thus
// compressing the logic graph to a single call when multiple like-operator
// expressions occur: e.g. a && b && c && d -> call(_&&_, [a, b, c, d])
Expand Down Expand Up @@ -157,7 +146,7 @@ func EagerlyValidateDeclarations(enabled bool) EnvOption {
// expression, as well as via conversion of well-known dynamic types, or with unchecked
// expressions.
func HomogeneousAggregateLiterals() EnvOption {
return features(featureDisableDynamicAggregateLiterals, true)
return ASTValidators(ValidateHomogeneousAggregateLiterals())
}

// variadicLogicalOperatorASTs flatten like-operator chained logical expressions into a single
Expand Down
Loading

0 comments on commit 308ea29

Please sign in to comment.