From 94476343d7de0bc6ce3467295b71e6e4d21cc676 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 7 Jun 2022 18:48:20 +0200 Subject: [PATCH 1/3] Complete size checks in `BLAS.[sy/he]mm!` --- stdlib/LinearAlgebra/src/blas.jl | 6 ++++++ stdlib/LinearAlgebra/test/blas.jl | 2 ++ 2 files changed, 8 insertions(+) diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 2710559e57d6b..057f3f2a51086 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -1545,6 +1545,9 @@ for (mfname, elty) in ((:dsymm_,:Float64), if size(B,2) != n throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) end + if j != size(B,1) + throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $(size(B,1))")) + end chkstride1(A) chkstride1(B) chkstride1(C) @@ -1619,6 +1622,9 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64), if size(B,2) != n throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) end + if j != size(B,1) + throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $(size(B,1))")) + end chkstride1(A) chkstride1(B) chkstride1(C) diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index 0a2ac87c8026d..b166800d07ec9 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -222,11 +222,13 @@ Random.seed!(100) @test_throws DimensionMismatch BLAS.symm('R','U',Cmn,Cnn) @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cmn) @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cmn,one(elty),Cnn) if elty <: BlasComplex @test_throws DimensionMismatch BLAS.hemm('L','U',Cnm,Cnn) @test_throws DimensionMismatch BLAS.hemm('R','U',Cmn,Cnn) @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cmn) @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cmn,one(elty),Cnn) end end end From e8e44970b2498b212eaff36935c5220673911021 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 8 Jun 2022 10:12:22 +0200 Subject: [PATCH 2/3] take side into account --- stdlib/LinearAlgebra/src/blas.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 057f3f2a51086..f3aa190bc6e0a 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -1542,10 +1542,10 @@ for (mfname, elty) in ((:dsymm_,:Float64), if j != (side == 'L' ? m : n) throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)")) end - if size(B,2) != n + if size(B,2) != (side == 'L' ? n : j) throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) end - if j != size(B,1) + if size(B,1) != (side == 'L' ? j : m) throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $(size(B,1))")) end chkstride1(A) @@ -1619,10 +1619,10 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64), if j != (side == 'L' ? m : n) throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)")) end - if size(B,2) != n + if size(B,2) != (side == 'L' ? n : j) throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) end - if j != size(B,1) + if size(B,1) != (side == 'L' ? j : m) throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $(size(B,1))")) end chkstride1(A) From 581a2b7f55f5aa7e0df98fb863684d0319bb06e7 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 9 Jun 2022 10:14:41 +0200 Subject: [PATCH 3/3] split size checks in left- and right-multiplication --- stdlib/LinearAlgebra/src/blas.jl | 58 +++++++++++++++++++------- stdlib/LinearAlgebra/test/blas.jl | 6 +++ stdlib/LinearAlgebra/test/symmetric.jl | 6 +++ 3 files changed, 54 insertions(+), 16 deletions(-) diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index f3aa190bc6e0a..0152ae3268da3 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -1539,14 +1539,27 @@ for (mfname, elty) in ((:dsymm_,:Float64), require_one_based_indexing(A, B, C) m, n = size(C) j = checksquare(A) - if j != (side == 'L' ? m : n) - throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)")) - end - if size(B,2) != (side == 'L' ? n : j) - throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) - end - if size(B,1) != (side == 'L' ? j : m) - throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $(size(B,1))")) + M, N = size(B) + if side == 'L' + if j != m + throw(DimensionMismatch(lazy"A has first dimension $j but needs to match first dimension of C, $m")) + end + if N != n + throw(DimensionMismatch(lazy"B has second dimension $N but needs to match second dimension of C, $n")) + end + if j != M + throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $M")) + end + else + if j != n + throw(DimensionMismatch(lazy"B has second dimension $j but needs to match second dimension of C, $n")) + end + if N != j + throw(DimensionMismatch(lazy"A has second dimension $N but needs to match first dimension of B, $j")) + end + if M != m + throw(DimensionMismatch(lazy"A has first dimension $M but needs to match first dimension of C, $m")) + end end chkstride1(A) chkstride1(B) @@ -1616,14 +1629,27 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64), require_one_based_indexing(A, B, C) m, n = size(C) j = checksquare(A) - if j != (side == 'L' ? m : n) - throw(DimensionMismatch(lazy"A has size $(size(A)), C has size ($m,$n)")) - end - if size(B,2) != (side == 'L' ? n : j) - throw(DimensionMismatch(lazy"B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) - end - if size(B,1) != (side == 'L' ? j : m) - throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $(size(B,1))")) + M, N = size(B) + if side == 'L' + if j != m + throw(DimensionMismatch(lazy"A has first dimension $j but needs to match first dimension of C, $m")) + end + if N != n + throw(DimensionMismatch(lazy"B has second dimension $N but needs to match second dimension of C, $n")) + end + if j != M + throw(DimensionMismatch(lazy"A has second dimension $j but needs to match first dimension of B, $M")) + end + else + if j != n + throw(DimensionMismatch(lazy"B has second dimension $j but needs to match second dimension of C, $n")) + end + if N != j + throw(DimensionMismatch(lazy"A has second dimension $N but needs to match first dimension of B, $j")) + end + if M != m + throw(DimensionMismatch(lazy"A has first dimension $M but needs to match first dimension of C, $m")) + end end chkstride1(A) chkstride1(B) diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index b166800d07ec9..a9cbfcf60b714 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -223,12 +223,18 @@ Random.seed!(100) @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cmn) @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cnm) @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cmn,one(elty),Cnn) + @test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnm,one(elty),Cmn) + @test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cmn,one(elty),Cnn) if elty <: BlasComplex @test_throws DimensionMismatch BLAS.hemm('L','U',Cnm,Cnn) @test_throws DimensionMismatch BLAS.hemm('R','U',Cmn,Cnn) @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cmn) @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cnm) @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cmn,one(elty),Cnn) + @test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnm,one(elty),Cmn) + @test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cmn,one(elty),Cnn) end end end diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 47a36df5e7883..60b7f642b3b37 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -352,6 +352,9 @@ end C = zeros(eltya,n,n) @test Hermitian(aherm) * a ≈ aherm * a @test a * Hermitian(aherm) ≈ a * aherm + # rectangular multiplication + @test [a; a] * Hermitian(aherm) ≈ [a; a] * aherm + @test Hermitian(aherm) * [a a] ≈ aherm * [a a] @test Hermitian(aherm) * Hermitian(aherm) ≈ aherm*aherm @test_throws DimensionMismatch Hermitian(aherm) * Vector{eltya}(undef, n+1) LinearAlgebra.mul!(C,a,Hermitian(aherm)) @@ -360,6 +363,9 @@ end @test Symmetric(asym) * Symmetric(asym) ≈ asym*asym @test Symmetric(asym) * a ≈ asym * a @test a * Symmetric(asym) ≈ a * asym + # rectangular multiplication + @test Symmetric(asym) * [a a] ≈ asym * [a a] + @test [a; a] * Symmetric(asym) ≈ [a; a] * asym @test_throws DimensionMismatch Symmetric(asym) * Vector{eltya}(undef, n+1) LinearAlgebra.mul!(C,a,Symmetric(asym)) @test C ≈ a*asym