Skip to content

Commit

Permalink
fix: correct automatic type conversion for untyped constants
Browse files Browse the repository at this point in the history
  • Loading branch information
mvertes authored and traefiker committed Sep 25, 2019
1 parent 8db7a81 commit 03596da
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 21 deletions.
10 changes: 10 additions & 0 deletions _test/shift3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package main

const a = 1.0

const b = a + 3

func main() { println(b << (1)) }

// Output:
// 8
57 changes: 54 additions & 3 deletions interp/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package interp
import (
"fmt"
"log"
"math"
"path"
"reflect"
"unicode"
Expand Down Expand Up @@ -403,7 +404,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
err = n.cfgErrorf("illegal operand types for '%v' operator", n.action)
}
case aShlAssign, aShrAssign:
if !(isInt(t0) && isUint(t1)) {
if !(dest.isInteger() && src.isNatural()) {
err = n.cfgErrorf("illegal operand types for '%v' operator", n.action)
}
default:
Expand Down Expand Up @@ -521,7 +522,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
err = n.cfgErrorf("illegal operand types for '%v' operator", n.action)
}
case aShl, aShr:
if !(isInt(t0) && isUint(t1)) {
if !(c0.isInteger() && c1.isNatural()) {
err = n.cfgErrorf("illegal operand types for '%v' operator", n.action)
}
n.typ = c0.typ
Expand Down Expand Up @@ -1506,7 +1507,57 @@ func wireChild(n *node) {
}
}

// last returns the last child of a node
// isInteger returns true if node type is integer, false otherwise
func (n *node) isInteger() bool {
if isInt(n.typ.TypeOf()) {
return true
}
if n.typ.untyped && n.rval.IsValid() {
t := n.rval.Type()
if isInt(t) {
return true
}
if isFloat(t) {
// untyped float constant with null decimal part is ok
f := n.rval.Float()
if f == math.Round(f) {
n.rval = reflect.ValueOf(int(f))
n.typ.rtype = n.rval.Type()
return true
}
}
}
return false
}

// isNatural returns true if node type is natural, false otherwise
func (n *node) isNatural() bool {
if isUint(n.typ.TypeOf()) {
return true
}
if n.typ.untyped && n.rval.IsValid() {
t := n.rval.Type()
if isUint(t) {
return true
}
if isInt(t) && n.rval.Int() >= 0 {
// positive untyped integer constant is ok
return true
}
if isFloat(t) {
// positive untyped float constant with null decimal part is ok
f := n.rval.Float()
if f == math.Round(f) && f >= 0 {
n.rval = reflect.ValueOf(uint(f))
n.typ.rtype = n.rval.Type()
return true
}
}
}
return false
}

// lastChild returns the last child of a node
func (n *node) lastChild() *node { return n.child[len(n.child)-1] }

func isKey(n *node) bool {
Expand Down
8 changes: 5 additions & 3 deletions interp/interp_eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ func TestEvalArithmetic(t *testing.T) {
{desc: "rem_FI", src: "8.0 % 4", err: "1:28: illegal operand types for '%' operator"},
{desc: "shl_II", src: "1 << 8", res: "256"},
{desc: "shl_IN", src: "1 << -1", err: "1:28: illegal operand types for '<<' operator"},
{desc: "shl_IF", src: "1 << 1.0", err: "1:28: illegal operand types for '<<' operator"},
{desc: "shl_IF", src: "1.0 << 1", err: "1:28: illegal operand types for '<<' operator"},
{desc: "shl_IF", src: "1 << 1.0", res: "2"},
{desc: "shl_IF1", src: "1 << 1.1", err: "1:28: illegal operand types for '<<' operator"},
{desc: "shl_IF", src: "1.0 << 1", res: "2"},
{desc: "shr_II", src: "1 >> 8", res: "0"},
{desc: "shr_IN", src: "1 >> -1", err: "1:28: illegal operand types for '>>' operator"},
{desc: "shr_IF", src: "1 >> 1.0", err: "1:28: illegal operand types for '>>' operator"},
{desc: "shr_IF", src: "1 >> 1.0", res: "0"},
{desc: "shr_IF1", src: "1 >> 1.1", err: "1:28: illegal operand types for '>>' operator"},
})
}

Expand Down
17 changes: 2 additions & 15 deletions interp/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,8 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
t.name = "float64"
t.untyped = true
case int:
if isShiftOperand(n) && v >= 0 {
t.cat = uintT
t.name = "uint"
n.rval = reflect.ValueOf(uint(v))
} else {
t.cat = intT
t.name = "int"
}
t.cat = intT
t.name = "int"
t.untyped = true
case uint:
t.cat = uintT
Expand Down Expand Up @@ -909,13 +903,6 @@ func isShiftNode(n *node) bool {
return false
}

func isShiftOperand(n *node) bool {
if isShiftNode(n.anc) {
return n.anc.lastChild() == n
}
return false
}

func isInterface(t *itype) bool { return t.cat == interfaceT || t.TypeOf().Kind() == reflect.Interface }

func isStruct(t *itype) bool { return t.TypeOf().Kind() == reflect.Struct }
Expand Down

0 comments on commit 03596da

Please sign in to comment.