diff --git a/src/linalg.jl b/src/linalg.jl index 322b506d..7ef517ab 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1483,7 +1483,7 @@ const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A} const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A} const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{<:Any,A} # AbstractTriangular{T,A} const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays} -const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays} +const _DenseKronGroup = Union{Number, Vector, Matrix, AdjOrTrans{<:Any,<:VecOrMat}, _Annotated_DenseArrays} @inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC) mA, nA = size(A); mB, nB = size(B) @@ -1541,9 +1541,9 @@ end end return z end -kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_DenseConcatGroup) = +kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_DenseKronGroup) = kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B)) -kron!(C::SparseMatrixCSC, A::_DenseConcatGroup, B::_SparseKronGroup) = +kron!(C::SparseMatrixCSC, A::_DenseKronGroup, B::_SparseKronGroup) = kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B)) kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_SparseKronGroup) = kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B)) @@ -1580,8 +1580,8 @@ end # extend to annotated sparse arrays, but leave out the (dense ⊗ dense)-case kron(A::_SparseKronGroup, B::_SparseKronGroup) = kron(convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B)) -kron(A::_SparseKronGroup, B::_DenseConcatGroup) = kron(A, sparse(B)) -kron(A::_DenseConcatGroup, B::_SparseKronGroup) = kron(sparse(A), B) +kron(A::_SparseKronGroup, B::_DenseKronGroup) = kron(A, sparse(B)) +kron(A::_DenseKronGroup, B::_SparseKronGroup) = kron(sparse(A), B) kron(A::_SparseVectorUnion, B::_AdjOrTransSparseVectorUnion) = A .* B # disambiguation kron(A::AbstractCompressedVector, B::AdjOrTrans{<:Any,<:AbstractCompressedVector}) = A .* B diff --git a/test/linalg.jl b/test/linalg.jl index c9aef837..a9396ebf 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -751,31 +751,27 @@ end @test Array(kron(t(a), c_di)::SparseMatrixCSC) == kron(t(a_d), c_d) @test Array(kron(a, t(c_di))::SparseMatrixCSC) == kron(a_d, t(c_d)) @test Array(kron(t(a), t(c_di))::SparseMatrixCSC) == kron(t(a_d), t(c_d)) - @test issparse(kron(c_di, y)) - @test Array(kron(c_di, y)) == kron(c_di, y_d) - @test issparse(kron(x, d_di)) - @test Array(kron(x, d_di)) == kron(x_d, d_di) + @test Array(kron(c_di, y)::SparseMatrixCSC) == kron(c_di, y_d) + @test Array(kron(x, d_di)::SparseMatrixCSC) == kron(x_d, d_di) end end # vec ⊗ vec - @test Vector(kron(x, y)) == kron(x_d, y_d) - @test Vector(kron(x_d, y)) == kron(x_d, y_d) - @test Vector(kron(x, y_d)) == kron(x_d, y_d) + @test Vector(kron(x, y)::SparseVector) == kron(x_d, y_d) + @test Vector(kron(x_d, y)::SparseVector) == kron(x_d, y_d) + @test Vector(kron(x, y_d)::SparseVector) == kron(x_d, y_d) for t in (identity, adjoint, transpose) # mat ⊗ vec @test Array(kron(t(a), y)::SparseMatrixCSC) == kron(t(a_d), y_d) - @test Array(kron(t(a_d), y)) == kron(t(a_d), y_d) + @test Array(kron(t(a_d), y)::SparseMatrixCSC) == kron(t(a_d), y_d) @test Array(kron(t(a), y_d)::SparseMatrixCSC) == kron(t(a_d), y_d) # vec ⊗ mat @test Array(kron(x, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d)) @test Array(kron(x_d, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d)) - @test Array(kron(x, t(b_d))) == kron(x_d, t(b_d)) + @test Array(kron(x, t(b_d))::SparseMatrixCSC) == kron(x_d, t(b_d)) end # vec ⊗ vec' - @test issparse(kron(v, y')) - @test issparse(kron(x, y')) - @test Array(kron(v, y')) == kron(v_d, y_d') - @test Array(kron(x, y')) == kron(x_d, y_d') + @test Array(kron(v, y')::SparseMatrixCSC) == kron(v_d, y_d') + @test Array(kron(x, y')::SparseMatrixCSC) == kron(x_d, y_d') # test different types z = convert(SparseVector{Float16, Int8}, y); z_d = Vector(z) @test Vector(kron(x, z)) == kron(x_d, z_d)