feat: partial NNlib.gather support + better indexing support #252

Merged · 12 commits · Nov 10, 2024
56 changes: 56 additions & 0 deletions ext/ReactantNNlibExt.jl
@@ -296,4 +296,60 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
)
end

# XXX: reevaluate this manual optimization once
# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled
function NNlib.gather!(
dst::TracedRArray{T1,2},
src::AnyTracedRArray{T2,2},
idxs::AbstractUnitRange{<:Number},
) where {T1,T2}
dst.mlir_data = src[:, idxs].mlir_data
return dst
end

function NNlib.gather!(
dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number}
) where {T1,T2}
dims = NNlib.scatter_dims(src, dst, idxs)
@assert dims == 1 # scatter_dims lets us do some size checks so we call that function
idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data

dimension_numbers = """
#stablehlo.gather<
offset_dims = [0],
collapsed_slice_dims = [1],
start_index_map = [1],
index_vector_dim = 1>"""
dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)

res = MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
src.mlir_data, idxs, slice_sizes; dimension_numbers
),
1,
)
dst.mlir_data = res
return dst
end
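
For context, a minimal sketch of the gather semantics these specializations lower (plain NNlib, no tracing; shapes are illustrative):

using NNlib

src = Float32[1 2 3; 4 5 6]    # 2×3 matrix
idxs = [3, 1, 1]               # column indices
dst = NNlib.gather(src, idxs)  # gathers whole columns: 3, 1, 1
@assert dst == Float32[3 1 1; 6 4 4]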

# XXX: For performance, use `stablehlo.dynamic_gather` or at least a traced loop
# instead of unrolling the loop (the AbstractArray case can just use
# `stablehlo.gather`). See above for the optimized special-case implementations.
function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
@warn "Using fallback implementation of `gather!` that uses `stablehlo.dynamic_slice`. \
This case is not optimized and will be slow." maxlog = 1

Member: If you'd like, you can put this behind a global to make sure it's only printed once (I think other indexing does that). Though also I'm fine with having it always warn.
Member: Ah, that's what maxlog does!

dims = NNlib.scatter_dims(src, dst, idxs)
Member: Can we emit a warning here, at least in the interim?
Collaborator (author): I don't think this is the right way to go. Even for a small test case (nanoGPT) it takes forever to compile. Let me try to understand the stablehlo gather and get it fixed.
Member: 👍, if you'd like feel free to separate the dynamic_slice stuff into a different PR.
Collaborator (author): Optimized the common cases and printing a warning for the other cases.

colons = ntuple(Returns(Colon()), dims)
start_sizes = ntuple(i -> size(src, i), dims)
results = map(CartesianIndices(idxs)) do k
res = src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
dst.mlir_data = res.mlir_data
return dst
end
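
For context, a sketch of what the unrolled fallback computes with tuple indices (plain NNlib; shapes are illustrative): each index selects a slice and the slices are concatenated along a trailing dimension.

using NNlib

src = reshape(Float32.(1:24), 4, 3, 2)
idxs = [(1, 1), (3, 2)]        # each tuple indexes dims 2:3 of src
dst = NNlib.gather(src, idxs)  # size (4, 2); dst[:, k] == src[:, idxs[k]...]
@assert dst[:, 2] == src[:, 3, 2]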

end # module ReactantNNlibExt
74 changes: 50 additions & 24 deletions src/TracedRArray.jl
@@ -56,7 +56,9 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
end

- function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
+ function Base.getindex(
+ a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N}
+ ) where {T,N}
@warn(
"""Performing scalar indexing on task $(current_task()).
Invocation resulted in scalar indexing of a TracedRArray.
@@ -65,49 +67,59 @@ Such implementations *do not* execute on device, but very slowly on the CPU,
and require expensive copies and synchronization each time and therefore should be avoided."""
)

start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index]
slice_sizes = [1 for _ in index]

res1 = MLIR.IR.result(
- MLIR.Dialects.stablehlo.slice(
- a.mlir_data;
- start_indices=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in index]),
- limit_indices=MLIR.IR.DenseArrayAttribute([Int64(i) for i in index]),
- strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in index]),
- ),
- 1,
+ MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1
)
res2 = MLIR.IR.result(
MLIR.Dialects.stablehlo.reshape(
res1; result_0=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1)))
),
1,
)

return TracedRNumber{T}((), res2)
end
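
A hedged sketch of what the dynamic_slice lowering enables: the scalar index can now itself be a traced value. Reactant.ConcreteRNumber and the @jit pattern are assumptions mirroring the test file below, not part of this diff:

using Reactant

x = randn(Float32, 5)
x_ra = Reactant.to_rarray(x)

f(a, i) = a[i]  # hits the scalar-indexing path above (and its warning)
y = @jit f(x_ra, Reactant.ConcreteRNumber(3))  # index is traced, not constant-folded
@assert y ≈ x[3]  # assumes concrete scalars support ≈ against plain numbers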

function Base.getindex(a::TracedRArray{T,0}) where {T}
return TracedRNumber{T}((), a.mlir_data)
end

# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
- indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
+ indices = map(enumerate(indices)) do (idx, i)
+ i isa Colon && return 1:size(a, idx)
+ i isa CartesianIndex && return Tuple(i)
+ return i
+ end

Member: Yeah, I think this should be doable with gather. That part I'm less confident we have all the optimization rules to lower into dynamic slice.

foreach(indices) do idxs
idxs isa Number && return nothing
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(contiguous) <: Bool
contiguous || error("non-contiguous indexing is not supported")
end
end

start_indices = map(indices) do i
return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data
end
slice_sizes = [length(i) for i in indices]
res = MLIR.IR.result(
- MLIR.Dialects.stablehlo.slice(
- a.mlir_data;
- start_indices=MLIR.IR.DenseArrayAttribute([
- Int64(first(i) - 1) for i in indices
- ]),
- limit_indices=MLIR.IR.DenseArrayAttribute([Int64(last(i)) for i in indices]),
- strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in indices]),
- ),
- 1,
+ MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1
)

x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
- ddims = findall(x -> x isa Integer, indices)
- !isempty(ddims) && return dropdims(x; dims=Tuple(ddims))
+ ddims = findall(Base.Fix2(isa, Integer), indices)
+ isempty(ddims) || return dropdims(x; dims=Tuple(ddims))
return x
end
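
A hedged sketch of the normalized index paths (usage mirrors the test added below, not part of this diff): contiguous ranges and Colons funnel into a single dynamic_slice, and integer positions are dropped afterwards.

using Reactant

x = randn(Float32, 4, 3)
x_ra = Reactant.to_rarray(x)

y = @jit(getindex(x_ra, 2:3, :))  # contiguous range → one dynamic_slice, size (2, 3)
z = @jit(getindex(x_ra, 2, :))    # integer dim dropped via dropdims, size (3,)
@assert size(y) == (2, 3) && size(z) == (3,)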

# Prevent ambiguity
- function Base.getindex(a::WrappedTracedRArray, index::Int...)
+ function Base.getindex(a::WrappedTracedRArray, index::Union{Int,TracedRNumber{Int}}...)
return getindex(ancestor(a), get_ancestor_indices(a, index...)...)
end

@@ -116,7 +128,9 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
end

function Base.setindex!(
- a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
+ a::TracedRArray{T,N},
+ v,
+ indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
@@ -138,13 +152,17 @@ end
end

function Base.setindex!(
- a::AnyTracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
+ a::AnyTracedRArray{T,N},
+ v,
+ indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
ancestor_indices = get_ancestor_indices(a, indices...)
setindex!(ancestor(a), v, ancestor_indices...)
return a
end

Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(Base.getindex, x), length(x))

Base.size(x::TracedRArray) = x.shape

Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))
@@ -699,7 +717,7 @@ end

function broadcast_to_size(arg::T, rsize) where {T<:Number}
attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize)))
- return arg = TracedRArray{T,length(rsize)}(
+ return TracedRArray{T,length(rsize)}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize
)
end
@@ -711,6 +729,11 @@ function broadcast_to_size(arg::TracedRNumber, rsize)
)
end

function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T}
arg = materialize_traced_array(arg)
return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize)
end

function broadcast_to_size(arg::AnyTracedRArray, rsize)
arg = materialize_traced_array(arg)
size(arg) == rsize && return arg
@@ -856,3 +879,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber))
return x
end
end

Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x)
Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x)
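
A hedged sketch of why these reductions are needed: the contiguity check above calls all(isone, diff(idxs)), which must trace through as an & reduction rather than erroring. The @jit usage is an assumption mirroring the tests:

using Reactant

x_ra = Reactant.to_rarray([1.0, 2.0, 3.0])
allpos = @jit(all(>(0), x_ra))  # lowers to mapreduce(f, &, x)
hasneg = @jit(any(<(0), x_ra))  # lowers to mapreduce(f, |, x)
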
32 changes: 18 additions & 14 deletions src/TracedRNumber.jl
@@ -193,20 +193,24 @@ function Base.ifelse(
)
end

- function Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool})
- return TracedRNumber{Bool}(
- (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1)
- )
- end
- function Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool})
- return TracedRNumber{Bool}(
- (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1)
- )
- end
- function Base.:!(x::TracedRNumber{Bool})
- return TracedRNumber{Bool}(
- (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1)
- )
- end
+ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
+ @eval begin
+ function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
+ return TracedRNumber{promote_type(eltype(x), eltype(y))}(
+ (), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1)
+ )
+ end
+ function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
+ return TracedRNumber{promote_type(eltype(x), eltype(y))}(
+ (), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1)
+ )
+ end
+ function Base.:!(x::TracedRNumber{<:$(T1)})
+ return TracedRNumber{eltype(x)}(
+ (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1)
+ )
+ end
+ end
+ end
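
A hedged sketch of what the generalized definitions allow (hypothetical usage, not from the PR): comparisons produce traced Bools, and .& broadcasts the scalar & defined above over them:

using Reactant

inrange(x) = (x .> 0) .& (x .< 10)  # elementwise Bool & Bool

x_ra = Reactant.to_rarray([1, 5, 20])
y = @jit inrange(x_ra)  # expected [true, true, false]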

function Base.literal_pow(
11 changes: 11 additions & 0 deletions test/basic.jl
@@ -557,3 +557,14 @@ end
@test minimum(y) ≥ 0.0
@test x_ra ≈ x
end

@testset "dynamic indexing" begin
x = randn(5, 3)
x_ra = Reactant.to_rarray(x)

idx = [1, 2, 3]
idx_ra = Reactant.to_rarray(idx)

y = @jit(getindex(x_ra, idx_ra, :))
@test y ≈ x[idx, :]
end