diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a14b439a0..3593f141f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -97,7 +97,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} foreach(indices) do idxs idxs isa Number && return nothing - all(isone, diff(idxs)) || error("non-contiguous indexing is not supported") + contiguous = all(isone, diff(idxs)) + # XXX: We want to throw error even for dynamic indexing + if typeof(a) <: Bool + contiguous || error("non-contiguous indexing is not supported") + end end start_indices = map(indices) do i @@ -875,3 +879,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN return x end end + +Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x) +Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 89513a2d9..4ddb02131 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -193,20 +193,24 @@ function Base.ifelse( ) end -function Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) - ) -end -function Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) - ) -end -function Base.:!(x::TracedRNumber{Bool}) - return TracedRNumber{Bool}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) - ) +for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) + @eval begin + function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return TracedRNumber{promote_type(eltype(x), eltype(y))}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1) + ) + end + function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) + return TracedRNumber{promote_type(eltype(x), eltype(y))}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1) + ) + end + function Base.:!(x::TracedRNumber{<:$(T1)}) + return TracedRNumber{eltype(x)}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1) + ) + end + end end function Base.literal_pow( diff --git a/test/basic.jl b/test/basic.jl index 5137379bc..e386ae545 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -557,3 +557,14 @@ end @test minimum(y) ≥ 0.0 @test x_ra ≈ x end + +@testset "dynamic indexing" begin + x = randn(5, 3) + x_ra = Reactant.to_rarray(x) + + idx = [1, 2, 3] + idx_ra = Reactant.to_rarray(idx) + + y = @jit(getindex(x_ra, idx_ra, :)) + @test y ≈ x[idx, :] +end