Skip to content

Commit

Permalink
Fix BigFloat for Julia v1.12 (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Oct 13, 2024
1 parent 381a59d commit 208d0f2
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 143 deletions.
130 changes: 56 additions & 74 deletions src/implementations/BigFloat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@

mutability(::Type{BigFloat}) = IsMutable()

# Copied from `deepcopy_internal` implementation in Julia:
# https://github.com/JuliaLang/julia/blob/7d41d1eb610cad490cbaece8887f9bbd2a775021/base/mpfr.jl#L1041-L1050
function mutable_copy(x::BigFloat)
d = x._d
d′ = GC.@preserve d unsafe_string(pointer(d), sizeof(d)) # creates a definitely-new String
return Base.MPFR._BigFloat(x.prec, x.sign, x.exp, d′)
# These methods are copied from `deepcopy_internal` in `base/mpfr.jl`. We don't
# use `mutable_copy(x) = deepcopy(x)` because this creates an empty `IdDict()`
# which costs some extra allocations. We don't need the IdDict case because we
# never call `mutable_copy` recursively.
@static if VERSION >= v"1.12.0-DEV.1343"
mutable_copy(x::BigFloat) = Base.MPFR._BigFloat(copy(getfield(x, :d)))
else
function mutable_copy(x::BigFloat)
d = x._d
GC.@preserve d begin
d′ = unsafe_string(pointer(d), sizeof(d))
return Base.MPFR._BigFloat(x.prec, x.sign, x.exp, d′)
end
end
end

const _MPFRRoundingMode = Base.MPFR.MPFRRoundingMode
Expand Down Expand Up @@ -297,12 +305,12 @@ function operate_to!(
end

struct DotBuffer{F<:Real}
compensation::F
summation_temp::F
multiplication_temp::F
inner_temp::F
c::F
t::F
input::F
tmp::F

DotBuffer{F}() where {F<:Real} = new{F}(ntuple(i -> F(), Val{4}())...)
DotBuffer{F}() where {F<:Real} = new{F}(zero(F), zero(F), zero(F), zero(F))
end

function buffer_for(
Expand Down Expand Up @@ -366,87 +374,61 @@ end
#
# function KahanBabushkaNeumaierSum(input)
# sum = 0.0
#
# # A running compensation for lost low-order bits.
# c = 0.0
#
# for i ∈ eachindex(input)
# t = sum + input[i]
#
# if abs(input[i]) ≤ abs(sum)
# c += (sum - t) + input[i]
# tmp = (sum - t) + input[i]
# else
# c += (input[i] - t) + sum
# tmp = (input[i] - t) + sum
# end
#
# c += tmp
# sum = t
# end
#
# # The result, with the correction only applied once in the very
# # end.
# # The result, with the correction only applied once in the very end.
# sum + c
# end
function buffered_operate_to!(
buf::DotBuffer{F},
sum::F,
::typeof(LinearAlgebra.dot),
x::AbstractVector{F},
y::AbstractVector{F},
) where {F<:BigFloat}
set! = (o, i) -> operate_to!(o, copy, i)

local swap! = function (x::BigFloat, y::BigFloat)
ccall((:mpfr_swap, :libmpfr), Cvoid, (Ref{BigFloat}, Ref{BigFloat}), x, y)
return nothing
# Returns abs(x) <= abs(y) without allocating.
function _abs_lte_abs(x::BigFloat, y::BigFloat)
x_is_neg, y_is_neg = signbit(x), signbit(y)
if x_is_neg != y_is_neg
operate!(-, x)
end

# Returns abs(x) <= abs(y) without allocating.
local abs_lte_abs = function (x::F, y::F)
local x_is_neg = signbit(x)
local y_is_neg = signbit(y)

local x_neg = x_is_neg != y_is_neg

x_neg && operate!(-, x)

local ret = if y_is_neg
y <= x
else
x <= y
end

x_neg && operate!(-, x)

return ret
ret = y_is_neg ? y <= x : x <= y
if x_is_neg != y_is_neg
operate!(-, x)
end
return ret
end

operate!(zero, sum)
operate!(zero, buf.compensation)

for i in 0:(length(x)-1)
set!(buf.multiplication_temp, x[begin+i])
operate!(*, buf.multiplication_temp, y[begin+i])

operate!(zero, buf.summation_temp)
operate_to!(buf.summation_temp, +, buf.multiplication_temp, sum)

if abs_lte_abs(buf.multiplication_temp, sum)
set!(buf.inner_temp, sum)
operate!(-, buf.inner_temp, buf.summation_temp)
operate!(+, buf.inner_temp, buf.multiplication_temp)
function buffered_operate_to!(
buf::DotBuffer{BigFloat},
sum::BigFloat,
::typeof(LinearAlgebra.dot),
x::AbstractVector{BigFloat},
y::AbstractVector{BigFloat},
) # See pseudocode description
operate!(zero, sum) # sum = 0
operate!(zero, buf.c) # c = 0
for (xi, yi) in zip(x, y) # for i in eachindex(input)
operate_to!(buf.input, copy, xi) # input = x[i]
operate!(*, buf.input, yi) # input = x[i] * y[i]
operate_to!(buf.t, +, sum, buf.input) # t = sum + input
if _abs_lte_abs(buf.input, sum) # if |input| < |sum|
operate_to!(buf.tmp, copy, sum) # tmp = sum
operate!(-, buf.tmp, buf.t) # tmp = sum - t
operate!(+, buf.tmp, buf.input) # tmp = (sum - t) + input
else
set!(buf.inner_temp, buf.multiplication_temp)
operate!(-, buf.inner_temp, buf.summation_temp)
operate!(+, buf.inner_temp, sum)
operate_to!(buf.tmp, copy, buf.input) # tmp = input
operate!(-, buf.tmp, buf.t) # tmp = input - t
operate!(+, buf.tmp, sum) # tmp = (input - t) + sum
end

operate!(+, buf.compensation, buf.inner_temp)

swap!(sum, buf.summation_temp)
operate!(+, buf.c, buf.tmp) # c += tmp
operate_to!(sum, copy, buf.t) # sum = t
end

operate!(+, sum, buf.compensation)

operate!(+, sum, buf.c) # sum += c
return sum
end

Expand Down
99 changes: 30 additions & 69 deletions test/bigfloat_dot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,6 @@
# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
# one at http://mozilla.org/MPL/2.0/.

backup_bigfloats(v::AbstractVector{BigFloat}) = map(MA.copy_if_mutable, v)

absolute_error(accurate::Real, approximate::Real) = abs(accurate - approximate)

function relative_error(accurate::Real, approximate::Real)
return absolute_error(accurate, approximate) / abs(accurate)
end

function dotter(x::V, y::V) where {V<:AbstractVector{<:Real}}
let x = x, y = y
() -> LinearAlgebra.dot(x, y)
end
end

function reference_dot(x::V, y::V) where {F<:Real,V<:AbstractVector{F}}
return setprecision(dotter(x, y), F, 8 * precision(F))
end

function dot_test_relative_error(x::V, y::V) where {V<:AbstractVector{BigFloat}}
buf = MA.buffer_for(LinearAlgebra.dot, V, V)

input = (x, y)
backup = map(backup_bigfloats, input)

output = BigFloat()

MA.buffered_operate_to!!(buf, output, LinearAlgebra.dot, input...)

@test input == backup

return relative_error(reference_dot(input...), output)
end

subtracter(s::Real) =
let s = s
x -> x - s
end

our_rand(n::Int, bias::Real) = map(subtracter(bias), rand(BigFloat, n))

function rand_dot_rel_err(size::Int, bias::Real)
x = our_rand(size, bias)
y = our_rand(size, bias)
return dot_test_relative_error(x, y)
end

function max_rand_dot_rel_err(size::Int, bias::Real, iter_cnt::Int)
max_rel_err = zero(BigFloat)
for i in 1:iter_cnt
rel_err = rand_dot_rel_err(size, bias)
<(max_rel_err, rel_err) && (max_rel_err = rel_err)
end
return max_rel_err
end

function max_rand_dot_ulps(size::Int, bias::Real, iter_cnt::Int)
return max_rand_dot_rel_err(size, bias, iter_cnt) / eps(BigFloat)
end

function ulper(size::Int, bias::Real, iter_cnt::Int)
let s = size, b = bias, c = iter_cnt
() -> max_rand_dot_ulps(s, b, c)
end
end

@testset "prec:$prec size:$size bias:$bias" for (prec, size, bias) in
Iterators.product(
# These precisions (in bits) are most probably smaller than what
Expand All @@ -78,11 +13,9 @@ end
# precision (except when vector lengths are really huge with
# respect to the precision).
(32, 64),

# Compensated summation should be accurate even for very large
# input vectors, so test that.
(10000,),

# The zero "bias" signifies that the input will be entirely
# nonnegative (drawn from the interval [0, 1]), while a positive
# bias shifts that interval towards negative infinity. We want to
Expand All @@ -91,8 +24,36 @@ end
# no guarantee on the relative error in that case.
(0.0, 2^-2, 2^-2 + 2^-3 + 2^-4),
)
iter_cnt = 10
err = setprecision(ulper(size, bias, iter_cnt), BigFloat, prec)
err = setprecision(BigFloat, prec) do
maximum_relative_error = mapreduce(max, 1:10) do _
# Generate some random vectors for dot(x, y) input.
x = rand(BigFloat, size) .- bias
y = rand(BigFloat, size) .- bias
# Copy x and y so that we can check we haven't mutated them after
# the fact.
old_x, old_y = MA.copy_if_mutable(x), MA.copy_if_mutable(y)
# Compute output = dot(x, y)
buf = MA.buffer_for(
LinearAlgebra.dot,
Vector{BigFloat},
Vector{BigFloat},
)
output = BigFloat()
MA.buffered_operate_to!!(buf, output, LinearAlgebra.dot, x, y)
# Check that we haven't mutated x or y
@test old_x == x
@test old_y == y
# Compute dot(x, y) in larger precision. This will be used to
# compare with our `dot`.
accurate = setprecision(BigFloat, 8 * precision(BigFloat)) do
return LinearAlgebra.dot(x, y)
end
# Compute the relative error
return abs(accurate - output) / abs(accurate)
end
# Return estimate for ULP
return maximum_relative_error / eps(BigFloat)
end
@test 0 <= err < 1
end

Expand Down

0 comments on commit 208d0f2

Please sign in to comment.