Skip to content

Commit

Permalink
test: port NNlib gather tests over
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 3cbbb65 commit 8d888cf
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,17 @@ end
# XXX: For performance to use `stablehlo.dynamic_gather` or atleast use traced loop
# instead of unrolling the loop (the case for AbstractArray can just use
# `stablehlo.gather`). See above for the special case implementation that is optimized.
function NNlib.gather!(
dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray
) where {T1,T2,N}
function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
@warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \
This case is not optimized and will be slow." maxlog = 1
dims = NNlib.scatter_dims(src, dst, idxs)
colons = ntuple(Returns(Colon()), dims)
start_sizes = ntuple(i -> size(src, i), dims)
results = [reshape(src[colons..., Tuple(idxs[k])...], start_sizes..., :) for k in idxs]
results = map(CartesianIndices(idxs)) do k
res = src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
dst.mlir_data = res.mlir_data
return dst
Expand Down
184 changes: 184 additions & 0 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,187 @@ end
causal_mask2(x) = NNlib.make_causal_mask(x; dims=1)
@test @jit(causal_mask2(x_ra)) causal_mask2(x)
end

# Adapted from https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/test/testsuite/gather.jl#L5
@testset "NNlib gather" begin
@testset "gather scalar index" begin
## 1d src, 2d index of ints -> 2d output
src = Float32[3, 4, 5, 6, 7]
index = [
1 2 3 4
4 2 1 3
3 5 5 3
]
output = Float32[
3 4 5 6
6 4 3 5
5 7 7 5
]

y1 = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
@test y1 output
@test y1 isa ConcreteRArray{Float32,2}
@test size(y1) == size(index)

y2 = @jit(NNlib.gather(Reactant.to_rarray(src), index))
@test y2 output
@test y2 isa ConcreteRArray{Float32,2}
@test size(y2) == size(index)

dst = Float32.(zero.(index))
@test @jit(
NNlib.gather!(
Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index)
)
) output

dst = zeros(Float32, 3, 5)
@test_throws ArgumentError @jit(
NNlib.gather!(
Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index)
)
)

## 1d src, 3d index of ints -> 3d output
src = Float32[3, 4, 5, 6, 7]
index = [
1 2 3 4
4 2 1 3
3 5 5 3
][:, :, 1:1]
output = Float32[
3 4 5 6
6 4 3 5
5 7 7 5
][:, :, 1:1]

y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
@test y output
@test y isa ConcreteRArray{Float32,3}
@test size(y) == size(index)

y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
@test y output
@test y isa ConcreteRArray{Float32,3}
@test size(y) == size(index)

## 2d src, 2d index of ints -> 3d output
src = Float32[
3 5 7
4 6 8
]
index = [
1 2 3
2 2 1
3 1 3
]

output = zeros(Float32, 2, 3, 3)
output[:, :, 1] = [
3 5 7
4 6 8
]
output[:, :, 2] = [
5 5 3
6 6 4
]
output[:, :, 3] = [
7 3 7
8 4 8
]

y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
@test y output
@test y isa ConcreteRArray{Float32,3}
@test size(y) == (size(src)[1:(end - 1)]..., size(index)...)

y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
@test y output
@test y isa ConcreteRArray{Float32,3}
@test size(y) == (size(src)[1:(end - 1)]..., size(index)...)
end


@testset "gather tuple index" begin
## 2d src, 1d index of 2-tuples -> 1d output
src = Float32[
3 5 7
4 6 8
]
index = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
output = Float32[3, 5, 7, 4, 6, 8]

y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa ConcreteRArray{Float32,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y output

y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
@test y output
@test y isa ConcreteRArray{Float32,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y output

## 3d src, 2d index of 2-tuples -> 3d output
n1, nsrc, nidx = 2, 3, 6
src = rand(Float32, n1, nsrc, nsrc)
index = [
(rand(1:nsrc), rand(1:nsrc)) for i = 1:nidx, j = 1:nidx
]

y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa ConcreteRArray{Float32,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)

y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa ConcreteRArray{Float32,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end

@testset "gather cartesian index" begin
## 2d src, 1d index of 2-tuples -> 1d output
src = Float32[
3 5 7
4 6 8
]
index = CartesianIndex.([(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)])
output = Float32[3, 5, 7, 4, 6, 8]

y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa ConcreteRArray{Float32,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y output

y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
@test y output
@test y isa ConcreteRArray{Float32,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)

## 3d src, 2d index of 2-tuples -> 3d output
n1, nsrc, nidx = 2, 3, 6
src = rand(Float32, n1, nsrc, nsrc)
index = [
CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i = 1:nidx, j = 1:nidx
]

y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa ConcreteRArray{Float32,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)

y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa ConcreteRArray{Float32,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end
end

0 comments on commit 8d888cf

Please sign in to comment.