From 99fdb32906134917bfcd3f39ba429dd52ddf9c38 Mon Sep 17 00:00:00 2001 From: Seyoon Ko Date: Tue, 27 Aug 2024 21:54:38 -0700 Subject: [PATCH 1/5] Update wrappers.jl Fix incorrect definition of m and n in gemv_strided_batched! --- lib/cublas/wrappers.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 150a01f66e..c17f40efd7 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -462,9 +462,9 @@ for (fname, fname_64, eltyin, eltyout) in ( if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2) throw(DimensionMismatch("Batch sizes must be equal for all inputs")) end - m = size(A, trans == 'N' ? 1 : 2) - n = size(A, trans == 'N' ? 2 : 1) - if m != size(y, 1) || n != size(x, 1) + m = size(A, 1) + n = size(A, 2) + if size(y, 1) != (trans == 'N' ? m : n) || size(x, 1) != (trans == 'N' ? n : m) throw(DimensionMismatch("A has dimension $(size(A)), x has dimension $(size(x)), y has dimension $(size(y))")) end From ddfde49a7ff365ba1a5d6a0983b4d3a41bcaf6ff Mon Sep 17 00:00:00 2001 From: Seyoon Ko Date: Wed, 28 Aug 2024 10:09:01 -0700 Subject: [PATCH 2/5] Update cublas gemv tests --- test/libraries/cublas.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/libraries/cublas.jl b/test/libraries/cublas.jl index cab9271dc7..6bb857f1df 100644 --- a/test/libraries/cublas.jl +++ b/test/libraries/cublas.jl @@ -124,6 +124,10 @@ end dy = CUBLAS.gemv('N', dA, dx) hy = collect(dy) @test hy ≈ A * x + dy = CuArray(y) + dx = CUBLAS.gemv('T', alpha, dA, dy) + hx = collect(dx) + @test hx ≈ alpha * A' * y end if CUBLAS.version() >= v"11.9" @@ -150,6 +154,16 @@ end y[i] = alpha * A[i] * x[i] + beta * y[i] @test y[i] ≈ hy end + dy = CuArray{elty, 1}[] + for i=1:length(A) + push!(dy, CuArray(y[i])) + end + CUBLAS.gemv_batched!('T', alpha, dA, dy, beta, dx) + for i=1:size(A, 3) + hx = collect(dx[:, i]) + x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i] + @test x[:, i] ≈ hx + end end end @@ -173,6 +187,13 @@ end y[:, i] = alpha * A[:, :, i] * x[:, i] + beta * y[:, i] @test y[:, i] ≈ hy end + dy = CuArray(y) + CUBLAS.gemv_strided_batched!('T', alpha, dA, dy, beta, dx) + for i=1:size(A, 3) + hx = collect(dx[:, i]) + x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i] + @test x[:, i] ≈ hx + end end end From e20b641c9efa2cce84d470f717f7c4ac188bf7ad Mon Sep 17 00:00:00 2001 From: Seyoon Ko Date: Wed, 28 Aug 2024 10:12:33 -0700 Subject: [PATCH 3/5] fix the corresponding bug in gemv_batched! --- lib/cublas/wrappers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index c17f40efd7..a64c3b3901 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -423,8 +423,8 @@ for (fname, fname_64, eltyin, eltyout) in ( end end - m = size(A[1], trans == 'N' ? 1 : 2) - n = size(A[1], trans == 'N' ? 2 : 1) + m = size(A[1], 1) + n = size(A[1], 2) lda = max(1,stride(A[1],2)) incx = stride(x[1],1) incy = stride(y[1],1) From 318887dd07b39cad9ef64ebd717c6bd890e180b8 Mon Sep 17 00:00:00 2001 From: Seyoon Ko Date: Wed, 28 Aug 2024 10:27:46 -0700 Subject: [PATCH 4/5] more fix for gemv_batched! all the input dimensions should be identical for gemv_batched! --- lib/cublas/wrappers.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index a64c3b3901..c09a456000 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -416,15 +416,16 @@ for (fname, fname_64, eltyin, eltyout) in ( if length(A) != length(x) || length(A) != length(y) throw(DimensionMismatch("Lengths of inputs must be the same")) end + m = size(A[1], 1) + n = size(A[1], 2) for (i, (As,xs,ys)) in enumerate(zip(A,x,y)) - m,n = size(As) + if size(As) != (m, n) + throw(DimensionMismatch("A[$i] has different dimension from A[1]. Dimensions between A's should be identical.")) + end if length(xs) != (trans == 'N' ? n : m) || length(ys) != (trans == 'N' ? m : n) throw(DimensionMismatch("Input $i: A has dimension $(size(As)), x has dimension $(size(xs)), y has dimension $(size(ys))")) end end - - m = size(A[1], 1) - n = size(A[1], 2) lda = max(1,stride(A[1],2)) incx = stride(x[1],1) incy = stride(y[1],1) From 68e4cdfacb443b9fcbbb933711a8b2c40737ef6d Mon Sep 17 00:00:00 2001 From: Seyoon Ko Date: Wed, 28 Aug 2024 13:05:17 -0700 Subject: [PATCH 5/5] fix tests --- test/libraries/cublas.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/libraries/cublas.jl b/test/libraries/cublas.jl index 6bb857f1df..1a8bfb0934 100644 --- a/test/libraries/cublas.jl +++ b/test/libraries/cublas.jl @@ -105,7 +105,7 @@ end @test testf(*, rand(elty, m, n)', rand(elty, m)) x = rand(elty, m) A = rand(elty, m, m + 1 ) - y = rand(elty, m) + y = rand(elty, n) dx = CuArray(x) dA = CuArray(A) dy = CuArray(y) @@ -125,7 +125,7 @@ end hy = collect(dy) @test hy ≈ A * x dy = CuArray(y) - dx = CUBLAS.gemv('T', alpha, dA, dy) + dx = CUBLAS.gemv(elty <: Real ? 'T' : 'C', alpha, dA, dy) hx = collect(dx) @test hx ≈ alpha * A' * y end @@ -158,11 +158,11 @@ end for i=1:length(A) push!(dy, CuArray(y[i])) end - CUBLAS.gemv_batched!('T', alpha, dA, dy, beta, dx) + CUBLAS.gemv_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx) for i=1:size(A, 3) - hx = collect(dx[:, i]) - x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i] - @test x[:, i] ≈ hx + hx = collect(dx[i]) + x[i] = alpha * A[i]' * y[i] + beta * x[i] + @test x[i] ≈ hx end end end @@ -188,10 +188,10 @@ end @test y[:, i] ≈ hy end dy = CuArray(y) - CUBLAS.gemv_strided_batched!('T', alpha, dA, dy, beta, dx) + CUBLAS.gemv_strided_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx) for i=1:size(A, 3) hx = collect(dx[:, i]) - x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i] + x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i] @test x[:, i] ≈ hx end end