Skip to content

Commit

Permalink
fix #31674, error when storing nonzeros into structural zeros with .= (
Browse files Browse the repository at this point in the history
…#31678)

Previously, broadcasted assignment (`.=`) would happily ignore all nonstructured portions of the destination, regardless of whether the broadcasted expression would actually evaluate to zero or not. This changes these in-place methods to use the same infrastructure that out-of-place broadcast uses to determine the result type. If we are unsure of the structural properties of the output, we fall back to the generic implementation, which will attempt to store into every single location of the destination -- including those structural zeros. Thus we now error in cases where we generate nonzeros in those locations.
  • Loading branch information
mbauman authored and JeffBezanson committed May 16, 2019
1 parent ffb26e7 commit 6bd3967
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
10 changes: 9 additions & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType})
end

function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -111,6 +112,7 @@ function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -129,18 +131,22 @@ function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::SymTridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
end
for i = 1:size(dest, 1)-1
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
v = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
v == Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) || throw(ArgumentError("broadcasted assignment breaks symmetry between locations ($i, $(i+1)) and ($(i+1), $i)"))
dest.ev[i] = v
end
return dest
end

function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -154,6 +160,7 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for j in axs[2]
Expand All @@ -165,6 +172,7 @@ function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
end

function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for j in axs[2]
Expand Down
29 changes: 26 additions & 3 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,37 @@ end
A = rand(N, N)
sA = A + copy(A')
D = Diagonal(rand(N))
B = Bidiagonal(rand(N), rand(N - 1), :U)
Bu = Bidiagonal(rand(N), rand(N - 1), :U)
Bl = Bidiagonal(rand(N), rand(N - 1), :L)
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
= LowerTriangular(rand(N,N))
= UpperTriangular(rand(N,N))

@test broadcast!(sin, copy(D), D) == Diagonal(sin.(D))
@test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U)
@test broadcast!(sin, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
@test broadcast!(sin, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
@test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T))
@test broadcast!(sin, copy(◣), ◣) == LowerTriangular(sin.(◣))
@test broadcast!(sin, copy(◥), ◥) == UpperTriangular(sin.(◥))
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
@test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U)
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))

@test_throws ArgumentError broadcast!(cos, copy(D), D) == Diagonal(sin.(D))
@test_throws ArgumentError broadcast!(cos, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
@test_throws ArgumentError broadcast!(cos, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
@test_throws ArgumentError broadcast!(cos, copy(T), T) == Tridiagonal(sin.(T))
@test_throws ArgumentError broadcast!(cos, copy(◣), ◣) == LowerTriangular(sin.(◣))
@test_throws ArgumentError broadcast!(cos, copy(◥), ◥) == UpperTriangular(sin.(◥))
@test_throws ArgumentError broadcast!(+, copy(D), D, A) == Diagonal(broadcast(*, D, A))
@test_throws ArgumentError broadcast!(+, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
@test_throws ArgumentError broadcast!(+, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
@test_throws ArgumentError broadcast!(+, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
@test_throws ArgumentError broadcast!(+, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
@test_throws ArgumentError broadcast!(+, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
end

@testset "map[!] over combinations of structured matrices" begin
Expand Down

2 comments on commit 6bd3967

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @ararslan

Please sign in to comment.