Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support `CartesianIndex` as the element type of `idx` in scatter and gather #308

Merged
merged 9 commits into from
Apr 21, 2021
19 changes: 5 additions & 14 deletions src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,11 @@ or multiple `dst` columns.

See [`gather`](@ref) for an allocating version.
"""
function gather!(dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx, Nidx}) where
{Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple}

M = typelength(Tidx)
d = Ndst - Nidx
d == Nsrc - M || throw(ArgumentError("Incompatible input shapes."))
size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes."))
size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))

colons = ntuple(i -> Colon(), d)
function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
dims = _check_dims(src, dst, idx)
colons = ntuple(i -> Colon(), dims)
for k in CartesianIndices(idx)
view(dst, colons..., k) .= view(src, colons..., idx[k]...)
_view(dst, colons, k) .= _view(src, colons, idx[k])
end
return dst
end
Expand Down Expand Up @@ -64,7 +55,7 @@ See [`gather!`](@ref) for an in-place version.
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
idx::AbstractArray{Tidx, Nidx}) where
{Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple}
{Tsrc, Nsrc, Nidx, Tidx}

M = typelength(Tidx)
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
Expand Down
100 changes: 55 additions & 45 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,39 @@
# - ∇scatter_src!
#

function _check_dims(Ndst, Nsrc, N, Nidx)
@assert Ndst - N == Nsrc - Nidx "Incompatible input shapes of (dst, src, idx) = ($Ndst, $Nsrc, $Nidx)."
dims = Ndst - N
if dims < 0
throw(ArgumentError("dims must be non-negative but got dims=$dims."))
end
# Length of a single index element, derived from its type alone:
# 1 for any number, M for an M-tuple, and M for a CartesianIndex{M}.
function typelength(::Type{<:Number})
    return 1
end
function typelength(::Type{<:NTuple{M}}) where {M}
    return M
end
function typelength(::Type{CartesianIndex{M}}) where {M}
    return M
end

# Validate that `X`, `Y` and `idx` have compatible shapes when the elements of
# `idx` are integers or integer tuples, and return `dims`, the number of
# leading dimensions that `X` and `Y` must share.
#
# Throws `ArgumentError` when the leading `dims` sizes of `X` and `Y` differ,
# or when the trailing sizes of `Y` do not equal `size(idx)`. The rank check
# itself is delegated to the four-argument `_check_dims` method.
function _check_dims(X::AbstractArray{Tx,Nx},
                     Y::AbstractArray{Ty,Ny},
                     idx::AbstractArray{Tidx,Nidx}) where
                     {Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx}
    # Each index element addresses `typelength(Tidx)` dimensions of `X`.
    dims = _check_dims(Nx, Ny, typelength(Tidx), Nidx)

    # The first `dims` dimensions must match between `X` and `Y`.
    if size(X)[1:dims] != size(Y)[1:dims]
        throw(ArgumentError("Incompatible input shapes."))
    end

    # The remaining dimensions of `Y` must line up with `idx`.
    if size(Y)[dims+1:end] != size(idx)
        throw(ArgumentError("Incompatible input shapes."))
    end

    return dims
end

typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M
function _check_dims(X::AbstractArray{Tx,Nx},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we could unify the two methods by converting the Cartesian indices just once?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are very low-level methods, so we should not do any allocation.

Y::AbstractArray{Ty,Ny},
idx::AbstractArray{CartesianIndex{M},Nidx}) where {Tx,Ty,Nx,Ny,M,Nidx}
dims = _check_dims(Nx, Ny, M, Nidx)
size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
return dims
end

# Rank-level compatibility check.
#
# `Nx` and `Ny` are array ranks, `M` is the length of one index element and
# `Nidx` is the rank of the index array. Returns `dims = Nx - M`.
#
# Fails with an `AssertionError` when `Nx - M != Ny - Nidx`, and throws an
# `ArgumentError` when `dims` is negative.
function _check_dims(Nx, Ny, M, Nidx)
    dims = Nx - M
    @assert dims == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)."
    if dims < 0
        throw(ArgumentError("dims must be non-negative but got dims=$dims."))
    end
    return dims
end

# Build a view of `X` that keeps the leading dimensions selected by `colons`
# and fixes the trailing dimensions with the index `k`.

# Generic index (e.g. a tuple): its components are splatted into `view`.
function _view(X, colons, k)
    return view(X, colons..., k...)
end

# An `Integer` or `CartesianIndex` is handed to `view` as a single index.
function _view(X, colons, k::Union{Integer, CartesianIndex})
    return view(X, colons..., k)
end

"""
scatter!(op, dst, src, idx)
Expand All @@ -42,30 +64,18 @@ index of `dst` and the value of `idx` must indicate the last few dimensions of `
Once the dimensions match, arrays are aligned automatically. The value of `idx` can be
of `Int`, `Tuple` or `CartesianIndex` type.
"""
function scatter!(op,
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx}
M = typelength(Tidx)
dims = _check_dims(Ndst, Nsrc, M, Nidx)
scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrIntTuple},
dims::Val{N}) where {Tdst,Tsrc,N}
function scatter!(op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
dims = _check_dims(dst, src, idx)
colons = Base.ntuple(_->Colon(), dims)
for k in CartesianIndices(idx)
dst_v = view(dst, colons..., idx[k]...)
src_v = view(src, colons..., k)
dst_v = _view(dst, colons, idx[k])
src_v = _view(src, colons, k)
dst_v .= (op).(dst_v, src_v)
end
dst
end

function scatter!(op::typeof(mean),
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
Ns = scatter!(+, zero(dst), one.(src), idx)
dst_ = scatter!(+, zero(dst), src, idx)
dst .+= safe_div.(dst_, Ns)
Expand Down Expand Up @@ -93,55 +103,55 @@ function scatter end

for op in [+, -]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, Base.reduce_empty(+, T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, Base.reduce_empty(+, Tsrc))
scatter!(op, dst, src, idx)
end
end

for op in [*, /]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, Base.reduce_empty(*, T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, Base.reduce_empty(*, Tsrc))
scatter!(op, dst, src, idx)
end
end

function scatter(op::typeof(max),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, typemin(T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, typemin(Tsrc))
scatter!(op, dst, src, idx)
end

function scatter(op::typeof(min),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, typemax(T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, typemax(Tsrc))
scatter!(op, dst, src, idx)
end

function scatter(op::typeof(mean),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
FT = float(T)
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
FT = float(Tsrc)
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
dst = similar(src, Tsrc, dstsize)
fill!(dst, Base.reduce_empty(+, FT))
scatter!(op, dst, src, idx)
end
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ The maximum of each dimension in the element is computed.
"""
function maximum_dims(dims::AbstractArray{<:Integer})
    # Integer indices address a single dimension, so the result is a 1-tuple.
    return (maximum(dims), )
end
function maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T}
    # Component-wise maximum over every tuple stored in `dims`.
    return ntuple(N) do i
        maximum(el -> el[i], dims)
    end
end
function maximum_dims(dims::AbstractArray{CartesianIndex{N}}) where {N}
    # Same component-wise maximum, for CartesianIndex elements.
    return ntuple(N) do i
        maximum(el -> el[i], dims)
    end
end
30 changes: 30 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,33 @@ end
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end

@testset "gather cartesian index" begin
    T = Float32

    ## 2d src, 1d index of CartesianIndex{2} -> 1d output
    src = T[3 5 7
            4 6 8]
    index = [CartesianIndex(i, j) for (i, j) in [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]]
    output = T[3, 5, 7, 4, 6, 8]

    y = gather(src, index)
    # Expected size: leading (non-indexed) dims of src followed by the dims of index.
    M = NNlib.typelength(eltype(index))
    Nsrc = ndims(src)
    expected_size = (size(src)[1:Nsrc-M]..., size(index)...)
    @test y isa Array{T,1}
    @test size(y) == expected_size
    @test y == output

    ## 3d src, 2d index of CartesianIndex{2} -> 3d output
    n1, nsrc, nidx = 2, 3, 6
    src = rand(Float32, n1, nsrc, nsrc)
    index = [CartesianIndex(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx]

    y = gather(src, index)
    # Only type and shape are checked here because the indices are random.
    M = NNlib.typelength(eltype(index))
    Nsrc = ndims(src)
    expected_size = (size(src)[1:Nsrc-M]..., size(index)...)
    @test y isa Array{T,3}
    @test size(y) == expected_size
end
4 changes: 4 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ idxs = Dict(
:tup => [(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)],
:car => CartesianIndex.(
[(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)]),
)
res = Dict(
(+, 0, true) => [5, 6, 9, 8, 9],
Expand Down
4 changes: 4 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@
ind3 = [(3,4,5) (1,2,3) (2,3,9);
(4,6,2) (5,3,2) (4,4,4)]
@test NNlib.maximum_dims(ind3) == (5,6,9)
ind4 = CartesianIndex.(
[(3,4,5) (1,2,3) (2,3,9);
(4,6,2) (5,3,2) (4,4,4)])
@test NNlib.maximum_dims(ind4) == (5,6,9)
end