# Review notes for this patch (plain text ahead of the first `diff --git` header; the hunks below are unchanged):
# - src/rulesets/Base/arraymath.jl: inside the pullback of rrule(::typeof(\), A, B), non-array values of Ȳ, Y and
#   (B - A * Y) * B̄' are wrapped in one-element vectors when VERSION >= v"1.9" before being used with `\` or `/`;
#   the in-code comments attribute this to JuliaLang/julia PR 44358.
# - The patch also adds `_maybe_descalar` plus a second method `rrule(A, B)` that takes no `::typeof(\)` argument,
#   applies no ProjectTo, and returns `∂A, ∂B` without a leading NoTangent().
#   NOTE(review): looks like an unfinished refactor of the method above; confirm it is intended before merging.
# - A commented-out `#@info "vars" ...` debugging line is added. NOTE(review): presumably meant to be removed.
# - test/rulesets/Base/arraymath.jl: the branch holding the `@gpu` inv tests is now taken only when
#   v"1.7" <= VERSION < v"1.9", and two test_rrule calls gain check_inferred=false.
diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 3673c0a43..85fd8df51 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -342,20 +342,70 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R project_B = ProjectTo(B) Y = A \ B + function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) + + Ȳf = Ȳ + @static if VERSION >= v"1.9" + # Need to ensure Ȳ is an array since https://github.com/JuliaLang/julia/pull/44358 + if !isa(Ȳ, AbstractArray) + Ȳf = [Ȳ] + end + end + Yf = Y + @static if VERSION >= v"1.9" + # Need to ensure Yf is an array since https://github.com/JuliaLang/julia/pull/44358 + if !isa(Y, AbstractArray) + Yf = [Y] + end + end + #@info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B) ∂A = @thunk begin - B̄ = A' \ Ȳ + B̄ = A' \ Ȳf Ā = -B̄ * Y' - Ā = add!!(Ā, (B - A * Y) * B̄' / A') - Ā = add!!(Ā, A' \ Y * (Ȳ' - B̄'A)) + t = (B - A * Y) * B̄' + @static if VERSION >= v"1.9" + # Need to ensure t is an array since https://github.com/JuliaLang/julia/pull/44358 + if !isa(t, AbstractArray) + t = [t] + end + end + Ā = add!!(Ā, t / A') + Ā = add!!(Ā, A' \ Yf * (Ȳ' - B̄'A)) project_A(Ā) end - ∂B = @thunk project_B(A' \ Ȳ) + ∂B = @thunk project_B(A' \ Ȳf) return NoTangent(), ∂A, ∂B end return Y, backslash_pullback +end + +@static if VERSION >= v"1.9" + # Need to ensure things are not scalar since https://github.com/JuliaLang/julia/pull/44358 + _maybe_descalar(x) = x isa AbstractArray ? 
x : [x] +else + _maybe_descalar(x) = x +end + +function rrule(A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) + Y = A \ B + + + function backslash_pullback(ȳ) + Ȳ = unthunk(ȳ) + ∂A = @thunk begin + B̄ = A' \ _maybe_descalar(Ȳ) + Ā = -B̄ * Y' + Ā += _maybe_descalar((B - A * Y) * B̄') / A' + Ā += (A' \ _maybe_descalar(Y)) * (Ȳ' - B̄'A) + (Ā) + end + ∂B = @thunk (A' \ _maybe_descalar(Ȳ)) + return ∂A, ∂B + end + return Y, backslash_pullback end ##### diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 5eaf9e7fc..847808c1f 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -1,7 +1,7 @@ @testset "arraymath.jl" begin @testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64) B = generate_well_conditioned_matrix(T, 3) - if VERSION >= v"1.7" + if v"1.7" <= VERSION < v"1.9" @gpu test_frule(inv, B) @gpu test_rrule(inv, B) else @@ -167,12 +167,12 @@ @testset "Matrix $f Vector" begin X = randn(10, 4) y = randn(10) - test_rrule(f, X, y) + test_rrule(f, X, y; check_inferred=false) end @testset "Vector $f Matrix" begin x = randn(10) Y = randn(10, 4) - test_rrule(f, x, Y; output_tangent=Transpose(rand(4))) + test_rrule(f, x, Y; output_tangent=Transpose(rand(4)), check_inferred=false) end else A = rand(2, 4)