Skip to content

Commit

Permalink
stdlib: faster kronecker product between hermitian and symmetric matr…
Browse files Browse the repository at this point in the history
…ices (JuliaLang#53186)

The kronecker product between complex hermitian matrices is again
hermitian, so it can be computed much faster by only doing the upper (or
lower) triangular. As @andreasnoack will surely notice, this only true
for types where `conj(a*b) == conj(a)*conj(b)`, so I'm restricting the
function to act only on real and complex numbers. In the symmetric case,
however, no additional assumption is needed, so I'm letting it act on
anything.

Benchmarking showed that the code is roughly 2 times as fast as the
vanilla kronecker product, as expected. The fastest case was always the
UU case, and the slowest the LU case. The code I used is below
```julia
using LinearAlgebra
using BenchmarkTools
using Quaternions

randrmatrix(d, uplo = :U) = Hermitian(randn(Float64, d, d), uplo)
randcmatrix(d, uplo = :U) = Hermitian(randn(ComplexF64, d, d), uplo)
randsmatrix(d, uplo = :U) = Symmetric(randn(ComplexF64, d, d), uplo)
randqmatrix(d, uplo = :U) = Symmetric(randn(QuaternionF64, d, d), uplo)

dima = 69
dimb = 71
for randmatrix in [randrmatrix, randcmatrix, randsmatrix, randqmatrix]
    for auplo in [:U, :L]
        for buplo in [:U, :L]
            a = randmatrix(dima, auplo)
            b = randmatrix(dimb, buplo)
            c = kron(a,b)
            therm = @belapsed kron!($c, $a, $b)
            C = Matrix(c)
            A = Matrix(a)
            B = Matrix(b)
            told = @belapsed kron!($C, $A, $B)
            @show told/therm
        end
    end
end
```
Weirdly enough, I got this expected speedup in one of my machines, but
when running the benchmark in another I got roughly the same time. I
guess that's a bug with `BechmarkTools`, because that's not consistent
with the times I get running the functions individually, out of the
loop.

Another issue is that although I added a couple of tests, I couldn't get
them to run. Perhaps someone here can tell me what's going on? I could
run the tests from LinearAlgebra, it's just that editing the files made
no difference to what was being run. I did get hundreds of errors from
`triangular.jl`, but that's untouched by my code.

---------

Co-authored-by: Oscar Smith <oscardssmith@gmail.com>
  • Loading branch information
araujoms and oscardssmith authored Apr 18, 2024
1 parent 0f7674e commit c741bd3
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,8 @@ julia> reshape(kron(v,w), (length(w), length(v)))
```
"""
function kron(A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S}
R = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
return kron!(R, A, B)
C = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
return kron!(C, A, B)
end
function kron(a::AbstractVector{T}, b::AbstractVector{S}) where {T,S}
c = Vector{promote_op(*,T,S)}(undef, length(a)*length(b))
Expand Down
124 changes: 124 additions & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,130 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
end
end

function kron(A::Hermitian{T}, B::Hermitian{S}) where {T<:Union{Real,Complex},S<:Union{Real,Complex}}
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Hermitian(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron(A::Symmetric{T}, B::Symmetric{S}) where {T<:Number,S<:Number}
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Symmetric(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Complex}}, B::Hermitian{<:Union{Real,Complex}})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
end
_hermkron!(C.data, A.data, B.data, conj, real, A.uplo, B.uplo)
return C
end

function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
end
_hermkron!(C.data, A.data, B.data, identity, identity, A.uplo, B.uplo)
return C
end

function _hermkron!(C, A, B, conj::TC, real::TR, Auplo, Buplo) where {TC,TR}
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds if Auplo == 'U' && Buplo == 'U'
for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:(l-1)
C[inB+k, jnB+l] = Aij * B[k, l]
C[inB+l, jnB+k] = Aij * conj(B[k, l])
end
C[inB+l, jnB+l] = Aij * real(B[l, l])
end
end
Ajj = real(A[j, j])
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
end
end
elseif Auplo == 'U' && Buplo == 'L'
for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
C[inB+l, jnB+l] = Aij * real(B[l, l])
for k = (l+1):n_B
C[inB+l, jnB+k] = Aij * conj(B[k, l])
C[inB+k, jnB+l] = Aij * B[k, l]
end
end
end
Ajj = real(A[j, j])
for l = 1:n_B
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
for k = (l+1):n_B
C[jnB+l, jnB+k] = Ajj * conj(B[k, l])
end
end
end
elseif Auplo == 'L' && Buplo == 'U'
for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = real(A[j, j])
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
end
for i = (j+1):n_A
conjAij = conj(A[i, j])
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, inB+l] = conjAij * B[k, l]
C[jnB+l, inB+k] = conjAij * conj(B[k, l])
end
C[jnB+l, inB+l] = conjAij * real(B[l, l])
end
end
end
else #if Auplo == 'L' && Buplo == 'L'
for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = real(A[j, j])
for l = 1:n_B
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
for k = (l+1):n_B
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
for i = (j+1):n_A
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
C[inB+l, jnB+l] = Aij * real(B[l, l])
for k = (l+1):n_B
C[inB+k, jnB+l] = Aij * B[k, l]
C[inB+l, jnB+k] = Aij * conj(B[k, l])
end
end
end
end
end
end

(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo))
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo))

Expand Down
74 changes: 74 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,80 @@ for op in (:+, :-)
end
end

function kron(A::UpperTriangular{T}, B::UpperTriangular{S}) where {T<:Number,S<:Number}
C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron(A::LowerTriangular{T}, B::LowerTriangular{S}) where {T<:Number,S<:Number}
C = LowerTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron!(C::UpperTriangular{<:Number}, A::UpperTriangular{<:Number}, B::UpperTriangular{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_triukron!(C.data, A.data, B.data)
return C
end

function kron!(C::LowerTriangular{<:Number}, A::LowerTriangular{<:Number}, B::LowerTriangular{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_trilkron!(C.data, A.data, B.data)
return C
end

function _triukron!(C, A, B)
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:l
C[inB+k, jnB+l] = Aij * B[k, l]
end
for k = 1:(l-1)
C[inB+l, jnB+k] = zero(eltype(C))
end
end
end
Ajj = A[j, j]
for l = 1:n_B
for k = 1:l
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
end
end

function _trilkron!(C, A, B)
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = A[j, j]
for l = 1:n_B
for k = l:n_B
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
for i = (j+1):n_A
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = l:n_B
C[inB+k, jnB+l] = Aij * B[k, l]
end
for k = (l+1):n_B
C[inB+l, jnB+k] = zero(eltype(C))
end
end
end
end
end

######################
# BlasFloat routines #
######################
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,28 @@ end
@test dot(symblockml, symblockml) dot(msymblockml, msymblockml)
end
end

@testset "kronecker product of symmetric and Hermitian matrices" begin
for mtype in (Symmetric, Hermitian)
symau = mtype(a, :U)
symal = mtype(a, :L)
msymau = Matrix(symau)
msymal = Matrix(symal)
for eltyc in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int)
creal = randn(n, n)/2
cimag = randn(n, n)/2
c = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(creal, cimag) : creal)
symcu = mtype(c, :U)
symcl = mtype(c, :L)
msymcu = Matrix(symcu)
msymcl = Matrix(symcl)
@test kron(symau, symcu) kron(msymau, msymcu)
@test kron(symau, symcl) kron(msymau, msymcl)
@test kron(symal, symcu) kron(msymal, msymcu)
@test kron(symal, symcl) kron(msymal, msymcl)
end
end
end
end
end

Expand All @@ -487,6 +509,7 @@ end
@test S - S == MS - MS
@test S*2 == 2*S == 2*MS
@test S/2 == MS/2
@test kron(S,S) == kron(MS,MS)
end
@testset "mixed uplo" begin
Mu = Matrix{Complex{BigFloat}}(undef,2,2)
Expand All @@ -502,6 +525,8 @@ end
MSl = Matrix(Sl)
@test Su + Sl == Sl + Su == MSu + MSl
@test Su - Sl == -(Sl - Su) == MSu - MSl
@test kron(Su,Sl) == kron(MSu,MSl)
@test kron(Sl,Su) == kron(MSl,MSu)
end
end
end
Expand All @@ -517,6 +542,16 @@ end
@test dot(A, B) dot(Symmetric(A), Symmetric(B))
end

# let's make sure the analogous bug will not show up with kronecker products
@testset "kron Hermitian quaternion #52318" begin
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2]
@test A == Hermitian(A) && B == Hermitian(B)
@test kron(A, B) kron(Hermitian(A), Hermitian(B))
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2]
@test A == Symmetric(A) && B == Symmetric(B)
@test kron(A, B) kron(Symmetric(A), Symmetric(B))
end

#Issue #7647: test xsyevr, xheevr, xstevr drivers.
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
(Symmetric(diagm(0 => 1.0:3.0)),
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ debug && println("Test basic type functionality")
# Binary operations
@test A1 + A2 == M1 + M2
@test A1 - A2 == M1 - M2
@test kron(A1,A2) == kron(M1,M2)

# Triangular-Triangular multiplication and division
@test A1*A2 M1*M2
Expand Down Expand Up @@ -1014,6 +1015,7 @@ end
@test 2\L == 2\B
@test real(L) == real(B)
@test imag(L) == imag(B)
@test kron(L,L) == kron(B,B)
@test transpose!(MT(copy(A))) == transpose(L) broken=!(A isa Matrix)
@test adjoint!(MT(copy(A))) == adjoint(L) broken=!(A isa Matrix)
end
Expand All @@ -1035,6 +1037,7 @@ end
@test 2\U == 2\B
@test real(U) == real(B)
@test imag(U) == imag(B)
@test kron(U,U) == kron(B,B)
@test transpose!(MT(copy(A))) == transpose(U) broken=!(A isa Matrix)
@test adjoint!(MT(copy(A))) == adjoint(U) broken=!(A isa Matrix)
end
Expand Down

0 comments on commit c741bd3

Please sign in to comment.