From 183b751768460003e0fc56d1736a6b2557050ecb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 8 Nov 2024 22:27:05 -0500 Subject: [PATCH 01/12] feat: unbreak NNlib.gather --- src/TracedRArray.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 118316b73..27501d3ac 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -699,7 +699,7 @@ end function broadcast_to_size(arg::T, rsize) where {T<:Number} attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize))) - return arg = TracedRArray{T,length(rsize)}( + return TracedRArray{T,length(rsize)}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize ) end @@ -711,6 +711,11 @@ function broadcast_to_size(arg::TracedRNumber, rsize) ) end +function broadcast_to_size(arg::AnyTracedRArray{T, 0}, rsize) where {T} + arg = materialize_traced_array(arg) + return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) +end + function broadcast_to_size(arg::AnyTracedRArray, rsize) arg = materialize_traced_array(arg) size(arg) == rsize && return arg From ccd6a721ac335791a916f30cefcf92042d0c0505 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 13:26:46 -0500 Subject: [PATCH 02/12] feat: use dynamic slicing --- ext/ReactantNNlibExt.jl | 15 +++++++++++ src/TracedRArray.jl | 58 ++++++++++++++++++++++++----------------- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index bccdce05d..a765029b3 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -296,4 +296,19 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) ) end +# XXX: For performance to use `stablehlo.dynamic_gather` or atleast use traced loop +# instead of unrolling the loop (the case for AbstractArray can just use +# `stablehlo.gather`) +function NNlib.gather!( + dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray +) where {T1,T2,N} + dims = NNlib.scatter_dims(src, dst, idxs) + colons = ntuple(i -> Colon(), dims) + start_sizes = ntuple(i -> size(src, i), dims) + results = [reshape(src[colons..., Tuple(idxs[k])...], start_sizes..., :) for k in idxs] + res = reshape(cat(results...; dims=(dims + 1)), size(dst)) + dst.mlir_data = res.mlir_data + return dst +end + end # module ReactantNNlibExt diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 27501d3ac..f9dbd0fea 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -56,7 +56,9 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...) return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) end -function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N} +function Base.getindex( + a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} +) where {T,N} @warn( """Performing scalar indexing on task $(current_task()). Invocation resulted in scalar indexing of a TracedRArray. @@ -65,21 +67,19 @@ Such implementations *do not* execute on device, but very slowly on the CPU, and require expensive copies and synchronization each time and therefore should be avoided.""" ) + start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] + slice_sizes = [1 for _ in index] + res1 = MLIR.IR.result( - MLIR.Dialects.stablehlo.slice( - a.mlir_data; - start_indices=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in index]), - limit_indices=MLIR.IR.DenseArrayAttribute([Int64(i) for i in index]), - strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in index]), - ), - 1, + MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1 ) res2 = MLIR.IR.result( MLIR.Dialects.stablehlo.reshape( - res1; result_0=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1))) + res1; result_0=MLIR.IR.TensorType(Int[], eltype(MLIR.IR.type(res1))) ), 1, ) + return TracedRNumber{T}((), res2) end @@ -87,27 +87,35 @@ function Base.getindex(a::TracedRArray{T,0}) where {T} return TracedRNumber{T}((), a.mlir_data) end +# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} - indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)] + indices = map(enumerate(indices)) do (idx, i) + i isa Colon && return 1:size(a, idx) + i isa CartesianIndex && return Tuple(i) + return i + end + + foreach(indices) do idxs + idxs isa Number && return + all(isone, diff(idxs)) || error("non-contiguous indexing is not supported") + end + + start_indices = map(indices) do i + return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data + end + slice_sizes = [length(i) for i in indices] res = MLIR.IR.result( - MLIR.Dialects.stablehlo.slice( - a.mlir_data; - start_indices=MLIR.IR.DenseArrayAttribute([ - Int64(first(i) - 1) for i in indices - ]), - limit_indices=MLIR.IR.DenseArrayAttribute([Int64(last(i)) for i in indices]), - strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in indices]), - ), - 1, + MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1 ) + x = TracedRArray{T,N}((), res, Tuple(length.(indices))) - ddims = findall(x -> x isa Integer, indices) - !isempty(ddims) && return dropdims(x; dims=Tuple(ddims)) + ddims = findall(Base.Fix2(isa, Integer), indices) + isempty(ddims) || return dropdims(x; dims=Tuple(ddims)) return x end # Prevent ambiguity -function Base.getindex(a::WrappedTracedRArray, index::Int...) +function Base.getindex(a::WrappedTracedRArray, index::Union{Int,TracedRNumber{Int}}...) return getindex(ancestor(a), get_ancestor_indices(a, index...)...) end @@ -116,7 +124,9 @@ function Base.getindex(a::WrappedTracedRArray, indices...) end function Base.setindex!( - a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N} + a::TracedRArray{T,N}, + v, + indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, ) where {T,N} indices = map(enumerate(indices)) do (idx, i) i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i) @@ -711,7 +721,7 @@ function broadcast_to_size(arg::TracedRNumber, rsize) ) end -function broadcast_to_size(arg::AnyTracedRArray{T, 0}, rsize) where {T} +function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} arg = materialize_traced_array(arg) return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) end From ebfb54362766fb43f31e54cabeb98de094e67d89 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 13:28:30 -0500 Subject: [PATCH 03/12] chore: apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TracedRArray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f9dbd0fea..6133ed05a 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -96,7 +96,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} end foreach(indices) do idxs - idxs isa Number && return + idxs isa Number && return nothing all(isone, diff(idxs)) || error("non-contiguous indexing is not supported") end From d40781a91a5fba94a168c2f8e57ab0009da2b6ca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 13:43:24 -0500 Subject: [PATCH 04/12] feat: add an overload of Base.Tuple --- src/TracedRArray.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 6133ed05a..b273da7da 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -155,6 +155,8 @@ function Base.setindex!( return a end +Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(Base.getindex, x), length(x)) + Base.size(x::TracedRArray) = x.shape Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A)) From 4a2cef2962169d54573b972403fff8a83a761cd5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 13:51:08 -0500 Subject: [PATCH 05/12] fix: ambiguity error --- src/TracedRArray.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index b273da7da..fbbd7c022 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -148,7 +148,9 @@ function Base.setindex!( end function Base.setindex!( - a::AnyTracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N} + a::AnyTracedRArray{T,N}, + v, + indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, ) where {T,N} ancestor_indices = get_ancestor_indices(a, indices...) setindex!(ancestor(a), v, ancestor_indices...) From c360871080c87ba23ae63e7925072f2e11393e37 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 14:23:47 -0500 Subject: [PATCH 06/12] feat: special case `gather!` for the most common cases --- ext/ReactantNNlibExt.jl | 32 ++++++++++++++++++++++++++++++-- src/TracedRArray.jl | 2 +- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index a765029b3..449a1252f 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -296,14 +296,42 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) ) end +function NNlib.gather!( + dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number} +) where {T1,T2} + dims = NNlib.scatter_dims(src, dst, idxs) + @assert dims == 1 # scatter_dims lets us do some size checks so we call that function + idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data + slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data + + dimension_numbers = """ + #stablehlo.gather< + offset_dims = [0], + collapsed_slice_dims = [1], + start_index_map = [1], + index_vector_dim = 1>""" + dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers) + + res = MLIR.IR.result( + Reactant.MLIR.Dialects.stablehlo.dynamic_gather( + src.mlir_data, idxs, slice_sizes; dimension_numbers + ), + 1, + ) + dst.mlir_data = res + return dst +end + # XXX: For performance to use `stablehlo.dynamic_gather` or atleast use traced loop # instead of unrolling the loop (the case for AbstractArray can just use -# `stablehlo.gather`) +# `stablehlo.gather`). See above for the special case implementation that is optimized. function NNlib.gather!( dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray ) where {T1,T2,N} + @warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \ + This case is not optimized and will be slow." maxlog = 1 dims = NNlib.scatter_dims(src, dst, idxs) - colons = ntuple(i -> Colon(), dims) + colons = ntuple(Returns(Colon()), dims) start_sizes = ntuple(i -> size(src, i), dims) results = [reshape(src[colons..., Tuple(idxs[k])...], start_sizes..., :) for k in idxs] res = reshape(cat(results...; dims=(dims + 1)), size(dst)) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index fbbd7c022..a14b439a0 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -75,7 +75,7 @@ and require expensive copies and synchronization each time and therefore should ) res2 = MLIR.IR.result( MLIR.Dialects.stablehlo.reshape( - res1; result_0=MLIR.IR.TensorType(Int[], eltype(MLIR.IR.type(res1))) + res1; result_0=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1))) ), 1, ) From 95e1ba9ead30e2f66a565b11b73d3a33015abdc4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 16:03:32 -0500 Subject: [PATCH 07/12] feat: optimize the special case of indexing with unitranges --- ext/ReactantNNlibExt.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 449a1252f..940accfbe 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -296,6 +296,17 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) ) end +# XXX: reevaluate this manual optimization once +# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled +function NNlib.gather!( + dst::TracedRArray{T1,2}, + src::AnyTracedRArray{T2,2}, + idxs::Union{AbstractUnitRange{<:Number}}, +) where {T1,T2} + dst.mlir_data = src[:, idxs].mlir_data + return dst +end + function NNlib.gather!( dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number} ) where {T1,T2} From 3cbbb6514ce7d4ea16fbc569d2e09900e5d48c5d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 22:51:53 -0500 Subject: [PATCH 08/12] test: dynamic slice test --- src/TracedRArray.jl | 9 ++++++++- src/TracedRNumber.jl | 32 ++++++++++++++++++-------------- test/basic.jl | 11 +++++++++++ 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a14b439a0..3593f141f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -97,7 +97,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} foreach(indices) do idxs idxs isa Number && return nothing - all(isone, diff(idxs)) || error("non-contiguous indexing is not supported") + contiguous = all(isone, diff(idxs)) + # XXX: We want to throw error even for dynamic indexing + if typeof(a) <: Bool + contiguous || error("non-contiguous indexing is not supported") + end end start_indices = map(indices) do i @@ -875,3 +879,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN return x end end + +Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x) +Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 89513a2d9..4ddb02131 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -193,20 +193,24 @@ function Base.ifelse( ) end -function Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) - ) -end -function Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) - ) -end -function Base.:!(x::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) - ) +for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) + @eval begin + function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return TracedRNumber{promote_type(eltype(x), eltype(y))}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) + ) + end + function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return TracedRNumber{promote_type(eltype(x), eltype(y))}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) + ) + end + function Base.:!(x::TracedRNumber{<:$(T1)}) + return TracedRNumber{eltype(x)}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) + ) + end + end end function Base.literal_pow( diff --git a/test/basic.jl b/test/basic.jl index 5137379bc..e386ae545 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -557,3 +557,14 @@ end @test minimum(y) ≥ 0.0 @test x_ra ≈ x end + +@testset "dynamic indexing" begin + x = randn(5, 3) + x_ra = Reactant.to_rarray(x) + + idx = [1, 2, 3] + idx_ra = Reactant.to_rarray(idx) + + y = @jit(getindex(x_ra, idx_ra, :)) + @test y ≈ x[idx, :] +end From 8d888cf2941cc5a7f4d3a474bd0ef37161f723e5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 23:20:11 -0500 Subject: [PATCH 09/12] test: port NNlib gather tests over --- ext/ReactantNNlibExt.jl | 10 ++- test/nn/nnlib.jl | 184 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 4 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 940accfbe..fadbdaa2d 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -336,15 +336,17 @@ end # XXX: For performance to use `stablehlo.dynamic_gather` or atleast use traced loop # instead of unrolling the loop (the case for AbstractArray can just use # `stablehlo.gather`). See above for the special case implementation that is optimized. -function NNlib.gather!( - dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray -) where {T1,T2,N} +function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray) @warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \ This case is not optimized and will be slow." maxlog = 1 dims = NNlib.scatter_dims(src, dst, idxs) colons = ntuple(Returns(Colon()), dims) start_sizes = ntuple(i -> size(src, i), dims) - results = [reshape(src[colons..., Tuple(idxs[k])...], start_sizes..., :) for k in idxs] + results = map(CartesianIndices(idxs)) do k + res = src[colons..., Tuple(idxs[k])...] + res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,))) + return reshape(res, start_sizes..., :) + end res = reshape(cat(results...; dims=(dims + 1)), size(dst)) dst.mlir_data = res.mlir_data return dst diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 9aa024c73..7635c9ffd 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -173,3 +173,187 @@ end causal_mask2(x) = NNlib.make_causal_mask(x; dims=1) @test @jit(causal_mask2(x_ra)) ≈ causal_mask2(x) end + +# Adapted from https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/test/testsuite/gather.jl#L5 +@testset "NNlib gather" begin + @testset "gather scalar index" begin + ## 1d src, 2d index of ints -> 2d output + src = Float32[3, 4, 5, 6, 7] + index = [ + 1 2 3 4 + 4 2 1 3 + 3 5 5 3 + ] + output = Float32[ + 3 4 5 6 + 6 4 3 5 + 5 7 7 5 + ] + + y1 = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + @test y1 ≈ output + @test y1 isa ConcreteRArray{Float32,2} + @test size(y1) == size(index) + + y2 = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y2 ≈ output + @test y2 isa ConcreteRArray{Float32,2} + @test size(y2) == size(index) + + dst = Float32.(zero.(index)) + @test @jit( + NNlib.gather!( + Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index) + ) + ) ≈ output + + dst = zeros(Float32, 3, 5) + @test_throws ArgumentError @jit( + NNlib.gather!( + Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index) + ) + ) + + ## 1d src, 3d index of ints -> 3d output + src = Float32[3, 4, 5, 6, 7] + index = [ + 1 2 3 4 + 4 2 1 3 + 3 5 5 3 + ][:, :, 1:1] + output = Float32[ + 3 4 5 6 + 6 4 3 5 + 5 7 7 5 + ][:, :, 1:1] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == size(index) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == size(index) + + ## 2d src, 2d index of ints -> 3d output + src = Float32[ + 3 5 7 + 4 6 8 + ] + index = [ + 1 2 3 + 2 2 1 + 3 1 3 + ] + + output = zeros(Float32, 2, 3, 3) + output[:, :, 1] = [ + 3 5 7 + 4 6 8 + ] + output[:, :, 2] = [ + 5 5 3 + 6 6 4 + ] + output[:, :, 3] = [ + 7 3 7 + 8 4 8 + ] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(end - 1)]..., size(index)...) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(end - 1)]..., size(index)...) + end + + + @testset "gather tuple index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = Float32[ + 3 5 7 + 4 6 8 + ] + index = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)] + output = Float32[3, 5, 7, 4, 6, 8] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y ≈ output + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y ≈ output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = rand(Float32, n1, nsrc, nsrc) + index = [ + (rand(1:nsrc), rand(1:nsrc)) for i = 1:nidx, j = 1:nidx + ] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + end + + @testset "gather cartesian index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = Float32[ + 3 5 7 + 4 6 8 + ] + index = CartesianIndex.([(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]) + output = Float32[3, 5, 7, 4, 6, 8] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y ≈ output + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = rand(Float32, n1, nsrc, nsrc) + index = [ + CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i = 1:nidx, j = 1:nidx + ] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + end +end From 2547eebf39787efe2dfc411ef928ca6244133141 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 23:22:33 -0500 Subject: [PATCH 10/12] chore: apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/nn/nnlib.jl | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 7635c9ffd..3be3d97fc 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -226,7 +226,6 @@ end 6 4 3 5 5 7 7 5 ][:, :, 1:1] - y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) @test y ≈ output @test y isa ConcreteRArray{Float32,3} @@ -273,7 +272,6 @@ end @test size(y) == (size(src)[1:(end - 1)]..., size(index)...) end - @testset "gather tuple index" begin ## 2d src, 1d index of 2-tuples -> 1d output src = Float32[ @@ -287,33 +285,31 @@ end M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa ConcreteRArray{Float32,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) @test y ≈ output y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) @test y ≈ output @test y isa ConcreteRArray{Float32,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) @test y ≈ output ## 3d src, 2d index of 2-tuples -> 3d output n1, nsrc, nidx = 2, 3, 6 src = rand(Float32, n1, nsrc, nsrc) - index = [ - (rand(1:nsrc), rand(1:nsrc)) for i = 1:nidx, j = 1:nidx - ] + index = [(rand(1:nsrc), rand(1:nsrc)) for i in 1:nidx, j in 1:nidx] y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa ConcreteRArray{Float32,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa ConcreteRArray{Float32,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) end @testset "gather cartesian index" begin @@ -329,31 +325,29 @@ end M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa ConcreteRArray{Float32,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) @test y ≈ output y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) @test y ≈ output @test y isa ConcreteRArray{Float32,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) ## 3d src, 2d index of 2-tuples -> 3d output n1, nsrc, nidx = 2, 3, 6 src = rand(Float32, n1, nsrc, nsrc) - index = [ - CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i = 1:nidx, j = 1:nidx - ] + index = [CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i in 1:nidx, j in 1:nidx] y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa ConcreteRArray{Float32,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa ConcreteRArray{Float32,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) end end From 4de7be43638c262b2d955cac5612d3d94ce8fdc6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Nov 2024 10:30:31 -0500 Subject: [PATCH 11/12] fix: mark length as Int64 --- src/TracedRArray.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 3593f141f..ed704124b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -68,7 +68,7 @@ and require expensive copies and synchronization each time and therefore should ) start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] - slice_sizes = [1 for _ in index] + slice_sizes = [Int64(1) for _ in index] res1 = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1 @@ -107,7 +107,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} start_indices = map(indices) do i return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data end - slice_sizes = [length(i) for i in indices] + slice_sizes = [Int64(length(i)) for i in indices] res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1 ) From cf6bffd62d61dc6588e644e08a7324153de4e9bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Nov 2024 10:35:33 -0500 Subject: [PATCH 12/12] fix: use the C API for dimension numbers --- ext/ReactantNNlibExt.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index fadbdaa2d..bb1df12bb 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -315,13 +315,17 @@ function NNlib.gather!( idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data - dimension_numbers = """ - #stablehlo.gather< - offset_dims = [0], - collapsed_slice_dims = [1], - start_index_map = [1], - index_vector_dim = 1>""" - dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers) + #! format: off + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( + MLIR.IR.context(), + Int64(1), Int64[0], + Int64(1), Int64[1], + Int64(0), Int64[], + Int64(0), Int64[], + Int64(1), Int64[1], + Int64(1) + ) + #! format: on res = MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.dynamic_gather(