Skip to content

Commit

Permalink
feat: unbreak NNlib.gather
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 9, 2024
1 parent 2208689 commit 041cdc1
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ end

function broadcast_to_size(arg::T, rsize) where {T<:Number}
attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize)))
return arg = TracedRArray{T,length(rsize)}(
return TracedRArray{T,length(rsize)}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize
)
end
Expand All @@ -710,6 +710,11 @@ function broadcast_to_size(arg::TracedRNumber, rsize)
)
end

function broadcast_to_size(arg::AnyTracedRArray{T, 0}, rsize) where {T}
arg = materialize_traced_array(arg)
return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize)
end

function broadcast_to_size(arg::AnyTracedRArray, rsize)
arg = materialize_traced_array(arg)
size(arg) == rsize && return arg
Expand Down

0 comments on commit 041cdc1

Please sign in to comment.