Skip to content

Commit

Permalink
refactor: directly overload inplace conv routine from NNlib
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 6deb95d commit 3102dfa
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where
return out
end

function NNlib.conv(
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
function NNlib.conv!(
y::TracedRArray{T,N}, x::AnyTracedRArray, W::AnyTracedRArray, cdims::DenseConvDims
) where {T,N}
x = materialize_traced_array(x)
W = materialize_traced_array(W)
Expand Down Expand Up @@ -83,33 +83,31 @@ function NNlib.conv(
pl, pr = padding[2i - 1], padding[2i]
d = dilation[i]
s = stride[i]

(size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
end
output_batch_dim = input_batch_dim
output_feature_dim = input_feature_dim
output_spatial_dims = input_spatial_dims

output_shape = (output_spatial_shapes..., size(W, kernel_output_dim), size(x, N))

dimension_numbers = """
#stablehlo.conv<raw
input_batch_dimension = $(input_batch_dim - 1),
input_feature_dimension = $(input_feature_dim - 1),
input_spatial_dimensions = [$(join(input_spatial_dims .- 1, ", "))],
kernel_output_feature_dimension = $(kernel_output_dim - 1),
kernel_input_feature_dimension = $(kernel_input_dim - 1),
kernel_spatial_dimensions = [$(join(kernel_spatial_dims .- 1, ", "))],
output_batch_dimension = $( output_batch_dim - 1 ),
output_feature_dimension = $( output_feature_dim - 1),
output_spatial_dimensions = [$(join(output_spatial_dims .- 1, ", "))],
>"""
dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
#! format: off
dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
MLIR.IR.context(),
Int64(input_batch_dim - 1),
Int64(input_feature_dim - 1),
length(input_spatial_dims), Int64[i - 1 for i in input_spatial_dims],
Int64(kernel_input_dim - 1),
Int64(kernel_output_dim - 1),
length(kernel_spatial_dims), Int64[i - 1 for i in kernel_spatial_dims],
Int64(output_batch_dim - 1),
Int64(output_feature_dim - 1),
length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
)
#! format: on

padding = Reactant.MLIR.IR.DenseElementsAttribute(
reshape(collect(padding), (num_spatial_dims, 2))
)
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

weight = W.mlir_data
if !flipkernel
Expand All @@ -132,8 +130,8 @@ function NNlib.conv(
feature_group_count,
batch_group_count=1,
)

return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
y.mlir_data = Reactant.MLIR.IR.result(conv)
return y
end

function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
Expand Down

0 comments on commit 3102dfa

Please sign in to comment.