Skip to content

Commit

Permalink
Fix / on 1.9
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed May 19, 2023
1 parent 55a48c6 commit fe026b5
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,20 +342,24 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
project_B = ProjectTo(B)

Y = A \ B
# Ever since https://github.com/JuliaLang/julia/pull/44358
# we need to use `pinv` rather than `/` to support both the cases of Y being scalar and array
# See also https://github.com/JuliaLang/julia/issues/28827 which would improve this
function backslash_pullback(ȳ)
= unthunk(ȳ)
Ati = pinv(A')
∂A = @thunk begin
= A' \

= Ati *
= -* Y'
= add!!(Ā, (B - A * Y) *' / A')
= add!!(Ā, A' \ Y * (Ȳ' -'A))
= add!!(Ā, ((B - A * Y) *') * Ati)
= add!!(Ā, Ati * Y * (Ȳ' -'A))
project_A(Ā)
end
∂B = @thunk project_B(A' \ Ȳ)
∂B = @thunk project_B(Ati * Ȳ)
return NoTangent(), ∂A, ∂B
end
return Y, backslash_pullback

end

#####
Expand Down

0 comments on commit fe026b5

Please sign in to comment.