From 041cdc15d7ec65d7453ce43e510b3c64f962d9f4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 8 Nov 2024 22:27:05 -0500 Subject: [PATCH] feat: unbreak NNlib.gather --- src/TracedRArray.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 4f25707a8..cf028b19c 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -698,7 +698,7 @@ end function broadcast_to_size(arg::T, rsize) where {T<:Number} attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize))) - return arg = TracedRArray{T,length(rsize)}( + return TracedRArray{T,length(rsize)}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize ) end @@ -710,6 +710,11 @@ function broadcast_to_size(arg::TracedRNumber, rsize) ) end +function broadcast_to_size(arg::AnyTracedRArray{T, 0}, rsize) where {T} + arg = materialize_traced_array(arg) + return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) +end + function broadcast_to_size(arg::AnyTracedRArray, rsize) arg = materialize_traced_array(arg) size(arg) == rsize && return arg