diff --git a/Project.toml b/Project.toml index c4182cfd..a495df86 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.3.0" +version = "1.4.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 3222a8f7..4b2dbe47 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -6,7 +6,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, +, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, - show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat + show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, + parent import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, @@ -369,6 +370,8 @@ axes(T::UpperOrLowerTriangular{<:Any,<:AbstractFill}) = axes(parent(T)) axes(rd::RectDiagonal) = rd.axes size(rd::RectDiagonal) = map(length, rd.axes) +parent(rd::RectDiagonal) = rd.diag + @inline function getindex(rd::RectDiagonal{T}, i::Integer, j::Integer) where T @boundscheck checkbounds(rd, i, j) if i == j @@ -411,7 +414,8 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}} -const RectDiagonalEye{T} = RectDiagonal{T,<:Ones{T,1}} +const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes} +const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V} const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}} const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}} @@ -537,6 +541,15 @@ convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} = convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, Z) +function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv} + SparseMatrixCSC{Tv,eltype(axes(R,1))}(R) +end +function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti} + Base.require_one_based_indexing(R) + v = parent(R) + J = getindex_value(v)*I + SparseMatrixCSC{Tv,Ti}(J, size(R)) +end ######### # maximum/minimum diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 6d5c8394..2d1d5eca 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -453,4 +453,4 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) sz = _kronsize(f, g) _kron(f, g, sz) end -kron(E1::RectDiagonalEye, E2::RectDiagonalEye) = kron(sparse(E1), sparse(E2)) +kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2)) diff --git a/test/runtests.jl b/test/runtests.jl index 1ffd2d5a..546be82a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -514,6 +514,28 @@ end convert(AbstractSparseMatrix{Float64,Int},Mat) == SMat end + + function testsparsediag(E) + S = @inferred SparseMatrixCSC(E) + @test S == E + S = @inferred SparseMatrixCSC{Float64}(E) + @test S == E + @test S isa SparseMatrixCSC{Float64} + @test convert(SparseMatrixCSC{Float64}, E) == S + S = @inferred SparseMatrixCSC{Float64,Int32}(E) + @test S == E + @test S isa SparseMatrixCSC{Float64,Int32} + @test convert(SparseMatrixCSC{Float64,Int32}, E) == S + end + + for f in (Fill(Int8(4),3), Ones{Int8}(3), Zeros{Int8}(3)) + E = Diagonal(f) + testsparsediag(E) + for sz in ((3,6), (6,3), (3,3)) + E = RectDiagonal(f, sz) + testsparsediag(E) + end + end end @testset "==" begin @@ -1534,6 +1556,12 @@ end C = collect(E) @test K == kron(C, C) @test issparse(kron(E,E)) + + E = RectDiagonal(Fill(4,3), (6,3)) + C = collect(E) + K = kron(E, E) + @test K == kron(C, C) + @test issparse(K) end @testset "dot products" begin