From 5336b78819f506bdca2b68985d45557a615b3aff Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Sat, 31 May 2014 08:05:26 -0500 Subject: [PATCH] and, or, count reimplemented using new paradigm --- base/reduce.jl | 86 ++++++++++++++++++++++++++++---------------------- test/reduce.jl | 19 +++++++++++ 2 files changed, 67 insertions(+), 38 deletions(-) diff --git a/base/reduce.jl b/base/reduce.jl index 55f23b5d028b3..1e308fdde6edf 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -142,6 +142,8 @@ mr_empty(::Abs2Fun, op::AddFun, T) = r_promote(op, abs2(zero(T))) mr_empty(::IdFun, op::MulFun, T) = r_promote(op, one(T)) mr_empty(::AbsFun, op::MaxFun, T) = abs(zero(T)) mr_empty(::Abs2Fun, op::MaxFun, T) = abs2(zero(T)) +mr_empty(f, op::AndFun, T) = true +mr_empty(f, op::OrFun, T) = false function _mapreduce{T}(f, op, A::AbstractArray{T}) n = length(A) @@ -264,7 +266,7 @@ prod(A::AbstractArray{Bool}) = ## maximum & minimum -function mapreduce_seq_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::Int) +function mapreduce_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::Int) # locate the first non NaN number v = evaluate(f, A[first]) i = first + 1 @@ -282,7 +284,7 @@ function mapreduce_seq_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::I v end -function mapreduce_seq_impl(f, op::MinFun, A::AbstractArray, first::Int, last::Int) +function mapreduce_impl(f, op::MinFun, A::AbstractArray, first::Int, last::Int) # locate the first non NaN number v = evaluate(f, A[first]) i = first + 1 @@ -362,53 +364,63 @@ end ## all & any -function all(itr) +function mapfoldl(f, ::AndFun, itr) for x in itr - if !x + if !evaluate(f, x) return false end end return true end -function any(itr) +function mapfoldl(f, ::OrFun, itr) for x in itr - if x + if evaluate(f, x) return true end end return false end -function any(pred::Union(Function,Func{1}), itr) - for x in itr - if evaluate(pred, x) - return true +function mapreduce_impl(f, op::AndFun, A::AbstractArray, ifirst::Int, ilast::Int) + while ifirst <= ilast + @inbounds x = A[ifirst] + if !evaluate(f, x) + return false end + ifirst += 1 end - return false + return true end -function all(pred::Union(Function,Func{1}), itr) - for x in itr - if !evaluate(pred, x) - return false +function mapreduce_impl(f, op::OrFun, A::AbstractArray, ifirst::Int, ilast::Int) + while ifirst <= ilast + @inbounds x = A[ifirst] + if evaluate(f, x) + return true end + ifirst += 1 end - return true + return false end +all(a) = mapreduce(IdFun(), AndFun(), a) +any(a) = mapreduce(IdFun(), OrFun(), a) + +all(pred::Union(Function,Func{1}), a) = mapreduce(pred, AndFun(), a) +any(pred::Union(Function,Func{1}), a) = mapreduce(pred, OrFun(), a) + ## in & contains -function in(x, itr) - for y in itr - if y == x - return true - end - end - return false +immutable EqX{T} <: Func{1} + x::T end +EqX{T}(x::T) = EqX{T}(x) +evaluate(f::EqX, y) = (y == f.x) + +in(x, itr) = any(EqX(x), itr) + const ∈ = in ∉(x, itr)=!∈(x, itr) ∋(itr, x)= ∈(x, itr) @@ -431,33 +443,31 @@ end ## countnz & count -function countnz(itr) +function count(pred::Union(Function,Func{1}), itr) n = 0 for x in itr - if x != 0 + if evaluate(pred, x) n += 1 end end return n end -function countnz(a::AbstractArray) +function count(pred::Union(Function,Func{1}), a::AbstractArray) n = 0 - for i = 1:length(a) - @inbounds x = a[i] - if x != 0 + i = 0 + len = length(a) + while i < len + @inbounds x = a[i+=1] + if evaluate(pred, x) n += 1 end end return n end -function count(pred::Function, itr) - n = 0 - for x in itr - if pred(x) - n += 1 - end - end - return n -end +type NotEqZero <: Func{1} end +evaluate(NotEqZero, x) = (x != 0) + +countnz(a) = count(NotEqZero(), a) + diff --git a/test/reduce.jl b/test/reduce.jl index fc1e3db7aad2b..9cba9ac9c8d0c 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -166,6 +166,25 @@ prod2(itr) = invoke(prod, (Any,), itr) @test all(x->x>0, [4]) == true @test all(x->x>0, [-3, 4, 5]) == false +# in + +@test in(1, Int[]) == false +@test in(1, Int[1]) == true +@test in(1, Int[2]) == false +@test in(0, 1:3) == false +@test in(1, 1:3) == true +@test in(2, 1:3) == true + +# count & countnz + +@test count(x->x>0, Int[]) == 0 +@test count(x->x>0, -3:5) == 5 + +@test countnz(Int[]) == 0 +@test countnz(Int[0]) == 0 +@test countnz(Int[1]) == 1 +@test countnz([1, 0, 2, 0, 3, 0, 4]) == 4 + ## cumsum, cummin, cummax