Skip to content

Commit

Permalink
test: dynamic slice test
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 8b90aea commit cfcaece
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
9 changes: 8 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}

foreach(indices) do idxs
idxs isa Number && return nothing
all(isone, diff(idxs)) || error("non-contiguous indexing is not supported")
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(a) <: Bool
contiguous || error("non-contiguous indexing is not supported")
end
end

start_indices = map(indices) do i
Expand Down Expand Up @@ -875,3 +879,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN
return x
end
end

Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x)
Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x)
32 changes: 18 additions & 14 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,20 +193,24 @@ function Base.ifelse(
)
end

function Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool})
return TracedRNumber{Bool}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1)
)
end
function Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool})
return TracedRNumber{Bool}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1)
)
end
function Base.:!(x::TracedRNumber{Bool})
return TracedRNumber{Bool}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1)
)
for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
@eval begin
function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
return TracedRNumber{promote_type(eltype(x), eltype(y))}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1)
)
end
function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
return TracedRNumber{promote_type(eltype(x), eltype(y))}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1)
)
end
function Base.:!(x::TracedRNumber{<:$(T1)})
return TracedRNumber{eltype(x)}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1)
)
end
end
end

function Base.literal_pow(
Expand Down
11 changes: 11 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,14 @@ end
@test minimum(y) 0.0
@test x_ra x
end

@testset "dynamic indexing" begin
x = randn(5, 3)
x_ra = Reactant.to_rarray(x)

idx = [1, 2, 3]
idx_ra = Reactant.to_rarray(idx)

y = @jit(getindex(x_ra, idx_ra, :))
@test y x[idx, :]
end

0 comments on commit cfcaece

Please sign in to comment.