From 8628143b37c5d27385df99ddbedd33e1ae50b885 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Thu, 26 Mar 2020 13:16:45 -0700 Subject: [PATCH] Implement `accumulate` and friends for arbitrary iterators (#34656) --- NEWS.md | 2 +- base/accumulate.jl | 27 +++++++++++++++++++++++---- base/iterators.jl | 25 ++++++++++++++++++------- test/iterators.jl | 9 +++++++++ 4 files changed, 51 insertions(+), 12 deletions(-) diff --git a/NEWS.md b/NEWS.md index 936a7c932c7ed..1fda7342e3570 100644 --- a/NEWS.md +++ b/NEWS.md @@ -92,7 +92,7 @@ New library features * `isapprox` (or `≈`) now has a one-argument "curried" method `isapprox(x)` which returns a function, like `isequal` (or `==`)` ([#32305]). * `Ref{NTuple{N,T}}` can be passed to `Ptr{T}`/`Ref{T}` `ccall` signatures ([#34199]) -* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]). +* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]) and arbitrary iterators ([#34656]). * In `splice!` with no replacement, values to be removed can now be specified with an arbitrary iterable (instead of a `UnitRange`) ([#34524]). diff --git a/base/accumulate.jl b/base/accumulate.jl index 8e959df71b28a..25cb859e3937c 100644 --- a/base/accumulate.jl +++ b/base/accumulate.jl @@ -92,14 +92,14 @@ function cumsum(A::AbstractArray{T}; dims::Integer) where T end """ - cumsum(itr::Union{AbstractVector,Tuple}) + cumsum(itr) Cumulative sum an iterator. See also [`cumsum!`](@ref) to use a preallocated output array, both for performance and to control the precision of the output (e.g. to avoid overflow). !!! compat "Julia 1.5" - `cumsum` on a tuple requires at least Julia 1.5. + `cumsum` on a non-array iterator requires at least Julia 1.5. # Examples ```jldoctest @@ -117,6 +117,12 @@ julia> cumsum([fill(1, 2) for i in 1:3]) julia> cumsum((1, 1, 1)) (1, 2, 3) + +julia> cumsum(x^2 for x in 1:3) +3-element Array{Int64,1}: + 1 + 5 + 14 ``` """ cumsum(x::AbstractVector) = cumsum(x, dims=1) @@ -170,14 +176,14 @@ function cumprod(A::AbstractArray; dims::Integer) end """ - cumprod(itr::Union{AbstractVector,Tuple}) + cumprod(itr) Cumulative product of an iterator. See also [`cumprod!`](@ref) to use a preallocated output array, both for performance and to control the precision of the output (e.g. to avoid overflow). !!! compat "Julia 1.5" - `cumprod` on a tuple requires at least Julia 1.5. + `cumprod` on a non-array iterator requires at least Julia 1.5. # Examples ```jldoctest @@ -195,6 +201,12 @@ julia> cumprod([fill(1//3, 2, 2) for i in 1:3]) julia> cumprod((1, 2, 1)) (1, 2, 2) + +julia> cumprod(x^2 for x in 1:3) +3-element Array{Int64,1}: + 1 + 4 + 36 ``` """ cumprod(x::AbstractVector) = cumprod(x, dims=1) @@ -210,6 +222,9 @@ also [`accumulate!`](@ref) to use a preallocated output array, both for performa to control the precision of the output (e.g. to avoid overflow). For common operations there are specialized variants of `accumulate`, see: [`cumsum`](@ref), [`cumprod`](@ref) +!!! compat "Julia 1.5" + `accumulate` on a non-array iterator requires at least Julia 1.5. + # Examples ```jldoctest julia> accumulate(+, [1,2,3]) @@ -250,6 +265,10 @@ julia> accumulate(+, fill(1, 3, 3), dims=2) ``` """ function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...) + if dims === nothing && !(A isa AbstractVector) + # This branch takes care of the cases not handled by `_accumulate!`. + return collect(Iterators.accumulate(op, A; kw...)) + end nt = kw.data if nt isa NamedTuple{()} out = similar(A, promote_op(op, eltype(A), eltype(A))) diff --git a/base/iterators.jl b/base/iterators.jl index 20baf2077e59d..da7fa2656e90e 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -443,13 +443,14 @@ reverse(f::Filter) = Filter(f.flt, reverse(f.itr)) # Accumulate -- partial reductions of a function over an iterator -struct Accumulate{F,I} +struct Accumulate{F,I,T} f::F itr::I + init::T end """ - Iterators.accumulate(f, itr) + Iterators.accumulate(f, itr; [init]) Given a 2-argument function `f` and an iterator `itr`, return a new iterator that successively applies `f` to the previous value and the @@ -457,26 +458,36 @@ next element of `itr`. This is effectively a lazy version of [`Base.accumulate`](@ref). +!!! compat "Julia 1.5" + Keyword argument `init` is added in Julia 1.5. + # Examples ```jldoctest -julia> f = Iterators.accumulate(+, [1,2,3,4]) -Base.Iterators.Accumulate{typeof(+),Array{Int64,1}}(+, [1, 2, 3, 4]) +julia> f = Iterators.accumulate(+, [1,2,3,4]); julia> foreach(println, f) 1 3 6 10 + +julia> f = Iterators.accumulate(+, [1,2,3]; init = 100); + +julia> foreach(println, f) +101 +103 +106 ``` """ -accumulate(f, itr) = Accumulate(f, itr) +accumulate(f, itr; init = Base._InitialValue()) = Accumulate(f, itr, init) function iterate(itr::Accumulate) state = iterate(itr.itr) if state === nothing return nothing end - return (state[1], state) + val = Base.BottomRF(itr.f)(itr.init, state[1]) + return (val, (val, state[2])) end function iterate(itr::Accumulate, state) @@ -491,7 +502,7 @@ end length(itr::Accumulate) = length(itr.itr) size(itr::Accumulate) = size(itr.itr) -IteratorSize(::Type{Accumulate{F,I}}) where {F,I} = IteratorSize(I) +IteratorSize(::Type{<:Accumulate{F,I}}) where {F,I} = IteratorSize(I) IteratorEltype(::Type{<:Accumulate}) = EltypeUnknown() # Rest -- iterate starting at the given state diff --git a/test/iterators.jl b/test/iterators.jl index d456969b83aed..71a24e5e4ca1f 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -793,8 +793,17 @@ end @test collect(Iterators.accumulate(+, [1,2])) == [1,3] @test collect(Iterators.accumulate(+, [1,2,3])) == [1,3,6] @test collect(Iterators.accumulate(=>, [:a,:b,:c])) == [:a, :a => :b, (:a => :b) => :c] + @test collect(Iterators.accumulate(+, (x for x in [true])))::Vector{Int} == [1] + @test collect(Iterators.accumulate(+, (x for x in [true, true, false])))::Vector{Int} == [1, 2, 2] + @test collect(Iterators.accumulate(+, (x for x in [true]), init=10.0))::Vector{Float64} == [11.0] @test length(Iterators.accumulate(+, [10,20,30])) == 3 @test size(Iterators.accumulate(max, rand(2,3))) == (2,3) @test Base.IteratorSize(Iterators.accumulate(max, rand(2,3))) === Base.IteratorSize(rand(2,3)) @test Base.IteratorEltype(Iterators.accumulate(*, ())) isa Base.EltypeUnknown end + +@testset "Base.accumulate" begin + @test cumsum(x^2 for x in 1:3) == [1, 5, 14] + @test cumprod(x + 1 for x in 1:3) == [2, 6, 24] + @test accumulate(+, (x^2 for x in 1:3); init=100) == [101, 105, 114] +end