From 8b90aea0f85d26f7b923c542d591de4aaf5ade82 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 16:03:32 -0500 Subject: [PATCH] feat: optimize the special case of indexing with unitranges --- ext/ReactantNNlibExt.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 449a1252f..940accfbe 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -296,6 +296,17 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) ) end +# XXX: reevaluate this manual optimization once +# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled +function NNlib.gather!( + dst::TracedRArray{T1,2}, + src::AnyTracedRArray{T2,2}, + idxs::Union{AbstractUnitRange{<:Number}}, +) where {T1,T2} + dst.mlir_data = src[:, idxs].mlir_data + return dst +end + function NNlib.gather!( dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number} ) where {T1,T2}