Skip to content

Commit

Permalink
String format validator (#775)
Browse files Browse the repository at this point in the history
* String format validator
* Remove unused method from prior string validation
  • Loading branch information
TristonianJones authored Jul 31, 2023
1 parent 766076f commit 965e9c8
Show file tree
Hide file tree
Showing 12 changed files with 949 additions and 892 deletions.
40 changes: 0 additions & 40 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
return nil, fmt.Errorf("unsupported decl: %v", d)
}
}

func typeValueToKind(tv ref.Type) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}
2 changes: 2 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ const (
OptTrackCost EvalOption = 1 << iota

// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
//
// Deprecated: use ext.ValidateFormatString() as this option is now a no-op.
OptCheckStringFormat EvalOption = 1 << iota
)

Expand Down
60 changes: 17 additions & 43 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"fmt"
"sync"

celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
Expand Down Expand Up @@ -148,7 +148,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()

Expand Down Expand Up @@ -208,34 +208,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if t.Kind() == k {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}

// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
Expand Down Expand Up @@ -263,18 +235,18 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
decs = append(decs, interpreter.Observe(observers...))
}

return p.clone().initInterpretable(ast, decs)
return p.clone().initInterpretable(a, decs)
}
return newProgGen(factory)
}
return p.initInterpretable(ast, decorators)
return p.initInterpretable(a, decorators)
}

func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// Unchecked programs do not contain type and reference information and may be slower to execute.
if !ast.IsChecked() {
if !a.IsChecked() {
interpretable, err :=
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...)
p.interpreter.NewUncheckedInterpretable(a.Expr(), decs...)
if err != nil {
return nil, err
}
Expand All @@ -283,12 +255,7 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
}

// When the AST has been checked it contains metadata that can be used to speed up program execution.
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
checked := astToCheckedAST(a)
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -558,9 +525,16 @@ func (p *evalActivationPool) Put(value any) {
p.Pool.Put(a)
}

var (
emptyEvalState = interpreter.NewEvalState()
func astToCheckedAST(a *Ast) *ast.CheckedAST {
return &ast.CheckedAST{
Expr: a.Expr(),
SourceInfo: a.SourceInfo(),
TypeMap: a.typeMap,
ReferenceMap: a.refMap,
}
}

var (
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()

Expand Down
11 changes: 0 additions & 11 deletions cel/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,6 @@ func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}

// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}

// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
Expand Down
1 change: 1 addition & 0 deletions common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package(
"//cel:__subpackages__",
"//checker:__subpackages__",
"//common:__subpackages__",
"//ext:__subpackages__",
"//interpreter:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
Expand Down
9 changes: 8 additions & 1 deletion common/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func KindMatcher(kind ExprKind) ExprMatcher {
}

// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose
// function name is equal to `funcName`.
// function name is equal to `funcName` regardless of whether it is a member or a global function.
func FunctionMatcher(funcName string) ExprMatcher {
return func(e NavigableExpr) bool {
if e.Kind() != CallKind {
Expand Down Expand Up @@ -217,6 +217,9 @@ type NavigableCallExpr interface {
// FunctionName returns the name of the function.
FunctionName() string

// IsMemberFunction returns whether the call has a non-nil target indicating it is a member function
IsMemberFunction() bool

// Target returns the target of the expression if one is present.
Target() NavigableExpr

Expand Down Expand Up @@ -444,6 +447,10 @@ func (call navigableCallImpl) FunctionName() string {
return call.ToExpr().GetCallExpr().GetFunction()
}

func (call navigableCallImpl) IsMemberFunction() bool {
return call.ToExpr().GetCallExpr().GetTarget() != nil
}

func (call navigableCallImpl) Target() NavigableExpr {
t := call.ToExpr().GetCallExpr().GetTarget()
if t != nil {
Expand Down
2 changes: 2 additions & 0 deletions ext/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ go_library(
name = "go_default_library",
srcs = [
"encoders.go",
"formatting.go",
"guards.go",
"lists.go",
"math.go",
Expand All @@ -21,6 +22,7 @@ go_library(
deps = [
"//cel:go_default_library",
"//checker/decls:go_default_library",
"//common/ast:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
Expand Down
Loading

0 comments on commit 965e9c8

Please sign in to comment.