From 0fee46419715fc8c52d29465c927a51342cf605c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Dec 2024 12:41:15 -0500 Subject: [PATCH] fix: de-specialize 3 arg mul! --- src/Compiler.jl | 2 +- src/Overlay.jl | 38 +++++++++++++++++++++----------------- test/wrapped_arrays.jl | 30 ++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 62611047b..3c4c3996d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -105,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic end function create_result( - tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol}, + tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char}, path, result_stores, ) diff --git a/src/Overlay.jl b/src/Overlay.jl index 701d9c711..228edffe0 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -119,24 +119,28 @@ end # LinearAlgebra.jl overloads ## `_mul!` goes through too many layers of abstractions and we aren't able to overload ## without specializing on every possible combination of types -@reactant_overlay @noinline function LinearAlgebra.mul!( - C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number +for (cT, aT, bT) in ( + (:AbstractVector, :AbstractMatrix, :AbstractVector), + (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat), ) - if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) - TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) - else - LinearAlgebra._mul!(C, A, B, α, β) - end - return C -end + @eval begin + @reactant_overlay @noinline function LinearAlgebra.mul!( + C::$cT, A::$aT, B::$bT, α::Number, β::Number + ) + if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) + TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) + else + LinearAlgebra._mul!(C, A, B, α, β) + end + return C + end -@reactant_overlay @noinline function LinearAlgebra.mul!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number -) - if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B)) - TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β) - else - LinearAlgebra._mul!(C, A, B, α, β) + # Needed mostly for 1.10 where 3-arg mul is often specialized + @reactant_overlay @noinline function LinearAlgebra.mul!( + C::$cT, A::$aT, B::$bT + ) + call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false) + return C + end end - return C end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index f5418e5c8..c522bcd17 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -172,3 +172,33 @@ end @test all(iszero, y_res) end end + +function lower_triangular_write(x) + y = LowerTriangular(copy(x)) + @. y *= 2 + return y +end + +function upper_triangular_write(x) + y = UpperTriangular(copy(x)) + @. y *= 2 + return y +end + +function tridiagonal_write(x) + y = Tridiagonal(copy(x)) + @. y *= 2 + return y +end + +@testset "Broadcasted Multiply and Alloate" begin + @testset "$(aType)" for (aType, fn) in [ + ("LowerTriangular", lower_triangular_write), + ("UpperTriangular", upper_triangular_write), + ("Tridiagonal", tridiagonal_write), + ] + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + @test @jit(fn(x_ra)) ≈ fn(x) + end +end