Skip to content

Commit

Permalink
fix: use the C API for dimension numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 4de7be4 commit cf6bffd
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,17 @@ function NNlib.gather!(
idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data

dimension_numbers = """
#stablehlo.gather<
offset_dims = [0],
collapsed_slice_dims = [1],
start_index_map = [1],
index_vector_dim = 1>"""
dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
#! format: off
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
MLIR.IR.context(),
Int64(1), Int64[0],
Int64(1), Int64[1],
Int64(0), Int64[],
Int64(0), Int64[],
Int64(1), Int64[1],
Int64(1)
)
#! format: on

res = MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
Expand Down

0 comments on commit cf6bffd

Please sign in to comment.