diff --git a/pkg/sql/exec/execgen/cmd/execgen/overloads.go b/pkg/sql/exec/execgen/cmd/execgen/overloads.go index 05395371fb4e..b801cc251093 100644 --- a/pkg/sql/exec/execgen/cmd/execgen/overloads.go +++ b/pkg/sql/exec/execgen/cmd/execgen/overloads.go @@ -336,6 +336,13 @@ func (c floatCustomizer) getHashAssignFunc() assignFunc { } } +func (c floatCustomizer) getCmpOpCompareFunc() compareFunc { + // Float comparisons need special handling for NaN. + return func(l, r string) string { + return fmt.Sprintf("compareFloats(float64(%s), float64(%s))", l, r) + } +} + func (c intCustomizer) getHashAssignFunc() assignFunc { return func(op overload, target, v, _ string) string { return fmt.Sprintf("%[1]s = memhash%[3]d(noescape(unsafe.Pointer(&%[2]s)), %[1]s)", target, v, c.width) diff --git a/pkg/sql/exec/float.go b/pkg/sql/exec/float.go new file mode 100644 index 000000000000..a30cdca8333f --- /dev/null +++ b/pkg/sql/exec/float.go @@ -0,0 +1,34 @@ +// Copyright 2019 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package exec + +import "math" + +// compareFloats compares two float values. This function is necessary for NaN +// handling. In SQL, NaN is treated as less than all other float values. In Go, +// any comparison with NaN returns false. +func compareFloats(a, b float64) int { + if a < b { + return -1 + } + if a > b { + return 1 + } + // Compare bits so that NaN == NaN. + if math.Float64bits(a) == math.Float64bits(b) { + return 0 + } + // Either a or b is NaN. + if math.IsNaN(a) { + return -1 + } + return 1 +} diff --git a/pkg/sql/exec/sort_test.go b/pkg/sql/exec/sort_test.go index 335e18618f24..22772b4bf481 100644 --- a/pkg/sql/exec/sort_test.go +++ b/pkg/sql/exec/sort_test.go @@ -13,6 +13,7 @@ package exec import ( "context" "fmt" + "math" "math/rand" "sort" "testing" @@ -76,8 +77,8 @@ func TestSort(t *testing.T) { ordCols: []distsqlpb.Ordering_Column{{ColIdx: 0}}, }, { - tuples: tuples{{3.2}, {2.0}, {2.4}}, - expected: tuples{{2.0}, {2.4}, {3.2}}, + tuples: tuples{{3.2}, {2.0}, {2.4}, {math.NaN()}, {math.Inf(-1)}, {math.Inf(1)}}, + expected: tuples{{math.NaN()}, {math.Inf(-1)}, {2.0}, {2.4}, {3.2}, {math.Inf(1)}}, typ: []types.T{types.Float64}, ordCols: []distsqlpb.Ordering_Column{{ColIdx: 0}}, }, diff --git a/pkg/sql/exec/utils_test.go b/pkg/sql/exec/utils_test.go index d89e78787e37..3055319e89aa 100644 --- a/pkg/sql/exec/utils_test.go +++ b/pkg/sql/exec/utils_test.go @@ -13,6 +13,7 @@ package exec import ( "context" "fmt" + "math" "math/rand" "reflect" "sort" @@ -516,6 +517,14 @@ func tupleEquals(expected tuple, actual tuple) bool { return false } } else { + // Special case for NaN, since it does not equal itself. + if f1, ok := expected[i].(float64); ok { + if f2, ok := actual[i].(float64); ok { + if math.IsNaN(f1) && math.IsNaN(f2) { + continue + } + } + } if !reflect.DeepEqual(reflect.ValueOf(actual[i]).Convert(reflect.TypeOf(expected[i])).Interface(), expected[i]) { return false }