From abc1f1d2c4d47ca54731e3e97afed6d3b738daca Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 26 Feb 2021 14:19:05 +0000 Subject: [PATCH] Add \ for Sym/Tri/Bi/Diagonal and support non-commutative numbers (#39701) Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/bidiag.jl | 3 ++- stdlib/LinearAlgebra/src/diagonal.jl | 1 + stdlib/LinearAlgebra/src/tridiag.jl | 6 ++++-- stdlib/LinearAlgebra/test/bidiag.jl | 14 ++++++++++++++ stdlib/LinearAlgebra/test/tridiag.jl | 16 ++++++++++++++++ 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 1803effa24361..69fbaa476de73 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -371,8 +371,9 @@ end -(A::Bidiagonal)=Bidiagonal(-A.dv,-A.ev,A.uplo) *(A::Bidiagonal, B::Number) = Bidiagonal(A.dv*B, A.ev*B, A.uplo) -*(B::Number, A::Bidiagonal) = A*B +*(B::Number, A::Bidiagonal) = Bidiagonal(B*A.dv, B*A.ev, A.uplo) /(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.uplo) +\(B::Number, A::Bidiagonal) = Bidiagonal(B\A.dv, B\A.ev, A.uplo) function ==(A::Bidiagonal, B::Bidiagonal) if A.uplo == B.uplo diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 0426945e24e73..64ef8c69660c0 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -174,6 +174,7 @@ end (*)(x::Number, D::Diagonal) = Diagonal(x * D.diag) (*)(D::Diagonal, x::Number) = Diagonal(D.diag * x) (/)(D::Diagonal, x::Number) = Diagonal(D.diag / x) +(\)(x::Number, D::Diagonal) = Diagonal(x \ D.diag) function (*)(Da::Diagonal, Db::Diagonal) nDa, mDb = size(Da, 2), size(Db, 1) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 949f196df2b69..8420750f8f4a1 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -206,8 +206,9 @@ end -(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev) -(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev) *(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B) -*(B::Number, A::SymTridiagonal) = A*B +*(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev) /(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B) +\(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev) ==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (A.ev==B.ev) @inline mul!(A::StridedVecOrMat, B::SymTridiagonal, C::StridedVecOrMat, @@ -733,8 +734,9 @@ end +(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl+B.dl, A.d+B.d, A.du+B.du) -(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl-B.dl, A.d-B.d, A.du-B.du) *(A::Tridiagonal, B::Number) = Tridiagonal(A.dl*B, A.d*B, A.du*B) -*(B::Number, A::Tridiagonal) = A*B +*(B::Number, A::Tridiagonal) = Tridiagonal(B*A.dl, B*A.d, B*A.du) /(A::Tridiagonal, B::Number) = Tridiagonal(A.dl/B, A.d/B, A.du/B) +\(B::Number, A::Tridiagonal) = Tridiagonal(B\A.dl, B\A.d, B\A.du) ==(A::Tridiagonal, B::Tridiagonal) = (A.dl==B.dl) && (A.d==B.d) && (A.du==B.du) ==(A::Tridiagonal, B::SymTridiagonal) = (A.dl==A.du==B.ev) && (A.d==B.dv) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 943c6862a3186..e4dcd14053778 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -5,6 +5,11 @@ module TestBidiagonal using Test, LinearAlgebra, SparseArrays, Random using LinearAlgebra: BlasReal, BlasFloat +const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") + +isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl")) +using .Main.Quaternions + include("testutils.jl") # test_approx_eq_modphase n = 10 #Size of test matrix @@ -635,4 +640,13 @@ end @test ubd .* 3 == ubd end +@testset "non-commutative algebra (#39701)" begin + A = Bidiagonal(Quaternion.(randn(5), randn(5), randn(5), randn(5)), Quaternion.(randn(4), randn(4), randn(4), randn(4)), :U) + c = Quaternion(1,2,3,4) + @test A * c ≈ Matrix(A) * c + @test A / c ≈ Matrix(A) / c + @test c * A ≈ c * Matrix(A) + @test c \ A ≈ c \ Matrix(A) +end + end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index e0a21651eee96..ec777bcd46222 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -4,6 +4,11 @@ module TestTridiagonal using Test, LinearAlgebra, SparseArrays, Random +const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") + +isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl")) +using .Main.Quaternions + include("testutils.jl") # test_approx_eq_modphase #Test equivalence of eigenvectors/singular vectors taking into account possible phase (sign) differences @@ -587,4 +592,15 @@ end @test F.values ≈ F2.values end +@testset "non-commutative algebra (#39701)" begin + for A in (SymTridiagonal(Quaternion.(randn(5), randn(5), randn(5), randn(5)), Quaternion.(randn(4), randn(4), randn(4), randn(4))), + Tridiagonal(Quaternion.(randn(4), randn(4), randn(4), randn(4)), Quaternion.(randn(5), randn(5), randn(5), randn(5)), Quaternion.(randn(4), randn(4), randn(4), randn(4)))) + c = Quaternion(1,2,3,4) + @test A * c ≈ Matrix(A) * c + @test A / c ≈ Matrix(A) / c + @test c * A ≈ c * Matrix(A) + @test c \ A ≈ c \ Matrix(A) + end +end + end # module TestTridiagonal