diff --git a/src/Overlay.jl b/src/Overlay.jl index b9785b7fa..701d9c711 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -115,3 +115,28 @@ for randfun in (:rand, :randn, :randexp) # end end end + +# LinearAlgebra.jl overloads +## `_mul!` goes through too many layers of abstractions and we aren't able to overload +## without specializing on every possible combination of types +@reactant_overlay @noinline function LinearAlgebra.mul!( + C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number +) + if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) + TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) + else + LinearAlgebra._mul!(C, A, B, α, β) + end + return C +end + +@reactant_overlay @noinline function LinearAlgebra.mul!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number +) + if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) + TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) + else + LinearAlgebra._mul!(C, A, B, α, β) + end + return C +end diff --git a/src/Reactant.jl b/src/Reactant.jl index f8dd8008b..d06784c13 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -119,7 +119,9 @@ const WrappedTracedRArray{T,N} = WrappedArray{ const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} const AnyTracedRVector{T} = AnyTracedRArray{T,1} const AnyTracedRMatrix{T} = Union{ - AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} + AnyTracedRArray{T,2}, + LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}, + LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, } const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 95369e2a2..02605e8dc 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -34,7 +34,9 @@ function TracedUtils.materialize_traced_array( return diagm(parent(x)) end -function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} +function TracedUtils.materialize_traced_array( + x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}} +) where {T} return diagm(-1 => x.dl, 0 => x.d, 1 => x.du) end @@ -162,7 +164,7 @@ function TracedUtils.set_mlir_data!( end # Core functions -function LinearAlgebra.mul!( +function overloaded_mul!( @nospecialize(C::TracedRArray{T,1}), @nospecialize(A::AnyTracedRMatrix), @nospecialize(B::AnyTracedRVector), @@ -176,7 +178,7 @@ function LinearAlgebra.mul!( return C end -function LinearAlgebra.mul!( +function overloaded_mul!( @nospecialize(C::TracedRArray{T,2}), @nospecialize(A::AnyTracedRMatrix), @nospecialize(B::AnyTracedRVector), @@ -187,7 +189,7 @@ function LinearAlgebra.mul!( return C end -function LinearAlgebra.mul!( +function overloaded_mul!( @nospecialize(C::TracedRArray{T,2}), @nospecialize(A::AnyTracedRMatrix), @nospecialize(B::AnyTracedRMatrix), diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 0aac0f11c..ea39556f9 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -1,4 +1,4 @@ -using LinearAlgebra, Reactant +using LinearAlgebra, Reactant, Test function muladd2(A, x, b) C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2)) @@ -143,13 +143,29 @@ end @test @jit(diagm(1 => x_ra1, 1 => x_ra2)) ≈ diagm(1 => x1, 1 => x2) end -# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be -# optimized +# TODO: Currently (x) * x goes down the generic matmul path but it should +# clearly be optimized mul_diagonal(x) = Diagonal(x) * x - -@testset "mul_diagonal" begin - x = rand(4) +mul_tridiagonal(x) = Tridiagonal(x) * x +mul_unit_lower_triangular(x) = UnitLowerTriangular(x) * x +mul_unit_upper_triangular(x) = UnitUpperTriangular(x) * x +mul_lower_triangular(x) = LowerTriangular(x) * x +mul_upper_triangular(x) = UpperTriangular(x) * x +mul_symmetric(x) = Symmetric(x) * x + +@testset "Wrapper Types Matrix Multiplication" begin + x = rand(4, 4) x_ra = Reactant.to_rarray(x) - @test @jit(mul_diagonal(x_ra)) ≈ mul_diagonal(x) + @testset "$(wrapper_type)" for (wrapper_type, fn) in [ + (Diagonal, mul_diagonal), + (Tridiagonal, mul_tridiagonal), + (UnitLowerTriangular, mul_unit_lower_triangular), + (UnitUpperTriangular, mul_unit_upper_triangular), + (LowerTriangular, mul_lower_triangular), + (UpperTriangular, mul_upper_triangular), + (Symmetric, mul_symmetric), + ] + @test @jit(fn(x_ra)) ≈ fn(x) + end end