Skip to content

Commit

Permalink
feat: special case gather! for the most common cases
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 9, 2024
1 parent e33d616 commit 6955241
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
32 changes: 30 additions & 2 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,42 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
)
end

# Special-cased `gather!` for a 2-D traced source indexed by a vector of numbers.
# Emits one `stablehlo.dynamic_gather` op and stores its result in `dst.mlir_data`.
#
# Arguments:
#   - `dst`:  traced matrix whose `mlir_data` is overwritten with the op result.
#   - `src`:  traced matrix the slices are read from.
#   - `idxs`: numeric indices; they are shifted by `-1` before being passed to the op.
#
# Returns `dst`.
function NNlib.gather!(
    dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number}
) where {T1,T2}
    # `scatter_dims` is called for the size checks it performs on its arguments; this
    # method only handles the case where it reports a single leading dimension.
    dims = NNlib.scatter_dims(src, dst, idxs)
    @assert dims == 1

    # Promote the indices to a traced `Int` vector and subtract one
    # (presumably converting Julia's 1-based indices to 0-based ones for StableHLO).
    start_indices = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data

    # Every gathered slice spans all of `src`'s first dimension and one entry of the
    # second.
    slice_shape = [size(src, 1), 1]
    slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, slice_shape).mlir_data

    # The common leading indentation of a triple-quoted literal is stripped by Julia, so
    # the attribute text handed to the parser starts at `#stablehlo.gather<`.
    gather_attr = parse(
        Reactant.MLIR.IR.Attribute,
        """
        #stablehlo.gather<
        offset_dims = [0],
        collapsed_slice_dims = [1],
        start_index_map = [1],
        index_vector_dim = 1>""",
    )

    gather_op = Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
        src.mlir_data, start_indices, slice_sizes; dimension_numbers=gather_attr
    )

    # Take the op's first result as the new backing value of `dst`.
    dst.mlir_data = MLIR.IR.result(gather_op, 1)
    return dst
end

# XXX: For performance, use `stablehlo.dynamic_gather` or at least use a traced loop
# instead of unrolling the loop (the case for AbstractArray can just use
# `stablehlo.gather`)
# `stablehlo.gather`). See above for the special case implementation that is optimized.
function NNlib.gather!(
dst::TracedRArray{T1,N}, src::AnyTracedRArray{T2,N}, idxs::AbstractArray
) where {T1,T2,N}
@warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \
This case is not optimized and will be slow." maxlog = 1
dims = NNlib.scatter_dims(src, dst, idxs)
colons = ntuple(i -> Colon(), dims)
colons = ntuple(Returns(Colon()), dims)
start_sizes = ntuple(i -> size(src, i), dims)
results = [reshape(src[colons..., Tuple(idxs[k])...], start_sizes..., :) for k in idxs]
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
Expand Down
2 changes: 1 addition & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ and require expensive copies and synchronization each time and therefore should
)
res2 = MLIR.IR.result(
MLIR.Dialects.stablehlo.reshape(
res1; result_0=MLIR.IR.TensorType(Int[], eltype(MLIR.IR.type(res1)))
res1; result_0=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1)))
),
1,
)
Expand Down

0 comments on commit 6955241

Please sign in to comment.