Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

String format validator #775

Merged
merged 3 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
return nil, fmt.Errorf("unsupported decl: %v", d)
}
}

func typeValueToKind(tv ref.Type) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
case types.DoubleType:
return DoubleKind, nil
case types.IntType:
return IntKind, nil
case types.UintType:
return UintKind, nil
case types.ListType:
return ListKind, nil
case types.MapType:
return MapKind, nil
case types.StringType:
return StringKind, nil
case types.BytesType:
return BytesKind, nil
case types.DurationType:
return DurationKind, nil
case types.TimestampType:
return TimestampKind, nil
case types.NullType:
return NullTypeKind, nil
case types.TypeType:
return TypeKind, nil
default:
switch tv.TypeName() {
case "dyn":
return DynKind, nil
case "google.protobuf.Any":
return AnyKind, nil
case "optional":
return OpaqueKind, nil
default:
return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName())
}
}
}
2 changes: 2 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ const (
OptTrackCost EvalOption = 1 << iota

// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
//
// Deprecated: use ext.ValidateFormatString() as this option is now a no-op.
OptCheckStringFormat EvalOption = 1 << iota
)

Expand Down
60 changes: 17 additions & 43 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"fmt"
"sync"

celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
Expand Down Expand Up @@ -148,7 +148,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()

Expand Down Expand Up @@ -208,34 +208,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if t.Kind() == k {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}

// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
Expand Down Expand Up @@ -263,18 +235,18 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
decs = append(decs, interpreter.Observe(observers...))
}

return p.clone().initInterpretable(ast, decs)
return p.clone().initInterpretable(a, decs)
}
return newProgGen(factory)
}
return p.initInterpretable(ast, decorators)
return p.initInterpretable(a, decorators)
}

func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// Unchecked programs do not contain type and reference information and may be slower to execute.
if !ast.IsChecked() {
if !a.IsChecked() {
interpretable, err :=
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...)
p.interpreter.NewUncheckedInterpretable(a.Expr(), decs...)
if err != nil {
return nil, err
}
Expand All @@ -283,12 +255,7 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
}

// When the AST has been checked it contains metadata that can be used to speed up program execution.
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
checked := astToCheckedAST(a)
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -558,9 +525,16 @@ func (p *evalActivationPool) Put(value any) {
p.Pool.Put(a)
}

var (
emptyEvalState = interpreter.NewEvalState()
func astToCheckedAST(a *Ast) *ast.CheckedAST {
return &ast.CheckedAST{
Expr: a.Expr(),
SourceInfo: a.SourceInfo(),
TypeMap: a.typeMap,
ReferenceMap: a.refMap,
}
}

var (
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()

Expand Down
11 changes: 0 additions & 11 deletions cel/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,6 @@ func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}

// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}

// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
Expand Down
1 change: 1 addition & 0 deletions common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package(
"//cel:__subpackages__",
"//checker:__subpackages__",
"//common:__subpackages__",
"//ext:__subpackages__",
"//interpreter:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
Expand Down
9 changes: 8 additions & 1 deletion common/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func KindMatcher(kind ExprKind) ExprMatcher {
}

// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose
// function name is equal to `funcName`.
// function name is equal to `funcName` regardless of whether it is a member or a global function.
func FunctionMatcher(funcName string) ExprMatcher {
return func(e NavigableExpr) bool {
if e.Kind() != CallKind {
Expand Down Expand Up @@ -217,6 +217,9 @@ type NavigableCallExpr interface {
// FunctionName returns the name of the function.
FunctionName() string

// IsMemberFunction returns whether the call has a non-nil target indicating it is a member function
IsMemberFunction() bool

// Target returns the target of the expression if one is present.
Target() NavigableExpr

Expand Down Expand Up @@ -444,6 +447,10 @@ func (call navigableCallImpl) FunctionName() string {
return call.ToExpr().GetCallExpr().GetFunction()
}

func (call navigableCallImpl) IsMemberFunction() bool {
return call.ToExpr().GetCallExpr().GetTarget() != nil
}

func (call navigableCallImpl) Target() NavigableExpr {
t := call.ToExpr().GetCallExpr().GetTarget()
if t != nil {
Expand Down
2 changes: 2 additions & 0 deletions ext/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ go_library(
name = "go_default_library",
srcs = [
"encoders.go",
"formatting.go",
"guards.go",
"lists.go",
"math.go",
Expand All @@ -21,6 +22,7 @@ go_library(
deps = [
"//cel:go_default_library",
"//checker/decls:go_default_library",
"//common/ast:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
Expand Down
Loading