Skip to content

Commit

Permalink
fix: improve interface type checks using method sets
Browse files Browse the repository at this point in the history
  • Loading branch information
mvertes authored and traefiker committed Dec 11, 2019
1 parent 0d2c39d commit 273df8a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
18 changes: 18 additions & 0 deletions _test/interface18.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package main

type T struct{}

func (t *T) Error() string { return "T: error" }
func (*T) Foo() { println("foo") }

var invalidT = &T{}

func main() {
var err error
if err != invalidT {
println("ok")
}
}

// Output:
// ok
2 changes: 1 addition & 1 deletion interp/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ func (interp *Interpreter) cfg(root *node) ([]*node, error) {
t0, t1 := c0.typ.TypeOf(), c1.typ.TypeOf()
// Shift operator type is inherited from first parameter only
// All other binary operators require both parameter types to be the same
if !isShiftNode(n) && !c0.typ.untyped && !c1.typ.untyped && !c0.typ.equal(c1.typ) {
if !isShiftNode(n) && !c0.typ.untyped && !c1.typ.untyped && !c0.typ.equals(c1.typ) {
err = n.cfgErrorf("mismatched types %s and %s", c0.typ.id(), c1.typ.id())
break
}
Expand Down
47 changes: 32 additions & 15 deletions interp/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,18 +636,41 @@ func (t *itype) finalize() (*itype, error) {
return t, err
}

// equal returns true if the given type is identical to the receiver one
func (t *itype) equal(o *itype) bool {
if isInterface(t) || isInterface(o) {
// Check for identical methods sets
return reflect.DeepEqual(t.methods(), o.methods())
// Equals returns true if the given type is identical to the receiver one.
func (t *itype) equals(o *itype) bool {
switch ti, oi := isInterface(t), isInterface(o); {
case ti && oi:
return t.methods().equals(o.methods())
case ti && !oi:
return o.methods().contains(t.methods())
case oi && !ti:
return t.methods().contains(o.methods())
default:
return t.id() == o.id()
}
return t.id() == o.id()
}

// methods returns a map of method type strings, indexed by method names
func (t *itype) methods() map[string]string {
res := make(map[string]string)
// MethodSet defines the set of methods signatures as strings, indexed per method name.
type methodSet map[string]string

// Contains returns true if the method set m contains the method set n.
func (m methodSet) contains(n methodSet) bool {
for k, v := range n {
if m[k] != v {
return false
}
}
return true
}

// Equal returns true if the method set m is equal to the method set n.
func (m methodSet) equals(n methodSet) bool {
return m.contains(n) && n.contains(m)
}

// Methods returns a map of method type strings, indexed by method names.
func (t *itype) methods() methodSet {
res := make(methodSet)
switch t.cat {
case interfaceT:
// Get methods from recursive analysis of interface fields
Expand Down Expand Up @@ -879,15 +902,9 @@ func (t *itype) refType(defined map[string]bool) reflect.Type {
in := make([]reflect.Type, len(t.arg))
out := make([]reflect.Type, len(t.ret))
for i, v := range t.arg {
if defined[v.name] {
v.rtype = interf
}
in[i] = v.refType(defined)
}
for i, v := range t.ret {
if defined[v.name] {
v.rtype = interf
}
out[i] = v.refType(defined)
}
t.rtype = reflect.FuncOf(in, out, false)
Expand Down

0 comments on commit 273df8a

Please sign in to comment.