diff --git a/types/sort.go b/types/sort.go index 1704343c74f..8df4acb5d63 100644 --- a/types/sort.go +++ b/types/sort.go @@ -178,15 +178,25 @@ func equal(a, b Val) bool { } switch a.Tid { case DateTimeID: - return a.Value.(time.Time).Equal((b.Value.(time.Time))) + aVal, aOk := a.Value.(time.Time) + bVal, bOk := b.Value.(time.Time) + return aOk && bOk && aVal.Equal(bVal) case IntID: - return (a.Value.(int64)) == (b.Value.(int64)) + aVal, aOk := a.Value.(int64) + bVal, bOk := b.Value.(int64) + return aOk && bOk && aVal == bVal case FloatID: - return (a.Value.(float64)) == (b.Value.(float64)) + aVal, aOk := a.Value.(float64) + bVal, bOk := b.Value.(float64) + return aOk && bOk && aVal == bVal case StringID, DefaultID: - return (a.Value.(string)) == (b.Value.(string)) + aVal, aOk := a.Value.(string) + bVal, bOk := b.Value.(string) + return aOk && bOk && aVal == bVal case BoolID: - return a.Value.(bool) == (b.Value.(bool)) + aVal, aOk := a.Value.(bool) + bVal, bOk := b.Value.(bool) + return aOk && bOk && aVal == bVal } return false } diff --git a/types/sort_test.go b/types/sort_test.go index 605dd012583..3f1ff912ede 100644 --- a/types/sort_test.go +++ b/types/sort_test.go @@ -122,6 +122,32 @@ func TestSortIntAndFloat(t *testing.T) { } +func TestEqual(t *testing.T) { + require.True(t, equal(Val{Tid: IntID, Value: int64(3)}, Val{Tid: IntID, Value: int64(3)}), + "equal should return true for two equal values") + + require.False(t, equal(Val{Tid: IntID, Value: int64(3)}, Val{Tid: IntID, Value: int64(4)}), + "equal should return false for two different values") + + // not equal when the types are different + require.False(t, equal(Val{Tid: IntID, Value: int64(3)}, Val{Tid: FloatID, Value: float64(3.0)}), + "equal should return false for two values with different types") + + // not equal when either parameter has the Value field being nil + require.False(t, equal(Val{Tid: IntID, Value: int64(3)}, Val{Tid: IntID}), + "equal should return false when either parameter cannot have its value converted") + require.False(t, equal(Val{Tid: IntID}, Val{Tid: IntID, Value: int64(3)}), + "equal should return false when either parameter cannot have its value converted") + require.False(t, equal(Val{Tid: IntID}, Val{Tid: IntID}), "equal should return false when either parameter cannot have its value converted") + + // not equal when there is a type mismatch between value and tid for either parameter + require.False(t, equal(Val{Tid: IntID, Value: float64(3.0)}, Val{Tid: FloatID, Value: float64(3.0)}), + "equal should return false when either parameter's value has a type mismatch with its Tid") + require.False(t, equal(Val{Tid: FloatID, Value: float64(3.0)}, Val{Tid: IntID, Value: float64(3.0)}), + "equal should return false when either parameter's value has a type mismatch with its Tid") + +} + func findIndex(t *testing.T, uids []uint64, uid uint64) int { for i := range uids { if uids[i] == uid {