Skip to content

Commit

Permalink
fix: use the C API for dimension numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 4de7be4 commit 3ecbe40
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,15 @@ function NNlib.gather!(
idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data

dimension_numbers = """
#stablehlo.gather<
offset_dims = [0],
collapsed_slice_dims = [1],
start_index_map = [1],
index_vector_dim = 1>"""
dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
MLIR.IR.context(),
Int64(1), Int64[0],
Int64(1), Int64[1],
Int64(0), Int64[],
Int64(0), Int64[],
Int64(1), Int64[1],
Int64(1)
)

res = MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
Expand Down

0 comments on commit 3ecbe40

Please sign in to comment.