
feat: more coverage for NNlib functions (EnzymeAD#258)
* feat: use dynamic slicing

* feat: special case `gather!` for the most common cases

* feat: use `@trace` to implement softmax

* refactor: directly overload inplace conv routine from NNlib

* refactor: overload inplace pooling layers

* refactor: overload inplace batched matmul

* fix: reactant needs latest reactant core

* fix: temporarily avoid tracing in softmax and logsoftmax
avik-pal authored and Pangoraw committed Nov 11, 2024
1 parent 4000995 commit 9a21427
Showing 5 changed files with 85 additions and 60 deletions.
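For context, a minimal sketch of how these overloads get exercised (illustrative only, not part of the diff; `Reactant.to_rarray` and `Reactant.compile` are assumed entry points whose exact signatures may differ across Reactant versions):

```julia
using NNlib, Reactant

# Once f is compiled, the NNlib calls below dispatch to the TracedRArray
# overloads in this commit and lower to StableHLO ops.
W = Reactant.to_rarray(randn(Float32, 4, 16))
x = Reactant.to_rarray(randn(Float32, 16, 10))

f(W, x) = NNlib.softmax(W * x; dims=1)

f_compiled = Reactant.compile(f, (W, x))   # assumed compile entry point
y = f_compiled(W, x)
```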
2 changes: 1 addition & 1 deletion Project.toml
@@ -39,7 +39,7 @@ LinearAlgebra = "1.10"
NNlib = "0.9"
OrderedCollections = "1"
Preferences = "1.4"
ReactantCore = "0.1"
ReactantCore = "0.1.1"
Reactant_jll = "0.0.24"
Scratch = "1.2"
Statistics = "1.10"
126 changes: 73 additions & 53 deletions ext/ReactantNNlibExt.jl
@@ -3,6 +3,7 @@ module ReactantNNlibExt
using NNlib
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
+using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu

for (jlop, hloop) in (
@@ -20,38 +21,46 @@ for (jlop, hloop) in (
end
end

# TODO handle non finite cases
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
max_ = NNlib.fast_maximum(x; dims)
-#if all(isfinite, max_)
-@fastmath out .= exp.(x .- max_)
-#else
-#    _zero, _one, _inf = T(0), T(1), T(Inf)
-#    @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
-#end
+# XXX: Once reverse mode of if is properly supported, we can make it @trace
+# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
+# one_num = Reactant.promote_to(TracedRNumber{T}, 1)
+# @trace if all(isfinite, max_)
+@. out = exp(x - max_)
+# else
+#     cond = max_ .== Inf
+#     true_pred = ifelse.(x .== Inf, one_num, zero_num)
+#     @. out = ifelse(cond, true_pred, exp(x - max_))
+# end
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
-return out ./= tmp
+out ./= tmp
+return out
end

function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
max_ = NNlib.fast_maximum(x; dims)
-# if all(isfinite, max_)
-@fastmath out .= x .- max_
+# XXX: Once reverse mode of if is properly supported, we can make it @trace
+# inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
+# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
+# @trace if all(isfinite, max_)
+@. out = x - max_
# else
-# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
-# @. out = ifelse(
-# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
-# )
+# cond = max_ .== Inf
+# true_pred = ifelse.(x .== Inf, zero_num, -inf_num)
+# @. out = ifelse(cond, true_pred, x - max_)
# end
@fastmath log_ = log.(sum(exp, out; dims))
-return out .-= log_
+out .-= log_
+return out
end
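For reference, both kernels rely on the standard subtract-max trick, which keeps `exp` finite for large inputs; a plain-Julia illustration (ordinary arrays, no tracing):

```julia
x = Float32[1000, 1001, 1002]

exp.(x)                        # overflows to Inf32 for every element

m = maximum(x)                 # 1002.0f0
unnorm = exp.(x .- m)          # finite: exp(-2), exp(-1), exp(0)
softmax_x = unnorm ./ sum(unnorm)

# logsoftmax uses the same shift: (x .- m) .- log(sum(exp, x .- m))
```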

-function NNlib.conv(
-x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
+function NNlib.conv!(
+y::TracedRArray{T,N}, x::AnyTracedRArray, W::AnyTracedRArray, cdims::DenseConvDims
) where {T,N}
-x = materialize_traced_array(x)
-W = materialize_traced_array(W)
+# StableHLO expects matching element types
+x = T.(materialize_traced_array(x))
+W = T.(materialize_traced_array(W))

kernel_size = NNlib.kernel_size(cdims)
padding = NNlib.padding(cdims)
@@ -77,33 +86,31 @@ function NNlib.conv(
pl, pr = padding[2i - 1], padding[2i]
d = dilation[i]
s = stride[i]

-(size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
+return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
end
output_batch_dim = input_batch_dim
output_feature_dim = input_feature_dim
output_spatial_dims = input_spatial_dims

output_shape = (output_spatial_shapes..., size(W, kernel_output_dim), size(x, N))

-dimension_numbers = """
-#stablehlo.conv<raw
-input_batch_dimension = $(input_batch_dim - 1),
-input_feature_dimension = $(input_feature_dim - 1),
-input_spatial_dimensions = [$(join(input_spatial_dims .- 1, ", "))],
-kernel_output_feature_dimension = $(kernel_output_dim - 1),
-kernel_input_feature_dimension = $(kernel_input_dim - 1),
-kernel_spatial_dimensions = [$(join(kernel_spatial_dims .- 1, ", "))],
-output_batch_dimension = $( output_batch_dim - 1 ),
-output_feature_dimension = $( output_feature_dim - 1),
-output_spatial_dimensions = [$(join(output_spatial_dims .- 1, ", "))],
->"""
-dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
+#! format: off
+dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
+MLIR.IR.context(),
+Int64(input_batch_dim - 1),
+Int64(input_feature_dim - 1),
+length(input_spatial_dims), Int64[i - 1 for i in input_spatial_dims],
+Int64(kernel_input_dim - 1),
+Int64(kernel_output_dim - 1),
+length(kernel_spatial_dims), Int64[i - 1 for i in kernel_spatial_dims],
+Int64(output_batch_dim - 1),
+Int64(output_feature_dim - 1),
+length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
+)
+#! format: on

padding = Reactant.MLIR.IR.DenseElementsAttribute(
reshape(collect(padding), (num_spatial_dims, 2))
)
-result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
+result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

weight = W.mlir_data
if !flipkernel
@@ -126,8 +133,8 @@ function NNlib.conv(
feature_group_count,
batch_group_count=1,
)

-return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
+y.mlir_data = Reactant.MLIR.IR.result(conv)
+return y
end
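As a sanity check on the output-size formula used above, `(size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1`, a worked example with plain NNlib (no tracing; values chosen for illustration):

```julia
using NNlib

# One spatial dim: input 28, kernel K = 5, padding pl = pr = 2,
# stride s = 2, dilation d = 1:
# (28 + 2 + 2 - 1 * (5 - 1) - 1) ÷ 2 + 1 == 27 ÷ 2 + 1 == 14
x = randn(Float32, 28, 1, 1)   # (width, channels, batch)
w = randn(Float32, 5, 1, 1)    # (width, in_channels, out_channels)
cdims = DenseConvDims(x, w; padding=2, stride=2)
size(NNlib.conv(x, w, cdims))  # (14, 1, 1)
```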

function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
@@ -198,27 +205,39 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
return TracedRArray{T,N}((), Reactant.MLIR.IR.result(reduction), size(result_type))
end

-function NNlib.maxpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
-return reduce_window(
-Reactant.MLIR.Dialects.stablehlo.maximum, x, pdims; init=typemin(T)
-)
+function NNlib.maxpool!(
+y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
+) where {T}
+y.mlir_data =
+reduce_window(
+Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
+).mlir_data
+return y
end

-function NNlib.meanpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
-numel = prod(NNlib.kernel_size(pdims))
-return reduce_window(Reactant.MLIR.Dialects.stablehlo.add, x, pdims; init=zero(T)) ./
-T(numel)
+function NNlib.meanpool!(
+y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
+) where {T}
+res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T))
+y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data
+return y
end
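Both pooling overloads bottom out in the same `reduce_window` primitive: max-pooling reduces with `maximum` starting from `typemin(T)`, mean-pooling reduces with `add` from `zero(T)` and divides by the window length. A rough plain-Julia model of that semantics (1-D, stride equal to the window, no padding; illustrative only):

```julia
# 1-D window reduction, stride == window size, no padding.
function reduce_window_1d(f, x::AbstractVector{T}, k::Int; init) where {T}
    return [foldl(f, view(x, i:(i + k - 1)); init) for i in 1:k:(length(x) - k + 1)]
end

x = Float32[1, 3, 2, 8, 5, 4]
reduce_window_1d(max, x, 2; init=typemin(Float32))     # [3.0, 8.0, 5.0]  (maxpool)
reduce_window_1d(+, x, 2; init=zero(Float32)) ./ 2     # [2.0, 5.0, 4.5]  (meanpool)
```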

NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
+function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
+y = permutedims(x, (2, 1, 3))
+conj!(y)
+return y
+end

-function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
+function NNlib.batched_mul!(
+res::TracedRArray{T1,3}, x::AnyTracedRArray{T2,3}, y::AnyTracedRArray{T3,3}
+) where {T1,T2,T3}
if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
(size(x, 2) != size(y, 1))
throw(
DimensionMismatch(
lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.",
lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_mul.",
),
)
end
@@ -227,7 +246,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe

B = max(size(x, 1), size(y, 1))
out_shape = (B, size(x, 2), size(y, 3))
-resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(x.mlir_data)))
+resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))

if size(x, 1) != size(y, 1)
if size(x, 1) == 1
@@ -244,7 +263,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
-res = TracedRArray{T,3}(
+tmp = TracedRArray{T1,3}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
@@ -258,7 +277,8 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
),
size(resty),
)
-return permutedims(res, (2, 3, 1))
+res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
+return res
end
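For reference, the contract implemented here matches NNlib's `batched_mul!`: multiply matching batch slices along dimension 3, broadcasting a singleton batch on either side. A plain-Julia reference (illustrative only; the traced version instead permutes the batch dimension to the front for `stablehlo.dot_general`):

```julia
# Reference semantics for batched matrix multiply with batch broadcasting.
function batched_mul_ref(x::AbstractArray{<:Number,3}, y::AbstractArray{<:Number,3})
    B = max(size(x, 3), size(y, 3))
    slices = (x[:, :, min(b, size(x, 3))] * y[:, :, min(b, size(y, 3))] for b in 1:B)
    return cat(slices...; dims=3)
end

x = randn(Float32, 2, 3, 4)
y = randn(Float32, 3, 5, 1)          # singleton batch broadcasts against B = 4
size(batched_mul_ref(x, y))          # (2, 5, 4)
```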

function NNlib.pad_constant(
2 changes: 1 addition & 1 deletion lib/ReactantCore/Project.toml
@@ -1,7 +1,7 @@
name = "ReactantCore"
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.1.0"
version = "0.1.1"

[deps]
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
8 changes: 5 additions & 3 deletions lib/ReactantCore/src/ReactantCore.jl
@@ -15,6 +15,10 @@ end

MissingTracedValue() = MissingTracedValue(())

+const SPECIAL_SYMBOLS = [
+:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
+]

# Code generation
"""
@trace <expr>
@@ -89,7 +93,7 @@ end
### Certain Symbols are Reserved
-Symbols like `nothing`, `missing` and `:` are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
+Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
example, the following will not work:
```julia
@@ -372,6 +376,4 @@ function error_if_return(expr)
end
end

-const SPECIAL_SYMBOLS = [:(:), :nothing, :missing]

end
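For context, `@trace` (the reason this commit bumps the ReactantCore requirement to 0.1.1) rewrites supported Julia control flow into traced form; a minimal usage sketch consistent with the docstring above (semantics live in ReactantCore, not in this diff):

```julia
using ReactantCore: @trace

# Both branches get traced; note that the reserved SPECIAL_SYMBOLS
# (`:`, `nothing`, `missing`, `Inf`, ...) must not be used as variables here.
function shifted(x)
    @trace if sum(x) > 0
        y = x .+ 1
    else
        y = x .- 1
    end
    return y
end
```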
7 changes: 5 additions & 2 deletions src/TracedRArray.jl
@@ -538,8 +538,10 @@ function Base.mapreduce(
dims = [dims]
end

+op_in_T = Core.Compiler.return_type(f, Tuple{T})
+
if isnothing(init)
-init = Base.reduce_empty(Base.BottomRF(op), Core.Compiler.return_type(f, Tuple{T}))
+init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
else
init = init::T
end
@@ -561,7 +563,8 @@ function Base.mapreduce(
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])

args = (
-TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys)
+TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, i)) for
+(i, ty) in enumerate(in_tys)
)

res = MLIR.IR.block!(fnbody) do
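The `op_in_T` change matters whenever `f` maps the input eltype to a different type: the reduction block arguments (and the default `init`) must be built over `f`'s return type, not over `T`. A plain-Julia illustration of the typing involved:

```julia
x = Int32[1, 2, 3]
f = xi -> Float64(xi) / 2

# The block arguments must be Float64, even though eltype(x) is Int32:
op_in_T = Core.Compiler.return_type(f, Tuple{eltype(x)})  # Float64
init = Base.reduce_empty(Base.BottomRF(+), op_in_T)       # 0.0

mapreduce(f, +, x)                                        # 3.0
```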
