diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 449a1252f..940accfbe 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -296,6 +296,17 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) ) end +# XXX: reevaluate this manual optimization once +# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled +function NNlib.gather!( + dst::TracedRArray{T1,2}, + src::AnyTracedRArray{T2,2}, + idxs::Union{AbstractUnitRange{<:Number}}, +) where {T1,T2} + dst.mlir_data = src[:, idxs].mlir_data + return dst +end + function NNlib.gather!( dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number} ) where {T1,T2}