From f9bcd064f5a8be01fb4bcd3ea676aa1e52413421 Mon Sep 17 00:00:00 2001 From: Sebastian Stock <42280794+sostock@users.noreply.github.com> Date: Wed, 9 Feb 2022 10:00:55 +0100 Subject: [PATCH] Allow negative `stride(A,2)` in `gemv!` (#42054) (cherry picked from commit 33a71b7b81948b1c5a130bcacf0084ecdcf1266c) --- stdlib/LinearAlgebra/src/blas.jl | 19 +++++++--- stdlib/LinearAlgebra/test/blas.jl | 58 +++++++++++++++---------------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 1c29d88f410b0..236d7fe927418 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -702,17 +702,26 @@ for (fname, elty) in ((:dgemv_,:Float64), end chkstride1(A) lda = stride(A,2) - lda >= max(1, size(A,1)) || error("`stride(A,2)` must be at least `max(1, size(A,1))`") sX = stride(X,1) - pX = pointer(X, sX > 0 ? firstindex(X) : lastindex(X)) sY = stride(Y,1) - pY = pointer(Y, sY > 0 ? firstindex(X) : lastindex(X)) - GC.@preserve X Y ccall((@blasfunc($fname), libblastrampoline), Cvoid, + if lda < 0 + colindex = lastindex(A, 2) + lda = -lda + trans == 'N' ? (sX = -sX) : (sY = -sY) + else + colindex = firstindex(A, 2) + end + lda >= size(A,1) || size(A,2) <= 1 || error("when `size(A,2) > 1`, `abs(stride(A,2))` must be at least `size(A,1)`") + lda = max(1, size(A,1), lda) + pA = pointer(A, Base._sub2ind(A, 1, colindex)) + pX = pointer(X, stride(X,1) > 0 ? firstindex(X) : lastindex(X)) + pY = pointer(Y, stride(Y,1) > 0 ? firstindex(Y) : lastindex(Y)) + GC.@preserve A X Y ccall((@blasfunc($fname), libblastrampoline), Cvoid, (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Clong), trans, size(A,1), size(A,2), alpha, - A, lda, pX, sX, + pA, lda, pX, sX, beta, pY, sY, 1) Y end diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index 18c1410f64d41..5bb47080d1b31 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -381,38 +381,38 @@ Random.seed!(100) @test all(BLAS.gemv('N', U4, o4) .== v41) @test all(BLAS.gemv('N', U4, o4) .== v41) @testset "non-standard strides" begin - if elty <: Complex - A = elty[1+2im 3+4im 5+6im 7+8im; 2+3im 4+5im 6+7im 8+9im; 3+4im 5+6im 7+8im 9+10im] - v = elty[1+2im, 2+3im, 3+4im, 4+5im] - dest = view(ones(elty, 5), 4:-2:2) - @test BLAS.gemv!('N', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[-35+178im, -39+202im] - @test BLAS.gemv('N', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[15-41im, 17-49im] - @test BLAS.gemv('N', view(A, 1:0, 1:2), view(v, 1:2)) == elty[] - dest = view(ones(elty, 5), 4:-2:2) - @test BLAS.gemv!('T', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[-29+124im, -45+220im] - @test BLAS.gemv('T', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[14-38im, 18-54im] - @test BLAS.gemv('T', view(A, 2:3, 2:1), view(v, 1:2)) == elty[] - dest = view(ones(elty, 5), 4:-2:2) - @test BLAS.gemv!('C', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[131+8im, 227+24im] - @test BLAS.gemv('C', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-40-6im, -56-10im] - @test BLAS.gemv('C', view(A, 2:3, 2:1), view(v, 1:2)) == elty[] - else - A = elty[1 2 3 4; 5 6 7 8; 9 10 11 12] - v = elty[1, 2, 3, 4] - dest = view(ones(elty, 5), 4:-2:2) - @test BLAS.gemv!('N', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[79, 119] - @test BLAS.gemv('N', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-19, -31] - @test BLAS.gemv('N', view(A, 1:0, 1:2), view(v, 1:2)) == elty[] - for trans = ('T', 'C') - dest = view(ones(elty, 5), 4:-2:2) - @test BLAS.gemv!(trans, elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[95, 115] - @test BLAS.gemv(trans, elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-22, -25] - @test BLAS.gemv(trans, view(A, 2:3, 2:1), view(v, 1:2)) == elty[] + A = rand(elty, 3, 4) + x = rand(elty, 5) + for y = (view(ones(elty, 5), 1:2:5), view(ones(elty, 7), 6:-2:2)) + ycopy = copy(y) + @test BLAS.gemv!('N', elty(2), view(A, :, 2:2:4), view(x, 1:3:4), elty(3), y) ≈ 2*A[:,2:2:4]*x[1:3:4] + 3*ycopy + ycopy = copy(y) + @test BLAS.gemv!('N', elty(2), view(A, :, 4:-2:2), view(x, 1:3:4), elty(3), y) ≈ 2*A[:,4:-2:2]*x[1:3:4] + 3*ycopy + ycopy = copy(y) + @test BLAS.gemv!('N', elty(2), view(A, :, 2:2:4), view(x, 4:-3:1), elty(3), y) ≈ 2*A[:,2:2:4]*x[4:-3:1] + 3*ycopy + ycopy = copy(y) + @test BLAS.gemv!('N', elty(2), view(A, :, 4:-2:2), view(x, 4:-3:1), elty(3), y) ≈ 2*A[:,4:-2:2]*x[4:-3:1] + 3*ycopy + ycopy = copy(y) + @test BLAS.gemv!('N', elty(2), view(A, :, StepRangeLen(1,0,1)), view(x, 1:1), elty(3), y) ≈ 2*A[:,1:1]*x[1:1] + 3*ycopy # stride(A,2) == 0 + end + @test BLAS.gemv!('N', elty(1), zeros(elty, 0, 5), zeros(elty, 5), elty(1), zeros(elty, 0)) == elty[] # empty matrix, stride(A,2) == 0 + @test BLAS.gemv('N', elty(-1), view(A, 2:3, 1:2:3), view(x, 2:-1:1)) ≈ -1*A[2:3,1:2:3]*x[2:-1:1] + @test BLAS.gemv('N', view(A, 2:3, 3:-2:1), view(x, 1:2:3)) ≈ A[2:3,3:-2:1]*x[1:2:3] + for (trans, f) = (('T',transpose), ('C',adjoint)) + for y = (view(ones(elty, 3), 1:2:3), view(ones(elty, 5), 4:-2:2)) + ycopy = copy(y) + @test BLAS.gemv!(trans, elty(2), view(A, :, 2:2:4), view(x, 1:2:5), elty(3), y) ≈ 2*f(A[:,2:2:4])*x[1:2:5] + 3*ycopy + ycopy = copy(y) + @test BLAS.gemv!(trans, elty(2), view(A, :, 4:-2:2), view(x, 1:2:5), elty(3), y) ≈ 2*f(A[:,4:-2:2])*x[1:2:5] + 3*ycopy + ycopy = copy(y) + @test BLAS.gemv!(trans, elty(2), view(A, :, 2:2:4), view(x, 5:-2:1), elty(3), y) ≈ 2*f(A[:,2:2:4])*x[5:-2:1] + 3*ycopy + ycopy = copy(y) + @test BLAS.gemv!(trans, elty(2), view(A, :, 4:-2:2), view(x, 5:-2:1), elty(3), y) ≈ 2*f(A[:,4:-2:2])*x[5:-2:1] + 3*ycopy end + @test BLAS.gemv!(trans, elty(2), view(A, :, StepRangeLen(1,0,1)), view(x, 1:2:5), elty(3), elty[1]) ≈ 2*f(A[:,1:1])*x[1:2:5] + elty[3] # stride(A,2) == 0 end for trans = ('N', 'T', 'C') - @test_throws ErrorException BLAS.gemv(trans, view(A, 1:2:3, 1:2), view(v, 1:2)) - @test_throws ErrorException BLAS.gemv(trans, view(A, 1:2, 2:-1:1), view(v, 1:2)) + @test_throws ErrorException BLAS.gemv(trans, view(A, 1:2:3, 1:2), view(x, 1:2)) # stride(A,1) must be 1 end end end