From b018f288c0ac336a8297d5c575e062e6bb791902 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20M=C3=BCller?= Date: Sun, 28 Feb 2021 14:43:06 +0100 Subject: [PATCH] Implement rounding mode for `Number#round` (#10413) --- spec/std/number_spec.cr | 220 +++++++++++++++++++++++++++++++++++++++- src/float.cr | 32 +++++- src/int.cr | 14 ++- src/math/libm.cr | 6 ++ src/number.cr | 71 +++++++++++-- 5 files changed, 327 insertions(+), 16 deletions(-) diff --git a/spec/std/number_spec.cr b/spec/std/number_spec.cr index 4aa44f65c1e1..e74a45a67909 100644 --- a/spec/std/number_spec.cr +++ b/spec/std/number_spec.cr @@ -53,13 +53,32 @@ describe "Number" do 15.151.round.should eq(15) end - it "infinity" do - Float64::INFINITY.round.infinite?.should eq(1) - Float32::INFINITY.round.infinite?.should eq(1) - (-Float64::INFINITY).round.infinite?.should eq(-1) - (-Float32::INFINITY).round.infinite?.should eq(-1) + it "infinity Float64" do + Float64::INFINITY.round.should eq Float64::INFINITY + Float64::INFINITY.round(digits: 0).should eq Float64::INFINITY + Float64::INFINITY.round(digits: 3).should eq Float64::INFINITY + Float64::INFINITY.round(digits: -3).should eq Float64::INFINITY + (-Float64::INFINITY).round.should eq -Float64::INFINITY + (-Float64::INFINITY).round(digits: 0).should eq -Float64::INFINITY + (-Float64::INFINITY).round(digits: 3).should eq -Float64::INFINITY + (-Float64::INFINITY).round(digits: -3).should eq -Float64::INFINITY end + {% if compare_versions(Crystal::VERSION, "0.36.1") > 0 %} + it "infinity Float32" do + Float32::INFINITY.round.should eq Float32::INFINITY + Float32::INFINITY.round(digits: 0).should eq Float32::INFINITY + Float32::INFINITY.round(digits: 3).should eq Float32::INFINITY + Float32::INFINITY.round(digits: -3).should eq Float32::INFINITY + (-Float32::INFINITY).round.should eq -Float32::INFINITY + (-Float32::INFINITY).round(digits: 0).should eq -Float32::INFINITY + (-Float32::INFINITY).round(digits: 3).should eq -Float32::INFINITY + (-Float32::INFINITY).round(digits: -3).should eq -Float32::INFINITY + end + {% else %} + pending "infinity Float32" + {% end %} + it "nan" do Float64::NAN.round.nan?.should be_true Float32::NAN.round.nan?.should be_true @@ -116,6 +135,155 @@ describe "Number" do 6543210987654321.0.round(-15).should eq(7000000000000000.0) end + describe "rounding modes" do + it "to_zero" do + -1.5.round(:to_zero).should eq -1.0 + -1.0.round(:to_zero).should eq -1.0 + -0.9.round(:to_zero).should eq 0.0 + -0.5.round(:to_zero).should eq 0.0 + -0.1.round(:to_zero).should eq 0.0 + 0.0.round(:to_zero).should eq 0.0 + 0.1.round(:to_zero).should eq 0.0 + 0.5.round(:to_zero).should eq 0.0 + 0.9.round(:to_zero).should eq 0.0 + 1.0.round(:to_zero).should eq 1.0 + 1.5.round(:to_zero).should eq 1.0 + end + + it "to_positive" do + -1.5.round(:to_positive).should eq -1.0 + -1.0.round(:to_positive).should eq -1.0 + -0.9.round(:to_positive).should eq 0.0 + -0.5.round(:to_positive).should eq 0.0 + -0.1.round(:to_positive).should eq 0.0 + 0.0.round(:to_positive).should eq 0.0 + 0.1.round(:to_positive).should eq 1.0 + 0.5.round(:to_positive).should eq 1.0 + 0.9.round(:to_positive).should eq 1.0 + 1.0.round(:to_positive).should eq 1.0 + 1.5.round(:to_positive).should eq 2.0 + end + + it "to_negative" do + -1.5.round(:to_negative).should eq -2.0 + -1.0.round(:to_negative).should eq -1.0 + -0.9.round(:to_negative).should eq -1.0 + -0.5.round(:to_negative).should eq -1.0 + -0.1.round(:to_negative).should eq -1.0 + 0.0.round(:to_negative).should eq 0.0 + 0.1.round(:to_negative).should eq 0.0 + 0.5.round(:to_negative).should eq 0.0 + 0.9.round(:to_negative).should eq 0.0 + 1.0.round(:to_negative).should eq 1.0 + 1.5.round(:to_negative).should eq 1.0 + end + + it "ties_even" do + -2.5.round(:ties_even).should eq -2.0 + -1.5.round(:ties_even).should eq -2.0 + -1.0.round(:ties_even).should eq -1.0 + -0.9.round(:ties_even).should eq -1.0 + -0.5.round(:ties_even).should eq 0.0 + -0.1.round(:ties_even).should eq 0.0 + 0.0.round(:ties_even).should eq 0.0 + 0.1.round(:ties_even).should eq 0.0 + 0.5.round(:ties_even).should eq 0.0 + 0.9.round(:ties_even).should eq 1.0 + 1.0.round(:ties_even).should eq 1.0 + 1.5.round(:ties_even).should eq 2.0 + 2.5.round(:ties_even).should eq 2.0 + end + + it "ties_away" do + -2.5.round(:ties_away).should eq -3.0 + -1.5.round(:ties_away).should eq -2.0 + -1.0.round(:ties_away).should eq -1.0 + -0.9.round(:ties_away).should eq -1.0 + -0.5.round(:ties_away).should eq -1.0 + -0.1.round(:ties_away).should eq 0.0 + 0.0.round(:ties_away).should eq 0.0 + 0.1.round(:ties_away).should eq 0.0 + 0.5.round(:ties_away).should eq 1.0 + 0.9.round(:ties_away).should eq 1.0 + 1.0.round(:ties_away).should eq 1.0 + 1.5.round(:ties_away).should eq 2.0 + 2.5.round(:ties_away).should eq 3.0 + end + + it "default (=ties_away)" do + -2.5.round.should eq -3.0 + -1.5.round.should eq -2.0 + -1.0.round.should eq -1.0 + -0.9.round.should eq -1.0 + -0.5.round.should eq -1.0 + -0.1.round.should eq 0.0 + 0.0.round.should eq 0.0 + 0.1.round.should eq 0.0 + 0.5.round.should eq 1.0 + 0.9.round.should eq 1.0 + 1.0.round.should eq 1.0 + 1.5.round.should eq 2.0 + 2.5.round.should eq 3.0 + end + end + + describe "with digits" do + it "to_zero" do + 12.345.round(-1, mode: :to_zero).should eq 10 + 12.345.round(0, mode: :to_zero).should eq 12 + 12.345.round(1, mode: :to_zero).should eq 12.3 + 12.345.round(2, mode: :to_zero).should eq 12.34 + -12.345.round(-1, mode: :to_zero).should eq -10 + -12.345.round(0, mode: :to_zero).should eq -12 + -12.345.round(1, mode: :to_zero).should eq -12.3 + -12.345.round(2, mode: :to_zero).should eq -12.34 + end + + it "to_positive" do + 12.345.round(-1, mode: :to_positive).should eq 20 + 12.345.round(0, mode: :to_positive).should eq 13 + 12.345.round(1, mode: :to_positive).should eq 12.4 + 12.345.round(2, mode: :to_positive).should eq 12.35 + -12.345.round(-1, mode: :to_positive).should eq -10 + -12.345.round(0, mode: :to_positive).should eq -12 + -12.345.round(1, mode: :to_positive).should eq -12.3 + -12.345.round(2, mode: :to_positive).should eq -12.34 + end + + it "to_negative" do + 12.345.round(-1, mode: :to_negative).should eq 10 + 12.345.round(0, mode: :to_negative).should eq 12 + 12.345.round(1, mode: :to_negative).should eq 12.3 + 12.345.round(2, mode: :to_negative).should eq 12.34 + -12.345.round(-1, mode: :to_negative).should eq -20 + -12.345.round(0, mode: :to_negative).should eq -13 + -12.345.round(1, mode: :to_negative).should eq -12.4 + -12.345.round(2, mode: :to_negative).should eq -12.35 + end + + it "ties_away" do + 13.825.round(-1, mode: :ties_away).should eq 10 + 13.825.round(0, mode: :ties_away).should eq 14 + 13.825.round(1, mode: :ties_away).should eq 13.8 + 13.825.round(2, mode: :ties_away).should eq 13.83 + -13.825.round(-1, mode: :ties_away).should eq -10 + -13.825.round(0, mode: :ties_away).should eq -14 + -13.825.round(1, mode: :ties_away).should eq -13.8 + -13.825.round(2, mode: :ties_away).should eq -13.83 + end + + it "ties_even" do + 15.255.round(-1, mode: :ties_even).should eq 20 + 15.255.round(0, mode: :ties_even).should eq 15 + 15.255.round(1, mode: :ties_even).should eq 15.3 + 15.255.round(2, mode: :ties_even).should eq 15.26 + -15.255.round(-1, mode: :ties_even).should eq -20 + -15.255.round(0, mode: :ties_even).should eq -15 + -15.255.round(1, mode: :ties_even).should eq -15.3 + -15.255.round(2, mode: :ties_even).should eq -15.26 + end + end + describe "base" do it "2" do -1763.116.round(2, base: 2).should eq(-1763.0) @@ -132,6 +300,48 @@ describe "Number" do end end + describe "#round_even" do + -2.5.round_even.should eq -2.0 + -1.5.round_even.should eq -2.0 + -1.0.round_even.should eq -1.0 + -0.9.round_even.should eq -1.0 + -0.5.round_even.should eq -0.0 + -0.1.round_even.should eq 0.0 + 0.0.round_even.should eq 0.0 + 0.1.round_even.should eq 0.0 + 0.5.round_even.should eq 0.0 + 0.9.round_even.should eq 1.0 + 1.0.round_even.should eq 1.0 + 1.5.round_even.should eq 2.0 + 2.5.round_even.should eq 2.0 + + 1.round_even.should eq 1 + 1.round_even.should be_a(Int32) + 1_u8.round_even.should be_a(UInt8) + 1_f32.round_even.should be_a(Float32) + end + + describe "#round_away" do + -2.5.round_away.should eq -3.0 + -1.5.round_away.should eq -2.0 + -1.0.round_away.should eq -1.0 + -0.9.round_away.should eq -1.0 + -0.5.round_away.should eq -1.0 + -0.1.round_away.should eq 0.0 + 0.0.round_away.should eq 0.0 + 0.1.round_away.should eq 0.0 + 0.5.round_away.should eq 1.0 + 0.9.round_away.should eq 1.0 + 1.0.round_away.should eq 1.0 + 1.5.round_away.should eq 2.0 + 2.5.round_away.should eq 3.0 + + 1.round_away.should eq 1 + 1.round_away.should be_a(Int32) + 1_u8.round_away.should be_a(UInt8) + 1_f32.round_away.should be_a(Float32) + end + it "gives the absolute value" do 123.abs.should eq(123) -123.abs.should eq(123) diff --git a/src/float.cr b/src/float.cr index bc4fdf0c2540..cee021be938b 100644 --- a/src/float.cr +++ b/src/float.cr @@ -151,7 +151,21 @@ struct Float32 LibM.floor_f32(self) end - def round + # Rounds towards the nearest integer. If both neighboring integers are equidistant, + # rounds towards the even neighbor (Banker's rounding). + def round_even : self + # LLVM 11 introduced llvm.roundeven.* intrinsics which may replace rint in + # the future. + {% if compare_versions(Crystal::LLVM_VERSION, "11.0.0") >= 0 %} + LibM.roundeven_f32(self) + {% else %} + LibM.rint_f32(self) + {% end %} + end + + # Rounds towards the nearest integer. If both neighboring integers are equidistant, + # rounds away from zero. + def round_away LibM.round_f32(self) end @@ -238,7 +252,21 @@ struct Float64 LibM.floor_f64(self) end - def round + # Rounds towards the nearest integer. If both neighboring integers are equidistant, + # rounds towards the even neighbor (Banker's rounding). + def round_even : self + # LLVM 11 introduced llvm.roundeven.* intrinsics which may replace rint in + # the future. + {% if compare_versions(Crystal::LLVM_VERSION, "11.0.0") >= 0 %} + LibM.roundeven_f64(self) + {% else %} + LibM.rint_f64(self) + {% end %} + end + + # Rounds towards the nearest integer. If both neighboring integers are equidistant, + # rounds away from zero. + def round_away LibM.round_f64(self) end diff --git a/src/int.cr b/src/int.cr index 90f20589e4e6..2dd86031e8a7 100644 --- a/src/int.cr +++ b/src/int.cr @@ -241,6 +241,10 @@ struct Int self >= 0 ? self : -self end + def round(mode : RoundingMode) : self + self + end + def ceil self end @@ -249,11 +253,17 @@ struct Int self end - def round + def trunc + self + end + + # Returns `self`. + def round_even : self self end - def trunc + # Returns `self`. + def round_away self end diff --git a/src/math/libm.cr b/src/math/libm.cr index 0899fc3998ac..c71bd2803c26 100644 --- a/src/math/libm.cr +++ b/src/math/libm.cr @@ -34,6 +34,12 @@ lib LibM {% end %} fun round_f32 = "llvm.round.f32"(value : Float32) : Float32 fun round_f64 = "llvm.round.f64"(value : Float64) : Float64 + {% if compare_versions(Crystal::LLVM_VERSION, "11.0.0") >= 0 %} + fun roundeven_f32 = "llvm.roundeven.f32"(value : Float32) : Float32 + fun roundeven_f64 = "llvm.roundeven.f64"(value : Float64) : Float64 + {% end %} + fun rint_f32 = "llvm.rint.f32"(value : Float32) : Float32 + fun rint_f64 = "llvm.rint.f64"(value : Float64) : Float64 fun sin_f32 = "llvm.sin.f32"(value : Float32) : Float32 fun sin_f64 = "llvm.sin.f64"(value : Float64) : Float64 fun sqrt_f32 = "llvm.sqrt.f32"(value : Float32) : Float32 diff --git a/src/number.cr b/src/number.cr index 9255410534f2..d8e1ef8a9798 100644 --- a/src/number.cr +++ b/src/number.cr @@ -389,19 +389,76 @@ struct Number self.class.new((x / y).round * y) end - # Rounds this number to a given precision in decimal *digits*. + # Rounds this number to a given precision. + # + # Rounds to the specified number of *digits* after the decimal place, + # (or before if negative), in base *base*. + # + # The rounding *mode* controls the direction of the rounding. The default is + # `RoundingMode::TIES_AWAY` which rounds to the nearest integer, with ties + # (fractional value of `0.5`) being rounded away from zero. # # ``` # -1763.116.round(2) # => -1763.12 # ``` - def round(digits = 0, base = 10) - x = self.to_f + def round(digits : Number, base = 10, *, mode : RoundingMode = :ties_away) + if digits < 0 + multiplier = base.to_f ** digits.abs + shifted = self / multiplier + else + multiplier = base.to_f ** digits + shifted = self * multiplier + end + + rounded = shifted.round(mode) + if digits < 0 - y = base.to_f ** digits.abs - self.class.new((x / y).round * y) + result = rounded * multiplier else - y = base.to_f ** digits - self.class.new((x * y).round / y) + result = rounded / multiplier + end + + self.class.new result + end + + # Specifies rounding behaviour for numerical operations capable of discarding + # precision. + enum RoundingMode + # Rounds towards the nearest integer. If both neighboring integers are equidistant, + # rounds towards the even neighbor (Banker's rounding). + TIES_EVEN + + # Rounds towards the nearest integer. If both neighboring integers are equidistant, + # rounds away from zero. + TIES_AWAY + + # Rounds towards zero (truncate). + TO_ZERO + + # Rounds towards positive infinity (ceil). + TO_POSITIVE + + # Rounds towards negative infinity (floor). + TO_NEGATIVE + end + + # Rounds `self` to an integer value using rounding *mode*. + # + # The rounding mode controls the direction of the rounding. The default is + # `RoundingMode::TIES_AWAY` which rounds to the nearest integer, with ties + # (fractional value of `0.5`) being rounded away from zero. + def round(mode : RoundingMode = :ties_away) : self + case mode + in .to_zero? + trunc + in .to_positive? + ceil + in .to_negative? + floor + in .ties_away? + round_away + in .ties_even? + round_even end end