Skip to content

Commit

Permalink
fix: higher dimensional indexing + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 29, 2024
1 parent 1affd14 commit b60ca6c
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 25 deletions.
42 changes: 27 additions & 15 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1445,14 +1445,21 @@ instead.
pushfirst!(update_computation, block)

#! format: off
update_window_dims = Int64[]
inserted_window_dims = collect(Int64, 0:(N - 1))
input_batching_dims = Int64[]
scatter_indices_batching_dims = Int64[]
scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1))
index_vector_dim = Int64(1)

scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
MLIR.IR.context(),
Int64(0), Int64[],
Int64(N), collect(Int64, 0:(N - 1)),
Int64(0), Int64[],
Int64(0), Int64[],
Int64(N), collect(Int64, 0:(N - 1)),
Int64(1)
length(update_window_dims), update_window_dims,
length(inserted_window_dims), inserted_window_dims,
length(input_batching_dims), input_batching_dims,
length(scatter_indices_batching_dims), scatter_indices_batching_dims,
length(scatter_dims_to_operand_dims), scatter_dims_to_operand_dims,
index_vector_dim,
)
#! format: on

Expand Down Expand Up @@ -1486,20 +1493,26 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
@assert size(gather_indices, 2) == N

#! format: off
offset_dims = Int64[1]
collapsed_slice_dims = collect(Int64, 0:(N - 2))
operand_batching_dims = Int64[]
start_indices_batching_dims = Int64[]
start_index_map = collect(Int64, 0:(N - 1))
index_vector_dim = Int64(1)

dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
MLIR.IR.context(),
Int64(1), Int64[1],
Int64(N - 1), collect(Int64, 0:(N - 2)),
Int64(0), Int64[],
Int64(0), Int64[],
Int64(N), collect(Int64, 0:(N - 1)),
1
Int64(length(offset_dims)), offset_dims,
Int64(length(collapsed_slice_dims)), collapsed_slice_dims,
Int64(length(operand_batching_dims)), operand_batching_dims,
Int64(length(start_indices_batching_dims)), start_indices_batching_dims,
Int64(length(start_index_map)), start_index_map,
Int64(index_vector_dim),
)
#! format: on

return reshape(
TracedRArray{T,2}(
(),
TracedRArray{T}(
MLIR.IR.result(
MLIR.Dialects.stablehlo.gather(
src.mlir_data,
Expand All @@ -1510,7 +1523,6 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
),
1,
),
(size(gather_indices, 1), 1),
),
size(gather_indices, 1),
)
Expand Down
20 changes: 10 additions & 10 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ function Base.getindex(
return TracedRNumber{T}((), res2)
end

function Base.getindex(a::TracedRArray{T,0}) where {T}
return TracedRNumber{T}((), a.mlir_data)
end
Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data)

function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
Expand All @@ -80,12 +78,13 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}

if non_contiguous_getindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(undef, (length(indices_tuples), 2))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, 1] = idx[1] - 1
indices[i, 2] = idx[2] - 1
indices[i, :] .= idx .- 1
end
indices = promote_to(TracedRArray{Int,2}, indices)
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.gather_getindex(a, indices)
return Ops.reshape(res, size(indices_tuples)...)
end
Expand Down Expand Up @@ -133,10 +132,11 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {

if non_contiguous_setindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(undef, (length(indices_tuples), 2))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, 1] = idx[1] - 1
indices[i, 2] = idx[2] - 1
indices[i, :] .= idx .- 1
end
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
Expand Down
54 changes: 54 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,57 @@ end
@test res[1] isa ConcreteRArray{Float64,2}
@test res[2] isa ConcreteRNumber{Float64}
end

@testset "non-contiguous indexing" begin
x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1(x) = x[[1, 3, 2], :, :]
non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]]

@test @jit(non_contiguous_indexing1(x_ra)) non_contiguous_indexing1(x)
@test @jit(non_contiguous_indexing2(x_ra)) non_contiguous_indexing2(x)

x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1(x) = x[[1, 3, 2], :]
non_contiguous_indexing2(x) = x[:, [1, 2, 2]]

@test @jit(non_contiguous_indexing1(x_ra)) non_contiguous_indexing1(x)
@test @jit(non_contiguous_indexing2(x_ra)) non_contiguous_indexing2(x)

x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2
non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2

@jit(non_contiguous_indexing1!(x_ra))
non_contiguous_indexing1!(x)
@test x_ra x

x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

@jit(non_contiguous_indexing2!(x_ra))
non_contiguous_indexing2!(x)
@test x_ra x

x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

non_contiguous_indexing1!(x) = x[[1, 3, 2], :] .= 2
non_contiguous_indexing2!(x) = x[:, [1, 2, 2]] .= 2

@jit(non_contiguous_indexing1!(x_ra))
non_contiguous_indexing1!(x)
@test x_ra x

x = rand(4, 2)
x_ra = Reactant.to_rarray(x)

@jit(non_contiguous_indexing2!(x_ra))
non_contiguous_indexing2!(x)
@test x_ra x
end

0 comments on commit b60ca6c

Please sign in to comment.