Skip to content

Commit

Permalink
kron for RectDiagonal fill (#272)
Browse files Browse the repository at this point in the history
* kron for RectDiagonal fill

* specialize sparse for Diagonal Fill
  • Loading branch information
jishnub authored Jul 7, 2023
1 parent e3c8eb9 commit 9be3ba2
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FillArrays"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.3.0"
version = "1.4.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
17 changes: 15 additions & 2 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat
show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat,
parent

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
Expand Down Expand Up @@ -369,6 +370,8 @@ axes(T::UpperOrLowerTriangular{<:Any,<:AbstractFill}) = axes(parent(T))
axes(rd::RectDiagonal) = rd.axes
size(rd::RectDiagonal) = map(length, rd.axes)

parent(rd::RectDiagonal) = rd.diag

@inline function getindex(rd::RectDiagonal{T}, i::Integer, j::Integer) where T
@boundscheck checkbounds(rd, i, j)
if i == j
Expand Down Expand Up @@ -411,7 +414,8 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac


const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}}
const RectDiagonalEye{T} = RectDiagonal{T,<:Ones{T,1}}
const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes}
const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V}
const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}}
const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}}

Expand Down Expand Up @@ -537,6 +541,15 @@ convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(SparseMatrixCSC{Tv,Ti}, Z)

function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv}
SparseMatrixCSC{Tv,eltype(axes(R,1))}(R)
end
function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti}
Base.require_one_based_indexing(R)
v = parent(R)
J = getindex_value(v)*I
SparseMatrixCSC{Tv,Ti}(J, size(R))
end

#########
# maximum/minimum
Expand Down
2 changes: 1 addition & 1 deletion src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,4 +453,4 @@ function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat)
sz = _kronsize(f, g)
_kron(f, g, sz)
end
kron(E1::RectDiagonalEye, E2::RectDiagonalEye) = kron(sparse(E1), sparse(E2))
kron(E1::RectDiagonalFill, E2::RectDiagonalFill) = kron(sparse(E1), sparse(E2))
28 changes: 28 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,28 @@ end
convert(AbstractSparseMatrix{Float64,Int},Mat) ==
SMat
end

function testsparsediag(E)
S = @inferred SparseMatrixCSC(E)
@test S == E
S = @inferred SparseMatrixCSC{Float64}(E)
@test S == E
@test S isa SparseMatrixCSC{Float64}
@test convert(SparseMatrixCSC{Float64}, E) == S
S = @inferred SparseMatrixCSC{Float64,Int32}(E)
@test S == E
@test S isa SparseMatrixCSC{Float64,Int32}
@test convert(SparseMatrixCSC{Float64,Int32}, E) == S
end

for f in (Fill(Int8(4),3), Ones{Int8}(3), Zeros{Int8}(3))
E = Diagonal(f)
testsparsediag(E)
for sz in ((3,6), (6,3), (3,3))
E = RectDiagonal(f, sz)
testsparsediag(E)
end
end
end

@testset "==" begin
Expand Down Expand Up @@ -1534,6 +1556,12 @@ end
C = collect(E)
@test K == kron(C, C)
@test issparse(kron(E,E))

E = RectDiagonal(Fill(4,3), (6,3))
C = collect(E)
K = kron(E, E)
@test K == kron(C, C)
@test issparse(K)
end

@testset "dot products" begin
Expand Down

0 comments on commit 9be3ba2

Please sign in to comment.