diff --git a/checker/checker.go b/checker/checker.go index 9b0fec8ab..52491b3f8 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -10,17 +10,7 @@ import ( "github.com/antonmedv/expr/parser" ) -func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { - defer func() { - if r := recover(); r != nil { - if h, ok := r.(file.Error); ok { - err = fmt.Errorf("%v", h.Format(tree.Source)) - } else { - err = fmt.Errorf("%v", r) - } - } - }() - +func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { v := &visitor{ collections: make([]reflect.Type, 0), } @@ -32,7 +22,7 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { v.defaultType = config.DefaultType } - t = v.visit(tree.Node) + t := v.visit(tree.Node) if v.expect != reflect.Invalid { switch v.expect { @@ -47,7 +37,11 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { } } - return + if v.err != nil { + return t, fmt.Errorf("%v", v.err.Format(tree.Source)) + } + + return t, nil } type visitor struct { @@ -57,6 +51,7 @@ type visitor struct { collections []reflect.Type strict bool defaultType reflect.Type + err *file.Error } func (v *visitor) visit(node ast.Node) reflect.Type { @@ -111,14 +106,17 @@ func (v *visitor) visit(node ast.Node) reflect.Type { return t } -func (v *visitor) error(node ast.Node, format string, args ...interface{}) file.Error { - return file.Error{ - Location: node.Location(), - Message: fmt.Sprintf(format, args...), +func (v *visitor) error(node ast.Node, format string, args ...interface{}) reflect.Type { + if v.err == nil { // show first error + v.err = &file.Error{ + Location: node.Location(), + Message: fmt.Sprintf(format, args...), + } } + return interfaceType // interface represent undefined type } -func (v *visitor) NilNode(node *ast.NilNode) reflect.Type { +func (v *visitor) NilNode(*ast.NilNode) reflect.Type { return nilType } @@ -135,22 +133,22 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) reflect.Type { } return interfaceType } - panic(v.error(node, "unknown name %v", node.Value)) + return v.error(node, "unknown name %v", node.Value) } -func (v *visitor) IntegerNode(node *ast.IntegerNode) reflect.Type { +func (v *visitor) IntegerNode(*ast.IntegerNode) reflect.Type { return integerType } -func (v *visitor) FloatNode(node *ast.FloatNode) reflect.Type { +func (v *visitor) FloatNode(*ast.FloatNode) reflect.Type { return floatType } -func (v *visitor) BoolNode(node *ast.BoolNode) reflect.Type { +func (v *visitor) BoolNode(*ast.BoolNode) reflect.Type { return boolType } -func (v *visitor) StringNode(node *ast.StringNode) reflect.Type { +func (v *visitor) StringNode(*ast.StringNode) reflect.Type { return stringType } @@ -170,10 +168,10 @@ func (v *visitor) UnaryNode(node *ast.UnaryNode) reflect.Type { } default: - panic(v.error(node, "unknown operator (%v)", node.Operator)) + return v.error(node, "unknown operator (%v)", node.Operator) } - panic(v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t)) + return v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t) } func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type { @@ -255,11 +253,11 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type { } default: - panic(v.error(node, "unknown operator (%v)", node.Operator)) + return v.error(node, "unknown operator (%v)", node.Operator) } - panic(v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)) + return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r) } func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type { @@ -270,7 +268,7 @@ func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type { return boolType } - panic(v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r)) + return v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r) } func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type { @@ -280,7 +278,7 @@ func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type { return t } - panic(v.error(node, "type %v has no field %v", t, node.Property)) + return v.error(node, "type %v has no field %v", t, node.Property) } func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type { @@ -289,12 +287,12 @@ func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type { if t, ok := indexType(t); ok { if !isInteger(i) && !isString(i) { - panic(v.error(node, "invalid operation: cannot use %v as index to %v", i, t)) + return v.error(node, "invalid operation: cannot use %v as index to %v", i, t) } return t } - panic(v.error(node, "invalid operation: type %v does not support indexing", t)) + return v.error(node, "invalid operation: type %v does not support indexing", t) } func (v *visitor) SliceNode(node *ast.SliceNode) reflect.Type { @@ -306,19 +304,19 @@ func (v *visitor) SliceNode(node *ast.SliceNode) reflect.Type { if node.From != nil { from := v.visit(node.From) if !isInteger(from) { - panic(v.error(node.From, "invalid operation: non-integer slice index %v", from)) + return v.error(node.From, "invalid operation: non-integer slice index %v", from) } } if node.To != nil { to := v.visit(node.To) if !isInteger(to) { - panic(v.error(node.To, "invalid operation: non-integer slice index %v", to)) + return v.error(node.To, "invalid operation: non-integer slice index %v", to) } } return t } - panic(v.error(node, "invalid operation: cannot slice %v", t)) + return v.error(node, "invalid operation: cannot slice %v", t) } func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type { @@ -349,7 +347,7 @@ func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type { } return interfaceType } - panic(v.error(node, "unknown func %v", node.Name)) + return v.error(node, "unknown func %v", node.Name) } func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type { @@ -359,7 +357,7 @@ func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type { return v.checkFunc(fn, method, node, node.Method, node.Arguments) } } - panic(v.error(node, "type %v has no method %v", t, node.Method)) + return v.error(node, "type %v has no method %v", t, node.Method) } // checkFunc checks func arguments and returns "return type" of func or method. @@ -369,10 +367,10 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st } if fn.NumOut() == 0 { - panic(v.error(node, "func %v doesn't return value", name)) + return v.error(node, "func %v doesn't return value", name) } if fn.NumOut() != 1 { - panic(v.error(node, "func %v returns more then one value", name)) + return v.error(node, "func %v returns more then one value", name) } numIn := fn.NumIn() @@ -385,14 +383,14 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st if fn.IsVariadic() { if len(arguments) < numIn-1 { - panic(v.error(node, "not enough arguments to call %v", name)) + return v.error(node, "not enough arguments to call %v", name) } } else { if len(arguments) > numIn { - panic(v.error(node, "too many arguments to call %v", name)) + return v.error(node, "too many arguments to call %v", name) } if len(arguments) < numIn { - panic(v.error(node, "not enough arguments to call %v", name)) + return v.error(node, "not enough arguments to call %v", name) } } @@ -426,7 +424,7 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st } if !t.AssignableTo(in) { - panic(v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name)) + return v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name) } } @@ -441,12 +439,12 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type { if isArray(param) || isMap(param) || isString(param) { return integerType } - panic(v.error(node, "invalid argument for len (type %v)", param)) + return v.error(node, "invalid argument for len (type %v)", param) case "all", "none", "any", "one": collection := v.visit(node.Arguments[0]) if !isArray(collection) { - panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.collections = append(v.collections, collection) @@ -458,16 +456,16 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type { closure.NumIn() == 1 && isInterface(closure.In(0)) { if !isBool(closure.Out(0)) { - panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())) + return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()) } return boolType } - panic(v.error(node.Arguments[1], "closure should has one input and one output param")) + return v.error(node.Arguments[1], "closure should has one input and one output param") case "filter": collection := v.visit(node.Arguments[0]) if !isArray(collection) { - panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.collections = append(v.collections, collection) @@ -479,16 +477,16 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type { closure.NumIn() == 1 && isInterface(closure.In(0)) { if !isBool(closure.Out(0)) { - panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())) + return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()) } return arrayType } - panic(v.error(node.Arguments[1], "closure should has one input and one output param")) + return v.error(node.Arguments[1], "closure should has one input and one output param") case "map": collection := v.visit(node.Arguments[0]) if !isArray(collection) { - panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.collections = append(v.collections, collection) @@ -501,12 +499,12 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type { return reflect.SliceOf(closure.Out(0)) } - panic(v.error(node.Arguments[1], "closure should has one input and one output param")) + return v.error(node.Arguments[1], "closure should has one input and one output param") case "count": collection := v.visit(node.Arguments[0]) if !isArray(collection) { - panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)) + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.collections = append(v.collections, collection) @@ -517,15 +515,15 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type { closure.NumOut() == 1 && closure.NumIn() == 1 && isInterface(closure.In(0)) { if !isBool(closure.Out(0)) { - panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())) + return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()) } return integerType } - panic(v.error(node.Arguments[1], "closure should has one input and one output param")) + return v.error(node.Arguments[1], "closure should has one input and one output param") default: - panic(v.error(node, "unknown builtin %v", node.Name)) + return v.error(node, "unknown builtin %v", node.Name) } } @@ -536,7 +534,7 @@ func (v *visitor) ClosureNode(node *ast.ClosureNode) reflect.Type { func (v *visitor) PointerNode(node *ast.PointerNode) reflect.Type { if len(v.collections) == 0 { - panic(v.error(node, "cannot use pointer accessor outside closure")) + return v.error(node, "cannot use pointer accessor outside closure") } collection := v.collections[len(v.collections)-1] @@ -544,13 +542,13 @@ func (v *visitor) PointerNode(node *ast.PointerNode) reflect.Type { if t, ok := indexType(collection); ok { return t } - panic(v.error(node, "cannot use %v as array", collection)) + return v.error(node, "cannot use %v as array", collection) } func (v *visitor) ConditionalNode(node *ast.ConditionalNode) reflect.Type { c := v.visit(node.Cond) if !isBool(c) { - panic(v.error(node.Cond, "non-bool expression (type %v) used as condition", c)) + return v.error(node.Cond, "non-bool expression (type %v) used as condition", c) } t1 := v.visit(node.Exp1) diff --git a/expr.go b/expr.go index ed9fa4379..961138ec1 100644 --- a/expr.go +++ b/expr.go @@ -143,15 +143,24 @@ func Compile(input string, ops ...Option) (*vm.Program, error) { } _, err = checker.Check(tree, config) - if err != nil { + + // If we have a patch to apply, it may fix out error and + // second type check is needed. Otherwise it is an error. + if err != nil && len(config.Visitors) == 0 { return nil, err } // Patch operators before Optimize, as we may also mark it as ConstExpr. compiler.PatchOperators(&tree.Node, config) - for _, v := range config.Visitors { - ast.Walk(&tree.Node, v) + if len(config.Visitors) >= 0 { + for _, v := range config.Visitors { + ast.Walk(&tree.Node, v) + } + _, err = checker.Check(tree, config) + if err != nil { + return nil, err + } } if config.Optimize { diff --git a/expr_test.go b/expr_test.go index a1930105f..763ee32a5 100644 --- a/expr_test.go +++ b/expr_test.go @@ -3,6 +3,7 @@ package expr_test import ( "fmt" "github.com/antonmedv/expr/ast" + "reflect" "strings" "testing" "time" @@ -986,6 +987,36 @@ func TestConstExpr_error_no_env(t *testing.T) { require.Equal(t, "no environment for const expression: divide", err.Error()) } +func TestPatch(t *testing.T) { + program, err := expr.Compile( + `Ticket == "$100" and "$90" != Ticket + "0"`, + expr.Env(mockEnv{}), + expr.Patch(&stringerPatcher{}), + ) + require.NoError(t, err) + + env := mockEnv{ + Ticket: &ticket{Price: 100}, + } + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, true, output) +} + +func TestPatch_length(t *testing.T) { + program, err := expr.Compile( + `String.length == 5`, + expr.Env(mockEnv{}), + expr.Patch(&lengthPatcher{}), + ) + require.NoError(t, err) + + env := mockEnv{String: "hello"} + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, true, output) +} + // // Mock types // @@ -1126,3 +1157,37 @@ func (p *patcher) Exit(node *ast.Node) { }) } } + +var stringer = reflect.TypeOf((*fmt.Stringer)(nil)).Elem() + +type stringerPatcher struct{} + +func (p *stringerPatcher) Enter(_ *ast.Node) {} +func (p *stringerPatcher) Exit(node *ast.Node) { + t := (*node).Type() + if t == nil { + return + } + if t.Implements(stringer) { + ast.Patch(node, &ast.MethodNode{ + Node: *node, + Method: "String", + }) + } + +} + +type lengthPatcher struct{} + +func (p *lengthPatcher) Enter(_ *ast.Node) {} +func (p *lengthPatcher) Exit(node *ast.Node) { + switch n := (*node).(type) { + case *ast.PropertyNode: + if n.Property == "length" { + ast.Patch(node, &ast.BuiltinNode{ + Name: "len", + Arguments: []ast.Node{n.Node}, + }) + } + } +}