Skip to content

Commit

Permalink
and, or, count reimplemented using new paradigm
Browse files Browse the repository at this point in the history
  • Loading branch information
lindahua committed May 31, 2014
1 parent 80e8a28 commit 5336b78
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 38 deletions.
86 changes: 48 additions & 38 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ mr_empty(::Abs2Fun, op::AddFun, T) = r_promote(op, abs2(zero(T)))
mr_empty(::IdFun, op::MulFun, T) = r_promote(op, one(T))
mr_empty(::AbsFun, op::MaxFun, T) = abs(zero(T))
mr_empty(::Abs2Fun, op::MaxFun, T) = abs2(zero(T))
mr_empty(f, op::AndFun, T) = true
mr_empty(f, op::OrFun, T) = false

function _mapreduce{T}(f, op, A::AbstractArray{T})
n = length(A)
Expand Down Expand Up @@ -264,7 +266,7 @@ prod(A::AbstractArray{Bool}) =

## maximum & minimum

function mapreduce_seq_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::Int)
function mapreduce_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::Int)
# locate the first non NaN number
v = evaluate(f, A[first])
i = first + 1
Expand All @@ -282,7 +284,7 @@ function mapreduce_seq_impl(f, op::MaxFun, A::AbstractArray, first::Int, last::I
v
end

function mapreduce_seq_impl(f, op::MinFun, A::AbstractArray, first::Int, last::Int)
function mapreduce_impl(f, op::MinFun, A::AbstractArray, first::Int, last::Int)
# locate the first non NaN number
v = evaluate(f, A[first])
i = first + 1
Expand Down Expand Up @@ -362,53 +364,63 @@ end

## all & any

function all(itr)
function mapfoldl(f, ::AndFun, itr)
for x in itr
if !x
if !evaluate(f, x)
return false
end
end
return true
end

function any(itr)
function mapfoldl(f, ::OrFun, itr)
for x in itr
if x
if evaluate(f, x)
return true
end
end
return false
end

function any(pred::Union(Function,Func{1}), itr)
for x in itr
if evaluate(pred, x)
return true
function mapreduce_impl(f, op::AndFun, A::AbstractArray, ifirst::Int, ilast::Int)
while ifirst <= ilast
@inbounds x = A[ifirst]
if !evaluate(f, x)
return false
end
ifirst += 1
end
return false
return true
end

function all(pred::Union(Function,Func{1}), itr)
for x in itr
if !evaluate(pred, x)
return false
function mapreduce_impl(f, op::OrFun, A::AbstractArray, ifirst::Int, ilast::Int)
while ifirst <= ilast
@inbounds x = A[ifirst]
if evaluate(f, x)
return true
end
ifirst += 1
end
return true
return false
end

all(a) = mapreduce(IdFun(), AndFun(), a)
any(a) = mapreduce(IdFun(), OrFun(), a)

all(pred::Union(Function,Func{1}), a) = mapreduce(pred, AndFun(), a)
any(pred::Union(Function,Func{1}), a) = mapreduce(pred, OrFun(), a)


## in & contains

function in(x, itr)
for y in itr
if y == x
return true
end
end
return false
immutable EqX{T} <: Func{1}
x::T
end
EqX{T}(x::T) = EqX{T}(x)
evaluate(f::EqX, y) = (y == f.x)

in(x, itr) = any(EqX(x), itr)

const = in
(x, itr)=!(x, itr)
(itr, x)= (x, itr)
Expand All @@ -431,33 +443,31 @@ end

## countnz & count

function countnz(itr)
function count(pred::Union(Function,Func{1}), itr)
n = 0
for x in itr
if x != 0
if evaluate(pred, x)
n += 1
end
end
return n
end

function countnz(a::AbstractArray)
function count(pred::Union(Function,Func{1}), a::AbstractArray)
n = 0
for i = 1:length(a)
@inbounds x = a[i]
if x != 0
i = 0
len = length(a)
while i < len
@inbounds x = a[i+=1]
if evaluate(pred, x)
n += 1
end
end
return n
end

function count(pred::Function, itr)
n = 0
for x in itr
if pred(x)
n += 1
end
end
return n
end
type NotEqZero <: Func{1} end
evaluate(NotEqZero, x) = (x != 0)

countnz(a) = count(NotEqZero(), a)

19 changes: 19 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,25 @@ prod2(itr) = invoke(prod, (Any,), itr)
@test all(x->x>0, [4]) == true
@test all(x->x>0, [-3, 4, 5]) == false

# in

@test in(1, Int[]) == false
@test in(1, Int[1]) == true
@test in(1, Int[2]) == false
@test in(0, 1:3) == false
@test in(1, 1:3) == true
@test in(2, 1:3) == true

# count & countnz

@test count(x->x>0, Int[]) == 0
@test count(x->x>0, -3:5) == 5

@test countnz(Int[]) == 0
@test countnz(Int[0]) == 0
@test countnz(Int[1]) == 1
@test countnz([1, 0, 2, 0, 3, 0, 4]) == 4


## cumsum, cummin, cummax

Expand Down

1 comment on commit 5336b78

@StefanKarpinski
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #10838.

Please sign in to comment.