Skip to content

Commit

Permalink
Use destination type to determine output of cumsum! and cumprod!
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Jan 31, 2018
1 parent 512fbcd commit c5f5d47
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
34 changes: 26 additions & 8 deletions base/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,15 +416,33 @@ function _accumulate_pairwise_small!(op, dest::AbstractArray{T}, itr, accv, w, i
end
end

"""
Base.ConvertOp{T}(op)(x,y)
An operator which converts `x` and `y` to type `T` before performing the `op`.
The main purpose is for use in [`cumsum!`](@ref) and [`cumprod!`](@ref), where `T` is determined by the output array.
"""
struct ConvertOp{T,O} <: Function
op::O
end
ConvertOp{T}(op::O) where {T,O} = ConvertOp{T,O}(op)
(c::ConvertOp{T})(x,y) where {T} = c.op(convert(T,x),convert(T,y))
reduce_first(c::ConvertOp{T},x) where {T} = reduce_first(c.op, convert(T,x))




function cumsum!(out, v::AbstractVector{T}) where T
function cumsum!(out::AbstractVector, v::AbstractVector{T}) where T
# we dispatch on the possibility of numerical accuracy issues
cumsum!(out, v, ArithmeticStyle(T))
end
cumsum!(out, v::AbstractVector, ::ArithmeticRounds) = accumulate_pairwise!(+, out, v)
cumsum!(out, v::AbstractVector, ::ArithmeticUnknown) = accumulate_pairwise!(+, out, v)
cumsum!(out, v::AbstractVector, ::ArithmeticStyle) = accumulate!(+, out, v)
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticRounds) where {T} =
accumulate_pairwise!(ConvertOp{T}(+), out, v)
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticUnknown) where {T} =
accumulate_pairwise!(ConvertOp{T}(+), out, v)
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticStyle) where {T} =
accumulate!(ConvertOp{T}(+), out, v)

"""
cumsum(A, dim::Integer)
Expand Down Expand Up @@ -488,14 +506,14 @@ cumsum(v::AbstractVector, ::ArithmeticStyle) = accumulate(add_sum, v)
Cumulative sum of `A` along the dimension `dim`, storing the result in `B`. See also [`cumsum`](@ref).
"""
cumsum!(dest, A, dim::Integer) = accumulate!(+, dest, A, dim)
cumsum!(dest::AbstractArray{T}, A, dim::Integer) where {T} = accumulate!(ConvertOp{T}(+), dest, A, dim)

"""
cumsum!(y::AbstractVector, x::AbstractVector)
Cumulative sum of a vector `x`, storing the result in `y`. See also [`cumsum`](@ref).
"""
cumsum!(dest, itr) = accumulate!(+, dest, src)
cumsum!(dest::AbstractArray{T}, itr) where {T} = accumulate!(ConvertOp{T}(+), dest, src)

"""
cumprod(A, dim::Integer)
Expand Down Expand Up @@ -555,12 +573,12 @@ cumprod(x::AbstractVector) = accumulate(mul_prod, x)
Cumulative product of `A` along the dimension `dim`, storing the result in `B`.
See also [`cumprod`](@ref).
"""
cumprod!(dest, A, dim::Integer) = accumulate!(*, dest, A, dim)
cumprod!(dest::AbstractArray{T}, A, dim::Integer) where {T} = accumulate!(ConvertOp{T}(*), dest, A, dim)

"""
cumprod!(y::AbstractVector, x::AbstractVector)
Cumulative product of a vector `x`, storing the result in `y`.
See also [`cumprod`](@ref).
"""
cumprod!(dest, itr) = accumulate!(*, dest, itr)
cumprod!(dest::AbstractArray{T}, itr) where {T} = accumulate!(ConvertOp{T}(*), dest, itr)
1 change: 1 addition & 0 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_sum(x::SmallSigned) = Int(x)
add_sum(x::SmallUnsigned) = UInt(x)
add_sum(X::AbstractArray) = broadcast(add_sum, X)


"""
Base.mul_prod(x,y)
Expand Down

0 comments on commit c5f5d47

Please sign in to comment.