From e326778650463a209542bcf665483e146b8228fd Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 6 Jul 2023 15:33:16 +0530 Subject: [PATCH 1/2] kron for RectDiagonal fill --- Project.toml | 2 +- src/FillArrays.jl | 12 +++++++++++- src/fillalgebra.jl | 2 +- test/runtests.jl | 22 ++++++++++++++++++++++ 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c4182cfd..a495df86 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.3.0" +version = "1.4.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 3222a8f7..b4389eda 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -411,7 +411,8 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}} -const RectDiagonalEye{T} = RectDiagonal{T,<:Ones{T,1}} +const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V} +const RectDiagonalEye{T} = RectDiagonalFill{T,<:OnesVector{T}} const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}} const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}} @@ -537,6 +538,15 @@ convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} = convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, Z) +function SparseMatrixCSC{Tv}(R::RectDiagonalFill) where {Tv} + SparseMatrixCSC{Tv,eltype(axes(R,1))}(R) +end +function SparseMatrixCSC{Tv,Ti}(R::RectDiagonalFill) where {Tv,Ti} + Base.require_one_based_indexing(R) + v = R.diag + J = getindex_value(v)*I + SparseMatrixCSC{Tv,Ti}(J, size(R)) +end ######### # maximum/minimum diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 6d5c8394..2d1d5eca 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -453,4 +453,4 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) sz = _kronsize(f, g) _kron(f, g, sz) end -kron(E1::RectDiagonalEye, E2::RectDiagonalEye) = kron(sparse(E1), sparse(E2)) +kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2)) diff --git a/test/runtests.jl b/test/runtests.jl index 1ffd2d5a..827f834b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -514,6 +514,22 @@ end convert(AbstractSparseMatrix{Float64,Int},Mat) == SMat end + + for f in (Fill(Int8(4),3), Ones{Int8}(3), Zeros{Int8}(3)) + for sz in ((3,6), (6,3), (3,3)) + E = RectDiagonal(f, sz) + S = @inferred SparseMatrixCSC(E) + @test S == E + S = @inferred SparseMatrixCSC{Float64}(E) + @test S == E + @test S isa SparseMatrixCSC{Float64} + @test convert(SparseMatrixCSC{Float64}, E) == S + S = @inferred SparseMatrixCSC{Float64,Int32}(E) + @test S == E + @test S isa SparseMatrixCSC{Float64,Int32} + @test convert(SparseMatrixCSC{Float64,Int32}, E) == S + end + end end @testset "==" begin @@ -1534,6 +1550,12 @@ end C = collect(E) @test K == kron(C, C) @test issparse(kron(E,E)) + + E = RectDiagonal(Fill(4,3), (6,3)) + C = collect(E) + K = kron(E, E) + @test K == kron(C, C) + @test issparse(K) end @testset "dot products" begin From 1491582e95bf5e34307b410387458e3cd427a7d9 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 7 Jul 2023 12:14:30 +0530 Subject: [PATCH 2/2] specialize sparse for Diagonal Fill --- src/FillArrays.jl | 13 ++++++++----- test/runtests.jl | 26 ++++++++++++++++---------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index b4389eda..4b2dbe47 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -6,7 +6,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, +, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, - show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat + show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, + parent import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, @@ -369,6 +370,8 @@ axes(T::UpperOrLowerTriangular{<:Any,<:AbstractFill}) = axes(parent(T)) axes(rd::RectDiagonal) = rd.axes size(rd::RectDiagonal) = map(length, rd.axes) +parent(rd::RectDiagonal) = rd.diag + @inline function getindex(rd::RectDiagonal{T}, i::Integer, j::Integer) where T @boundscheck checkbounds(rd, i, j) if i == j @@ -411,8 +414,8 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}} +const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes} const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V} -const RectDiagonalEye{T} = RectDiagonalFill{T,<:OnesVector{T}} const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}} const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}} @@ -538,12 +541,12 @@ convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} = convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, Z) -function SparseMatrixCSC{Tv}(R::RectDiagonalFill) where {Tv} +function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv} SparseMatrixCSC{Tv,eltype(axes(R,1))}(R) end -function SparseMatrixCSC{Tv,Ti}(R::RectDiagonalFill) where {Tv,Ti} +function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti} Base.require_one_based_indexing(R) - v = R.diag + v = parent(R) J = getindex_value(v)*I SparseMatrixCSC{Tv,Ti}(J, size(R)) end diff --git a/test/runtests.jl b/test/runtests.jl index 827f834b..546be82a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -515,19 +515,25 @@ end SMat end + function testsparsediag(E) + S = @inferred SparseMatrixCSC(E) + @test S == E + S = @inferred SparseMatrixCSC{Float64}(E) + @test S == E + @test S isa SparseMatrixCSC{Float64} + @test convert(SparseMatrixCSC{Float64}, E) == S + S = @inferred SparseMatrixCSC{Float64,Int32}(E) + @test S == E + @test S isa SparseMatrixCSC{Float64,Int32} + @test convert(SparseMatrixCSC{Float64,Int32}, E) == S + end + for f in (Fill(Int8(4),3), Ones{Int8}(3), Zeros{Int8}(3)) + E = Diagonal(f) + testsparsediag(E) for sz in ((3,6), (6,3), (3,3)) E = RectDiagonal(f, sz) - S = @inferred SparseMatrixCSC(E) - @test S == E - S = @inferred SparseMatrixCSC{Float64}(E) - @test S == E - @test S isa SparseMatrixCSC{Float64} - @test convert(SparseMatrixCSC{Float64}, E) == S - S = @inferred SparseMatrixCSC{Float64,Int32}(E) - @test S == E - @test S isa SparseMatrixCSC{Float64,Int32} - @test convert(SparseMatrixCSC{Float64,Int32}, E) == S + testsparsediag(E) end end end