diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 95b537e4460ad..e85400ebd355a 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -574,13 +574,18 @@ function accumulate_pairwise(op, v::AbstractVector{T}) where T end function cumsum!(out, v::AbstractVector, axis::Integer=1) - # for types prone to numerical stability issues, we want - # accumulate_pairwise. - axis == 1 ? accumulate_pairwise!(+, out, v) : copy!(out,v) + # we dispatch on the possibility of numerical stability issues + _cumsum!(out, v, axis, TypeArithmetic(eltype(out))) end -function cumsum!(out, v::AbstractVector{<:Integer}, axis::Integer=1) - axis == 1 ? accumulate!(+, out, v) : copy!(out,v) +function _cumsum!(out, v, axis, ::ArithmeticRounds) + axis == 1 ? accumulate_pairwise!(+, out, v) : copy!(out, v) +end +function _cumsum!(out, v, axis, ::ArithmeticUnknown) + _cumsum!(out, v, axis, ArithmeticRounds()) +end +function _cumsum!(out, v, axis, ::TypeArithmetic) + axis == 1 ? accumulate!(+, out, v) : copy!(out, v) end """ diff --git a/test/arrayops.jl b/test/arrayops.jl index ae92e9ba1401b..c6851c018f04d 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2055,6 +2055,28 @@ end @test accumulate(op, [10 20 30], 2) == [10 op(10, 20) op(op(10, 20), 30)] == [10 40 110] end +struct F21666{T <: Base.TypeArithmetic} + x::Float32 +end + +@testset "Exactness of cumsum # 21666" begin + # test that cumsum uses more stable algorithm + # for types with unknown/rounding arithmetic + Base.TypeArithmetic(::Type{F21666{T}}) where {T} = T + Base.:+(x::F, y::F) where {F <: F21666} = F(x.x + y.x) + Base.convert(::Type{Float64}, x::F21666) = Float64(x.x) + # we make v pretty large, because stable algorithm may have a large base case + v = zeros(300); v[1] = 2; v[200:end] = eps(Float32) + + f_rounds = Float64.(cumsum(F21666{Base.ArithmeticRounds}.(v))) + f_unknown = Float64.(cumsum(F21666{Base.ArithmeticUnknown}.(v))) + f_truth = cumsum(v) + f_inexact = Float64.(accumulate(+, Float32.(v))) + @test f_rounds == f_unknown + @test f_rounds != f_inexact + @test norm(f_truth - f_rounds) < norm(f_truth - f_inexact) +end + @testset "zeros and ones" begin @test ones([1,2], Float64, (2,3)) == ones(2,3) @test ones(2) == ones(Int, 2) == ones([2,3], Float32, 2) == [1,1]