Skip to content

Commit

Permalink
Add double type-check when a patch is present
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed Mar 17, 2020
1 parent 8ec0158 commit 546dfc9
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 61 deletions.
114 changes: 56 additions & 58 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,7 @@ import (
"github.com/antonmedv/expr/parser"
)

func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
defer func() {
if r := recover(); r != nil {
if h, ok := r.(file.Error); ok {
err = fmt.Errorf("%v", h.Format(tree.Source))
} else {
err = fmt.Errorf("%v", r)
}
}
}()

func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) {
v := &visitor{
collections: make([]reflect.Type, 0),
}
Expand All @@ -32,7 +22,7 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
v.defaultType = config.DefaultType
}

t = v.visit(tree.Node)
t := v.visit(tree.Node)

if v.expect != reflect.Invalid {
switch v.expect {
Expand All @@ -47,7 +37,11 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) {
}
}

return
if v.err != nil {
return t, fmt.Errorf("%v", v.err.Format(tree.Source))
}

return t, nil
}

type visitor struct {
Expand All @@ -57,6 +51,7 @@ type visitor struct {
collections []reflect.Type
strict bool
defaultType reflect.Type
err *file.Error
}

func (v *visitor) visit(node ast.Node) reflect.Type {
Expand Down Expand Up @@ -111,14 +106,17 @@ func (v *visitor) visit(node ast.Node) reflect.Type {
return t
}

func (v *visitor) error(node ast.Node, format string, args ...interface{}) file.Error {
return file.Error{
Location: node.Location(),
Message: fmt.Sprintf(format, args...),
func (v *visitor) error(node ast.Node, format string, args ...interface{}) reflect.Type {
if v.err == nil { // show first error
v.err = &file.Error{
Location: node.Location(),
Message: fmt.Sprintf(format, args...),
}
}
return interfaceType // interface represent undefined type
}

func (v *visitor) NilNode(node *ast.NilNode) reflect.Type {
func (v *visitor) NilNode(*ast.NilNode) reflect.Type {
return nilType
}

Expand All @@ -135,22 +133,22 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) reflect.Type {
}
return interfaceType
}
panic(v.error(node, "unknown name %v", node.Value))
return v.error(node, "unknown name %v", node.Value)
}

func (v *visitor) IntegerNode(node *ast.IntegerNode) reflect.Type {
func (v *visitor) IntegerNode(*ast.IntegerNode) reflect.Type {
return integerType
}

func (v *visitor) FloatNode(node *ast.FloatNode) reflect.Type {
func (v *visitor) FloatNode(*ast.FloatNode) reflect.Type {
return floatType
}

func (v *visitor) BoolNode(node *ast.BoolNode) reflect.Type {
func (v *visitor) BoolNode(*ast.BoolNode) reflect.Type {
return boolType
}

func (v *visitor) StringNode(node *ast.StringNode) reflect.Type {
func (v *visitor) StringNode(*ast.StringNode) reflect.Type {
return stringType
}

Expand All @@ -170,10 +168,10 @@ func (v *visitor) UnaryNode(node *ast.UnaryNode) reflect.Type {
}

default:
panic(v.error(node, "unknown operator (%v)", node.Operator))
return v.error(node, "unknown operator (%v)", node.Operator)
}

panic(v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t))
return v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t)
}

func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
Expand Down Expand Up @@ -255,11 +253,11 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
}

default:
panic(v.error(node, "unknown operator (%v)", node.Operator))
return v.error(node, "unknown operator (%v)", node.Operator)

}

panic(v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r))
return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
}

func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type {
Expand All @@ -270,7 +268,7 @@ func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type {
return boolType
}

panic(v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r))
return v.error(node, `invalid operation: matches (mismatched types %v and %v)`, l, r)
}

func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type {
Expand All @@ -280,7 +278,7 @@ func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type {
return t
}

panic(v.error(node, "type %v has no field %v", t, node.Property))
return v.error(node, "type %v has no field %v", t, node.Property)
}

func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type {
Expand All @@ -289,12 +287,12 @@ func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type {

if t, ok := indexType(t); ok {
if !isInteger(i) && !isString(i) {
panic(v.error(node, "invalid operation: cannot use %v as index to %v", i, t))
return v.error(node, "invalid operation: cannot use %v as index to %v", i, t)
}
return t
}

panic(v.error(node, "invalid operation: type %v does not support indexing", t))
return v.error(node, "invalid operation: type %v does not support indexing", t)
}

func (v *visitor) SliceNode(node *ast.SliceNode) reflect.Type {
Expand All @@ -306,19 +304,19 @@ func (v *visitor) SliceNode(node *ast.SliceNode) reflect.Type {
if node.From != nil {
from := v.visit(node.From)
if !isInteger(from) {
panic(v.error(node.From, "invalid operation: non-integer slice index %v", from))
return v.error(node.From, "invalid operation: non-integer slice index %v", from)
}
}
if node.To != nil {
to := v.visit(node.To)
if !isInteger(to) {
panic(v.error(node.To, "invalid operation: non-integer slice index %v", to))
return v.error(node.To, "invalid operation: non-integer slice index %v", to)
}
}
return t
}

panic(v.error(node, "invalid operation: cannot slice %v", t))
return v.error(node, "invalid operation: cannot slice %v", t)
}

func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type {
Expand Down Expand Up @@ -349,7 +347,7 @@ func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type {
}
return interfaceType
}
panic(v.error(node, "unknown func %v", node.Name))
return v.error(node, "unknown func %v", node.Name)
}

func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type {
Expand All @@ -359,7 +357,7 @@ func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type {
return v.checkFunc(fn, method, node, node.Method, node.Arguments)
}
}
panic(v.error(node, "type %v has no method %v", t, node.Method))
return v.error(node, "type %v has no method %v", t, node.Method)
}

// checkFunc checks func arguments and returns "return type" of func or method.
Expand All @@ -369,10 +367,10 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st
}

if fn.NumOut() == 0 {
panic(v.error(node, "func %v doesn't return value", name))
return v.error(node, "func %v doesn't return value", name)
}
if fn.NumOut() != 1 {
panic(v.error(node, "func %v returns more then one value", name))
return v.error(node, "func %v returns more then one value", name)
}

numIn := fn.NumIn()
Expand All @@ -385,14 +383,14 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st

if fn.IsVariadic() {
if len(arguments) < numIn-1 {
panic(v.error(node, "not enough arguments to call %v", name))
return v.error(node, "not enough arguments to call %v", name)
}
} else {
if len(arguments) > numIn {
panic(v.error(node, "too many arguments to call %v", name))
return v.error(node, "too many arguments to call %v", name)
}
if len(arguments) < numIn {
panic(v.error(node, "not enough arguments to call %v", name))
return v.error(node, "not enough arguments to call %v", name)
}
}

Expand Down Expand Up @@ -426,7 +424,7 @@ func (v *visitor) checkFunc(fn reflect.Type, method bool, node ast.Node, name st
}

if !t.AssignableTo(in) {
panic(v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name))
return v.error(arg, "cannot use %v as argument (type %v) to call %v ", t, in, name)
}
}

Expand All @@ -441,12 +439,12 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
if isArray(param) || isMap(param) || isString(param) {
return integerType
}
panic(v.error(node, "invalid argument for len (type %v)", param))
return v.error(node, "invalid argument for len (type %v)", param)

case "all", "none", "any", "one":
collection := v.visit(node.Arguments[0])
if !isArray(collection) {
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

v.collections = append(v.collections, collection)
Expand All @@ -458,16 +456,16 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
closure.NumIn() == 1 && isInterface(closure.In(0)) {

if !isBool(closure.Out(0)) {
panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()))
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
}
return boolType
}
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
return v.error(node.Arguments[1], "closure should has one input and one output param")

case "filter":
collection := v.visit(node.Arguments[0])
if !isArray(collection) {
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

v.collections = append(v.collections, collection)
Expand All @@ -479,16 +477,16 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
closure.NumIn() == 1 && isInterface(closure.In(0)) {

if !isBool(closure.Out(0)) {
panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()))
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
}
return arrayType
}
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
return v.error(node.Arguments[1], "closure should has one input and one output param")

case "map":
collection := v.visit(node.Arguments[0])
if !isArray(collection) {
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

v.collections = append(v.collections, collection)
Expand All @@ -501,12 +499,12 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {

return reflect.SliceOf(closure.Out(0))
}
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
return v.error(node.Arguments[1], "closure should has one input and one output param")

case "count":
collection := v.visit(node.Arguments[0])
if !isArray(collection) {
panic(v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection))
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

v.collections = append(v.collections, collection)
Expand All @@ -517,15 +515,15 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) reflect.Type {
closure.NumOut() == 1 &&
closure.NumIn() == 1 && isInterface(closure.In(0)) {
if !isBool(closure.Out(0)) {
panic(v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String()))
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
}

return integerType
}
panic(v.error(node.Arguments[1], "closure should has one input and one output param"))
return v.error(node.Arguments[1], "closure should has one input and one output param")

default:
panic(v.error(node, "unknown builtin %v", node.Name))
return v.error(node, "unknown builtin %v", node.Name)
}
}

Expand All @@ -536,21 +534,21 @@ func (v *visitor) ClosureNode(node *ast.ClosureNode) reflect.Type {

func (v *visitor) PointerNode(node *ast.PointerNode) reflect.Type {
if len(v.collections) == 0 {
panic(v.error(node, "cannot use pointer accessor outside closure"))
return v.error(node, "cannot use pointer accessor outside closure")
}

collection := v.collections[len(v.collections)-1]

if t, ok := indexType(collection); ok {
return t
}
panic(v.error(node, "cannot use %v as array", collection))
return v.error(node, "cannot use %v as array", collection)
}

func (v *visitor) ConditionalNode(node *ast.ConditionalNode) reflect.Type {
c := v.visit(node.Cond)
if !isBool(c) {
panic(v.error(node.Cond, "non-bool expression (type %v) used as condition", c))
return v.error(node.Cond, "non-bool expression (type %v) used as condition", c)
}

t1 := v.visit(node.Exp1)
Expand Down
15 changes: 12 additions & 3 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,24 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
}

_, err = checker.Check(tree, config)
if err != nil {

// If we have a patch to apply, it may fix out error and
// second type check is needed. Otherwise it is an error.
if err != nil && len(config.Visitors) == 0 {
return nil, err
}

// Patch operators before Optimize, as we may also mark it as ConstExpr.
compiler.PatchOperators(&tree.Node, config)

for _, v := range config.Visitors {
ast.Walk(&tree.Node, v)
if len(config.Visitors) >= 0 {
for _, v := range config.Visitors {
ast.Walk(&tree.Node, v)
}
_, err = checker.Check(tree, config)
if err != nil {
return nil, err
}
}

if config.Optimize {
Expand Down
Loading

0 comments on commit 546dfc9

Please sign in to comment.