From edbf40a9547abcce2319ed3a794f88e5f1fe0d45 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 10 Jul 2023 22:55:30 +0530 Subject: [PATCH] BroadcastStyle for RectDiagonal --- src/FillArrays.jl | 12 ++++++++++-- test/runtests.jl | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 4b2dbe47..05ebcadb 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,14 +7,14 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent + parent, similar import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron -import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape +import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape, BroadcastStyle, Broadcasted import Statistics: mean, std, var, cov, cor @@ -467,6 +467,14 @@ end @inline Eye{T}(A::AbstractMatrix) where T = Eye{T}(size(A)...) @inline Eye(A::AbstractMatrix) = Eye{eltype(A)}(size(A)...) +# This may break, as it uses undocumented internals of LinearAlgebra +# Ideally this should be copied over to this package +# Also, maybe this should reuse the broadcasting behavior of the parent, +# once AbstractFill types implement their own BroadcastStyle +BroadcastStyle(::Type{<:RectDiagonal}) = LinearAlgebra.StructuredMatrixStyle{RectDiagonal}() +LinearAlgebra.structured_broadcast_alloc(bc, ::Type{<:RectDiagonal}, ::Type{ElType}, n) where {ElType} = + RectDiagonal(Array{ElType}(undef, n), axes(bc)) +@inline LinearAlgebra.fzero(S::RectDiagonal{T}) where {T} = zero(T) ######### # Special matrix types diff --git a/test/runtests.jl b/test/runtests.jl index 546be82a..12f3bcde 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1265,6 +1265,20 @@ end @test FillArrays._copy_oftype(D2, eltype(D2)) !== D2 end +@testset "Eye broadcast" begin + E = Eye(2,3) + M = Matrix(E) + F = E .+ E + @test F isa FillArrays.RectDiagonal + @test F == M + M + + F = E .+ 1 + @test F == M .+ 1 + + E = Eye((SOneTo(2), SOneTo(2))) + @test axes(E .+ E) === axes(E) +end + @testset "Issue #31" begin @test convert(SparseMatrixCSC{Float64,Int64}, Zeros{Float64}(3, 3)) == spzeros(3, 3) @test sparse(Zeros(4, 2)) == spzeros(4, 2)