diff --git a/report/counters.go b/report/counters.go index 4ba60c947c..d7e882c197 100644 --- a/report/counters.go +++ b/report/counters.go @@ -108,12 +108,7 @@ func (c Counters) String() string { } // DeepEqual tests equality with other Counters -func (c Counters) DeepEqual(i interface{}) bool { - d, ok := i.(Counters) - if !ok { - return false - } - +func (c Counters) DeepEqual(d Counters) bool { if (c.psMap == nil) != (d.psMap == nil) { return false } else if c.psMap == nil && d.psMap == nil { diff --git a/report/edge_metadatas.go b/report/edge_metadatas.go index 43fb9ef3a8..ee410ab157 100644 --- a/report/edge_metadatas.go +++ b/report/edge_metadatas.go @@ -128,12 +128,7 @@ func (c EdgeMetadatas) String() string { } // DeepEqual tests equality with other Counters -func (c EdgeMetadatas) DeepEqual(i interface{}) bool { - d, ok := i.(EdgeMetadatas) - if !ok { - return false - } - +func (c EdgeMetadatas) DeepEqual(d EdgeMetadatas) bool { if c.Size() != d.Size() { return false } diff --git a/report/latest_map.go b/report/latest_map.go index 9f5619c28d..f75863deac 100644 --- a/report/latest_map.go +++ b/report/latest_map.go @@ -142,12 +142,7 @@ func (m LatestMap) String() string { } // DeepEqual tests equality with other LatestMap -func (m LatestMap) DeepEqual(i interface{}) bool { - n, ok := i.(LatestMap) - if !ok { - return false - } - +func (m LatestMap) DeepEqual(n LatestMap) bool { if m.Size() != n.Size() { return false } diff --git a/report/sets.go b/report/sets.go index 8ea9e6dcda..77067f9df0 100644 --- a/report/sets.go +++ b/report/sets.go @@ -104,12 +104,7 @@ func (s Sets) String() string { } // DeepEqual tests equality with other Sets -func (s Sets) DeepEqual(i interface{}) bool { - t, ok := i.(Sets) - if !ok { - return false - } - +func (s Sets) DeepEqual(t Sets) bool { if s.psMap.Size() != t.psMap.Size() { return false } diff --git a/test/reflect/deepequal.go b/test/reflect/deepequal.go index 70857196eb..f465f5c61d 100644 --- a/test/reflect/deepequal.go +++ b/test/reflect/deepequal.go @@ -32,13 +32,6 @@ func deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) boo } // if depth > 10 { panic("deepValueEqual") } // for debugging - hard := func(k reflect.Kind) bool { - switch k { - case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: - return true - } - return false - } if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { addr1 := v1.UnsafeAddr() @@ -72,88 +65,133 @@ func deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) boo return results[0].Bool() } - switch v1.Kind() { - case reflect.Array: - for i := 0; i < v1.Len(); i++ { - if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { - return false - } - } + test, ok := map[reflect.Kind]func(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool{ + reflect.Array: arrayEq, + reflect.Slice: sliceEq, + reflect.Interface: interfaceEq, + reflect.Ptr: pointerEq, + reflect.Struct: structEq, + reflect.Map: mapEq, + reflect.Func: funcEq, + reflect.Bool: boolEq, + reflect.Float32: floatEq, + reflect.Float64: floatEq, + reflect.Int: intEq, + reflect.Int8: intEq, + reflect.Int16: intEq, + reflect.Int32: intEq, + reflect.Int64: intEq, + reflect.Uint: uintEq, + reflect.Uintptr: uintEq, + reflect.Uint8: uintEq, + reflect.Uint16: uintEq, + reflect.Uint32: uintEq, + reflect.Uint64: uintEq, + reflect.String: stringEq, + }[v1.Kind()] + if !ok { + test = normalEq + } + return test(v1, v2, visited, depth) +} + +func hard(k reflect.Kind) bool { + switch k { + case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: return true - case reflect.Slice: - if v1.IsNil() != v2.IsNil() { - return false - } - if v1.Len() != v2.Len() { + } + return false +} + +func arrayEq(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + for i := 0; i < v1.Len(); i++ { + if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { return false } - if v1.Pointer() == v2.Pointer() { - return true - } - for i := 0; i < v1.Len(); i++ { - if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { - return false - } - } - return true - case reflect.Interface: - if v1.IsNil() || v2.IsNil() { - return v1.IsNil() == v2.IsNil() - } - return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) - case reflect.Ptr: - return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) - case reflect.Struct: - for i, n := 0, v1.NumField(); i < n; i++ { - if v1.Type().Field(i).Tag.Get("deepequal") == "skip" { - continue - } - if !deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) { - return false - } - } + } + return true +} + +func sliceEq(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + if v1.IsNil() != v2.IsNil() { + return false + } + if v1.Len() != v2.Len() { + return false + } + if v1.Pointer() == v2.Pointer() { return true - case reflect.Map: - if v1.IsNil() != v2.IsNil() { - return false - } - if v1.Len() != v2.Len() { + } + for i := 0; i < v1.Len(); i++ { + if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { return false } - if v1.Pointer() == v2.Pointer() { - return true - } - for _, k := range v1.MapKeys() { - if !deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) { - return false - } + } + return true +} + +func interfaceEq(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + if v1.IsNil() || v2.IsNil() { + return v1.IsNil() == v2.IsNil() + } + return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) +} + +func pointerEq(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) +} + +func structEq(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + for i, n := 0, v1.NumField(); i < n; i++ { + if v1.Type().Field(i).Tag.Get("deepequal") == "skip" { + continue } - return true - case reflect.Func: - if v1.IsNil() && v2.IsNil() { - return true + if !deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) { + return false } - // Can't do better than this: + } + return true +} + +func mapEq(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + if v1.IsNil() != v2.IsNil() { + return false + } + if v1.Len() != v2.Len() { return false - case reflect.Bool: - return v1.Bool() == v2.Bool() - case reflect.Float32, reflect.Float64: - return v1.Float() == v2.Float() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v1.Int() == v2.Int() - case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return v1.Uint() == v2.Uint() - case reflect.String: - return v1.String() == v2.String() - default: - // Normal equality suffices - if v1.CanInterface() && v2.CanInterface() { - return v1.Interface() == v1.Interface() - } else if v1.CanInterface() || v2.CanInterface() { + } + if v1.Pointer() == v2.Pointer() { + return true + } + for _, k := range v1.MapKeys() { + if !deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) { return false } + } + return true +} + +func funcEq(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { + if v1.IsNil() && v2.IsNil() { return true } + // Can't do better than this: + return false +} + +func boolEq(v1, v2 reflect.Value, _ map[visit]bool, _ int) bool { return v1.Bool() == v2.Bool() } +func floatEq(v1, v2 reflect.Value, _ map[visit]bool, _ int) bool { return v1.Float() == v2.Float() } +func intEq(v1, v2 reflect.Value, _ map[visit]bool, _ int) bool { return v1.Int() == v2.Int() } +func uintEq(v1, v2 reflect.Value, _ map[visit]bool, _ int) bool { return v1.Uint() == v2.Uint() } +func stringEq(v1, v2 reflect.Value, _ map[visit]bool, _ int) bool { return v1.String() == v2.String() } + +func normalEq(v1, v2 reflect.Value, _ map[visit]bool, _ int) bool { + if v1.CanInterface() && v2.CanInterface() { + return v1.Interface() == v1.Interface() + } else if v1.CanInterface() || v2.CanInterface() { + return false + } + return true } // DeepEqual tests for deep equality. It uses normal == equality where @@ -166,10 +204,10 @@ func DeepEqual(a1, a2 interface{}) bool { if a1 == nil || a2 == nil { return a1 == a2 } - v1 := reflect.ValueOf(a1) - v2 := reflect.ValueOf(a2) - if v1.Type() != v2.Type() { - return false - } - return deepValueEqual(v1, v2, make(map[visit]bool), 0) + return deepValueEqual( + reflect.ValueOf(a1), + reflect.ValueOf(a2), + make(map[visit]bool), + 0, + ) }