Skip to content

Commit

Permalink
add inplace kron (#31069)
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger-luo authored May 19, 2020
1 parent f0dd781 commit cebd4fa
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 48 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Build system changes

New library functions
---------------------

* New function `Base.kron!` and corresponding overloads for various matrix types for computing the Kronecker product in-place ([#31069]).

New library features
--------------------
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ export
adjoint,
transpose,
kron,
kron!,

# bitarrays
falses,
Expand Down
2 changes: 2 additions & 0 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ for op in (:+, :*, :&, :|, :xor, :min, :max, :kron)
end
end

# Declare the generic function `kron!` without attaching any method here;
# the methods are defined in LinearAlgebra and SparseArrays.
function kron! end

# Bind the identifier `'` (written `var"'"`) to `adjoint`.
const var"'" = adjoint

"""
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ Base.inv(::AbstractMatrix)
LinearAlgebra.pinv
LinearAlgebra.nullspace
Base.kron
Base.kron!
LinearAlgebra.exp(::StridedMatrix{<:LinearAlgebra.BlasFloat})
Base.:^(::AbstractMatrix, ::Number)
Base.:^(::Number, ::AbstractMatrix)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Base: \, /, *, ^, +, -, ==
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos,
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims,
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
Expand Down
26 changes: 20 additions & 6 deletions stdlib/LinearAlgebra/src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,29 @@ qr(A::BitMatrix) = qr(float(A))

## kron

function kron(a::BitVector, b::BitVector)
# In-place Kronecker product of two `BitVector`s.
#
# Overwrites `R` with `kron(a, b)` and returns `R`. `R` must have length
# `length(a) * length(b)`; otherwise a `DimensionMismatch` is thrown. The length
# check is a `@boundscheck`, so it can be elided by calling under `@inbounds`.
@inline function kron!(R::BitVector, a::BitVector, b::BitVector)
    m = length(a)
    n = length(b)
    @boundscheck length(R) == n*m || throw(DimensionMismatch())
    # Only the segments whose `a[j]` is true are written below, so clear `R`
    # first; otherwise stale bits survive wherever `a[j]` is false.
    fill!(R, false)
    Rc = R.chunks
    bc = b.chunks
    for j = 1:m
        # Segment j of the result is a copy of `b` when `a[j]` is set.
        a[j] && Base.copy_chunks!(Rc, (j-1)*n+1, bc, 1, n)
    end
    return R
end

function kron(a::BitMatrix, b::BitMatrix)
# Kronecker product of two `BitVector`s, returned as a newly allocated `BitVector`.
function kron(a::BitVector, b::BitVector)
    # All-false buffer of exactly the result length, so the length check inside
    # `kron!` can be skipped with `@inbounds`.
    out = falses(length(b) * length(a))
    return @inbounds kron!(out, a, b)
end

function kron!(R::BitMatrix, a::BitMatrix, b::BitMatrix)
mA,nA = size(a)
mB,nB = size(b)
R = falses(mA*mB, nA*nB)
@boundscheck size(R) == (mA*mB, nA*nB) || throw(DimensionMismatch())

for i = 1:mA
ri = (1:mB) .+ ((i-1)*mB)
Expand All @@ -118,7 +125,14 @@ function kron(a::BitMatrix, b::BitMatrix)
end
end
end
R
return R
end

# Kronecker product of two `BitMatrix` values, returned as a newly allocated `BitMatrix`.
function kron(a::BitMatrix, b::BitMatrix)
    rows_a, cols_a = size(a)
    rows_b, cols_b = size(b)
    # All-false buffer of exactly the result shape, so the shape check inside
    # `kron!` can be skipped with `@inbounds`.
    out = falses(rows_a * rows_b, cols_a * cols_b)
    return @inbounds kron!(out, a, b)
end

## Structure query functions
Expand Down
46 changes: 37 additions & 9 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,29 @@ function tr(A::Matrix{T}) where T
t
end

"""
kron!(C, A, B)
`kron!` is the in-place version of [`kron`](@ref). Computes `kron(A, B)` and stores the result in `C`
overwriting the existing value of `C`.
!!! tip
Bounds checking can be disabled by [`@inbounds`](@ref), but you need to take care of the shape
of `C`, `A`, `B` yourself.
"""
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
    require_one_based_indexing(A, B)
    rows_a, cols_a = size(A, 1), size(A, 2)
    rows_b, cols_b = size(B, 1), size(B, 2)
    # Shape check; elided when the caller wraps the call in `@inbounds`.
    @boundscheck (size(C) == (rows_a * rows_b, cols_a * cols_b)) || throw(DimensionMismatch())
    # `pos` is the linear index into `C` of the most recently written entry.
    # The loop nest visits the entries of the result in exactly that linear order.
    pos = 0
    @inbounds for ja = 1:cols_a, jb = 1:cols_b, ia = 1:rows_a
        scale = A[ia, ja]
        for ib = 1:rows_b
            pos += 1
            C[pos] = scale * B[ib, jb]
        end
    end
    return C
end

"""
kron(A, B)
Expand Down Expand Up @@ -383,18 +406,23 @@ julia> reshape(kron(v,w), (length(w), length(v)))
```
"""
function kron(a::AbstractMatrix{T}, b::AbstractMatrix{S}) where {T,S}
    # Reject offset-indexed inputs before allocating anything.
    require_one_based_indexing(a, b)
    # Result buffer: element type of `a[i,j] * b[k,l]`, exact Kronecker shape.
    R = Matrix{promote_op(*,T,S)}(undef, size(a,1)*size(b,1), size(a,2)*size(b,2))
    # `kron!` fills every entry of `R`. The shape is correct by construction, so
    # its shape check is skipped with `@inbounds`. The fill loop is not repeated
    # here, so the product is computed only once.
    return @inbounds kron!(R, a, b)
end

# The Kronecker product with a scalar is plain scaling, so delegate to `mul!`.
# Both argument orders are provided, matching the `kron` methods that accept a
# `Number` on either side.
kron!(c::AbstractVecOrMat, a::AbstractVecOrMat, b::Number) = mul!(c, a, b)
kron!(c::AbstractVecOrMat, a::Number, b::AbstractVecOrMat) = mul!(c, a, b)

# In-place Kronecker product of two vectors: view all three arguments as
# single-column matrices and reuse the matrix method. Returns the original `c`.
Base.@propagate_inbounds function kron!(c::AbstractVector, a::AbstractVector, b::AbstractVector)
    na, nb = length(a), length(b)
    kron!(reshape(c, na * nb, 1), reshape(a, na, 1), reshape(b, nb, 1))
    return c
end

# Mixed matrix/vector arguments: reshape the vector into a single-column matrix
# and forward to the matrix ⊗ matrix method.
Base.@propagate_inbounds function kron!(C::AbstractMatrix, a::AbstractMatrix, b::AbstractVector)
    return kron!(C, a, reshape(b, length(b), 1))
end
Base.@propagate_inbounds function kron!(C::AbstractMatrix, a::AbstractVector, b::AbstractMatrix)
    return kron!(C, reshape(a, length(a), 1), b)
end

# With a scalar on either side the Kronecker product is ordinary multiplication.
function kron(a::Number, b::Union{Number, AbstractVecOrMat})
    return a * b
end
function kron(a::AbstractVecOrMat, b::Number)
    return a * b
end
# Vector ⊗ vector: compute on single-column matrices and flatten the result.
function kron(a::AbstractVector, b::AbstractVector)
    col_a = reshape(a ,length(a), 1)
    col_b = reshape(b, length(b), 1)
    return vec(kron(col_a, col_b))
end
Expand Down
60 changes: 44 additions & 16 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -493,52 +493,80 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
(\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) =
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)

function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}

# In-place Kronecker product of two `Diagonal` matrices.
#
# Overwrites the square matrix `C` with `kron(A, B)` and returns `C`.
# `checksquare` rejects a non-square `C`, and the `@boundscheck` throws
# `DimensionMismatch` when its side length differs from
# `length(A.diag) * length(B.diag)`.
@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T
    valA = A.diag; nA = length(valA)
    valB = B.diag; nB = length(valB)
    nC = checksquare(C)
    @boundscheck nC == nA*nB ||
        throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))
    # Clear `C` only after its shape has been validated, so a rejected call
    # leaves the caller's data untouched.
    fill!(C, zero(T))
    @inbounds for i = 1:nA, j = 1:nB
        # Diagonal position of the product `A[i,i] * B[j,j]`.
        idx = (i-1)*nB+j
        C[idx, idx] = valA[i] * valB[j]
    end
    return C
end

function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
# Kronecker product of two numeric `Diagonal` matrices, returned as a `Diagonal`.
function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}
    Tprod = typeof(zero(T1)*zero(T2))
    len = length(A.diag) * length(B.diag)
    # Uninitialised diagonal storage; `kron!` writes every diagonal entry.
    out = Diagonal(Vector{Tprod}(undef, len))
    return @inbounds kron!(out, A, B)
end

# In-place Kronecker product `Diagonal ⊗ matrix`.
#
# Overwrites `C` with `kron(A, B)` and returns `C`. `C` must have size
# `(size(A,1)*size(B,1), size(A,2)*size(B,2))`; otherwise the `@boundscheck`
# throws `DimensionMismatch`.
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
    Base.require_one_based_indexing(B)
    (mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
    @boundscheck (mC, nC) == (mA * mB, nA * nB) ||
        throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
    # Only the diagonal blocks are written below. Zero `C` first so the
    # off-diagonal blocks do not keep whatever `C` held before the call.
    fill!(C, zero(eltype(C)))
    m = 1  # linear index into `C`
    @inbounds for j = 1:nA
        A_jj = A[j,j]
        for k = 1:nB
            for l = 1:mB
                C[m] = A_jj * B[l,k]
                m += 1
            end
            # Skip the rest of this column of `C` to reach the same block row
            # in the next column.
            m += (nA - 1) * mB
        end
        # Step down one block so the next `A[j,j]` lands on the diagonal.
        m += mB
    end
    return C
end

function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
# In-place Kronecker product `matrix ⊗ Diagonal`.
#
# Overwrites `C` with `kron(A, B)` and returns `C`. `C` must have size
# `(size(A,1)*size(B,1), size(A,2)*size(B,2))`; otherwise the `@boundscheck`
# throws `DimensionMismatch`.
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal)
    require_one_based_indexing(A)
    (mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
    @boundscheck (mC, nC) == (mA * mB, nA * nB) ||
        throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
    # Each block of the result is `A[k,j] * B`, which is diagonal, so only those
    # diagonal entries are written below. Zero `C` first so every other entry
    # does not keep whatever `C` held before the call.
    fill!(C, zero(eltype(C)))
    m = 1  # linear index into `C`
    @inbounds for j = 1:nA
        for l = 1:mB
            Bll = B[l,l]
            for k = 1:mA
                C[m] = A[k,j] * Bll
                # Same column of `C`, same offset inside the next block row.
                m += nB
            end
            # The `k` loop advanced exactly one full column; move one row down
            # for the next diagonal entry of `B`.
            m += 1
        end
        # Undo the accumulated row offset before starting the next column block.
        m -= nB
    end
    return C
end

# Kronecker product `Diagonal ⊗ matrix` for numeric element types, returned as
# a newly allocated dense matrix.
function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
    rows_a, cols_a = size(A)
    rows_b, cols_b = size(B)
    # Zero-initialised buffer of the exact result shape; `kron!` fills in the
    # diagonal blocks.
    out = zeros(Base.promote_op(*, T, S), rows_a * rows_b, cols_a * cols_b)
    return @inbounds kron!(out, A, B)
end

# Kronecker product `matrix ⊗ Diagonal` for numeric element types, returned as
# a newly allocated dense matrix.
function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
    rows_a, cols_a = size(A)
    rows_b, cols_b = size(B)
    # Zero-initialised buffer of the exact result shape; `kron!` fills in the
    # nonzero entries.
    out = zeros(promote_op(*, T, S), rows_a * rows_b, cols_a * cols_b)
    return @inbounds kron!(out, A, B)
end

# Conjugate of a `Diagonal`: conjugate the stored diagonal and wrap it again.
function conj(D::Diagonal)
    return Diagonal(conj(D.diag))
end
Expand Down
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import Base: acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
tand, tanh, trunc, abs, abs2,
broadcast, ceil, complex, conj, convert, copy, copyto!, adjoint,
exp, expm1, findall, findmax, findmin, float, getindex,
vcat, hcat, hvcat, cat, imag, argmax, kron, length, log, log1p, max, min,
vcat, hcat, hvcat, cat, imag, argmax, kron, kron!, length, log, log1p, max, min,
maximum, minimum, one, promote_eltype, real, reshape, rot180,
rotl90, rotr90, round, setindex!, similar, size, transpose,
vec, permute!, map, map!, Array, diff, circshift!, circshift
Expand Down
75 changes: 61 additions & 14 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1295,16 +1295,21 @@ function opnormestinv(A::AbstractSparseMatrixCSC{T}, t::Integer = min(2,maximum(
end

## kron

# sparse matrix ⊗ sparse matrix
function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S2}) where {T1,S1,T2,S2}
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
nnzC = nnz(A)*nnz(B)
mA, nA = size(A); mB, nB = size(B)
mC, nC = mA*mB, nA*nB
colptrC = Vector{promote_type(S1,S2)}(undef, nC+1)
rowvalC = Vector{promote_type(S1,S2)}(undef, nnzC)
nzvalC = Vector{typeof(one(T1)*one(T2))}(undef, nnzC)
colptrC[1] = 1

rowvalC = rowvals(C)
nzvalC = nonzeros(C)
colptrC = getcolptr(C)

@boundscheck begin
length(colptrC) == nC+1 || throw(DimensionMismatch("expect C to be preallocated with $(nC+1) colptrs "))
length(rowvalC) == nnzC || throw(DimensionMismatch("expect C to be preallocated with $(nnzC) rowvals"))
length(nzvalC) == nnzC || throw(DimensionMismatch("expect C to be preallocated with $(nnzC) nzvals"))
end

col = 1
@inbounds for j = 1:nA
startA = getcolptr(A)[j]
Expand All @@ -1328,7 +1333,43 @@ function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S
end
end
end
return SparseMatrixCSC(mC, nC, colptrC, rowvalC, nzvalC)
return C
end

# In-place Kronecker product of two sparse vectors.
#
# Overwrites the stored indices and values of `z` with those of `kron(x, y)`
# and returns `z`. `z` must have length `length(x) * length(y)` and exactly
# `nnz(x) * nnz(y)` stored entries; otherwise the `@boundscheck` throws
# `DimensionMismatch`.
@inline function kron!(z::SparseVector, x::SparseVector, y::SparseVector)
    nnzx = nnz(x); nnzy = nnz(y)
    nzind = nonzeroinds(z)
    nzval = nonzeros(z)

    @boundscheck begin
        # The indices written below range over 1:length(x)*length(y), so a `z`
        # of any other length would end up holding out-of-range or misplaced
        # entries.
        length(z) == length(x)*length(y) ||
            throw(DimensionMismatch("expect z to have length $(length(x)*length(y)), got $(length(z))"))
        nnzval = length(nzval); nnzind = length(nzind)
        nnzz = nnzx*nnzy
        nnzval == nnzz || throw(DimensionMismatch("expect z to be preallocated with $nnzz nonzeros"))
        nnzind == nnzz || throw(DimensionMismatch("expect z to be preallocated with $nnzz nonzeros"))
    end

    @inbounds for i = 1:nnzx, j = 1:nnzy
        # Stored entry (i, j) of the product goes to slot (i-1)*nnzy + j.
        this_ind = (i-1)*nnzy+j
        nzind[this_ind] = (nonzeroinds(x)[i]-1)*length(y) + nonzeroinds(y)[j]
        nzval[this_ind] = nonzeros(x)[i] * nonzeros(y)[j]
    end
    return z
end

# sparse matrix ⊗ sparse matrix
function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S2}) where {T1,S1,T2,S2}
    Tv = typeof(one(T1)*one(T2))
    Ti = promote_type(S1,S2)
    (mA, nA), (mB, nB) = size(A), size(B)
    mC, nC = mA*mB, nA*nB
    nstored = nnz(A)*nnz(B)
    # Uninitialised buffers sized for the result; only the first column pointer
    # is set here, the rest is left for `kron!` to fill.
    colptrC = Vector{Ti}(undef, nC+1)
    colptrC[1] = 1
    rowvalC = Vector{Ti}(undef, nstored)
    nzvalC = Vector{Tv}(undef, nstored)
    # skip sparse_check
    C = SparseMatrixCSC{Tv, Ti}(mC, nC, colptrC, rowvalC, nzvalC)
    return @inbounds kron!(C, A, B)
end

# sparse vector ⊗ sparse vector
Expand All @@ -1337,27 +1378,33 @@ function kron(x::SparseVector{T1,S1}, y::SparseVector{T2,S2}) where {T1,S1,T2,S2
nnzz = nnzx*nnzy # number of nonzeros in new vector
nzind = Vector{promote_type(S1,S2)}(undef, nnzz) # the indices of nonzeros
nzval = Vector{typeof(one(T1)*one(T2))}(undef, nnzz) # the values of nonzeros
@inbounds for i = 1:nnzx, j = 1:nnzy
this_ind = (i-1)*nnzy+j
nzind[this_ind] = (nonzeroinds(x)[i]-1)*length(y::SparseVector) + nonzeroinds(y)[j]
nzval[this_ind] = nonzeros(x)[i] * nonzeros(y)[j]
end
return SparseVector(length(x::SparseVector)*length(y::SparseVector), nzind, nzval)
z = SparseVector(length(x)*length(y), nzind, nzval)
return @inbounds kron!(z, x, y)
end

# sparse matrix ⊗ sparse vector & vice versa
# The sparse vector is converted with `SparseMatrixCSC(x)` and the call is
# forwarded to the sparse matrix ⊗ sparse matrix method.
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, x::SparseVector)
    return kron!(C, A, SparseMatrixCSC(x))
end
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, x::SparseVector, A::AbstractSparseMatrixCSC)
    return kron!(C, SparseMatrixCSC(x), A)
end

function kron(A::AbstractSparseMatrixCSC, x::SparseVector)
    return kron(A, SparseMatrixCSC(x))
end
function kron(x::SparseVector, A::AbstractSparseMatrixCSC)
    return kron(SparseMatrixCSC(x), A)
end

# sparse vec/mat ⊗ vec/mat and vice versa
# The dense operand is converted with `sparse` and the call is forwarded.
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat)
    return kron!(C, A, sparse(B))
end
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC})
    return kron!(C, sparse(A), B)
end

function kron(A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat)
    return kron(A, sparse(B))
end
function kron(A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC})
    return kron(sparse(A), B)
end

# sparse vec/mat ⊗ Diagonal and vice versa
# The `Diagonal` operand is converted with `sparse` and the call is forwarded.
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number}
    return kron!(C, sparse(A), B)
end
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number}
    return kron!(C, A, sparse(B))
end

function kron(A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number}
    return kron(sparse(A), B)
end
function kron(A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number}
    return kron(A, sparse(B))
end

# sparse outer product
# Both methods are elementwise `*` broadcasts of `A` against `B`; the in-place
# form writes the broadcast result into `C`.
function kron!(C::SparseMatrixCSC, A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion)
    return broadcast!(*, C, A, B)
end
function kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion)
    return A .* B
end

## det, inv, cond
Expand Down

0 comments on commit cebd4fa

Please sign in to comment.