diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 5d03283ea..5b48e27b7 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -222,14 +222,20 @@ function NNlib.meanpool!(
 end
 
 NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
-NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
+function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
+    y = permutedims(x, (2, 1, 3))
+    conj!(y)
+    return y
+end
 
-function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
+function NNlib.batched_mul!(
+    res::TracedRArray{T1,3}, x::AnyTracedRArray{T2,3}, y::AnyTracedRArray{T3,3}
+) where {T1,T2,T3}
     if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
         (size(x, 2) != size(y, 1))
         throw(
             DimensionMismatch(
-                lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.",
+                lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_mul.",
             ),
         )
     end
@@ -238,7 +244,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
 
     B = max(size(x, 1), size(y, 1))
     out_shape = (B, size(x, 2), size(y, 3))
-    resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(x.mlir_data)))
+    resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))
 
     if size(x, 1) != size(y, 1)
         if size(x, 1) == 1
@@ -255,7 +261,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
     prec = MLIR.IR.Attribute(
         MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
     )
-    res = TracedRArray{T,3}(
+    tmp = TracedRArray{T1,3}(
         (),
         MLIR.IR.result(
             MLIR.Dialects.stablehlo.dot_general(
@@ -269,7 +275,8 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
         ),
         size(resty),
     )
-    return permutedims(res, (2, 3, 1))
+    res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
+    return res
 end
 
 function NNlib.pad_constant(