Skip to content

Commit

Permalink
Add support for handling cycles
Browse files Browse the repository at this point in the history
  • Loading branch information
dsnet committed Mar 27, 2018
1 parent 0627e44 commit fd77732
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 36 deletions.
105 changes: 93 additions & 12 deletions cmp/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/google/go-cmp/cmp/internal/diff"
"github.com/google/go-cmp/cmp/internal/function"
"github.com/google/go-cmp/cmp/internal/pointer"
"github.com/google/go-cmp/cmp/internal/value"
)

Expand Down Expand Up @@ -81,6 +82,12 @@ var nothing = reflect.Value{}
// To equate empty slices and maps, consider using cmpopts.EquateEmpty.
// Map keys are equal according to the == operator.
// To use custom comparisons for map keys, consider using cmpopts.SortMaps.
//
// When recursing into a pointer, slice, or map, the current path is checked
// to detect whether the address for the given pointer, slice element,
// or map has already been visited. If there is a cycle, then the pointed to
// values are considered equal only if both addresses were previously visited
// in the same path step.
func Equal(x, y interface{}, opts ...Option) bool {
s := newState(opts)
s.compareAny(reflect.ValueOf(x), reflect.ValueOf(y))
Expand Down Expand Up @@ -110,6 +117,7 @@ type state struct {
// Calling statelessCompare must not result in observable changes to these.
result diff.Result // The current result of comparison
curPath Path // The current path in the value tree
pointers pointers // The current set of visited pointers
reporter reporter // Optional reporter used for difference formatting

// dynChecker triggers pseudo-random checks for option correctness.
Expand All @@ -123,6 +131,7 @@ type state struct {

func newState(opts []Option) *state {
s := new(state)
s.pointers = makePointers()
for _, opt := range opts {
s.processOption(opt)
}
Expand Down Expand Up @@ -180,8 +189,6 @@ func (s *state) statelessCompare(vx, vy reflect.Value) diff.Result {
}

func (s *state) compareAny(vx, vy reflect.Value) {
// TODO: Support cyclic data structures.

// Rule 0: Differing types are never equal.
if !vx.IsValid() || !vy.IsValid() {
s.report(vx.IsValid() == vy.IsValid(), vx, vy)
Expand Down Expand Up @@ -241,6 +248,13 @@ func (s *state) compareAny(vx, vy reflect.Value) {
}
s.curPath.push(&indirect{pathStep{t.Elem()}})
defer s.curPath.pop()

if eq, visited := s.pointers.Visit(vx, vy); visited {
s.report(eq, vx, vy)
return
}
defer s.pointers.Leave(vx, vy)

s.compareAny(vx.Elem(), vy.Elem())
return
case reflect.Interface:
Expand All @@ -261,6 +275,11 @@ func (s *state) compareAny(vx, vy reflect.Value) {
s.report(vx.IsNil() && vy.IsNil(), vx, vy)
return
}

// NOTE: A slice is technically a collection of pointers.
// Thus, instead of calling pointers.Visit here on the pointer in the
// slice header, we perform the check on each element in compareArray.

fallthrough
case reflect.Array:
s.compareArray(vx, vy, t)
Expand Down Expand Up @@ -393,21 +412,33 @@ func (s *state) compareArray(vx, vy reflect.Value, t reflect.Type) {
step := &sliceIndex{pathStep{t.Elem()}, 0, 0}
s.curPath.push(step)

// Compute an edit-script for slices vx and vy.
es := diff.Difference(vx.Len(), vy.Len(), func(ix, iy int) diff.Result {
step.xkey, step.ykey = ix, iy
return s.statelessCompare(vx.Index(ix), vy.Index(iy))
})

// Report the entire slice as is if the arrays are of primitive kind,
// and the arrays are different enough.
// Checking the visited map is only necessary for slices of non-primitves.
// The elements of a slice are always addressable.
isPrimitive := false
switch t.Elem().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Bool, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
isPrimitive = true
}
checkPointer := t.Kind() == reflect.Slice && !isPrimitive

// Compute an edit-script for slices vx and vy.
es := diff.Difference(vx.Len(), vy.Len(), func(ix, iy int) diff.Result {
step.xkey, step.ykey = ix, iy
vvx, vvy := vx.Index(ix), vy.Index(iy)
if checkPointer {
px, py := vvx.Addr(), vvy.Addr()
if eq, visited := s.pointers.Visit(px, py); visited {
return diff.BoolResult(eq)
}
defer s.pointers.Leave(px, py)
}
return s.statelessCompare(vvx, vvy)
})

// Report the entire slice as is if the arrays are of primitive kind,
// and the arrays are different enough.
if isPrimitive && es.Dist() > (vx.Len()+vy.Len())/4 {
s.curPath.pop() // Pop first since we are reporting the whole slice
s.report(false, vx, vy)
Expand All @@ -428,10 +459,21 @@ func (s *state) compareArray(vx, vy reflect.Value, t reflect.Type) {
iy++
default:
step.xkey, step.ykey = ix, iy
vvx, vvy := vx.Index(ix), vy.Index(iy)
if e == diff.Identity {
s.report(true, vx.Index(ix), vy.Index(iy))
s.report(true, vvx, vvy)
} else {
s.compareAny(vx.Index(ix), vy.Index(iy))
if checkPointer {
px, py := vvx.Addr(), vvy.Addr()
if eq, visited := s.pointers.Visit(px, py); visited {
s.report(eq, vvx, vvy)
} else {
s.compareAny(vvx, vvy)
s.pointers.Leave(px, py)
}
} else {
s.compareAny(vvx, vvy)
}
}
ix++
iy++
Expand All @@ -447,6 +489,12 @@ func (s *state) compareMap(vx, vy reflect.Value, t reflect.Type) {
return
}

if eq, visited := s.pointers.Visit(vx, vy); visited {
s.report(eq, vx, vy)
return
}
defer s.pointers.Leave(vx, vy)

// We combine and sort the two map keys so that we can perform the
// comparisons in a deterministic order.
step := &mapIndex{pathStep: pathStep{t.Elem()}}
Expand Down Expand Up @@ -518,6 +566,39 @@ func (s *state) report(eq bool, vx, vy reflect.Value) {
}
}

type pointers struct {
mx map[pointer.P]pointer.P
my map[pointer.P]pointer.P
}

func makePointers() pointers {
return pointers{
make(map[pointer.P]pointer.P),
make(map[pointer.P]pointer.P),
}
}

// Visit descends into pointers vx and vy if they have never been seen before.
// The comparison is equal if both pointers were encountered together.
func (m pointers) Visit(vx, vy reflect.Value) (equal, visited bool) {
px := pointer.New(vx)
py := pointer.New(vy)
_, ok1 := m.mx[px]
_, ok2 := m.my[py]
if ok1 || ok2 {
equal = m.mx[px] == py && m.my[py] == px // Pointers paired together
return equal, true
}
m.mx[px] = py
m.my[py] = px
return false, false
}

func (m pointers) Leave(vx, vy reflect.Value) {
delete(m.mx, pointer.New(vx))
delete(m.my, pointer.New(vy))
}

// dynChecker tracks the state needed to periodically perform checks that
// user provided functions are symmetric and deterministic.
// The zero value is safe for immediate use.
Expand Down
31 changes: 31 additions & 0 deletions cmp/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func TestDiff(t *testing.T) {
tests = append(tests, transformerTests()...)
tests = append(tests, embeddedTests()...)
tests = append(tests, methodTests()...)
tests = append(tests, cycleTests()...)
tests = append(tests, project1Tests()...)
tests = append(tests, project2Tests()...)
tests = append(tests, project3Tests()...)
Expand Down Expand Up @@ -1538,6 +1539,36 @@ func methodTests() []test {
}}
}

func cycleTests() []test {
const label = "Cycle"

type A *A
graphA := new(A)
*graphA = graphA

type B []B
graphB := B{nil}
graphB[0] = graphB

type C map[int]C
graphC := C{0: nil}
graphC[0] = graphC

return []test{{
label: label,
x: graphA,
y: graphA,
}, {
label: label,
x: graphB,
y: graphB,
}, {
label: label,
x: graphC,
y: graphC,
}}
}

func project1Tests() []test {
const label = "Project1"

Expand Down
8 changes: 8 additions & 0 deletions cmp/internal/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ type EqualFunc func(ix int, iy int) Result
// NDiff is the number of sub-elements that are not equal.
type Result struct{ NSame, NDiff int }

// BoolResult returns a Result that is either Equal or not Equal; never Similar.
func BoolResult(b bool) Result {
if b {
return Result{NSame: 1} // Equal, Similar
}
return Result{NDiff: 2} // Not Equal, not Similar
}

// Equal indicates whether the symbols are equal. Two symbols are equal
// if and only if NDiff == 0. If Equal, then they are also Similar.
func (r Result) Equal() bool { return r.NDiff == 0 }
Expand Down
25 changes: 25 additions & 0 deletions cmp/internal/pointer/pointer_purego.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2018, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.

// +build purego

// Package pointer provides an abstraction over opaque pointers.
package pointer

import (
"reflect"
)

// P is a typed pointer and is guaranteed to be comparable.
type P struct {
p uintptr
t reflect.Type
}

// New returns a typed pointer P from v, which must be a Ptr, Slice, or Map.
func New(v reflect.Value) P {
// NOTE: Storing a pointer as an uintptr is technically incorrect as it
// assumes that the GC implementation does not use a moving collector.
return P{v.Pointer(), v.Type()}
}
26 changes: 26 additions & 0 deletions cmp/internal/pointer/pointer_unsafe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2018, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.

// +build !purego

// Package pointer provides an abstraction over opaque pointers.
package pointer

import (
"reflect"
"unsafe"
)

// P is a typed pointer and is guaranteed to be comparable.
type P struct {
p unsafe.Pointer
t reflect.Type
}

// New returns a typed pointer P from v, which must be a Ptr, Slice, or Map.
func New(v reflect.Value) P {
// The proper representation of a pointer is unsafe.Pointer,
// which is necessary if the GC ever uses a moving collector.
return P{unsafe.Pointer(v.Pointer()), v.Type()}
}
Loading

0 comments on commit fd77732

Please sign in to comment.