From 6c323f3ac3977e4791c1b790c6eb1be226224655 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 2 May 2024 17:44:45 +0100 Subject: [PATCH 1/2] =?UTF-8?q?Support=20accumulate(=C2=B1,=20::AbstractFi?= =?UTF-8?q?ll)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/FillArrays.jl | 26 ++++++++++++++++++++++---- test/runtests.jl | 4 ++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 581e8252..b471dfdf 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -572,14 +572,32 @@ sum(x::AbstractZeros) = getindex_value(x) # needed to support infinite case steprangelen(st...) = StepRangeLen(st...) -cumsum(x::AbstractFill{<:Any,1}) = steprangelen(getindex_value(x), getindex_value(x), length(x)) +function cumsum(x::AbstractFill{T,1}) where T + V = promote_op(add_sum, T, T) + steprangelen(convert(V,getindex_value(x)), convert(V,getindex_value(x)), length(x)) +end -cumsum(x::AbstractZerosVector) = x -cumsum(x::AbstractZerosVector{Bool}) = x -cumsum(x::AbstractOnesVector{II}) where II<:Integer = convert(AbstractVector{II}, oneto(length(x))) +cumsum(x::AbstractZerosVector{T}) where T = convert(AbstractVector{promote_op(add_sum, T, T)}, x) +cumsum(x::AbstractZerosVector{Bool}) = convert(AbstractVector{Int}, x) +cumsum(x::AbstractOnesVector{T}) where T<:Integer = convert(AbstractVector{promote_op(add_sum, T, T)}, oneto(length(x))) cumsum(x::AbstractOnesVector{Bool}) = oneto(length(x)) +for op in (:+, :-) + @eval begin + function accumulate(::typeof($op), x::AbstractFill{T,1}) where T + V = promote_op($op, T, T) + steprangelen(convert(V,getindex_value(x)), convert(V,$op(getindex_value(x))), length(x)) + end + + accumulate(::typeof($op), x::AbstractZerosVector{T}) where T = convert(AbstractVector{promote_op($op, T, T)}, x) + accumulate(::typeof($op), x::AbstractZerosVector{Bool}) = convert(AbstractVector{Int}, x) + end +end + +accumulate(::typeof(+), x::AbstractOnesVector{T}) where T<:Integer = convert(AbstractVector{promote_op(+, T, T)}, oneto(length(x))) +accumulate(::typeof(+), x::AbstractOnesVector{Bool}) = oneto(length(x)) + ######### # Diff ######### diff --git a/test/runtests.jl b/test/runtests.jl index 2fd3c5ca..4369b4e0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -817,11 +817,11 @@ end @test_throws MethodError sort!(Fill(im, 2)) end -@testset "Cumsum and diff" begin +@testset "Cumsum, accumulate and diff" begin @test sum(Fill(3,10)) ≡ 30 @test reduce(+, Fill(3,10)) ≡ 30 @test sum(x -> x + 1, Fill(3,10)) ≡ 40 - @test cumsum(Fill(3,10)) ≡ StepRangeLen(3,3,10) + @test cumsum(Fill(3,10)) ≡ accumulate(+, Fill(3,10)) ≡ StepRangeLen(3,3,10) @test sum(Ones(10)) ≡ 10.0 @test sum(x -> x + 1, Ones(10)) ≡ 20.0 From e947b1c829ab5a9118d719700acf54d34d38a74c Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 2 May 2024 18:06:39 +0100 Subject: [PATCH 2/2] Overload accumulate and make types of cumsum consistent with Vector --- Project.toml | 2 +- src/FillArrays.jl | 18 ++++++++-------- src/fillbroadcast.jl | 2 ++ test/runtests.jl | 50 +++++++++++++++++++++++++++++++------------- 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index caea01ad..dd70d301 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.10.2" +version = "1.11" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index b471dfdf..2996b0b2 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, similar, issorted + parent, similar, issorted, add_sum, accumulate, OneTo import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, @@ -574,12 +574,12 @@ sum(x::AbstractZeros) = getindex_value(x) steprangelen(st...) = StepRangeLen(st...) function cumsum(x::AbstractFill{T,1}) where T V = promote_op(add_sum, T, T) - steprangelen(convert(V,getindex_value(x)), convert(V,getindex_value(x)), length(x)) + steprangelen(convert(V,getindex_value(x)), getindex_value(x), length(x)) end -cumsum(x::AbstractZerosVector{T}) where T = convert(AbstractVector{promote_op(add_sum, T, T)}, x) -cumsum(x::AbstractZerosVector{Bool}) = convert(AbstractVector{Int}, x) -cumsum(x::AbstractOnesVector{T}) where T<:Integer = convert(AbstractVector{promote_op(add_sum, T, T)}, oneto(length(x))) +cumsum(x::AbstractZerosVector{T}) where T = _range_convert(AbstractVector{promote_op(add_sum, T, T)}, x) +cumsum(x::AbstractZerosVector{Bool}) = _range_convert(AbstractVector{Int}, x) +cumsum(x::AbstractOnesVector{T}) where T<:Integer = _range_convert(AbstractVector{promote_op(add_sum, T, T)}, oneto(length(x))) cumsum(x::AbstractOnesVector{Bool}) = oneto(length(x)) @@ -587,15 +587,15 @@ for op in (:+, :-) @eval begin function accumulate(::typeof($op), x::AbstractFill{T,1}) where T V = promote_op($op, T, T) - steprangelen(convert(V,getindex_value(x)), convert(V,$op(getindex_value(x))), length(x)) + steprangelen(convert(V,getindex_value(x)), $op(getindex_value(x)), length(x)) end - accumulate(::typeof($op), x::AbstractZerosVector{T}) where T = convert(AbstractVector{promote_op($op, T, T)}, x) - accumulate(::typeof($op), x::AbstractZerosVector{Bool}) = convert(AbstractVector{Int}, x) + accumulate(::typeof($op), x::AbstractZerosVector{T}) where T = _range_convert(AbstractVector{promote_op($op, T, T)}, x) + accumulate(::typeof($op), x::AbstractZerosVector{Bool}) = _range_convert(AbstractVector{Int}, x) end end -accumulate(::typeof(+), x::AbstractOnesVector{T}) where T<:Integer = convert(AbstractVector{promote_op(+, T, T)}, oneto(length(x))) +accumulate(::typeof(+), x::AbstractOnesVector{T}) where T<:Integer = _range_convert(AbstractVector{promote_op(+, T, T)}, oneto(length(x))) accumulate(::typeof(+), x::AbstractOnesVector{Bool}) = oneto(length(x)) ######### diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index d286418a..2b5ea59c 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -177,7 +177,9 @@ end # special case due to missing converts for ranges _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a _range_convert(::Type{AbstractVector{T}}, a::AbstractUnitRange) where T = convert(T,first(a)):convert(T,last(a)) +_range_convert(::Type{AbstractVector{T}}, a::OneTo) where T = OneTo(convert(T, a.stop)) _range_convert(::Type{AbstractVector{T}}, a::AbstractRange) where T = convert(T,first(a)):step(a):convert(T,last(a)) +_range_convert(::Type{AbstractVector{T}}, a::ZerosVector) where T = ZerosVector{T}(length(a)) # TODO: replacing with the following will support more general broadcasting. diff --git a/test/runtests.jl b/test/runtests.jl index 4369b4e0..08fd513c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -818,30 +818,52 @@ end end @testset "Cumsum, accumulate and diff" begin - @test sum(Fill(3,10)) ≡ 30 - @test reduce(+, Fill(3,10)) ≡ 30 - @test sum(x -> x + 1, Fill(3,10)) ≡ 40 - @test cumsum(Fill(3,10)) ≡ accumulate(+, Fill(3,10)) ≡ StepRangeLen(3,3,10) + @test @inferred(sum(Fill(3,10))) ≡ 30 + @test @inferred(reduce(+, Fill(3,10))) ≡ 30 + @test @inferred(sum(x -> x + 1, Fill(3,10))) ≡ 40 + @test @inferred(cumsum(Fill(3,10))) ≡ @inferred(accumulate(+, Fill(3,10))) ≡ StepRangeLen(3,3,10) + @test @inferred(accumulate(-, Fill(3,10))) ≡ StepRangeLen(3,-3,10) - @test sum(Ones(10)) ≡ 10.0 - @test sum(x -> x + 1, Ones(10)) ≡ 20.0 - @test cumsum(Ones(10)) ≡ StepRangeLen(1.0, 1.0, 10) + @test @inferred(sum(Ones(10))) ≡ 10.0 + @test @inferred(sum(x -> x + 1, Ones(10))) ≡ 20.0 + @test @inferred(cumsum(Ones(10))) ≡ @inferred(accumulate(+, Ones(10))) ≡ StepRangeLen(1.0, 1.0, 10) + @test @inferred(accumulate(-, Ones(10))) ≡ StepRangeLen(1.0,-1.0,10) @test sum(Ones{Int}(10)) ≡ 10 @test sum(x -> x + 1, Ones{Int}(10)) ≡ 20 - @test cumsum(Ones{Int}(10)) ≡ Base.OneTo(10) + @test cumsum(Ones{Int}(10)) ≡ accumulate(+,Ones{Int}(10)) ≡ Base.OneTo(10) + @test accumulate(-, Ones{Int}(10)) ≡ StepRangeLen(1,-1,10) @test sum(Zeros(10)) ≡ 0.0 @test sum(x -> x + 1, Zeros(10)) ≡ 10.0 - @test cumsum(Zeros(10)) ≡ Zeros(10) + @test cumsum(Zeros(10)) ≡ accumulate(+,Zeros(10)) ≡ accumulate(-,Zeros(10)) ≡ Zeros(10) @test sum(Zeros{Int}(10)) ≡ 0 @test sum(x -> x + 1, Zeros{Int}(10)) ≡ 10 - @test cumsum(Zeros{Int}(10)) ≡ Zeros{Int}(10) - - @test cumsum(Zeros{Bool}(10)) ≡ Zeros{Bool}(10) - @test cumsum(Ones{Bool}(10)) ≡ Base.OneTo{Int}(10) - @test cumsum(Fill(true,10)) ≡ StepRangeLen(true, true, 10) + @test cumsum(Zeros{Int}(10)) ≡ accumulate(+,Zeros{Int}(10)) ≡ accumulate(-,Zeros{Int}(10)) ≡ Zeros{Int}(10) + + # we want cumsum of fills to match the types of the standard cusum + @test all(cumsum(Zeros{Bool}(10)) .≡ cumsum(zeros(Bool,10))) + @test all(accumulate(+, Zeros{Bool}(10)) .≡ accumulate(+, zeros(Bool,10)) .≡ accumulate(-, zeros(Bool,10))) + @test cumsum(Zeros{Bool}(10)) ≡ accumulate(+, Zeros{Bool}(10)) ≡ accumulate(-, Zeros{Bool}(10)) ≡ Zeros{Int}(10) + @test cumsum(Ones{Bool}(10)) ≡ accumulate(+, Ones{Bool}(10)) ≡ Base.OneTo{Int}(10) + @test all(cumsum(Fill(true,10)) .≡ cumsum(fill(true,10))) + @test cumsum(Fill(true,10)) ≡ StepRangeLen(1, true, 10) + + @test all(cumsum(Zeros{UInt8}(10)) .≡ cumsum(zeros(UInt8,10))) + @test all(accumulate(+, Zeros{UInt8}(10)) .≡ accumulate(+, zeros(UInt8,10))) + @test cumsum(Zeros{UInt8}(10)) ≡ Zeros{UInt64}(10) + @test accumulate(+, Zeros{UInt8}(10)) ≡ accumulate(-, Zeros{UInt8}(10)) ≡ Zeros{UInt8}(10) + + @test all(cumsum(Ones{UInt8}(10)) .≡ cumsum(ones(UInt8,10))) + @test all(accumulate(+, Ones{UInt8}(10)) .≡ accumulate(+, ones(UInt8,10))) + @test cumsum(Ones{UInt8}(10)) ≡ Base.OneTo(UInt64(10)) + @test accumulate(+, Ones{UInt8}(10)) ≡ Base.OneTo(UInt8(10)) + + @test all(cumsum(Fill(UInt8(2),10)) .≡ cumsum(fill(UInt8(2),10))) + @test all(accumulate(+, Fill(UInt8(2))) .≡ accumulate(+, fill(UInt8(2)))) + @test cumsum(Fill(UInt8(2),10)) ≡ StepRangeLen(UInt64(2), UInt8(2), 10) + @test accumulate(+, Fill(UInt8(2),10)) ≡ StepRangeLen(UInt8(2), UInt8(2), 10) @test diff(Fill(1,10)) ≡ Zeros{Int}(9) @test diff(Ones{Float64}(10)) ≡ Zeros{Float64}(9)