diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 3673c0a43..e1e2626c7 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -342,20 +342,24 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R project_B = ProjectTo(B) Y = A \ B + # Ever since https://github.com/JuliaLang/julia/pull/44358 + # we need to use `pinv` rather than `/` to support both the cases of Y being scalar and array + # See also https://github.com/JuliaLang/julia/issues/28827 which would improve this function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) + Ati = pinv(A') ∂A = @thunk begin - B̄ = A' \ Ȳ + + B̄ = Ati * Ȳ Ā = -B̄ * Y' - Ā = add!!(Ā, (B - A * Y) * B̄' / A') - Ā = add!!(Ā, A' \ Y * (Ȳ' - B̄'A)) + Ā = add!!(Ā, ((B - A * Y) * B̄') * Ati) + Ā = add!!(Ā, Ati * Y * (Ȳ' - B̄'A)) project_A(Ā) end - ∂B = @thunk project_B(A' \ Ȳ) + ∂B = @thunk project_B(Ati * Ȳ) return NoTangent(), ∂A, ∂B end return Y, backslash_pullback - end #####