diff --git a/_test/shift3.go b/_test/shift3.go new file mode 100644 index 000000000..76784d7d0 --- /dev/null +++ b/_test/shift3.go @@ -0,0 +1,10 @@ +package main + +const a = 1.0 + +const b = a + 3 + +func main() { println(b << (1)) } + +// Output: +// 8 diff --git a/interp/cfg.go b/interp/cfg.go index cd027133e..7ff957ba1 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -3,6 +3,7 @@ package interp import ( "fmt" "log" + "math" "path" "reflect" "unicode" @@ -403,7 +404,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) { err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) } case aShlAssign, aShrAssign: - if !(isInt(t0) && isUint(t1)) { + if !(dest.isInteger() && src.isNatural()) { err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) } default: @@ -521,7 +522,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) { err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) } case aShl, aShr: - if !(isInt(t0) && isUint(t1)) { + if !(c0.isInteger() && c1.isNatural()) { err = n.cfgErrorf("illegal operand types for '%v' operator", n.action) } n.typ = c0.typ @@ -1506,7 +1507,57 @@ func wireChild(n *node) { } } -// last returns the last child of a node +// isInteger returns true if node type is integer, false otherwise +func (n *node) isInteger() bool { + if isInt(n.typ.TypeOf()) { + return true + } + if n.typ.untyped && n.rval.IsValid() { + t := n.rval.Type() + if isInt(t) { + return true + } + if isFloat(t) { + // untyped float constant with null decimal part is ok + f := n.rval.Float() + if f == math.Round(f) { + n.rval = reflect.ValueOf(int(f)) + n.typ.rtype = n.rval.Type() + return true + } + } + } + return false +} + +// isNatural returns true if node type is natural, false otherwise +func (n *node) isNatural() bool { + if isUint(n.typ.TypeOf()) { + return true + } + if n.typ.untyped && n.rval.IsValid() { + t := n.rval.Type() + if isUint(t) { + return true + } + if isInt(t) && n.rval.Int() >= 0 { + // positive untyped integer constant is ok + return true + } + if isFloat(t) { + // positive untyped float constant with null decimal part is ok + f := n.rval.Float() + if f == math.Round(f) && f >= 0 { + n.rval = reflect.ValueOf(uint(f)) + n.typ.rtype = n.rval.Type() + return true + } + } + } + return false +} + +// lastChild returns the last child of a node func (n *node) lastChild() *node { return n.child[len(n.child)-1] } func isKey(n *node) bool { diff --git a/interp/interp_eval_test.go b/interp/interp_eval_test.go index 878259784..d41034b08 100644 --- a/interp/interp_eval_test.go +++ b/interp/interp_eval_test.go @@ -44,11 +44,13 @@ func TestEvalArithmetic(t *testing.T) { {desc: "rem_FI", src: "8.0 % 4", err: "1:28: illegal operand types for '%' operator"}, {desc: "shl_II", src: "1 << 8", res: "256"}, {desc: "shl_IN", src: "1 << -1", err: "1:28: illegal operand types for '<<' operator"}, - {desc: "shl_IF", src: "1 << 1.0", err: "1:28: illegal operand types for '<<' operator"}, - {desc: "shl_IF", src: "1.0 << 1", err: "1:28: illegal operand types for '<<' operator"}, + {desc: "shl_IF", src: "1 << 1.0", res: "2"}, + {desc: "shl_IF1", src: "1 << 1.1", err: "1:28: illegal operand types for '<<' operator"}, + {desc: "shl_IF", src: "1.0 << 1", res: "2"}, {desc: "shr_II", src: "1 >> 8", res: "0"}, {desc: "shr_IN", src: "1 >> -1", err: "1:28: illegal operand types for '>>' operator"}, - {desc: "shr_IF", src: "1 >> 1.0", err: "1:28: illegal operand types for '>>' operator"}, + {desc: "shr_IF", src: "1 >> 1.0", res: "0"}, + {desc: "shr_IF1", src: "1 >> 1.1", err: "1:28: illegal operand types for '>>' operator"}, }) } diff --git a/interp/type.go b/interp/type.go index b98bc0c55..05b0ccf5c 100644 --- a/interp/type.go +++ b/interp/type.go @@ -215,14 +215,8 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) { t.name = "float64" t.untyped = true case int: - if isShiftOperand(n) && v >= 0 { - t.cat = uintT - t.name = "uint" - n.rval = reflect.ValueOf(uint(v)) - } else { - t.cat = intT - t.name = "int" - } + t.cat = intT + t.name = "int" t.untyped = true case uint: t.cat = uintT @@ -909,13 +903,6 @@ func isShiftNode(n *node) bool { return false } -func isShiftOperand(n *node) bool { - if isShiftNode(n.anc) { - return n.anc.lastChild() == n - } - return false -} - func isInterface(t *itype) bool { return t.cat == interfaceT || t.TypeOf().Kind() == reflect.Interface } func isStruct(t *itype) bool { return t.TypeOf().Kind() == reflect.Struct }