From 2c5d8f3b607b4a81dcb15ff9b22b6a9076f7ad1b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 17 May 2024 12:07:49 +0530 Subject: [PATCH] Reland "Support broadcasting over structured block matrices #53909" (#54460) This was reverted in https://github.com/JuliaLang/julia/pull/54332. This needs https://github.com/JuliaLang/julia/pull/54459 for the tests to pass. Opening this now to not forget about it. --- .../LinearAlgebra/src/structuredbroadcast.jl | 6 +- stdlib/LinearAlgebra/test/special.jl | 1 + .../LinearAlgebra/test/structuredbroadcast.jl | 58 +++++++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl index 6bf2caee92105..7fc2282a7df24 100644 --- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -8,8 +8,8 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}() StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}() -const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular} -for ST in Base.uniontypes(StructuredMatrix) +const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}} +for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular) @eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}()) end @@ -133,6 +133,7 @@ fails as `zero(::Tuple{Int})` is not defined. However, iszerodefined(::Type) = false iszerodefined(::Type{<:Number}) = true iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T) +iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T) count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0) @@ -160,6 +161,7 @@ fzero(::Type{T}) where T = Some(T) fzero(r::Ref) = Some(r[]) fzero(t::Tuple{Any}) = Some(only(t)) fzero(S::StructuredMatrix) = Some(zero(eltype(S))) +fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = Some(haszero(T) ? zero(T)*I : nothing) fzero(x) = nothing function fzero(bc::Broadcast.Broadcasted) args = map(fzero, bc.args) diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index 6df27e00cc81a..be04fb564a6e8 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -111,6 +111,7 @@ Random.seed!(1) struct TypeWithZero end Base.promote_rule(::Type{TypeWithoutZero}, ::Type{TypeWithZero}) = TypeWithZero Base.convert(::Type{TypeWithZero}, ::TypeWithoutZero) = TypeWithZero() + Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x)) Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero() LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero() LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl index 5f7ac96fdf61e..384ed5b3b60cf 100644 --- a/stdlib/LinearAlgebra/test/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -307,4 +307,62 @@ end @test select_first.(missing, diag) isa Matrix{Missing} end +@testset "broadcast over structured matrices with matrix elements" begin + function standardbroadcastingtests(D, T) + M = [x for x in D] + Dsum = D .+ D + @test Dsum isa T + @test Dsum == M .+ M + Dcopy = copy.(D) + @test Dcopy isa T + @test Dcopy == D + Df = float.(D) + @test Df isa T + @test Df == D + @test eltype(eltype(Df)) <: AbstractFloat + @test (x -> (x,)).(D) == (x -> (x,)).(M) + @test (x -> 1).(D) == ones(Int,size(D)) + @test all(==(2), ndims.(D)) + @test_throws MethodError size.(D) + end + @testset "Diagonal" begin + @testset "square" begin + A = [1 3; 2 4] + D = Diagonal([A, A]) + standardbroadcastingtests(D, Diagonal) + @test sincos.(D) == sincos.(Matrix{eltype(D)}(D)) + M = [x for x in D] + @test cos.(D) == cos.(M) + end + + @testset "different-sized square blocks" begin + D = Diagonal([ones(3,3), fill(3.0,2,2)]) + standardbroadcastingtests(D, Diagonal) + end + + @testset "rectangular blocks" begin + D = Diagonal([ones(Bool,3,4), ones(Bool,2,3)]) + standardbroadcastingtests(D, Diagonal) + end + + @testset "incompatible sizes" begin + A = reshape(1:12, 4, 3) + B = reshape(1:12, 3, 4) + D1 = Diagonal(fill(A, 2)) + D2 = Diagonal(fill(B, 2)) + @test_throws DimensionMismatch D1 .+ D2 + end + end + @testset "Bidiagonal" begin + A = [1 3; 2 4] + B = Bidiagonal(fill(A,3), fill(A,2), :U) + standardbroadcastingtests(B, Bidiagonal) + end + @testset "UpperTriangular" begin + A = [1 3; 2 4] + U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3]) + standardbroadcastingtests(U, UpperTriangular) + end +end + end