-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Vector'Diagonal (and Transpose as well) to avoid infinite recursion. #26924
Conversation
stdlib/LinearAlgebra/src/diagonal.jl
Outdated
@@ -481,6 +481,10 @@ end | |||
# dismabiguation methods: * of Diagonal and Adj/Trans AbsVec | |||
*(A::Diagonal, B::Adjoint{<:Any,<:AbstractVector}) = A * copy(B) | |||
*(A::Diagonal, B::Transpose{<:Any,<:AbstractVector}) = A * copy(B) | |||
*(A::Adjoint{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B | |||
*(A::Transpose{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B | |||
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps (D' * x')'
would be simpler / clearer, if methods appropriate for D' * x'
exist? (Edit: I.e. methods for *(Adjoint{<:Any,<:Diagonal}, AbstractVector
).)
stdlib/LinearAlgebra/src/diagonal.jl
Outdated
*(A::Transpose{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B | ||
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) | ||
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = | ||
mapreduce(t -> t[1]'*t[2]*t[3], +, zip(parent(x), D.diag, y)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
?
stdlib/LinearAlgebra/src/diagonal.jl
Outdated
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) | ||
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = | ||
mapreduce(t -> t[1]'*t[2]*t[3], +, zip(parent(x), D.diag, y)) | ||
*(A::Transpose{<:Any,<:AbstractVector}, B::Diagonal) = Transpose(map(*, D.diag, parent(x))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
B
in this method signature should be D
? (Perhaps test?)
Similarly to the above, perhaps Transpose(transpose(D) * transpose(x))
would be simpler / clearer, if methods appropriate for transpose(D) * transpose(x)
exist? (Edit: I.e. methods for *(Transpose{<:Any,<:Diagonal}, AbstractVector
).)
stdlib/LinearAlgebra/src/diagonal.jl
Outdated
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = | ||
mapreduce(t -> t[1]'*t[2]*t[3], +, zip(parent(x), D.diag, y)) | ||
*(A::Transpose{<:Any,<:AbstractVector}, B::Diagonal) = Transpose(map(*, D.diag, parent(x))) | ||
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method signature is identical to the signature above? Perhaps you meant *(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector)
instead? In such case the method body seems slightly amiss as well, particularly for nested objects? Perhaps mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much thanks for patching these methods Andreas! Please find a few minor comments attached (edited to improve clarity after posting).
Also add optimized methods for x'D*y to avoid allocating temporary vector
Looks like this is working now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming the method deletions do not break an appreciable amount of external code, lgtm! :)
Also add optimized methods for x'D*y to avoid allocating temporary vector
Fixes JuliaLang/LinearAlgebra.jl#517