diff --git a/NEWS.md b/NEWS.md index c3173351ec9c5..ea03e90990f59 100644 --- a/NEWS.md +++ b/NEWS.md @@ -33,7 +33,7 @@ Build system changes New library functions --------------------- - +* New function `Base.kron!` and corresponding overloads for various matrix types for performing Kronecker product in-place. ([#31069]). New library features -------------------- diff --git a/base/exports.jl b/base/exports.jl index 316025db9ce6c..3d11cc0481931 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -463,6 +463,7 @@ export adjoint, transpose, kron, + kron!, # bitarrays falses, diff --git a/base/operators.jl b/base/operators.jl index c304530dfef80..c7998b51756c9 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -542,6 +542,8 @@ for op in (:+, :*, :&, :|, :xor, :min, :max, :kron) end end +function kron! end + const var"'" = adjoint """ diff --git a/stdlib/LinearAlgebra/docs/src/index.md b/stdlib/LinearAlgebra/docs/src/index.md index b78ed785080e0..c5f0448bfa629 100644 --- a/stdlib/LinearAlgebra/docs/src/index.md +++ b/stdlib/LinearAlgebra/docs/src/index.md @@ -409,6 +409,7 @@ Base.inv(::AbstractMatrix) LinearAlgebra.pinv LinearAlgebra.nullspace Base.kron +Base.kron! LinearAlgebra.exp(::StridedMatrix{<:LinearAlgebra.BlasFloat}) Base.:^(::AbstractMatrix, ::Number) Base.:^(::Number, ::AbstractMatrix) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index bb1dcb3c17ea7..e9476645d5e89 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -11,7 +11,7 @@ import Base: \, /, *, ^, +, -, == import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech, asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat, - getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims, + getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech, setindex!, show, similar, sin, sincos, sinh, size, sqrt, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec diff --git a/stdlib/LinearAlgebra/src/bitarray.jl b/stdlib/LinearAlgebra/src/bitarray.jl index 3e38b073a992b..d1857c3c38659 100644 --- a/stdlib/LinearAlgebra/src/bitarray.jl +++ b/stdlib/LinearAlgebra/src/bitarray.jl @@ -92,22 +92,29 @@ qr(A::BitMatrix) = qr(float(A)) ## kron -function kron(a::BitVector, b::BitVector) +@inline function kron!(R::BitVector, a::BitVector, b::BitVector) m = length(a) n = length(b) - R = falses(n * m) + @boundscheck length(R) == n*m || throw(DimensionMismatch()) Rc = R.chunks bc = b.chunks for j = 1:m a[j] && Base.copy_chunks!(Rc, (j-1)*n+1, bc, 1, n) end - R + return R end -function kron(a::BitMatrix, b::BitMatrix) +function kron(a::BitVector, b::BitVector) + m = length(a) + n = length(b) + R = falses(n * m) + return @inbounds kron!(R, a, b) +end + +function kron!(R::BitMatrix, a::BitMatrix, b::BitMatrix) mA,nA = size(a) mB,nB = size(b) - R = falses(mA*mB, nA*nB) + @boundscheck size(R) == (mA*mB, nA*nB) || throw(DimensionMismatch()) for i = 1:mA ri = (1:mB) .+ ((i-1)*mB) @@ -118,7 +125,14 @@ function kron(a::BitMatrix, b::BitMatrix) end end end - R + return R +end + +function kron(a::BitMatrix, b::BitMatrix) + mA,nA = size(a) + mB,nB = size(b) + R = falses(mA*mB, nA*nB) + return @inbounds kron!(R, a, b) end ## Structure query functions diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index d33cd809f5339..8a57a1eaf52cf 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -336,6 +336,29 @@ function tr(A::Matrix{T}) where T t end +""" + kron!(C, A, B) + +`kron!` is the in-place version of [`kron`](@ref). Computes `kron(A, B)` and stores the result in `C` +overwriting the existing value of `C`. + +!!! tip + Bounds checking can be disabled by [`@inbounds`](@ref), but you need to take care of the shape + of `C`, `A`, `B` yourself. +""" +@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + require_one_based_indexing(A, B) + @boundscheck (size(C) == (size(A,1)*size(B,1), size(A,2)*size(B,2))) || throw(DimensionMismatch()) + m = 0 + @inbounds for j = 1:size(A,2), l = 1:size(B,2), i = 1:size(A,1) + Aij = A[i,j] + for k = 1:size(B,1) + C[m += 1] = Aij*B[k,l] + end + end + return C +end + """ kron(A, B) @@ -383,18 +406,23 @@ julia> reshape(kron(v,w), (length(w), length(v))) ``` """ function kron(a::AbstractMatrix{T}, b::AbstractMatrix{S}) where {T,S} - require_one_based_indexing(a, b) R = Matrix{promote_op(*,T,S)}(undef, size(a,1)*size(b,1), size(a,2)*size(b,2)) - m = 0 - @inbounds for j = 1:size(a,2), l = 1:size(b,2), i = 1:size(a,1) - aij = a[i,j] - for k = 1:size(b,1) - R[m += 1] = aij*b[k,l] - end - end - R + return @inbounds kron!(R, a, b) end +kron!(c::AbstractVecOrMat, a::AbstractVecOrMat, b::Number) = mul!(c, a, b) + +Base.@propagate_inbounds function kron!(c::AbstractVector, a::AbstractVector, b::AbstractVector) + C = reshape(c, length(a)*length(b), 1) + A = reshape(a ,length(a), 1) + B = reshape(b, length(b), 1) + kron!(C, A, B) + return c +end + +Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractMatrix, b::AbstractVector) = kron!(C, a, reshape(b, length(b), 1)) +Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractVector, b::AbstractMatrix) = kron!(C, reshape(a, length(a), 1), b) + kron(a::Number, b::Union{Number, AbstractVecOrMat}) = a * b kron(a::AbstractVecOrMat, b::Number) = a * b kron(a::AbstractVector, b::AbstractVector) = vec(kron(reshape(a ,length(a), 1), reshape(b, length(b), 1))) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index f3b4ac17eec78..27e8ba4c80f3d 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -493,52 +493,80 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} = (\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) = invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B) -function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number} + +@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T + fill!(C, zero(T)) valA = A.diag; nA = length(valA) valB = B.diag; nB = length(valB) - valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB) + nC = checksquare(C) + @boundscheck nC == nA*nB || + throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)")) + @inbounds for i = 1:nA, j = 1:nB - valC[(i-1)*nB+j] = valA[i] * valB[j] + idx = (i-1)*nB+j + C[idx, idx] = valA[i] * valB[j] end - return Diagonal(valC) + return C end -function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number} +function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number} + valA = A.diag; nA = length(valA) + valB = B.diag; nB = length(valB) + valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB) + C = Diagonal(valC) + return @inbounds kron!(C, A, B) +end + +@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix) Base.require_one_based_indexing(B) - (mA, nA) = size(A); (mB, nB) = size(B) - R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB) + (mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C); + @boundscheck (mC, nC) == (mA * mB, nA * nB) || + throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) m = 1 - for j = 1:nA + @inbounds for j = 1:nA A_jj = A[j,j] for k = 1:nB for l = 1:mB - R[m] = A_jj * B[l,k] + C[m] = A_jj * B[l,k] m += 1 end m += (nA - 1) * mB end m += mB end - return R + return C end -function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number} +@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal) require_one_based_indexing(A) - (mA, nA) = size(A); (mB, nB) = size(B) - R = zeros(promote_op(*, T, S), mA * mB, nA * nB) + (mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C); + @boundscheck (mC, nC) == (mA * mB, nA * nB) || + throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) m = 1 - for j = 1:nA + @inbounds for j = 1:nA for l = 1:mB Bll = B[l,l] for k = 1:mA - R[m] = A[k,j] * Bll + C[m] = A[k,j] * Bll m += nB end m += 1 end m -= nB end - return R + return C +end + +function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number} + (mA, nA) = size(A); (mB, nB) = size(B) + R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB) + return @inbounds kron!(R, A, B) +end + +function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number} + (mA, nA) = size(A); (mB, nB) = size(B) + R = zeros(promote_op(*, T, S), mA * mB, nA * nB) + return @inbounds kron!(R, A, B) end conj(D::Diagonal) = Diagonal(conj(D.diag)) diff --git a/stdlib/SparseArrays/src/SparseArrays.jl b/stdlib/SparseArrays/src/SparseArrays.jl index f4867be25a042..0616763205696 100644 --- a/stdlib/SparseArrays/src/SparseArrays.jl +++ b/stdlib/SparseArrays/src/SparseArrays.jl @@ -24,7 +24,7 @@ import Base: acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh, tand, tanh, trunc, abs, abs2, broadcast, ceil, complex, conj, convert, copy, copyto!, adjoint, exp, expm1, findall, findmax, findmin, float, getindex, - vcat, hcat, hvcat, cat, imag, argmax, kron, length, log, log1p, max, min, + vcat, hcat, hvcat, cat, imag, argmax, kron, kron!, length, log, log1p, max, min, maximum, minimum, one, promote_eltype, real, reshape, rot180, rotl90, rotr90, round, setindex!, similar, size, transpose, vec, permute!, map, map!, Array, diff, circshift!, circshift diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index c9d25c58750bd..c54a6c760c0fd 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -1295,16 +1295,21 @@ function opnormestinv(A::AbstractSparseMatrixCSC{T}, t::Integer = min(2,maximum( end ## kron - -# sparse matrix ⊗ sparse matrix -function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S2}) where {T1,S1,T2,S2} +@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC) nnzC = nnz(A)*nnz(B) mA, nA = size(A); mB, nB = size(B) mC, nC = mA*mB, nA*nB - colptrC = Vector{promote_type(S1,S2)}(undef, nC+1) - rowvalC = Vector{promote_type(S1,S2)}(undef, nnzC) - nzvalC = Vector{typeof(one(T1)*one(T2))}(undef, nnzC) - colptrC[1] = 1 + + rowvalC = rowvals(C) + nzvalC = nonzeros(C) + colptrC = getcolptr(C) + + @boundscheck begin + length(colptrC) == nC+1 || throw(DimensionMismatch("expect C to be preallocated with $(nC+1) colptrs ")) + length(rowvalC) == nnzC || throw(DimensionMismatch("expect C to be preallocated with $(nnzC) rowvals")) + length(nzvalC) == nnzC || throw(DimensionMismatch("expect C to be preallocated with $(nnzC) nzvals")) + end + col = 1 @inbounds for j = 1:nA startA = getcolptr(A)[j] @@ -1328,7 +1333,43 @@ function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S end end end - return SparseMatrixCSC(mC, nC, colptrC, rowvalC, nzvalC) + return C +end + +@inline function kron!(z::SparseVector, x::SparseVector, y::SparseVector) + nnzx = nnz(x); nnzy = nnz(y); nnzz = nnz(z); + nzind = nonzeroinds(z) + nzval = nonzeros(z) + + @boundscheck begin + nnzval = length(nzval); nnzind = length(nzind) + nnzz = nnzx*nnzy + nnzval == nnzz || throw(DimensionMismatch("expect z to be preallocated with $nnzz nonzeros")) + nnzind == nnzz || throw(DimensionMismatch("expect z to be preallocated with $nnzz nonzeros")) + end + + @inbounds for i = 1:nnzx, j = 1:nnzy + this_ind = (i-1)*nnzy+j + nzind[this_ind] = (nonzeroinds(x)[i]-1)*length(y) + nonzeroinds(y)[j] + nzval[this_ind] = nonzeros(x)[i] * nonzeros(y)[j] + end + return z +end + +# sparse matrix ⊗ sparse matrix +function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S2}) where {T1,S1,T2,S2} + nnzC = nnz(A)*nnz(B) + mA, nA = size(A); mB, nB = size(B) + mC, nC = mA*mB, nA*nB + Tv = typeof(one(T1)*one(T2)) + Ti = promote_type(S1,S2) + colptrC = Vector{Ti}(undef, nC+1) + rowvalC = Vector{Ti}(undef, nnzC) + nzvalC = Vector{Tv}(undef, nnzC) + colptrC[1] = 1 + # skip sparse_check + C = SparseMatrixCSC{Tv, Ti}(mC, nC, colptrC, rowvalC, nzvalC) + return @inbounds kron!(C, A, B) end # sparse vector ⊗ sparse vector @@ -1337,27 +1378,33 @@ function kron(x::SparseVector{T1,S1}, y::SparseVector{T2,S2}) where {T1,S1,T2,S2 nnzz = nnzx*nnzy # number of nonzeros in new vector nzind = Vector{promote_type(S1,S2)}(undef, nnzz) # the indices of nonzeros nzval = Vector{typeof(one(T1)*one(T2))}(undef, nnzz) # the values of nonzeros - @inbounds for i = 1:nnzx, j = 1:nnzy - this_ind = (i-1)*nnzy+j - nzind[this_ind] = (nonzeroinds(x)[i]-1)*length(y::SparseVector) + nonzeroinds(y)[j] - nzval[this_ind] = nonzeros(x)[i] * nonzeros(y)[j] - end - return SparseVector(length(x::SparseVector)*length(y::SparseVector), nzind, nzval) + z = SparseVector(length(x)*length(y), nzind, nzval) + return @inbounds kron!(z, x, y) end # sparse matrix ⊗ sparse vector & vice versa +Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, x::SparseVector) = kron!(C, A, SparseMatrixCSC(x)) +Base.@propagate_inbounds kron!(C::SparseMatrixCSC, x::SparseVector, A::AbstractSparseMatrixCSC) = kron!(C, SparseMatrixCSC(x), A) + kron(A::AbstractSparseMatrixCSC, x::SparseVector) = kron(A, SparseMatrixCSC(x)) kron(x::SparseVector, A::AbstractSparseMatrixCSC) = kron(SparseMatrixCSC(x), A) # sparse vec/mat ⊗ vec/mat and vice versa +Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat) = kron!(C, A, sparse(B)) +Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC}) = kron!(C, sparse(A), B) + kron(A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat) = kron(A, sparse(B)) kron(A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC}) = kron(sparse(A), B) # sparse vec/mat ⊗ Diagonal and vice versa +Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number} = kron!(C, sparse(A), B) +Base.@propagate_inbounds kron!(C::SparseMatrixCSC, A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number} = kron!(C, A, sparse(B)) + kron(A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number} = kron(sparse(A), B) kron(A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number} = kron(A, sparse(B)) # sparse outer product +kron!(C::SparseMatrixCSC, A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = broadcast!(*, C, A, B) kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = A .* B ## det, inv, cond