diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 940accfbe..fadbdaa2d 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -336,15 +336,21 @@ end
 # XXX: For performance to use `stablehlo.dynamic_gather` or atleast use traced loop
 # instead of unrolling the loop (the case for AbstractArray can just use
 # `stablehlo.gather`). See above for the special case implementation that is optimized.
-function NNlib.gather!(
-    dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray
-) where {T1,T2,N}
+function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
     @warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \
            This case is not optimized and will be slow." maxlog = 1
     dims = NNlib.scatter_dims(src, dst, idxs)
     colons = ntuple(Returns(Colon()), dims)
     start_sizes = ntuple(i -> size(src, i), dims)
-    results = [reshape(src[colons..., Tuple(idxs[k])...], start_sizes..., :) for k in idxs]
+    results = map(CartesianIndices(idxs)) do k
+        # Named `slice` (not `res`) so the closure does not capture and reassign the
+        # enclosing function's `res` local, which Julia would box.
+        slice = src[colons..., Tuple(idxs[k])...]
+        # Indexing every dimension yields a scalar `TracedRNumber`; expand it to size
+        # `(1,)` so it can be reshaped like an array slice.
+        slice isa TracedRNumber && (slice = Reactant.broadcast_to_size(slice, (1,)))
+        return reshape(slice, start_sizes..., :)
+    end
     res = reshape(cat(results...; dims=(dims + 1)), size(dst))
     dst.mlir_data = res.mlir_data
     return dst
diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl
index 9aa024c73..7635c9ffd 100644
--- a/test/nn/nnlib.jl
+++ b/test/nn/nnlib.jl
@@ -173,3 +173,189 @@ end
     causal_mask2(x) = NNlib.make_causal_mask(x; dims=1)
     @test @jit(causal_mask2(x_ra)) ≈ causal_mask2(x)
 end
+
+# Adapted from https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/test/testsuite/gather.jl#L5
+@testset "NNlib gather" begin
+    @testset "gather scalar index" begin
+        ## 1d src, 2d index of ints -> 2d output
+        src = Float32[3, 4, 5, 6, 7]
+        index = [
+            1 2 3 4
+            4 2 1 3
+            3 5 5 3
+        ]
+        output = Float32[
+            3 4 5 6
+            6 4 3 5
+            5 7 7 5
+        ]
+
+        y1 = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
+        @test y1 ≈ output
+        @test y1 isa ConcreteRArray{Float32,2}
+        @test size(y1) == size(index)
+
+        y2 = @jit(NNlib.gather(Reactant.to_rarray(src), index))
+        @test y2 ≈ output
+        @test y2 isa ConcreteRArray{Float32,2}
+        @test size(y2) == size(index)
+
+        dst = Float32.(zero.(index))
+        @test @jit(
+            NNlib.gather!(
+                Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index)
+            )
+        ) ≈ output
+
+        dst = zeros(Float32, 3, 5)
+        @test_throws ArgumentError @jit(
+            NNlib.gather!(
+                Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index)
+            )
+        )
+
+        ## 1d src, 3d index of ints -> 3d output
+        src = Float32[3, 4, 5, 6, 7]
+        index = [
+            1 2 3 4
+            4 2 1 3
+            3 5 5 3
+        ][:, :, 1:1]
+        output = Float32[
+            3 4 5 6
+            6 4 3 5
+            5 7 7 5
+        ][:, :, 1:1]
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
+        @test y ≈ output
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == size(index)
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
+        @test y ≈ output
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == size(index)
+
+        ## 2d src, 2d index of ints -> 3d output
+        src = Float32[
+            3 5 7
+            4 6 8
+        ]
+        index = [
+            1 2 3
+            2 2 1
+            3 1 3
+        ]
+
+        output = zeros(Float32, 2, 3, 3)
+        output[:, :, 1] = [
+            3 5 7
+            4 6 8
+        ]
+        output[:, :, 2] = [
+            5 5 3
+            6 6 4
+        ]
+        output[:, :, 3] = [
+            7 3 7
+            8 4 8
+        ]
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
+        @test y ≈ output
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == (size(src)[1:(end - 1)]..., size(index)...)
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
+        @test y ≈ output
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == (size(src)[1:(end - 1)]..., size(index)...)
+    end
+
+    @testset "gather tuple index" begin
+        ## 2d src, 1d index of 2-tuples -> 1d output
+        src = Float32[
+            3 5 7
+            4 6 8
+        ]
+        index = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
+        output = Float32[3, 5, 7, 4, 6, 8]
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
+        M = NNlib.typelength(eltype(index))
+        Nsrc = ndims(src)
+        @test y isa ConcreteRArray{Float32,1}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+        @test y ≈ output
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
+        @test y ≈ output
+        @test y isa ConcreteRArray{Float32,1}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+
+        ## 3d src, 2d index of 2-tuples -> 3d output
+        n1, nsrc, nidx = 2, 3, 6
+        src = rand(Float32, n1, nsrc, nsrc)
+        index = [
+            (rand(1:nsrc), rand(1:nsrc)) for i in 1:nidx, j in 1:nidx
+        ]
+        # Reference values from NNlib's implementation for plain `Array`s.
+        expected = NNlib.gather(src, index)
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
+        M = NNlib.typelength(eltype(index))
+        Nsrc = ndims(src)
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+        @test y ≈ expected
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+        @test y ≈ expected
+    end
+
+    @testset "gather cartesian index" begin
+        ## 2d src, 1d index of CartesianIndex{2} -> 1d output
+        src = Float32[
+            3 5 7
+            4 6 8
+        ]
+        index = CartesianIndex.([(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)])
+        output = Float32[3, 5, 7, 4, 6, 8]
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
+        M = NNlib.typelength(eltype(index))
+        Nsrc = ndims(src)
+        @test y isa ConcreteRArray{Float32,1}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+        @test y ≈ output
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
+        @test y ≈ output
+        @test y isa ConcreteRArray{Float32,1}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+
+        ## 3d src, 2d index of CartesianIndex{2} -> 3d output
+        n1, nsrc, nidx = 2, 3, 6
+        src = rand(Float32, n1, nsrc, nsrc)
+        index = [
+            CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i in 1:nidx, j in 1:nidx
+        ]
+        # Reference values from NNlib's implementation for plain `Array`s.
+        expected = NNlib.gather(src, index)
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index)))
+        M = NNlib.typelength(eltype(index))
+        Nsrc = ndims(src)
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+        @test y ≈ expected
+
+        y = @jit(NNlib.gather(Reactant.to_rarray(src), index))
+        @test y isa ConcreteRArray{Float32,3}
+        @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
+        @test y ≈ expected
+    end
+end