From 596f54f26256d389a380968bef0d6ba67b0e004f Mon Sep 17 00:00:00 2001 From: Ganesan Karuppasamy Date: Tue, 21 May 2024 14:40:48 +0530 Subject: [PATCH] Invoke the Deref function as needed for the function arguments. (#651) --- checker/checker.go | 2 +- compiler/compiler.go | 48 +++++++++++++++++++++++----- test/deref/deref_test.go | 67 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 8 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index a2c86c20..c71a98f0 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -1044,7 +1044,7 @@ func (v *checker) checkArguments( continue } - if !t.AssignableTo(in) && kind(t) != reflect.Interface { + if !(t.AssignableTo(in) || deref.Type(t).AssignableTo(in)) && kind(t) != reflect.Interface { return anyType, &file.Error{ Location: arg.Location(), Message: fmt.Sprintf("cannot use %v as argument (type %v) to call %v ", t, in, name), diff --git a/compiler/compiler.go b/compiler/compiler.go index 457088d3..205d6023 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -592,8 +592,8 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { } func (c *compiler) equalBinaryNode(node *ast.BinaryNode) { - l := kind(node.Left) - r := kind(node.Right) + l := kind(node.Left.Type()) + r := kind(node.Right.Type()) leftIsSimple := isSimpleType(node.Left) rightIsSimple := isSimpleType(node.Right) @@ -727,9 +727,44 @@ func (c *compiler) SliceNode(node *ast.SliceNode) { } func (c *compiler) CallNode(node *ast.CallNode) { - for _, arg := range node.Arguments { - c.compile(arg) + fn := node.Callee.Type() + if kind(fn) == reflect.Func { + fnInOffset := 0 + fnNumIn := fn.NumIn() + switch callee := node.Callee.(type) { + case *ast.MemberNode: + if prop, ok := callee.Property.(*ast.StringNode); ok { + if _, ok = callee.Node.Type().MethodByName(prop.Value); ok && callee.Node.Type().Kind() != reflect.Interface { + fnInOffset = 1 + fnNumIn-- + } + } + case *ast.IdentifierNode: + if t, ok := c.config.Types[callee.Value]; ok && t.Method { + fnInOffset = 1 + fnNumIn-- + } + } + for i, arg := range node.Arguments { + c.compile(arg) + if k := kind(arg.Type()); k == reflect.Ptr || k == reflect.Interface { + var in reflect.Type + if fn.IsVariadic() && i >= fnNumIn-1 { + in = fn.In(fn.NumIn() - 1).Elem() + } else { + in = fn.In(i + fnInOffset) + } + if k = kind(in); k != reflect.Ptr && k != reflect.Interface { + c.emit(OpDeref) + } + } + } + } else { + for _, arg := range node.Arguments { + c.compile(arg) + } } + if ident, ok := node.Callee.(*ast.IdentifierNode); ok { if c.config != nil { if fn, ok := c.config.Functions[ident.Value]; ok { @@ -1162,7 +1197,7 @@ func (c *compiler) PairNode(node *ast.PairNode) { } func (c *compiler) derefInNeeded(node ast.Node) { - switch kind(node) { + switch kind(node.Type()) { case reflect.Ptr, reflect.Interface: c.emit(OpDeref) } @@ -1181,8 +1216,7 @@ func (c *compiler) optimize() { } } -func kind(node ast.Node) reflect.Kind { - t := node.Type() +func kind(t reflect.Type) reflect.Kind { if t == nil { return reflect.Invalid } diff --git a/test/deref/deref_test.go b/test/deref/deref_test.go index 0b228ca1..4bfb7616 100644 --- a/test/deref/deref_test.go +++ b/test/deref/deref_test.go @@ -3,6 +3,7 @@ package deref_test import ( "context" "testing" + "time" "github.com/expr-lang/expr/internal/testify/assert" "github.com/expr-lang/expr/internal/testify/require" @@ -253,3 +254,69 @@ func TestDeref_fetch_from_interface_mix_pointer(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "waldo", res) } + +func TestDeref_func_args(t *testing.T) { + i := 20 + env := map[string]any{ + "var": &i, + "fn": func(p int) int { + return p + 1 + }, + } + + program, err := expr.Compile(`fn(var) + fn(var + 0)`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 42, out) +} + +func TestDeref_struct_func_args(t *testing.T) { + n, _ := time.Parse(time.RFC3339, "2024-05-12T18:30:00+00:00") + duration := 30 * time.Minute + env := map[string]any{ + "time": n, + "duration": &duration, + } + + program, err := expr.Compile(`time.Add(duration).Format('2006-01-02T15:04:05Z07:00')`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, "2024-05-12T19:00:00Z", out) +} + +func TestDeref_ignore_func_args(t *testing.T) { + f := foo(1) + env := map[string]any{ + "foo": &f, + "fn": func(f *foo) int { + return f.Bar() + }, + } + + program, err := expr.Compile(`fn(foo)`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 42, out) +} + +func TestDeref_ignore_struct_func_args(t *testing.T) { + n := time.Now() + location, _ := time.LoadLocation("UTC") + env := map[string]any{ + "time": n, + "location": location, + } + + program, err := expr.Compile(`time.In(location).Location().String()`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, "UTC", out) +}