diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index bccdce05d..bb1df12bb 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -296,4 +296,64 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) ) end +# XXX: reevaluate this manual optimization once +# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled +function NNlib.gather!( + dst::TracedRArray{T1,2}, + src::AnyTracedRArray{T2,2}, + idxs::Union{AbstractUnitRange{<:Number}}, +) where {T1,T2} + dst.mlir_data = src[:, idxs].mlir_data + return dst +end + +function NNlib.gather!( + dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number} +) where {T1,T2} + dims = NNlib.scatter_dims(src, dst, idxs) + @assert dims == 1 # scatter_dims lets us do some size checks so we call that function + idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data + slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data + + #! format: off + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( + MLIR.IR.context(), + Int64(1), Int64[0], + Int64(1), Int64[1], + Int64(0), Int64[], + Int64(0), Int64[], + Int64(1), Int64[1], + Int64(1) + ) + #! format: on + + res = MLIR.IR.result( + Reactant.MLIR.Dialects.stablehlo.dynamic_gather( + src.mlir_data, idxs, slice_sizes; dimension_numbers + ), + 1, + ) + dst.mlir_data = res + return dst +end + +# XXX: For performance to use `stablehlo.dynamic_gather` or atleast use traced loop +# instead of unrolling the loop (the case for AbstractArray can just use +# `stablehlo.gather`). See above for the special case implementation that is optimized. +function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray) + @warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \ + This case is not optimized and will be slow." maxlog = 1 + dims = NNlib.scatter_dims(src, dst, idxs) + colons = ntuple(Returns(Colon()), dims) + start_sizes = ntuple(i -> size(src, i), dims) + results = map(CartesianIndices(idxs)) do k + res = src[colons..., Tuple(idxs[k])...] + res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,))) + return reshape(res, start_sizes..., :) + end + res = reshape(cat(results...; dims=(dims + 1)), size(dst)) + dst.mlir_data = res.mlir_data + return dst +end + end # module ReactantNNlibExt diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 118316b73..ed704124b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -56,7 +56,9 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...) return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) end -function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N} +function Base.getindex( + a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} +) where {T,N} @warn( """Performing scalar indexing on task $(current_task()). Invocation resulted in scalar indexing of a TracedRArray. @@ -65,14 +67,11 @@ Such implementations *do not* execute on device, but very slowly on the CPU, and require expensive copies and synchronization each time and therefore should be avoided.""" ) + start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] + slice_sizes = [Int64(1) for _ in index] + res1 = MLIR.IR.result( - MLIR.Dialects.stablehlo.slice( - a.mlir_data; - start_indices=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in index]), - limit_indices=MLIR.IR.DenseArrayAttribute([Int64(i) for i in index]), - strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in index]), - ), - 1, + MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1 ) res2 = MLIR.IR.result( MLIR.Dialects.stablehlo.reshape( @@ -80,6 +79,7 @@ and require expensive copies and synchronization each time and therefore should ), 1, ) + return TracedRNumber{T}((), res2) end @@ -87,27 +87,39 @@ function Base.getindex(a::TracedRArray{T,0}) where {T} return TracedRNumber{T}((), a.mlir_data) end +# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} - indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)] + indices = map(enumerate(indices)) do (idx, i) + i isa Colon && return 1:size(a, idx) + i isa CartesianIndex && return Tuple(i) + return i + end + + foreach(indices) do idxs + idxs isa Number && return nothing + contiguous = all(isone, diff(idxs)) + # XXX: We want to throw error even for dynamic indexing + if typeof(a) <: Bool + contiguous || error("non-contiguous indexing is not supported") + end + end + + start_indices = map(indices) do i + return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data + end + slice_sizes = [Int64(length(i)) for i in indices] res = MLIR.IR.result( - MLIR.Dialects.stablehlo.slice( - a.mlir_data; - start_indices=MLIR.IR.DenseArrayAttribute([ - Int64(first(i) - 1) for i in indices - ]), - limit_indices=MLIR.IR.DenseArrayAttribute([Int64(last(i)) for i in indices]), - strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in indices]), - ), - 1, + MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1 ) + x = TracedRArray{T,N}((), res, Tuple(length.(indices))) - ddims = findall(x -> x isa Integer, indices) - !isempty(ddims) && return dropdims(x; dims=Tuple(ddims)) + ddims = findall(Base.Fix2(isa, Integer), indices) + isempty(ddims) || return dropdims(x; dims=Tuple(ddims)) return x end # Prevent ambiguity -function Base.getindex(a::WrappedTracedRArray, index::Int...) +function Base.getindex(a::WrappedTracedRArray, index::Union{Int,TracedRNumber{Int}}...) return getindex(ancestor(a), get_ancestor_indices(a, index...)...) end @@ -116,7 +128,9 @@ function Base.getindex(a::WrappedTracedRArray, indices...) end function Base.setindex!( - a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N} + a::TracedRArray{T,N}, + v, + indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, ) where {T,N} indices = map(enumerate(indices)) do (idx, i) i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i) @@ -138,13 +152,17 @@ function Base.setindex!( end function Base.setindex!( - a::AnyTracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N} + a::AnyTracedRArray{T,N}, + v, + indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, ) where {T,N} ancestor_indices = get_ancestor_indices(a, indices...) setindex!(ancestor(a), v, ancestor_indices...) return a end +Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(Base.getindex, x), length(x)) + Base.size(x::TracedRArray) = x.shape Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A)) @@ -699,7 +717,7 @@ end function broadcast_to_size(arg::T, rsize) where {T<:Number} attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize))) - return arg = TracedRArray{T,length(rsize)}( + return TracedRArray{T,length(rsize)}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize ) end @@ -711,6 +729,11 @@ function broadcast_to_size(arg::TracedRNumber, rsize) ) end +function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} + arg = materialize_traced_array(arg) + return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) +end + function broadcast_to_size(arg::AnyTracedRArray, rsize) arg = materialize_traced_array(arg) size(arg) == rsize && return arg @@ -856,3 +879,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN return x end end + +Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x) +Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 89513a2d9..4ddb02131 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -193,20 +193,24 @@ function Base.ifelse( ) end -function Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) - ) -end -function Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) - ) -end -function Base.:!(x::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) - ) +for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) + @eval begin + function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return TracedRNumber{promote_type(eltype(x), eltype(y))}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) + ) + end + function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return TracedRNumber{promote_type(eltype(x), eltype(y))}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) + ) + end + function Base.:!(x::TracedRNumber{<:$(T1)}) + return TracedRNumber{eltype(x)}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) + ) + end + end end function Base.literal_pow( diff --git a/test/basic.jl b/test/basic.jl index 5137379bc..e386ae545 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -557,3 +557,14 @@ end @test minimum(y) ≥ 0.0 @test x_ra ≈ x end + +@testset "dynamic indexing" begin + x = randn(5, 3) + x_ra = Reactant.to_rarray(x) + + idx = [1, 2, 3] + idx_ra = Reactant.to_rarray(idx) + + y = @jit(getindex(x_ra, idx_ra, :)) + @test y ≈ x[idx, :] +end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 9aa024c73..3be3d97fc 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -173,3 +173,181 @@ end causal_mask2(x) = NNlib.make_causal_mask(x; dims=1) @test @jit(causal_mask2(x_ra)) ≈ causal_mask2(x) end + +# Adapted from https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/test/testsuite/gather.jl#L5 +@testset "NNlib gather" begin + @testset "gather scalar index" begin + ## 1d src, 2d index of ints -> 2d output + src = Float32[3, 4, 5, 6, 7] + index = [ + 1 2 3 4 + 4 2 1 3 + 3 5 5 3 + ] + output = Float32[ + 3 4 5 6 + 6 4 3 5 + 5 7 7 5 + ] + + y1 = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + @test y1 ≈ output + @test y1 isa ConcreteRArray{Float32,2} + @test size(y1) == size(index) + + y2 = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y2 ≈ output + @test y2 isa ConcreteRArray{Float32,2} + @test size(y2) == size(index) + + dst = Float32.(zero.(index)) + @test @jit( + NNlib.gather!( + Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index) + ) + ) ≈ output + + dst = zeros(Float32, 3, 5) + @test_throws ArgumentError @jit( + NNlib.gather!( + Reactant.to_rarray(dst), Reactant.to_rarray(src), Reactant.to_rarray(index) + ) + ) + + ## 1d src, 3d index of ints -> 3d output + src = Float32[3, 4, 5, 6, 7] + index = [ + 1 2 3 4 + 4 2 1 3 + 3 5 5 3 + ][:, :, 1:1] + output = Float32[ + 3 4 5 6 + 6 4 3 5 + 5 7 7 5 + ][:, :, 1:1] + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == size(index) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == size(index) + + ## 2d src, 2d index of ints -> 3d output + src = Float32[ + 3 5 7 + 4 6 8 + ] + index = [ + 1 2 3 + 2 2 1 + 3 1 3 + ] + + output = zeros(Float32, 2, 3, 3) + output[:, :, 1] = [ + 3 5 7 + 4 6 8 + ] + output[:, :, 2] = [ + 5 5 3 + 6 6 4 + ] + output[:, :, 3] = [ + 7 3 7 + 8 4 8 + ] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(end - 1)]..., size(index)...) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(end - 1)]..., size(index)...) + end + + @testset "gather tuple index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = Float32[ + 3 5 7 + 4 6 8 + ] + index = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)] + output = Float32[3, 5, 7, 4, 6, 8] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + @test y ≈ output + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + @test y ≈ output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = rand(Float32, n1, nsrc, nsrc) + index = [(rand(1:nsrc), rand(1:nsrc)) for i in 1:nidx, j in 1:nidx] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + end + + @testset "gather cartesian index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = Float32[ + 3 5 7 + 4 6 8 + ] + index = CartesianIndex.([(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]) + output = Float32[3, 5, 7, 4, 6, 8] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + @test y ≈ output + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + @test y ≈ output + @test y isa ConcreteRArray{Float32,1} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = rand(Float32, n1, nsrc, nsrc) + index = [CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i in 1:nidx, j in 1:nidx] + + y = @jit(NNlib.gather(Reactant.to_rarray(src), Reactant.to_rarray(index))) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + + y = @jit(NNlib.gather(Reactant.to_rarray(src), index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa ConcreteRArray{Float32,3} + @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...) + end +end