Skip to content

Commit

Permalink
fix: define getindexing into sub reshaped array
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 17, 2025
1 parent 532fd73 commit 46029f4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ end

function Base.getindex(a::TracedRArray{T,N}, indices) where {T,N}
if !(indices isa TracedRArray)
indices = TracedUtils.promote_to(TracedRArray{Int,1}, collect(indices))
indices = collect(indices)
eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices])
indices = TracedUtils.promote_to(TracedRArray{Int,1}, indices)
end
return Ops.gather_getindex(a, scalar_index_to_cartesian(indices, size(a)))
end
Expand Down
9 changes: 9 additions & 0 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ function set_mlir_data!(
return x
end

function get_ancestor_indices(
x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, indices...
) where {T,N,M}
cartesian_indices = CartesianIndex.(indices...)
linear_indices = LinearIndices(size(x))[cartesian_indices]
parent_cartesian_indices = CartesianIndices(size(parent(x)))[linear_indices]
return (parent_cartesian_indices,)
end

function set_mlir_data!(
x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}, data
) where {T,N,perm,iperm}
Expand Down

0 comments on commit 46029f4

Please sign in to comment.