Skip to content

Commit

Permalink
Handle possibility of eltype change in mapreducedim and diff
Browse files Browse the repository at this point in the history
  • Loading branch information
wsshin committed Aug 7, 2017
1 parent 359e347 commit 5172c39
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 42 deletions.
27 changes: 12 additions & 15 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ end
_mapreducedim(f, op, Size(a), a, Val{D}, v0)
end

@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S, D}
@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D}
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)
T0 = eltype(a)
T = :((T1 = Base.promote_op(f, $T0); Base.promote_op(op, T1, T1)))

exprs = Array{Expr}(Snew)
itr = [1:n for n Snew]
Expand All @@ -118,14 +120,13 @@ end
exprs[i...] = expr
end

# TODO element type might change
return quote
@_inline_meta
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
@inbounds return similar_type(a, $T, Size($Snew))(tuple($(exprs...)))
end
end

@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}, v0) where {S, D}
@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}, v0::T) where {S,D,T}
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)

Expand All @@ -142,10 +143,9 @@ end
exprs[i...] = expr
end

# TODO element type might change
return quote
@_inline_meta
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
@inbounds return similar_type(a, T, Size($Snew))(tuple($(exprs...)))
end
end

Expand Down Expand Up @@ -186,9 +186,6 @@ end
# all and any must return Bool, so we know the appropriate v0 is true and false,
# respectively. Therefore, all(f, ...) and any(f, ...) are implemented by mapreduce(f, ...)
# with an initial value v0 = true and false.
#
# 4. Some implementations (e.g., count(a, Val{D})) are commented out because corresponding
# Base functions (e.g., count(a, D)) do not exist yet.
@inline iszero(a::StaticArray{<:Any,T}) where {T} = reduce((x,y) -> x && (y==zero(T)), true, a)

@inline sum(a::StaticArray{<:Any,T}) where {T} = reduce(+, zero(T), a)
Expand All @@ -203,8 +200,8 @@ end

@inline count(a::StaticArray{<:Any,Bool}) = reduce(+, 0, a)
@inline count(f::Function, a::StaticArray) = mapreduce(x->f(x)::Bool, +, 0, a)
# @inline count(a::StaticArray{<:Any,Bool}, ::Type{Val{D}}) where {D} = reducedim(+, a, Val{D}, 0)
# @inline count(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(x->f(x)::Bool, +, a, Val{D}, 0)
@inline count(a::StaticArray{<:Any,Bool}, ::Type{Val{D}}) where {D} = reducedim(+, a, Val{D}, 0)
@inline count(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(x->f(x)::Bool, +, a, Val{D}, 0)

@inline all(a::StaticArray{<:Any,Bool}) = reduce(&, true, a) # non-branching versions
@inline all(f::Function, a::StaticArray) = mapreduce(x->f(x)::Bool, &, true, a)
Expand All @@ -219,7 +216,7 @@ end
@inline mean(a::StaticArray) = sum(a) / length(a)
@inline mean(f::Function, a::StaticArray) = sum(f, a) / length(a)
@inline mean(a::StaticArray, ::Type{Val{D}}) where {D} = sum(a, Val{D}) / size(a, D)
# @inline mean(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = sum(f, a, Val{D}) / size(a, D)
@inline mean(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = sum(f, a, Val{D}) / size(a, D)

@inline minimum(a::StaticArray) = reduce(min, a) # base has mapreduce(idenity, scalarmin, a)
@inline minimum(f::Function, a::StaticArray) = mapreduce(f, min, a)
Expand All @@ -235,9 +232,10 @@ end
@inline diff(a::StaticArray) = diff(a, Val{1})
@inline diff(a::StaticArray, ::Type{Val{D}}) where {D} = _diff(Size(a), a, Val{D})

@generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S, D}
@generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D}
N = length(S)
Snew = ([n==D ? S[n]-1 : S[n] for n = 1:N]...)
T = Base.promote_op(-, eltype(a), eltype(a))

exprs = Array{Expr}(Snew)
itr = [1:n for n = Snew]
Expand All @@ -248,9 +246,8 @@ end
exprs[i1...] = :(a[$(i2...)] - a[$(i1...)])
end

# TODO element type might change
return quote
@_inline_meta
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
@inbounds return similar_type(a, $T, Size($Snew))(tuple($(exprs...)))
end
end
59 changes: 32 additions & 27 deletions test/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,71 +24,76 @@
end

@testset "[map]reduce and [map]reducedim" begin
a = rand(4,3); sa = SMatrix{4,3}(a)
a = rand(4,3); sa = SMatrix{4,3}(a); (I,J) = size(a)
v1 = [2,4,6,8]; sv1 = SVector{4}(v1)
v2 = [4,3,2,1]; sv2 = SVector{4}(v2)
@test reduce(+, sv1) === reduce(+, v1)
@test reduce(+, 0, sv1) === reduce(+, 0, v1)
@test reducedim(max, sa, Val{1}, -1.) == reducedim(max, a, 1, -1.)
@test reducedim(max, sa, Val{2}, -1.) == reducedim(max, a, 2, -1.)
@test reducedim(max, sa, Val{1}, -1.) === SMatrix{1,J}(reducedim(max, a, 1, -1.))
@test reducedim(max, sa, Val{2}, -1.) === SMatrix{I,1}(reducedim(max, a, 2, -1.))
@test mapreduce(-, +, sv1) === mapreduce(-, +, v1)
@test mapreduce(-, +, 0, sv1) === mapreduce(-, +, 0, v1)
@test mapreduce(*, +, sv1, sv2) === 40
@test mapreduce(*, +, 0, sv1, sv2) === 40
@test mapreducedim(x->x^2, max, sa, Val{1}, -1.) == mapreducedim(x->x^2, max, a, 1, -1.)
@test mapreducedim(x->x^2, max, sa, Val{2}, -1.) == mapreducedim(x->x^2, max, a, 2, -1.)
@test mapreducedim(x->x^2, max, sa, Val{1}, -1.) == SMatrix{1,J}(mapreducedim(x->x^2, max, a, 1, -1.))
@test mapreducedim(x->x^2, max, sa, Val{2}, -1.) == SMatrix{I,1}(mapreducedim(x->x^2, max, a, 2, -1.))
end

@testset "implemented by [map]reduce and [map]reducedim" begin
a = randn(2,2,2); sa = SArray{Tuple{2,2,2}}(a)
b = rand(Bool,2,2,2); sb = SArray{Tuple{2,2,2}}(b)
z = zeros(2,2,2); sz = SArray{Tuple{2,2,2}}(z)
I, J, K = 2, 2, 2
OSArray = SArray{Tuple{I,J,K}} # original
RSArray1 = SArray{Tuple{1,J,K}} # reduced in dimension 1
RSArray2 = SArray{Tuple{I,1,K}} # reduced in dimension 2
RSArray3 = SArray{Tuple{I,J,1}} # reduced in dimension 3
a = randn(I,J,K); sa = OSArray(a)
b = rand(Bool,I,J,K); sb = OSArray(b)
z = zeros(I,J,K); sz = OSArray(z)

@test iszero(sz) == iszero(z)

@test sum(sa) === sum(a)
@test sum(abs2, sa) === sum(abs2, a)
@test sum(sa, Val{2}) == sum(a, 2)
@test sum(abs2, sa, Val{2}) == sum(abs2, a, 2)
@test sum(sa, Val{2}) === RSArray2(sum(a, 2))
@test sum(abs2, sa, Val{2}) === RSArray2(sum(abs2, a, 2))

@test prod(sa) === prod(a)
@test prod(abs2, sa) === prod(abs2, a)
@test prod(sa, Val{2}) == prod(a, 2)
@test prod(abs2, sa, Val{2}) == prod(abs2, a, 2)
@test prod(sa, Val{2}) === RSArray2(prod(a, 2))
@test prod(abs2, sa, Val{2}) === RSArray2(prod(abs2, a, 2))

@test count(sb) === count(b)
@test count(x->x>0, sa) === count(x->x>0, a)
# @test count(sb, Val{2}) == count(b, 2)
# @test count(x->x>0, sa, Val{2}) == count(x->x>0, a, 2)
@test count(sb, Val{2}) === RSArray2(reshape([count(b[i,:,k]) for i = 1:I, k = 1:K], (I,1,K)))
@test count(x->x>0, sa, Val{2}) === RSArray2(reshape([count(x->x>0, a[i,:,k]) for i = 1:I, k = 1:K], (I,1,K)))

@test all(sb) === all(b)
@test all(x->x>0, sa) === all(x->x>0, a)
@test all(sb, Val{2}) == all(b, 2)
@test all(x->x>0, sa, Val{2}) == all(x->x>0, a, 2)
@test all(sb, Val{2}) === RSArray2(all(b, 2))
@test all(x->x>0, sa, Val{2}) === RSArray2(all(x->x>0, a, 2))

@test any(sb) === any(b)
@test any(x->x>0, sa) === any(x->x>0, a)
@test any(sb, Val{2}) == any(b, 2)
@test any(x->x>0, sa, Val{2}) == any(x->x>0, a, 2)
@test any(sb, Val{2}) === RSArray2(any(b, 2))
@test any(x->x>0, sa, Val{2}) === RSArray2(any(x->x>0, a, 2))

@test mean(sa) === mean(a)
@test mean(abs2, sa) === mean(abs2, a)
@test mean(sa, Val{2}) == mean(a, 2)
# @test mean(abs2, sa, Val{2}) == mean(abs2, a, 2)
@test mean(sa, Val{2}) === RSArray2(mean(a, 2))
@test mean(abs2, sa, Val{2}) === RSArray2(mean(abs2.(a), 2))

@test minimum(sa) === minimum(a)
@test minimum(abs2, sa) === minimum(abs2, a)
@test minimum(sa, Val{2}) == minimum(a, 2)
@test minimum(abs2, sa, Val{2}) == minimum(abs2, a, 2)
@test minimum(sa, Val{2}) === RSArray2(minimum(a, 2))
@test minimum(abs2, sa, Val{2}) === RSArray2(minimum(abs2, a, 2))

@test maximum(sa) === maximum(a)
@test maximum(abs2, sa) === maximum(abs2, a)
@test maximum(sa, Val{2}) == maximum(a, 2)
@test maximum(abs2, sa, Val{2}) == maximum(abs2, a, 2)
@test maximum(sa, Val{2}) === RSArray2(maximum(a, 2))
@test maximum(abs2, sa, Val{2}) === RSArray2(maximum(abs2, a, 2))

@test diff(sa, Val{1}) == a[2:end,:,:] - a[1:end-1,:,:]
@test diff(sa, Val{2}) == a[:,2:end,:] - a[:,1:end-1,:]
@test diff(sa, Val{3}) == a[:,:,2:end] - a[:,:,1:end-1]
@test diff(sa, Val{1}) === RSArray1(a[2:end,:,:] - a[1:end-1,:,:])
@test diff(sa, Val{2}) === RSArray2(a[:,2:end,:] - a[:,1:end-1,:])
@test diff(sa, Val{3}) === RSArray3(a[:,:,2:end] - a[:,:,1:end-1])

# as of Julia v0.6, diff() for regular Array is defined only for vectors and matrices
m = randn(4,3); sm = SMatrix{4,3}(m)
Expand Down

0 comments on commit 5172c39

Please sign in to comment.