diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl index dfc58aabc63ec..6114c9a796390 100644 --- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -8,8 +8,8 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}() StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}() -const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular} -for ST in Base.uniontypes(StructuredMatrix) +const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}} +for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular) @eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}()) end @@ -133,6 +133,7 @@ fails as `zero(::Tuple{Int})` is not defined. However, iszerodefined(::Type) = false iszerodefined(::Type{<:Number}) = true iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T) +iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T) fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0)) # Like sparse matrices, we assume that the zero-preservation property of a broadcasted @@ -144,6 +145,7 @@ fzero(::Type{T}) where T = T fzero(r::Ref) = r[] fzero(t::Tuple{Any}) = t[1] fzero(S::StructuredMatrix) = zero(eltype(S)) +fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = haszero(T) ? zero(T)*I : missing fzero(x) = missing function fzero(bc::Broadcast.Broadcasted) args = map(fzero, bc.args) diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index 6465927db5235..a5198892ff995 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -111,6 +111,7 @@ Random.seed!(1) struct TypeWithZero end Base.promote_rule(::Type{TypeWithoutZero}, ::Type{TypeWithZero}) = TypeWithZero Base.convert(::Type{TypeWithZero}, ::TypeWithoutZero) = TypeWithZero() + Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x)) Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero() LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero() LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl index 02d7ead55af17..3767fc10055f2 100644 --- a/stdlib/LinearAlgebra/test/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -280,4 +280,62 @@ end # structured broadcast with function returning non-number type @test tuple.(Diagonal([1, 2])) == [(1,) (0,); (0,) (2,)] +@testset "broadcast over structured matrices with matrix elements" begin + function standardbroadcastingtests(D, T) + M = [x for x in D] + Dsum = D .+ D + @test Dsum isa T + @test Dsum == M .+ M + Dcopy = copy.(D) + @test Dcopy isa T + @test Dcopy == D + Df = float.(D) + @test Df isa T + @test Df == D + @test eltype(eltype(Df)) <: AbstractFloat + @test (x -> (x,)).(D) == (x -> (x,)).(M) + @test (x -> 1).(D) == ones(Int,size(D)) + @test all(==(2), ndims.(D)) + @test_throws MethodError size.(D) + end + @testset "Diagonal" begin + @testset "square" begin + A = [1 3; 2 4] + D = Diagonal([A, A]) + standardbroadcastingtests(D, Diagonal) + @test sincos.(D) == sincos.(Matrix{eltype(D)}(D)) + M = [x for x in D] + @test cos.(D) == cos.(M) + end + + @testset "different-sized square blocks" begin + D = Diagonal([ones(3,3), fill(3.0,2,2)]) + standardbroadcastingtests(D, Diagonal) + end + + @testset "rectangular blocks" begin + D = Diagonal([ones(Bool,3,4), ones(Bool,2,3)]) + standardbroadcastingtests(D, Diagonal) + end + + @testset "incompatible sizes" begin + A = reshape(1:12, 4, 3) + B = reshape(1:12, 3, 4) + D1 = Diagonal(fill(A, 2)) + D2 = Diagonal(fill(B, 2)) + @test_throws DimensionMismatch D1 .+ D2 + end + end + @testset "Bidiagonal" begin + A = [1 3; 2 4] + B = Bidiagonal(fill(A,3), fill(A,2), :U) + standardbroadcastingtests(B, Bidiagonal) + end + @testset "UpperTriangular" begin + A = [1 3; 2 4] + U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3]) + standardbroadcastingtests(U, UpperTriangular) + end +end + end