From 805706c5139da22245948b3543bb46c84c968679 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 15 Oct 2024 13:15:12 +0530 Subject: [PATCH] Fix zero elements for block-matrix kron involving Diagonal (#55941) --- stdlib/LinearAlgebra/src/diagonal.jl | 70 ++++++++++++++++++++++++--- stdlib/LinearAlgebra/test/diagonal.jl | 10 ++++ 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 830332db0dbb1a..eb0aec7726bc2b 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -634,16 +634,33 @@ for Tri in (:UpperTriangular, :LowerTriangular) end @inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal) - valA = A.diag; nA = length(valA) - valB = B.diag; nB = length(valB) + valA = A.diag; mA, nA = size(A) + valB = B.diag; mB, nB = size(B) nC = checksquare(C) @boundscheck nC == nA*nB || throw(DimensionMismatch(lazy"expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)")) - isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1])) + zerofilled = false + if !(isempty(A) || isempty(B)) + z = A[1,1] * B[1,1] + if haszero(typeof(z)) + # in this case, the zero is unique + fill!(C, zero(z)) + zerofilled = true + end + end @inbounds for i = 1:nA, j = 1:nB idx = (i-1)*nB+j C[idx, idx] = valA[i] * valB[j] end + if !zerofilled + for j in 1:nA, i in 1:mA + Δrow, Δcol = (i-1)*mB, (j-1)*nB + for k in 1:nB, l in 1:mB + i == j && k == l && continue + C[Δrow + l, Δcol + k] = A[i,j] * B[l,k] + end + end + end return C end @@ -670,7 +687,15 @@ end (mC, nC) = size(C) @boundscheck (mC, nC) == (mA * mB, nA * nB) || throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) - isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1])) + zerofilled = false + if !(isempty(A) || isempty(B)) + z = A[1,1] * B[1,1] + if haszero(typeof(z)) + # in this case, the zero is unique + fill!(C, zero(z)) + zerofilled = true + end + end m = 1 @inbounds for j = 1:nA A_jj = A[j,j] @@ -681,6 +706,18 @@ end end m += (nA - 1) * mB end + if !zerofilled + # populate the zero elements + for i in 1:mA + i == j && continue + A_ij = A[i, j] + Δrow, Δcol = (i-1)*mB, (j-1)*nB + for k in 1:nB, l in 1:nA + B_lk = B[l, k] + C[Δrow + l, Δcol + k] = A_ij * B_lk + end + end + end m += mB end return C @@ -693,17 +730,36 @@ end (mC, nC) = size(C) @boundscheck (mC, nC) == (mA * mB, nA * nB) || throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) - isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1])) + zerofilled = false + if !(isempty(A) || isempty(B)) + z = A[1,1] * B[1,1] + if haszero(typeof(z)) + # in this case, the zero is unique + fill!(C, zero(z)) + zerofilled = true + end + end m = 1 @inbounds for j = 1:nA for l = 1:mB Bll = B[l,l] - for k = 1:mA - C[m] = A[k,j] * Bll + for i = 1:mA + C[m] = A[i,j] * Bll m += nB end m += 1 end + if !zerofilled + for i in 1:mA + A_ij = A[i, j] + Δrow, Δcol = (i-1)*mB, (j-1)*nB + for k in 1:nB, l in 1:mB + l == k && continue + B_lk = B[l, k] + C[Δrow + l, Δcol + k] = A_ij * B_lk + end + end + end m -= nB end return C diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 7049dd784faa8a..cfd8a5277e5f04 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1323,4 +1323,14 @@ end @test checkbounds(Bool, D, diagind(D, IndexCartesian())) end +@testset "zeros in kron with block matrices" begin + D = Diagonal(1:2) + B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2) + @test kron(D, B) == kron(Array(D), B) + @test kron(B, D) == kron(B, Array(D)) + D2 = Diagonal([ones(2,2), ones(3,3)]) + @test kron(D, D2) == Diagonal([diag(D2); 2diag(D2)]) + @test kron(D2, D) == Diagonal([ones(2,2), fill(2.0,2,2), ones(3,3), fill(2.0,3,3)]) +end + end # module TestDiagonal