Skip to content

Commit

Permalink
refactor: overload inplace batched matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 022555b commit 53cecf3
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,20 @@ function NNlib.meanpool!(
end

NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
y = permutedims(x, (2, 1, 3))
conj!(y)
return y
end

function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
function NNlib.batched_mul!(
res::TracedRArray{T1,3}, x::AnyTracedRArray{T2,3}, y::AnyTracedRArray{T3,3}
) where {T1,T2,T3}
if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
(size(x, 2) != size(y, 1))
throw(
DimensionMismatch(
lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.",
lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_mul.",
),
)
end
Expand All @@ -238,7 +244,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe

B = max(size(x, 1), size(y, 1))
out_shape = (B, size(x, 2), size(y, 3))
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(x.mlir_data)))
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))

if size(x, 1) != size(y, 1)
if size(x, 1) == 1
Expand All @@ -255,7 +261,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
res = TracedRArray{T,3}(
tmp = TracedRArray{T1,3}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
Expand All @@ -269,7 +275,8 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
),
size(resty),
)
return permutedims(res, (2, 3, 1))
res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
return res
end

function NNlib.pad_constant(
Expand Down

0 comments on commit 53cecf3

Please sign in to comment.