Skip to content

Commit

Permalink
Add \ for Sym/Tri/Bi/Diagonal and support non-commutative numbers (#3…
Browse files Browse the repository at this point in the history
…9701)

Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
  • Loading branch information
dlfivefifty and dkarrasch authored Feb 26, 2021
1 parent 567f4bc commit 76698ea
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 3 deletions.
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,9 @@ end

-(A::Bidiagonal)=Bidiagonal(-A.dv,-A.ev,A.uplo)
*(A::Bidiagonal, B::Number) = Bidiagonal(A.dv*B, A.ev*B, A.uplo)
*(B::Number, A::Bidiagonal) = A*B
*(B::Number, A::Bidiagonal) = Bidiagonal(B*A.dv, B*A.ev, A.uplo)
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.uplo)
\(B::Number, A::Bidiagonal) = Bidiagonal(B\A.dv, B\A.ev, A.uplo)

function ==(A::Bidiagonal, B::Bidiagonal)
if A.uplo == B.uplo
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ end
(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
(/)(D::Diagonal, x::Number) = Diagonal(D.diag / x)
(\)(x::Number, D::Diagonal) = Diagonal(x \ D.diag)

function (*)(Da::Diagonal, Db::Diagonal)
nDa, mDb = size(Da, 2), size(Db, 1)
Expand Down
6 changes: 4 additions & 2 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,9 @@ end
-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev)
-(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev)
*(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B)
*(B::Number, A::SymTridiagonal) = A*B
*(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev)
/(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B)
\(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev)
==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (A.ev==B.ev)

@inline mul!(A::StridedVecOrMat, B::SymTridiagonal, C::StridedVecOrMat,
Expand Down Expand Up @@ -733,8 +734,9 @@ end
+(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl+B.dl, A.d+B.d, A.du+B.du)
-(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl-B.dl, A.d-B.d, A.du-B.du)
*(A::Tridiagonal, B::Number) = Tridiagonal(A.dl*B, A.d*B, A.du*B)
*(B::Number, A::Tridiagonal) = A*B
*(B::Number, A::Tridiagonal) = Tridiagonal(B*A.dl, B*A.d, B*A.du)
/(A::Tridiagonal, B::Number) = Tridiagonal(A.dl/B, A.d/B, A.du/B)
\(B::Number, A::Tridiagonal) = Tridiagonal(B\A.dl, B\A.d, B\A.du)

==(A::Tridiagonal, B::Tridiagonal) = (A.dl==B.dl) && (A.d==B.d) && (A.du==B.du)
==(A::Tridiagonal, B::SymTridiagonal) = (A.dl==A.du==B.ev) && (A.d==B.dv)
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ module TestBidiagonal
using Test, LinearAlgebra, SparseArrays, Random
using LinearAlgebra: BlasReal, BlasFloat

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")

isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

include("testutils.jl") # test_approx_eq_modphase

n = 10 #Size of test matrix
Expand Down Expand Up @@ -635,4 +640,13 @@ end
@test ubd .* 3 == ubd
end

@testset "non-commutative algebra (#39701)" begin
A = Bidiagonal(Quaternion.(randn(5), randn(5), randn(5), randn(5)), Quaternion.(randn(4), randn(4), randn(4), randn(4)), :U)
c = Quaternion(1,2,3,4)
@test A * c Matrix(A) * c
@test A / c Matrix(A) / c
@test c * A c * Matrix(A)
@test c \ A c \ Matrix(A)
end

end # module TestBidiagonal
16 changes: 16 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ module TestTridiagonal

using Test, LinearAlgebra, SparseArrays, Random

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")

isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

include("testutils.jl") # test_approx_eq_modphase

#Test equivalence of eigenvectors/singular vectors taking into account possible phase (sign) differences
Expand Down Expand Up @@ -587,4 +592,15 @@ end
@test F.values F2.values
end

@testset "non-commutative algebra (#39701)" begin
for A in (SymTridiagonal(Quaternion.(randn(5), randn(5), randn(5), randn(5)), Quaternion.(randn(4), randn(4), randn(4), randn(4))),
Tridiagonal(Quaternion.(randn(4), randn(4), randn(4), randn(4)), Quaternion.(randn(5), randn(5), randn(5), randn(5)), Quaternion.(randn(4), randn(4), randn(4), randn(4))))
c = Quaternion(1,2,3,4)
@test A * c Matrix(A) * c
@test A / c Matrix(A) / c
@test c * A c * Matrix(A)
@test c \ A c \ Matrix(A)
end
end

end # module TestTridiagonal

0 comments on commit 76698ea

Please sign in to comment.