From 3ecbe40c547f3af94abae76ac9c981e8ae4a51f9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Nov 2024 10:35:33 -0500 Subject: [PATCH] fix: use the C API for dimension numbers --- ext/ReactantNNlibExt.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index fadbdaa2d..9a06bbdc7 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -315,13 +315,15 @@ function NNlib.gather!( idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data - dimension_numbers = """ - #stablehlo.gather< - offset_dims = [0], - collapsed_slice_dims = [1], - start_index_map = [1], - index_vector_dim = 1>""" - dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers) + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( + MLIR.IR.context(), + Int64(1), Int64[0], + Int64(1), Int64[1], + Int64(0), Int64[], + Int64(0), Int64[], + Int64(1), Int64[1], + Int64(1) + ) res = MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.dynamic_gather(