From 4da1a66bdd84402ae9d583c6d5cab63edc39b0c2 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 10 Oct 2022 04:10:21 +0200 Subject: [PATCH] Check sizes in 3-arg diagonal (dot-)product (#47114) (cherry picked from commit 25e3809ea4c309f7d0d9d5db85c5107aa63877b5) --- stdlib/LinearAlgebra/src/diagonal.jl | 3 +++ stdlib/LinearAlgebra/test/diagonal.jl | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 4af42d8f53eb4..a2756bb3a1201 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -770,6 +770,9 @@ end dot(A::AbstractMatrix, B::Diagonal) = conj(dot(B, A)) function _mapreduce_prod(f, x, D::Diagonal, y) + if !(length(x) == length(D.diag) == length(y)) + throw(DimensionMismatch("x has length $(length(x)), D has size $(size(D)), and y has $(length(y))")) + end if isempty(x) && isempty(D) && isempty(y) return zero(promote_op(f, eltype(x), eltype(D), eltype(y))) else diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 2801332e840e6..007420f1eb999 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -922,10 +922,14 @@ end @test s1 == prod(sign, d) end -@testset "Empty (#35424)" begin +@testset "Empty (#35424) & size checks (#47060)" begin @test zeros(0)'*Diagonal(zeros(0))*zeros(0) === 0.0 @test transpose(zeros(0))*Diagonal(zeros(Complex{Int}, 0))*zeros(0) === 0.0 + 0.0im @test dot(zeros(Int32, 0), Diagonal(zeros(Int, 0)), zeros(Int16, 0)) === 0 + @test_throws DimensionMismatch zeros(2)' * Diagonal(zeros(2)) * zeros(3) + @test_throws DimensionMismatch zeros(3)' * Diagonal(zeros(2)) * zeros(2) + @test_throws DimensionMismatch dot(zeros(2), Diagonal(zeros(2)), zeros(3)) + @test_throws DimensionMismatch dot(zeros(3), Diagonal(zeros(2)), zeros(2)) end @testset "Diagonal(undef)" begin