diff --git a/checker/checker.go b/checker/checker.go index 3e787fa4f..efa1b7eb9 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -924,8 +924,13 @@ func (v *checker) checkArguments(name string, fn reflect.Type, method bool, argu } if isFloat(in) { - t = floatType - traverseAndReplaceIntegerNodesWithFloatNodes(&arg) + traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in) + continue + } + + if isInteger(in) && isInteger(t) && kind(t) != kind(in) { + traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in) + continue } if t == nil { @@ -943,19 +948,37 @@ func (v *checker) checkArguments(name string, fn reflect.Type, method bool, argu return fn.Out(0), nil } -func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node) { +func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newType reflect.Type) { switch (*node).(type) { case *ast.IntegerNode: *node = &ast.FloatNode{Value: float64((*node).(*ast.IntegerNode).Value)} + (*node).SetType(newType) + case *ast.UnaryNode: + unaryNode := (*node).(*ast.UnaryNode) + traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node, newType) + case *ast.BinaryNode: + binaryNode := (*node).(*ast.BinaryNode) + switch binaryNode.Operator { + case "+", "-", "*": + traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left, newType) + traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right, newType) + } + } +} + +func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newType reflect.Type) { + switch (*node).(type) { + case *ast.IntegerNode: + (*node).SetType(newType) case *ast.UnaryNode: unaryNode := (*node).(*ast.UnaryNode) - traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node) + traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newType) case *ast.BinaryNode: binaryNode := (*node).(*ast.BinaryNode) switch binaryNode.Operator { case "+", "-", "*": - traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left) - traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right) + traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Left, newType) + traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Right, newType) } } } diff --git a/compiler/compiler.go b/compiler/compiler.go index 66ad4f1f2..8e26d8788 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -307,7 +307,17 @@ func (c *compiler) IntegerNode(node *ast.IntegerNode) { } func (c *compiler) FloatNode(node *ast.FloatNode) { - c.emitPush(node.Value) + t := node.Type() + if t == nil { + c.emitPush(node.Value) + return + } + switch t.Kind() { + case reflect.Float32: + c.emitPush(float32(node.Value)) + case reflect.Float64: + c.emitPush(node.Value) + } } func (c *compiler) BoolNode(node *ast.BoolNode) { diff --git a/expr_test.go b/expr_test.go index 1857403ec..86945867f 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1969,3 +1969,32 @@ func TestMemoryBudget(t *testing.T) { }) } } + +func TestIssue432(t *testing.T) { + env := map[string]any{ + "func": func( + paramUint32 uint32, + paramUint16 uint16, + paramUint8 uint8, + paramUint uint, + paramInt32 int32, + paramInt16 int16, + paramInt8 int8, + paramInt int, + paramFloat64 float64, + paramFloat32 float32, + ) float64 { + return float64(paramUint32) + float64(paramUint16) + float64(paramUint8) + float64(paramUint) + + float64(paramInt32) + float64(paramInt16) + float64(paramInt8) + float64(paramInt) + + float64(paramFloat64) + float64(paramFloat32) + }, + } + code := `func(1,1,1,1,1,1,1,1,1,1)` + + program, err := expr.Compile(code, expr.Env(env)) + assert.NoError(t, err) + + out, err := expr.Run(program, env) + assert.NoError(t, err) + assert.Equal(t, float64(10), out) +} \ No newline at end of file