Skip to content

Commit

Permalink
fix: matrix multiplication of wrapper types
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 29, 2024
1 parent b60ca6c commit c8e4c1c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
25 changes: 25 additions & 0 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,28 @@ for randfun in (:rand, :randn, :randexp)
# end
end
end

# LinearAlgebra.jl overloads
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
## without specializing on every possible combination of types
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number
)
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
end
return C
end

@reactant_overlay @noinline function LinearAlgebra.mul!(
C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number
)
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
end
return C
end
4 changes: 3 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ const WrappedTracedRArray{T,N} = WrappedArray{
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
const AnyTracedRMatrix{T} = Union{
AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}}
AnyTracedRArray{T,2},
LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}},
LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}},
}
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}

Expand Down
10 changes: 6 additions & 4 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ function TracedUtils.materialize_traced_array(
return diagm(parent(x))
end

function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
function TracedUtils.materialize_traced_array(
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
end

Expand Down Expand Up @@ -162,7 +164,7 @@ function TracedUtils.set_mlir_data!(
end

# Core functions
function LinearAlgebra.mul!(
function overloaded_mul!(
@nospecialize(C::TracedRArray{T,1}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRVector),
Expand All @@ -176,7 +178,7 @@ function LinearAlgebra.mul!(
return C
end

function LinearAlgebra.mul!(
function overloaded_mul!(
@nospecialize(C::TracedRArray{T,2}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRVector),
Expand All @@ -187,7 +189,7 @@ function LinearAlgebra.mul!(
return C
end

function LinearAlgebra.mul!(
function overloaded_mul!(
@nospecialize(C::TracedRArray{T,2}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRMatrix),
Expand Down
30 changes: 23 additions & 7 deletions test/integration/linear_algebra.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearAlgebra, Reactant
using LinearAlgebra, Reactant, Test

function muladd2(A, x, b)
C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2))
Expand Down Expand Up @@ -143,13 +143,29 @@ end
@test @jit(diagm(1 => x_ra1, 1 => x_ra2)) diagm(1 => x1, 1 => x2)
end

# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be
# optimized
# TODO: Currently <Wrapper Type>(x) * x goes down the generic matmul path but it should
# clearly be optimized
mul_diagonal(x) = Diagonal(x) * x

@testset "mul_diagonal" begin
x = rand(4)
mul_tridiagonal(x) = Tridiagonal(x) * x
mul_unit_lower_triangular(x) = UnitLowerTriangular(x) * x
mul_unit_upper_triangular(x) = UnitUpperTriangular(x) * x
mul_lower_triangular(x) = LowerTriangular(x) * x
mul_upper_triangular(x) = UpperTriangular(x) * x
mul_symmetric(x) = Symmetric(x) * x

@testset "Wrapper Types Matrix Multiplication" begin
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(mul_diagonal(x_ra)) mul_diagonal(x)
@testset "$(wrapper_type)" for (wrapper_type, fn) in [
(Diagonal, mul_diagonal),
(Tridiagonal, mul_tridiagonal),
(UnitLowerTriangular, mul_unit_lower_triangular),
(UnitUpperTriangular, mul_unit_upper_triangular),
(LowerTriangular, mul_lower_triangular),
(UpperTriangular, mul_upper_triangular),
(Symmetric, mul_symmetric),
]
@test @jit(fn(x_ra)) fn(x)
end
end

0 comments on commit c8e4c1c

Please sign in to comment.