Skip to content

Commit

Permalink
feat: optimize the special case of indexing with unitranges
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 9, 2024
1 parent d786d2a commit 8b90aea
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,17 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
)
end

# XXX: reevaluate this manual optimization once
# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled
function NNlib.gather!(
dst::TracedRArray{T1,2},
src::AnyTracedRArray{T2,2},
idxs::Union{AbstractUnitRange{<:Number}},
) where {T1,T2}
dst.mlir_data = src[:, idxs].mlir_data
return dst
end

function NNlib.gather!(
dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number}
) where {T1,T2}
Expand Down

0 comments on commit 8b90aea

Please sign in to comment.