Skip to content

Commit

Permalink
interp: fix short-form type assertions
Browse files Browse the repository at this point in the history
The long-form (with comma-ok) ones were already fixed but the short-form
ones were not because they were in a completely different code path.

This PR also refactors the code so that both short-form and long-form
are now merged in the same function.

N.B: even though most (all?) cases seem to now be supported, one of them
still yields a result that does not satisfy reflect's Implements method
yet. It does not prevent the resulting assertion to be usable though.

N.B2: the code path for the third-form (_, ok) hasn't been fixed and/or
refactored yet.

Fixes #919
  • Loading branch information
mpl authored Dec 2, 2020
1 parent 101633c commit 2db4579
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 96 deletions.
40 changes: 39 additions & 1 deletion _test/assert0.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"reflect"
"time"
)

Expand All @@ -16,7 +17,8 @@ func (t TestStruct) Write(p []byte) (n int, err error) {
}

func usesWriter(w MyWriter) {
w.Write(nil)
n, _ := w.Write([]byte("hello world"))
fmt.Println(n)
}

type MyStringer interface {
Expand All @@ -28,6 +30,8 @@ func usesStringer(s MyStringer) {
}

func main() {
aType := reflect.TypeOf((*MyWriter)(nil)).Elem()

var t interface{}
t = TestStruct{}
var tw MyWriter
Expand All @@ -39,6 +43,19 @@ func main() {
fmt.Println("TestStruct implements MyWriter")
usesWriter(tw)
}
n, _ := t.(MyWriter).Write([]byte("hello world"))
fmt.Println(n)
bType := reflect.TypeOf(TestStruct{})
fmt.Println(bType.Implements(aType))

t = 42
foo, ok := t.(MyWriter)
if !ok {
fmt.Println("42 does not implement MyWriter")
} else {
fmt.Println("42 implements MyWriter")
}
_ = foo

var tt interface{}
tt = time.Nanosecond
Expand All @@ -50,9 +67,30 @@ func main() {
fmt.Println("time.Nanosecond implements MyStringer")
usesStringer(myD)
}
fmt.Println(tt.(MyStringer).String())
cType := reflect.TypeOf((*MyStringer)(nil)).Elem()
dType := reflect.TypeOf(time.Nanosecond)
fmt.Println(dType.Implements(cType))

tt = 42
bar, ok := tt.(MyStringer)
if !ok {
fmt.Println("42 does not implement MyStringer")
} else {
fmt.Println("42 implements MyStringer")
}
_ = bar

}

// Output:
// TestStruct implements MyWriter
// 11
// 11
// true
// 42 does not implement MyWriter
// time.Nanosecond implements MyStringer
// 1ns
// 1ns
// true
// 42 does not implement MyStringer
24 changes: 24 additions & 0 deletions _test/assert1.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"reflect"
"time"
)

Expand All @@ -12,6 +13,8 @@ func (t TestStruct) String() string {
}

func main() {
aType := reflect.TypeOf((*fmt.Stringer)(nil)).Elem()

var t interface{}
t = time.Nanosecond
s, ok := t.(fmt.Stringer)
Expand All @@ -20,6 +23,19 @@ func main() {
return
}
fmt.Println(s.String())
fmt.Println(t.(fmt.Stringer).String())
bType := reflect.TypeOf(time.Nanosecond)
fmt.Println(bType.Implements(aType))


t = 42
foo, ok := t.(fmt.Stringer)
if !ok {
fmt.Println("42 does not implement fmt.Stringer")
} else {
fmt.Println("42 implements fmt.Stringer")
}
_ = foo

var tt interface{}
tt = TestStruct{}
Expand All @@ -29,8 +45,16 @@ func main() {
return
}
fmt.Println(ss.String())
fmt.Println(tt.(fmt.Stringer).String())
// TODO(mpl): uncomment when fixed
// cType := reflect.TypeOf(TestStruct{})
// fmt.Println(cType.Implements(aType))
}

// Output:
// 1ns
// 1ns
// true
// 42 does not implement fmt.Stringer
// hello world
// hello world
182 changes: 87 additions & 95 deletions interp/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ var builtin = [...]bltnGenerator{
aStar: deref,
aSub: sub,
aSubAssign: subAssign,
aTypeAssert: typeAssert,
aTypeAssert: typeAssert1,
aXor: xor,
aXorAssign: xorAssign,
}
Expand Down Expand Up @@ -230,72 +230,6 @@ func typeAssertStatus(n *node) {
}
}

func typeAssert(n *node) {
c0, c1 := n.child[0], n.child[1]
value := genValue(c0) // input value
value0 := genValue(n) // returned result
next := getExec(n.tnext)

switch {
case isInterfaceSrc(c1.typ):
typ := n.child[1].typ
typID := n.child[1].typ.id()
n.exec = func(f *frame) bltn {
v := value(f)
vi, ok := v.Interface().(valueInterface)
if !ok {
panic(n.cfgErrorf("interface conversion: nil is not %v", typID))
}
if !vi.node.typ.implements(typ) {
panic(n.cfgErrorf("interface conversion: %v is not %v", vi.node.typ.id(), typID))
}
value0(f).Set(v)
return next
}
case isInterface(c1.typ):
n.exec = func(f *frame) bltn {
v := value(f).Elem()
typ := value0(f).Type()
if !v.IsValid() {
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", typ.String()))
}
if !canAssertTypes(v.Type(), typ) {
method := firstMissingMethod(v.Type(), typ)
panic(fmt.Sprintf("interface conversion: %s is not %s: missing method %s", v.Type().String(), typ.String(), method))
}
value0(f).Set(v)
return next
}
case c0.typ.cat == valueT || c0.typ.cat == errorT:
n.exec = func(f *frame) bltn {
v := value(f).Elem()
typ := value0(f).Type()
if !v.IsValid() {
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", typ.String()))
}
if !canAssertTypes(v.Type(), typ) {
method := firstMissingMethod(v.Type(), typ)
panic(fmt.Sprintf("interface conversion: %s is not %s: missing method %s", v.Type().String(), typ.String(), method))
}
value0(f).Set(v)
return next
}
default:
n.exec = func(f *frame) bltn {
v := value(f).Interface().(valueInterface)
typ := value0(f).Type()
if !v.value.IsValid() {
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", typ.String()))
}
if !canAssertTypes(v.value.Type(), typ) {
panic(fmt.Sprintf("interface conversion: interface {} is %s, not %s", v.value.Type().String(), typ.String()))
}
value0(f).Set(v.value)
return next
}
}
}

func stripReceiverFromArgs(signature string) (string, error) {
fields := receiverStripperRxp.FindStringSubmatch(signature)
if len(fields) < 5 {
Expand All @@ -307,13 +241,28 @@ func stripReceiverFromArgs(signature string) (string, error) {
return fmt.Sprintf("func(%s", fields[4]), nil
}

func typeAssert1(n *node) {
typeAssert(n, false)
}

func typeAssert2(n *node) {
typeAssert(n, true)
}

func typeAssert(n *node, withOk bool) {
c0, c1 := n.child[0], n.child[1]
value := genValue(c0) // input value
value0 := genValue(n.anc.child[0]) // returned result
value1 := genValue(n.anc.child[1]) // returned status
setStatus := n.anc.child[1].ident != "_" // do not assign status to "_"
typ := c1.typ // type to assert or convert to
value := genValue(c0) // input value
var value0, value1 func(*frame) reflect.Value
setStatus := false
if withOk {
value0 = genValue(n.anc.child[0]) // returned result
value1 = genValue(n.anc.child[1]) // returned status
setStatus = n.anc.child[1].ident != "_" // do not assign status to "_"
} else {
value0 = genValue(n) // returned result
}

typ := c1.typ // type to assert or convert to
typID := typ.id()
rtype := typ.rtype // type to assert
next := getExec(n.tnext)
Expand All @@ -322,13 +271,15 @@ func typeAssert2(n *node) {
case isInterfaceSrc(typ):
n.exec = func(f *frame) bltn {
v, ok := value(f).Interface().(valueInterface)
defer func() {
assertOk := ok
if setStatus {
value1(f).SetBool(assertOk)
}
}()
if setStatus {
defer func() {
value1(f).SetBool(ok)
}()
}
if !ok {
if !withOk {
panic(n.cfgErrorf("interface conversion: nil is not %v", typID))
}
return next
}
if v.node.typ.id() == typID {
Expand All @@ -339,6 +290,9 @@ func typeAssert2(n *node) {
m1 := typ.methods()
if len(m0) < len(m1) {
ok = false
if !withOk {
panic(n.cfgErrorf("interface conversion: %v is not %v", v.node.typ.id(), typID))
}
return next
}

Expand Down Expand Up @@ -383,12 +337,11 @@ func typeAssert2(n *node) {
var leftType reflect.Type
v := value(f)
val, ok := v.Interface().(valueInterface)
defer func() {
assertOk := ok
if setStatus {
value1(f).SetBool(assertOk)
}
}()
if setStatus {
defer func() {
value1(f).SetBool(ok)
}()
}
if ok && val.node.typ.cat != valueT {
m0 := val.node.typ.methods()
m1 := typ.methods()
Expand All @@ -409,6 +362,7 @@ func typeAssert2(n *node) {
}
}

// TODO(mpl): make this case compliant with reflect's Implements.
v = genInterfaceWrapper(val.node, rtype)(f)
value0(f).Set(v)
ok = true
Expand All @@ -423,8 +377,19 @@ func typeAssert2(n *node) {
leftType = v.Type()
ok = true
}
ok = v.IsValid() && canAssertTypes(leftType, rtype)
ok = v.IsValid()
if !ok {
if !withOk {
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", rtype.String()))
}
return next
}
ok = canAssertTypes(leftType, rtype)
if !ok {
if !withOk {
method := firstMissingMethod(leftType, rtype)
panic(fmt.Sprintf("interface conversion: %s is not %s: missing method %s", leftType.String(), rtype.String(), method))
}
return next
}
value0(f).Set(v)
Expand All @@ -433,25 +398,52 @@ func typeAssert2(n *node) {
case n.child[0].typ.cat == valueT || n.child[0].typ.cat == errorT:
n.exec = func(f *frame) bltn {
v := value(f).Elem()
ok := v.IsValid() && canAssertTypes(v.Type(), rtype)
if ok {
value0(f).Set(v)
}
ok := v.IsValid()
if setStatus {
value1(f).SetBool(ok)
defer func() {
value1(f).SetBool(ok)
}()
}
if !ok {
if !withOk {
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", rtype.String()))
}
return next
}
ok = canAssertTypes(v.Type(), rtype)
if !ok {
if !withOk {
method := firstMissingMethod(v.Type(), rtype)
panic(fmt.Sprintf("interface conversion: %s is not %s: missing method %s", v.Type().String(), rtype.String(), method))
}
return next
}
value0(f).Set(v)
return next
}
default:
n.exec = func(f *frame) bltn {
v, ok := value(f).Interface().(valueInterface)
ok = ok && v.value.IsValid() && canAssertTypes(v.value.Type(), rtype)
if ok {
value0(f).Set(v.value)
}
if setStatus {
value1(f).SetBool(ok)
defer func() {
value1(f).SetBool(ok)
}()
}
if !ok || !v.value.IsValid() {
ok = false
if !withOk {
panic(fmt.Sprintf("interface conversion: interface {} is nil, not %s", rtype.String()))
}
return next
}
ok = canAssertTypes(v.value.Type(), rtype)
if !ok {
if !withOk {
panic(fmt.Sprintf("interface conversion: interface {} is %s, not %s", v.value.Type().String(), rtype.String()))
}
return next
}
value0(f).Set(v.value)
return next
}
}
Expand Down

0 comments on commit 2db4579

Please sign in to comment.