diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 1b235ea6..7b920955 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,14 +7,14 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, issorted + parent, similar, issorted import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron -import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape +import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape, BroadcastStyle, Broadcasted import Statistics: mean, std, var, cov, cor @@ -481,6 +481,14 @@ end @inline Eye{T}(A::AbstractMatrix) where T = Eye{T}(size(A)...) @inline Eye(A::AbstractMatrix) = Eye{eltype(A)}(size(A)...) +# This may break, as it uses undocumented internals of LinearAlgebra +# Ideally this should be copied over to this package +# Also, maybe this should reuse the broadcasting behavior of the parent, +# once AbstractFill types implement their own BroadcastStyle +BroadcastStyle(::Type{<:RectDiagonal}) = LinearAlgebra.StructuredMatrixStyle{RectDiagonal}() +LinearAlgebra.structured_broadcast_alloc(bc, ::Type{<:RectDiagonal}, ::Type{ElType}, n) where {ElType} = + RectDiagonal(Array{ElType}(undef, n), axes(bc)) +@inline LinearAlgebra.fzero(S::RectDiagonal{T}) where {T} = zero(T) ######### # Special matrix types diff --git a/test/runtests.jl b/test/runtests.jl index 884aade0..a7f46561 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1271,6 +1271,20 @@ end @test FillArrays._copy_oftype(D2, eltype(D2)) !== D2 end +@testset "Eye broadcast" begin + E = Eye(2,3) + M = Matrix(E) + F = E .+ E + @test F isa FillArrays.RectDiagonal + @test F == M + M + + F = E .+ 1 + @test F == M .+ 1 + + E = Eye((SOneTo(2), SOneTo(2))) + @test axes(E .+ E) === axes(E) +end + @testset "Issue #31" begin @test convert(SparseMatrixCSC{Float64,Int64}, Zeros{Float64}(3, 3)) == spzeros(3, 3) @test sparse(Zeros(4, 2)) == spzeros(4, 2)