From c741bd3d7e29a29dfeb4fcf2144a006669543069 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Ara=C3=BAjo?= Date: Thu, 18 Apr 2024 08:13:49 +0200 Subject: [PATCH] stdlib: faster kronecker product between hermitian and symmetric matrices (#53186) The kronecker product between complex hermitian matrices is again hermitian, so it can be computed much faster by only doing the upper (or lower) triangular. As @andreasnoack will surely notice, this only true for types where `conj(a*b) == conj(a)*conj(b)`, so I'm restricting the function to act only on real and complex numbers. In the symmetric case, however, no additional assumption is needed, so I'm letting it act on anything. Benchmarking showed that the code is roughly 2 times as fast as the vanilla kronecker product, as expected. The fastest case was always the UU case, and the slowest the LU case. The code I used is below ```julia using LinearAlgebra using BenchmarkTools using Quaternions randrmatrix(d, uplo = :U) = Hermitian(randn(Float64, d, d), uplo) randcmatrix(d, uplo = :U) = Hermitian(randn(ComplexF64, d, d), uplo) randsmatrix(d, uplo = :U) = Symmetric(randn(ComplexF64, d, d), uplo) randqmatrix(d, uplo = :U) = Symmetric(randn(QuaternionF64, d, d), uplo) dima = 69 dimb = 71 for randmatrix in [randrmatrix, randcmatrix, randsmatrix, randqmatrix] for auplo in [:U, :L] for buplo in [:U, :L] a = randmatrix(dima, auplo) b = randmatrix(dimb, buplo) c = kron(a,b) therm = @belapsed kron!($c, $a, $b) C = Matrix(c) A = Matrix(a) B = Matrix(b) told = @belapsed kron!($C, $A, $B) @show told/therm end end end ``` Weirdly enough, I got this expected speedup in one of my machines, but when running the benchmark in another I got roughly the same time. I guess that's a bug with `BechmarkTools`, because that's not consistent with the times I get running the functions individually, out of the loop. Another issue is that although I added a couple of tests, I couldn't get them to run. Perhaps someone here can tell me what's going on? I could run the tests from LinearAlgebra, it's just that editing the files made no difference to what was being run. I did get hundreds of errors from `triangular.jl`, but that's untouched by my code. --------- Co-authored-by: Oscar Smith --- stdlib/LinearAlgebra/src/dense.jl | 4 +- stdlib/LinearAlgebra/src/symmetric.jl | 124 ++++++++++++++++++++++++ stdlib/LinearAlgebra/src/triangular.jl | 74 ++++++++++++++ stdlib/LinearAlgebra/test/symmetric.jl | 35 +++++++ stdlib/LinearAlgebra/test/triangular.jl | 3 + 5 files changed, 238 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index 97a64aee3e721..4ed3b8624c47b 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -491,8 +491,8 @@ julia> reshape(kron(v,w), (length(w), length(v))) ``` """ function kron(A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S} - R = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B)) - return kron!(R, A, B) + C = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B)) + return kron!(C, A, B) end function kron(a::AbstractVector{T}, b::AbstractVector{S}) where {T,S} c = Vector{promote_op(*,T,S)}(undef, length(a)*length(b)) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 10b996ca49e36..410140bf7e8be 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -525,6 +525,130 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni end end +function kron(A::Hermitian{T}, B::Hermitian{S}) where {T<:Union{Real,Complex},S<:Union{Real,Complex}} + resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L + C = Hermitian(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo) + return kron!(C, A, B) +end + +function kron(A::Symmetric{T}, B::Symmetric{S}) where {T<:Number,S<:Number} + resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L + C = Symmetric(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo) + return kron!(C, A, B) +end + +function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Complex}}, B::Hermitian{<:Union{Real,Complex}}) + size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) + if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L') + throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)")) + end + _hermkron!(C.data, A.data, B.data, conj, real, A.uplo, B.uplo) + return C +end + +function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Number}) + size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) + if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L') + throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)")) + end + _hermkron!(C.data, A.data, B.data, identity, identity, A.uplo, B.uplo) + return C +end + +function _hermkron!(C, A, B, conj::TC, real::TR, Auplo, Buplo) where {TC,TR} + n_A = size(A, 1) + n_B = size(B, 1) + @inbounds if Auplo == 'U' && Buplo == 'U' + for j = 1:n_A + jnB = (j - 1) * n_B + for i = 1:(j-1) + Aij = A[i, j] + inB = (i - 1) * n_B + for l = 1:n_B + for k = 1:(l-1) + C[inB+k, jnB+l] = Aij * B[k, l] + C[inB+l, jnB+k] = Aij * conj(B[k, l]) + end + C[inB+l, jnB+l] = Aij * real(B[l, l]) + end + end + Ajj = real(A[j, j]) + for l = 1:n_B + for k = 1:(l-1) + C[jnB+k, jnB+l] = Ajj * B[k, l] + end + C[jnB+l, jnB+l] = Ajj * real(B[l, l]) + end + end + elseif Auplo == 'U' && Buplo == 'L' + for j = 1:n_A + jnB = (j - 1) * n_B + for i = 1:(j-1) + Aij = A[i, j] + inB = (i - 1) * n_B + for l = 1:n_B + C[inB+l, jnB+l] = Aij * real(B[l, l]) + for k = (l+1):n_B + C[inB+l, jnB+k] = Aij * conj(B[k, l]) + C[inB+k, jnB+l] = Aij * B[k, l] + end + end + end + Ajj = real(A[j, j]) + for l = 1:n_B + C[jnB+l, jnB+l] = Ajj * real(B[l, l]) + for k = (l+1):n_B + C[jnB+l, jnB+k] = Ajj * conj(B[k, l]) + end + end + end + elseif Auplo == 'L' && Buplo == 'U' + for j = 1:n_A + jnB = (j - 1) * n_B + Ajj = real(A[j, j]) + for l = 1:n_B + for k = 1:(l-1) + C[jnB+k, jnB+l] = Ajj * B[k, l] + end + C[jnB+l, jnB+l] = Ajj * real(B[l, l]) + end + for i = (j+1):n_A + conjAij = conj(A[i, j]) + inB = (i - 1) * n_B + for l = 1:n_B + for k = 1:(l-1) + C[jnB+k, inB+l] = conjAij * B[k, l] + C[jnB+l, inB+k] = conjAij * conj(B[k, l]) + end + C[jnB+l, inB+l] = conjAij * real(B[l, l]) + end + end + end + else #if Auplo == 'L' && Buplo == 'L' + for j = 1:n_A + jnB = (j - 1) * n_B + Ajj = real(A[j, j]) + for l = 1:n_B + C[jnB+l, jnB+l] = Ajj * real(B[l, l]) + for k = (l+1):n_B + C[jnB+k, jnB+l] = Ajj * B[k, l] + end + end + for i = (j+1):n_A + Aij = A[i, j] + inB = (i - 1) * n_B + for l = 1:n_B + C[inB+l, jnB+l] = Aij * real(B[l, l]) + for k = (l+1):n_B + C[inB+k, jnB+l] = Aij * B[k, l] + C[inB+l, jnB+k] = Aij * conj(B[k, l]) + end + end + end + end + end +end + (-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo)) (-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo)) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index a78e838654f37..bc6c2a64a6d7d 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -757,6 +757,80 @@ for op in (:+, :-) end end +function kron(A::UpperTriangular{T}, B::UpperTriangular{S}) where {T<:Number,S<:Number} + C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B))) + return kron!(C, A, B) +end + +function kron(A::LowerTriangular{T}, B::LowerTriangular{S}) where {T<:Number,S<:Number} + C = LowerTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B))) + return kron!(C, A, B) +end + +function kron!(C::UpperTriangular{<:Number}, A::UpperTriangular{<:Number}, B::UpperTriangular{<:Number}) + size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) + _triukron!(C.data, A.data, B.data) + return C +end + +function kron!(C::LowerTriangular{<:Number}, A::LowerTriangular{<:Number}, B::LowerTriangular{<:Number}) + size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!")) + _trilkron!(C.data, A.data, B.data) + return C +end + +function _triukron!(C, A, B) + n_A = size(A, 1) + n_B = size(B, 1) + @inbounds for j = 1:n_A + jnB = (j - 1) * n_B + for i = 1:(j-1) + Aij = A[i, j] + inB = (i - 1) * n_B + for l = 1:n_B + for k = 1:l + C[inB+k, jnB+l] = Aij * B[k, l] + end + for k = 1:(l-1) + C[inB+l, jnB+k] = zero(eltype(C)) + end + end + end + Ajj = A[j, j] + for l = 1:n_B + for k = 1:l + C[jnB+k, jnB+l] = Ajj * B[k, l] + end + end + end +end + +function _trilkron!(C, A, B) + n_A = size(A, 1) + n_B = size(B, 1) + @inbounds for j = 1:n_A + jnB = (j - 1) * n_B + Ajj = A[j, j] + for l = 1:n_B + for k = l:n_B + C[jnB+k, jnB+l] = Ajj * B[k, l] + end + end + for i = (j+1):n_A + Aij = A[i, j] + inB = (i - 1) * n_B + for l = 1:n_B + for k = l:n_B + C[inB+k, jnB+l] = Aij * B[k, l] + end + for k = (l+1):n_B + C[inB+l, jnB+k] = zero(eltype(C)) + end + end + end + end +end + ###################### # BlasFloat routines # ###################### diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 255e4e398b446..e2a6d2b74ff18 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -467,6 +467,28 @@ end @test dot(symblockml, symblockml) ≈ dot(msymblockml, msymblockml) end end + + @testset "kronecker product of symmetric and Hermitian matrices" begin + for mtype in (Symmetric, Hermitian) + symau = mtype(a, :U) + symal = mtype(a, :L) + msymau = Matrix(symau) + msymal = Matrix(symal) + for eltyc in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int) + creal = randn(n, n)/2 + cimag = randn(n, n)/2 + c = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(creal, cimag) : creal) + symcu = mtype(c, :U) + symcl = mtype(c, :L) + msymcu = Matrix(symcu) + msymcl = Matrix(symcl) + @test kron(symau, symcu) ≈ kron(msymau, msymcu) + @test kron(symau, symcl) ≈ kron(msymau, msymcl) + @test kron(symal, symcu) ≈ kron(msymal, msymcu) + @test kron(symal, symcl) ≈ kron(msymal, msymcl) + end + end + end end end @@ -487,6 +509,7 @@ end @test S - S == MS - MS @test S*2 == 2*S == 2*MS @test S/2 == MS/2 + @test kron(S,S) == kron(MS,MS) end @testset "mixed uplo" begin Mu = Matrix{Complex{BigFloat}}(undef,2,2) @@ -502,6 +525,8 @@ end MSl = Matrix(Sl) @test Su + Sl == Sl + Su == MSu + MSl @test Su - Sl == -(Sl - Su) == MSu - MSl + @test kron(Su,Sl) == kron(MSu,MSl) + @test kron(Sl,Su) == kron(MSl,MSu) end end end @@ -517,6 +542,16 @@ end @test dot(A, B) ≈ dot(Symmetric(A), Symmetric(B)) end +# let's make sure the analogous bug will not show up with kronecker products +@testset "kron Hermitian quaternion #52318" begin + A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2] + @test A == Hermitian(A) && B == Hermitian(B) + @test kron(A, B) ≈ kron(Hermitian(A), Hermitian(B)) + A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2] + @test A == Symmetric(A) && B == Symmetric(B) + @test kron(A, B) ≈ kron(Symmetric(A), Symmetric(B)) +end + #Issue #7647: test xsyevr, xheevr, xstevr drivers. @testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in (Symmetric(diagm(0 => 1.0:3.0)), diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index daed1e9ebab3f..9c23ec92fdc74 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -359,6 +359,7 @@ debug && println("Test basic type functionality") # Binary operations @test A1 + A2 == M1 + M2 @test A1 - A2 == M1 - M2 + @test kron(A1,A2) == kron(M1,M2) # Triangular-Triangular multiplication and division @test A1*A2 ≈ M1*M2 @@ -1014,6 +1015,7 @@ end @test 2\L == 2\B @test real(L) == real(B) @test imag(L) == imag(B) + @test kron(L,L) == kron(B,B) @test transpose!(MT(copy(A))) == transpose(L) broken=!(A isa Matrix) @test adjoint!(MT(copy(A))) == adjoint(L) broken=!(A isa Matrix) end @@ -1035,6 +1037,7 @@ end @test 2\U == 2\B @test real(U) == real(B) @test imag(U) == imag(B) + @test kron(U,U) == kron(B,B) @test transpose!(MT(copy(A))) == transpose(U) broken=!(A isa Matrix) @test adjoint!(MT(copy(A))) == adjoint(U) broken=!(A isa Matrix) end