From 6d3dd379620745f397aecde1c8f2944fa2937a41 Mon Sep 17 00:00:00 2001 From: amcamd Date: Sun, 5 Mar 2023 11:26:04 -0500 Subject: [PATCH 01/11] make trsm offset calculations 64 bit safe --- library/src/blas3/rocblas_trsm.hpp | 592 ++++++++++++++--------------- 1 file changed, 296 insertions(+), 296 deletions(-) diff --git a/library/src/blas3/rocblas_trsm.hpp b/library/src/blas3/rocblas_trsm.hpp index 8127acb0e..0243c04cd 100644 --- a/library/src/blas3/rocblas_trsm.hpp +++ b/library/src/blas3/rocblas_trsm.hpp @@ -384,7 +384,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -409,7 +409,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i + BLOCK + i * lda + offset_Ain, + i + BLOCK + i * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -452,7 +452,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, alpha, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -476,7 +476,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, &alpha_negative_one, A, - i * lda + offset_Ain, + i * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -502,7 +502,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -525,7 +525,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i * lda + offset_Ain, + i * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -557,7 +557,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, alpha, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -605,7 +605,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -679,7 +679,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - BLOCK * lda + offset_Ain, + BLOCK * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -705,7 +705,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -728,7 +728,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i + (i + BLOCK) * lda + offset_Ain, + i + (i + BLOCK) * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -793,16 +793,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, alpha, U(B), - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -816,7 +816,7 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, @@ -841,16 +841,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -864,7 +864,7 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, @@ -919,12 +919,12 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, m, stride_X, A, - BLOCK * lda + offset_Ain, + BLOCK * size_t(lda) + offset_Ain, lda, stride_A, alpha, B, - BLOCK * ldb + offset_Bin, + BLOCK * size_t(ldb) + offset_Bin, ldb, stride_B, batch_count); @@ -941,16 +941,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -964,16 +964,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - i + (i + BLOCK) * lda + offset_Ain, + i + (i + BLOCK) * size_t(lda) + offset_Ain, lda, stride_A, &beta_1, B, - (i + BLOCK) * ldb + offset_Bin, + (i + BLOCK) * size_t(ldb) + offset_Bin, ldb, stride_B, batch_count); @@ -1027,7 +1027,7 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, stride_A, alpha, B, - BLOCK * ldb + offset_Bin, + BLOCK * size_t(ldb) + offset_Bin, ldb, stride_B, batch_count); @@ -1044,16 +1044,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -1067,16 +1067,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - BLOCK + i + i * lda + offset_Ain, + BLOCK + i + i * size_t(lda) + offset_Ain, lda, stride_A, &beta_1, B, - (i + BLOCK) * ldb + offset_Bin, + (i + BLOCK) * size_t(ldb) + offset_Bin, ldb, stride_B, batch_count); @@ -1096,16 +1096,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, alpha, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -1119,11 +1119,11 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - i * lda + offset_Ain, + i * size_t(lda) + offset_Ain, lda, stride_A, alpha, @@ -1144,16 +1144,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -1167,11 +1167,11 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - i * lda + offset_Ain, + i * size_t(lda) + offset_Ain, lda, stride_A, &beta_1, @@ -1222,7 +1222,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, for(size_t w = 0; w < W; w++) { - size_t width = std::min(bsize - w * B_chunk_size, B_chunk_size); + size_t width = std::min(bsize - w * size_t(B_chunk_size), B_chunk_size); if(side == rocblas_side_left) { @@ -1243,19 +1243,19 @@ rocblas_status special_trsm_template(rocblas_handle handle, BLOCK, stride_X, batch_count, - j * BLOCK + w * B_chunk_size * ldb + offset_Bin, + j * size_t(BLOCK) + w * size_t(B_chunk_size) * size_t(ldb) + offset_Bin, 0); if(r) { - rocblas_int offsetA = 0; - rocblas_int offsetB = parity ? w * B_chunk_size * ldb - : w * B_chunk_size * ldb + (q + 1) * BLOCK; + rocblas_stride offsetA = 0; + rocblas_stride offsetB = parity ? w * size_t(B_chunk_size) * size_t(ldb) + : w * size_t(B_chunk_size) * size_t(ldb) + (q + 1) * BLOCK; if(transA == rocblas_operation_none) - offsetA = parity ? r * BLOCK : BLOCK * (q * lda + q + lda); + offsetA = parity ? r * BLOCK : BLOCK * (q * size_t(lda) + q + lda); else - offsetA = parity ? r * BLOCK * lda : BLOCK * (q * lda + q + 1); + offsetA = parity ? r * BLOCK * size_t(lda) : BLOCK * (q * size_t(lda) + q + 1); if(!tensile_supports_ldc_ne_ldd) { @@ -1308,7 +1308,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, alpha, B, compute_type, - j * BLOCK + w * B_chunk_size * ldb + j * size_t(BLOCK) + w * size_t(B_chunk_size) * size_t(ldb) + offset_Bin, ldb, stride_B, @@ -1334,7 +1334,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, BLOCK, r ? &alpha_1 : alpha, invA, - size_t(j * BLOCK * BLOCK + offset_invAin), + size_t(j * size_t(BLOCK) * BLOCK + offset_invAin), size_t(BLOCK), stride_invA, (U)w_x_temp, @@ -1343,7 +1343,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, stride_X, &beta_0, B, - size_t(w * B_chunk_size * ldb + j * BLOCK + offset_Bin), + size_t(w * size_t(B_chunk_size) * size_t(ldb) + j * size_t(BLOCK) + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1368,18 +1368,18 @@ rocblas_status special_trsm_template(rocblas_handle handle, width, stride_X, batch_count, - j * BLOCK * ldb + w * B_chunk_size + offset_Bin, + j * size_t(BLOCK) * size_t(ldb) + w * size_t(B_chunk_size) + offset_Bin, 0); if(r) { - rocblas_int offsetA = 0; - rocblas_int offsetB - = parity ? w * B_chunk_size + (q + 1) * BLOCK * ldb : w * B_chunk_size; + rocblas_stride offsetA = 0; + rocblas_stride offsetB + = parity ? w * size_t(B_chunk_size) + (q + 1) * BLOCK * size_t(ldb) : w * size_t(B_chunk_size); if(transA == rocblas_operation_none) - offsetA = parity ? BLOCK * (q * lda + q + 1) : r * BLOCK * lda; + offsetA = parity ? BLOCK * (q * size_t(lda) + q + 1) : r * BLOCK * size_t(lda); else - offsetA = parity ? BLOCK * (q * lda + q + lda) : r * BLOCK; + offsetA = parity ? BLOCK * (q * size_t(lda) + q + lda) : r * BLOCK; if(!tensile_supports_ldc_ne_ldd) { @@ -1432,7 +1432,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, alpha, B, compute_type, - j * BLOCK * ldb + w * B_chunk_size + j * size_t(BLOCK) * size_t(ldb) + w * size_t(B_chunk_size) + offset_Bin, ldb, stride_B, @@ -1462,12 +1462,12 @@ rocblas_status special_trsm_template(rocblas_handle handle, width, stride_X, invA, - size_t(j * BLOCK * BLOCK + offset_invAin), + size_t(j * size_t(BLOCK) * BLOCK + offset_invAin), size_t(BLOCK), stride_invA, &beta_0, B, - size_t(w * B_chunk_size * ldb + j * BLOCK * ldb + offset_Bin), + size_t(w * size_t(B_chunk_size) * size_t(ldb) + j * size_t(BLOCK) * size_t(ldb) + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1933,7 +1933,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? NB : m - bx * NB; // offset B into correct block row - B += bx * NB; + B += size_t(bx) * NB; __shared__ T sA[NB * NB]; __shared__ T sB[NB * NB]; @@ -1944,18 +1944,18 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * NB + tx] = (CONJ) ? conj(A[i * lda + tx]) : A[i * lda + tx]; + sA[i * size_t(NB) + tx] = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) - sA[tx * NB + tx] = T(1.0); + sA[tx * size_t(NB) + tx] = T(1.0); } if(tx < maxColB) { // Load B into sB and multiply by alpha for(int i = 0; i < n; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -1966,48 +1966,48 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, for(i = 0; i + 3 <= maxColA; i += 4) { // Subtract previously solved parts - resB[0] = sB[(i + 0) * NB + tx]; - resB[1] = sB[(i + 1) * NB + tx]; - resB[2] = sB[(i + 2) * NB + tx]; - resB[3] = sB[(i + 3) * NB + tx]; + resB[0] = sB[(i + 0) * size_t(NB) + tx]; + resB[1] = sB[(i + 1) * size_t(NB) + tx]; + resB[2] = sB[(i + 2) * size_t(NB) + tx]; + resB[3] = sB[(i + 3) * size_t(NB) + tx]; for(int j = 0; j < i; j++) { - T sB_reg = sB[j * NB + tx]; - resB[0] -= sB_reg * sA[(i + 0) * NB + j]; - resB[1] -= sB_reg * sA[(i + 1) * NB + j]; - resB[2] -= sB_reg * sA[(i + 2) * NB + j]; - resB[3] -= sB_reg * sA[(i + 3) * NB + j]; + T sB_reg = sB[j * size_t(NB) + tx]; + resB[0] -= sB_reg * sA[(i + 0) * size_t(NB) + j]; + resB[1] -= sB_reg * sA[(i + 1) * size_t(NB) + j]; + resB[2] -= sB_reg * sA[(i + 2) * size_t(NB) + j]; + resB[3] -= sB_reg * sA[(i + 3) * size_t(NB) + j]; } - resB[0] /= sA[(i + 0) * NB + (i + 0)]; - sB[(i + 0) * NB + tx] = resB[0]; + resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; + sB[(i + 0) * size_t(NB) + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i + 1) * NB + (i + 0)]; - resB[1] /= sA[(i + 1) * NB + (i + 1)]; - sB[(i + 1) * NB + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i + 1) * size_t(NB) + (i + 0)]; + resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; + sB[(i + 1) * size_t(NB) + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i + 2) * NB + (i + 0)]; - resB[2] -= resB[1] * sA[(i + 2) * NB + (i + 1)]; - resB[2] /= sA[(i + 2) * NB + (i + 2)]; - sB[(i + 2) * NB + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i + 2) * size_t(NB) + (i + 0)]; + resB[2] -= resB[1] * sA[(i + 2) * size_t(NB) + (i + 1)]; + resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; + sB[(i + 2) * size_t(NB) + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i + 3) * NB + (i + 0)]; - resB[3] -= resB[1] * sA[(i + 3) * NB + (i + 1)]; - resB[3] -= resB[2] * sA[(i + 3) * NB + (i + 2)]; - resB[3] /= sA[(i + 3) * NB + (i + 3)]; - sB[(i + 3) * NB + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i + 3) * size_t(NB) + (i + 0)]; + resB[3] -= resB[1] * sA[(i + 3) * size_t(NB) + (i + 1)]; + resB[3] -= resB[2] * sA[(i + 3) * size_t(NB) + (i + 2)]; + resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; + sB[(i + 3) * size_t(NB) + tx] = resB[3]; } // tail end if not divisible by 4 for(; i <= maxColA; i++) { - resB[0] = sB[i * NB + tx]; + resB[0] = sB[i * size_t(NB) + tx]; for(int j = 0; j < i; j++) { - resB[0] -= sB[j * NB + tx] * sA[i * NB + j]; + resB[0] -= sB[j * size_t(NB) + tx] * sA[i * size_t(NB) + j]; } - sB[i * NB + tx] = resB[0] / sA[i * NB + i]; + sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; } } else if(transA == rocblas_operation_none && uplo == rocblas_fill_lower) @@ -2015,47 +2015,47 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, int i; for(i = maxColA; i >= 3; i -= 4) { - resB[0] = sB[(i - 0) * NB + tx]; - resB[1] = sB[(i - 1) * NB + tx]; - resB[2] = sB[(i - 2) * NB + tx]; - resB[3] = sB[(i - 3) * NB + tx]; + resB[0] = sB[(i - 0) * size_t(NB) + tx]; + resB[1] = sB[(i - 1) * size_t(NB) + tx]; + resB[2] = sB[(i - 2) * size_t(NB) + tx]; + resB[3] = sB[(i - 3) * size_t(NB) + tx]; for(int j = maxColA; j > i; j--) { - T sB_reg = sB[j * NB + tx]; - resB[0] -= sB_reg * sA[(i - 0) * NB + j]; - resB[1] -= sB_reg * sA[(i - 1) * NB + j]; - resB[2] -= sB_reg * sA[(i - 2) * NB + j]; - resB[3] -= sB_reg * sA[(i - 3) * NB + j]; + T sB_reg = sB[j * size_t(NB) + tx]; + resB[0] -= sB_reg * sA[(i - 0) * size_t(NB) + j]; + resB[1] -= sB_reg * sA[(i - 1) * size_t(NB) + j]; + resB[2] -= sB_reg * sA[(i - 2) * size_t(NB) + j]; + resB[3] -= sB_reg * sA[(i - 3) * size_t(NB) + j]; } - resB[0] /= sA[(i - 0) * NB + (i - 0)]; - sB[(i - 0) * NB + tx] = resB[0]; + resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; + sB[(i - 0) * size_t(NB) + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i - 1) * NB + (i - 0)]; - resB[1] /= sA[(i - 1) * NB + (i - 1)]; - sB[(i - 1) * NB + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i - 1) * size_t(NB) + (i - 0)]; + resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; + sB[(i - 1) * size_t(NB) + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i - 2) * NB + (i - 0)]; - resB[2] -= resB[1] * sA[(i - 2) * NB + (i - 1)]; - resB[2] /= sA[(i - 2) * NB + (i - 2)]; - sB[(i - 2) * NB + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i - 2) * size_t(NB) + (i - 0)]; + resB[2] -= resB[1] * sA[(i - 2) * size_t(NB) + (i - 1)]; + resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; + sB[(i - 2) * size_t(NB) + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i - 3) * NB + (i - 0)]; - resB[3] -= resB[1] * sA[(i - 3) * NB + (i - 1)]; - resB[3] -= resB[2] * sA[(i - 3) * NB + (i - 2)]; - resB[3] /= sA[(i - 3) * NB + (i - 3)]; - sB[(i - 3) * NB + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i - 3) * size_t(NB) + (i - 0)]; + resB[3] -= resB[1] * sA[(i - 3) * size_t(NB) + (i - 1)]; + resB[3] -= resB[2] * sA[(i - 3) * size_t(NB) + (i - 2)]; + resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; + sB[(i - 3) * size_t(NB) + tx] = resB[3]; } for(; i >= 0; i--) { - resB[0] = sB[i * NB + tx]; + resB[0] = sB[i * size_t(NB) + tx]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[j * NB + tx] * sA[i * NB + j]; + resB[0] -= sB[j * size_t(NB) + tx] * sA[i * size_t(NB) + j]; } - sB[i * NB + tx] = resB[0] / sA[i * NB + i]; + sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; } } else if(uplo == rocblas_fill_upper) @@ -2063,14 +2063,14 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, int i; for(i = maxColA; i >= 3; i -= 4) { - resB[0] = sB[(i - 0) * NB + tx]; - resB[1] = sB[(i - 1) * NB + tx]; - resB[2] = sB[(i - 2) * NB + tx]; - resB[3] = sB[(i - 3) * NB + tx]; + resB[0] = sB[(i - 0) * size_t(NB) + tx]; + resB[1] = sB[(i - 1) * size_t(NB) + tx]; + resB[2] = sB[(i - 2) * size_t(NB) + tx]; + resB[3] = sB[(i - 3) * size_t(NB) + tx]; for(int j = maxColA; j > i; j--) { - rocblas_int col_off = j * NB; + size_t col_off = j * size_t(NB); T sB_reg = sB[col_off + tx]; resB[0] -= sB_reg * sA[col_off + (i - 0)]; resB[1] -= sB_reg * sA[col_off + (i - 1)]; @@ -2078,33 +2078,33 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, resB[3] -= sB_reg * sA[col_off + (i - 3)]; } - resB[0] /= sA[(i - 0) * NB + (i - 0)]; - sB[(i - 0) * NB + tx] = resB[0]; + resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; + sB[(i - 0) * size_t(NB) + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i - 0) * NB + (i - 1)]; - resB[1] /= sA[(i - 1) * NB + (i - 1)]; - sB[(i - 1) * NB + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 1)]; + resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; + sB[(i - 1) * size_t(NB) + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i - 0) * NB + (i - 2)]; - resB[2] -= resB[1] * sA[(i - 1) * NB + (i - 2)]; - resB[2] /= sA[(i - 2) * NB + (i - 2)]; - sB[(i - 2) * NB + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 2)]; + resB[2] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 2)]; + resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; + sB[(i - 2) * size_t(NB) + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i - 0) * NB + (i - 3)]; - resB[3] -= resB[1] * sA[(i - 1) * NB + (i - 3)]; - resB[3] -= resB[2] * sA[(i - 2) * NB + (i - 3)]; - resB[3] /= sA[(i - 3) * NB + (i - 3)]; - sB[(i - 3) * NB + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 3)]; + resB[3] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 3)]; + resB[3] -= resB[2] * sA[(i - 2) * size_t(NB) + (i - 3)]; + resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; + sB[(i - 3) * size_t(NB) + tx] = resB[3]; } for(; i >= 0; i--) { - resB[0] = sB[i * NB + tx]; + resB[0] = sB[i * size_t(NB) + tx]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[j * NB + tx] * sA[j * NB + i]; + resB[0] -= sB[j * size_t(NB) + tx] * sA[j * size_t(NB) + i]; } - sB[i * NB + tx] = resB[0] / sA[i * NB + i]; + sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; } } else // lower (conjugate-)transpose @@ -2113,14 +2113,14 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, for(i = 0; i + 3 <= maxColA; i += 4) { // Subtract previously solved parts - resB[0] = sB[(i + 0) * NB + tx]; - resB[1] = sB[(i + 1) * NB + tx]; - resB[2] = sB[(i + 2) * NB + tx]; - resB[3] = sB[(i + 3) * NB + tx]; + resB[0] = sB[(i + 0) * size_t(NB) + tx]; + resB[1] = sB[(i + 1) * size_t(NB) + tx]; + resB[2] = sB[(i + 2) * size_t(NB) + tx]; + resB[3] = sB[(i + 3) * size_t(NB) + tx]; for(int j = 0; j < i; j++) { - rocblas_int col_off = j * NB; + size_t col_off = j * size_t(NB); T sB_reg = sB[col_off + tx]; resB[0] -= sB_reg * sA[col_off + (i + 0)]; resB[1] -= sB_reg * sA[col_off + (i + 1)]; @@ -2128,34 +2128,34 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, resB[3] -= sB_reg * sA[col_off + (i + 3)]; } - resB[0] /= sA[(i + 0) * NB + (i + 0)]; - sB[(i + 0) * NB + tx] = resB[0]; + resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; + sB[(i + 0) * size_t(NB) + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i + 0) * NB + (i + 1)]; - resB[1] /= sA[(i + 1) * NB + (i + 1)]; - sB[(i + 1) * NB + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 1)]; + resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; + sB[(i + 1) * size_t(NB) + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i + 0) * NB + (i + 2)]; - resB[2] -= resB[1] * sA[(i + 1) * NB + (i + 2)]; - resB[2] /= sA[(i + 2) * NB + (i + 2)]; - sB[(i + 2) * NB + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 2)]; + resB[2] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 2)]; + resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; + sB[(i + 2) * size_t(NB) + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i + 0) * NB + (i + 3)]; - resB[3] -= resB[1] * sA[(i + 1) * NB + (i + 3)]; - resB[3] -= resB[2] * sA[(i + 2) * NB + (i + 3)]; - resB[3] /= sA[(i + 3) * NB + (i + 3)]; - sB[(i + 3) * NB + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 3)]; + resB[3] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 3)]; + resB[3] -= resB[2] * sA[(i + 2) * size_t(NB) + (i + 3)]; + resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; + sB[(i + 3) * size_t(NB) + tx] = resB[3]; } // tail end if not divisible by 4 for(; i <= maxColA; i++) { - resB[0] = sB[i * NB + tx]; + resB[0] = sB[i * size_t(NB) + tx]; for(int j = 0; j < i; j++) { - resB[0] -= sB[j * NB + tx] * sA[j * NB + i]; + resB[0] -= sB[j * size_t(NB) + tx] * sA[j * size_t(NB) + i]; } - sB[i * NB + tx] = resB[0] / sA[i * NB + i]; + sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; } } @@ -2163,7 +2163,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, if(tx < maxColB) { for(int i = 0; i < n; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; } } @@ -2205,7 +2205,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? NB : m - bx * NB; // offset B into correct block row - B += bx * NB; + B += bx * size_t(NB); __shared__ T sB[NB * NB]; @@ -2213,7 +2213,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, { // Load B into sB and multiply by alpha for(int i = 0; i < n; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); // Solve for B in shared memory @@ -2224,61 +2224,61 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, for(int i = 0; i <= maxColA; i++) { // Subtract previously solved parts - T temp_reg_B = sB[i * NB + tx]; + T temp_reg_B = sB[i * size_t(NB) + tx]; for(int j = 0; j < i; j++) { - T valA = A[i * lda + j]; - temp_reg_B -= sB[j * NB + tx] * valA; + T valA = A[i * size_t(lda) + j]; + temp_reg_B -= sB[j * size_t(NB) + tx] * valA; } // Solve - sB[i * NB + tx] = temp_reg_B; + sB[i * size_t(NB) + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= A[i * lda + i]; + sB[i * size_t(NB) + tx] /= A[i * size_t(lda) + i]; } } else if(transA == rocblas_operation_none && uplo == rocblas_fill_lower) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[i * NB + tx]; + T temp_reg_B = sB[i * size_t(NB) + tx]; for(int j = maxColA; j > i; j--) { - T valA = A[i * lda + j]; - temp_reg_B -= sB[j * NB + tx] * valA; + T valA = A[i * size_t(lda) + j]; + temp_reg_B -= sB[j * size_t(NB) + tx] * valA; } - sB[i * NB + tx] = temp_reg_B; + sB[i * size_t(NB) + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= A[i * lda + i]; + sB[i * size_t(NB) + tx] /= A[i * size_t(lda) + i]; } } else if(uplo == rocblas_fill_upper) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[i * NB + tx]; + T temp_reg_B = sB[i * size_t(NB) + tx]; for(int j = maxColA; j > i; j--) { - T valA = CONJ ? conj(A[j * lda + i]) : A[j * lda + i]; - temp_reg_B -= sB[j * NB + tx] * valA; + T valA = CONJ ? conj(A[j * size_t(lda) + i]) : A[j * size_t(lda) + i]; + temp_reg_B -= sB[j * size_t(NB) + tx] * valA; } - sB[i * NB + tx] = temp_reg_B; + sB[i * size_t(NB) + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= CONJ ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[i * size_t(NB) + tx] /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else // lower (conjugate-)transpose { for(int i = 0; i <= maxColA; i++) { - T temp_reg_B = sB[i * NB + tx]; + T temp_reg_B = sB[i * size_t(NB) + tx]; for(int j = 0; j < i; j++) { - T valA = CONJ ? conj(A[j * lda + i]) : A[j * lda + i]; - temp_reg_B -= sB[j * NB + tx] * valA; + T valA = CONJ ? conj(A[j * size_t(lda) + i]) : A[j * size_t(lda) + i]; + temp_reg_B -= sB[j * size_t(NB) + tx] * valA; } - sB[i * NB + tx] = temp_reg_B; + sB[i * size_t(NB) + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= CONJ ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[i * size_t(NB) + tx] /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2286,7 +2286,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, if(tx < maxColB) { for(int i = 0; i < n; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; } } @@ -2330,7 +2330,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? NB : n - bx * NB; // offset B into correct block column - B += bx * NB * ldb; + B += bx * NB * size_t(ldb); // shared A and shared B __shared__ T sA[NB * NB]; @@ -2343,7 +2343,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * NB + tx] = (CONJ) ? conj(A[i * lda + tx]) : A[i * lda + tx]; + sA[i * size_t(NB) + tx] = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) @@ -2351,7 +2351,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, // Load B into sB and multiply by alpha for(int i = 0; i < maxColB; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -2369,7 +2369,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = 0; j < i; j++) { - rocblas_int col_off = j * NB; + size_t col_off = j * size_t(NB); T sB_reg = sB[sb_col + j]; resB[0] -= sB_reg * sA[col_off + i]; resB[1] -= sB_reg * sA[col_off + (i + 1)]; @@ -2377,22 +2377,22 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[3] -= sB_reg * sA[col_off + (i + 3)]; } - resB[0] /= sA[(i + 0) * NB + (i + 0)]; + resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; sB[sb_col + i + 0] = resB[0]; - resB[1] -= resB[0] * sA[(i + 0) * NB + (i + 1)]; - resB[1] /= sA[(i + 1) * NB + (i + 1)]; + resB[1] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 1)]; + resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; sB[sb_col + i + 1] = resB[1]; - resB[2] -= resB[0] * sA[(i + 0) * NB + (i + 2)]; - resB[2] -= resB[1] * sA[(i + 1) * NB + (i + 2)]; - resB[2] /= sA[(i + 2) * NB + (i + 2)]; + resB[2] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 2)]; + resB[2] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 2)]; + resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; sB[sb_col + i + 2] = resB[2]; - resB[3] -= resB[0] * sA[(i + 0) * NB + (i + 3)]; - resB[3] -= resB[1] * sA[(i + 1) * NB + (i + 3)]; - resB[3] -= resB[2] * sA[(i + 2) * NB + (i + 3)]; - resB[3] /= sA[(i + 3) * NB + (i + 3)]; + resB[3] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 3)]; + resB[3] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 3)]; + resB[3] -= resB[2] * sA[(i + 2) * size_t(NB) + (i + 3)]; + resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; sB[sb_col + i + 3] = resB[3]; } @@ -2402,9 +2402,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = 0; j < i; j++) { - resB[0] -= sB[sb_col + j] * sA[j * NB + i]; + resB[0] -= sB[sb_col + j] * sA[j * size_t(NB) + i]; } - sB[sb_col + i] = resB[0] / sA[i * NB + i]; + sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; } } else if(!LOWER && transA == rocblas_operation_none) @@ -2419,7 +2419,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = maxColA; j > i; j--) { - rocblas_int col_off = j * NB; + size_t col_off = j * size_t(NB); T sB_reg = sB[sb_col + j]; resB[0] -= sB_reg * sA[col_off + (i - 0)]; resB[1] -= sB_reg * sA[col_off + (i - 1)]; @@ -2427,22 +2427,22 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[3] -= sB_reg * sA[col_off + (i - 3)]; } - resB[0] /= sA[(i - 0) * NB + (i - 0)]; + resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; sB[sb_col + i - 0] = resB[0]; - resB[1] -= resB[0] * sA[(i - 0) * NB + (i - 1)]; - resB[1] /= sA[(i - 1) * NB + (i - 1)]; + resB[1] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 1)]; + resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; sB[sb_col + i - 1] = resB[1]; - resB[2] -= resB[0] * sA[(i - 0) * NB + (i - 2)]; - resB[2] -= resB[1] * sA[(i - 1) * NB + (i - 2)]; - resB[2] /= sA[(i - 2) * NB + (i - 2)]; + resB[2] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 2)]; + resB[2] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 2)]; + resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; sB[sb_col + i - 2] = resB[2]; - resB[3] -= resB[0] * sA[(i - 0) * NB + (i - 3)]; - resB[3] -= resB[1] * sA[(i - 1) * NB + (i - 3)]; - resB[3] -= resB[2] * sA[(i - 2) * NB + (i - 3)]; - resB[3] /= sA[(i - 3) * NB + (i - 3)]; + resB[3] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 3)]; + resB[3] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 3)]; + resB[3] -= resB[2] * sA[(i - 2) * size_t(NB) + (i - 3)]; + resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; sB[sb_col + i - 3] = resB[3]; } @@ -2451,9 +2451,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[sb_col + j] * sA[j * NB + i]; + resB[0] -= sB[sb_col + j] * sA[j * size_t(NB) + i]; } - sB[sb_col + i] = resB[0] / sA[i * NB + i]; + sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; } } else if(LOWER) @@ -2469,28 +2469,28 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = maxColA; j > i; j--) { T sB_reg = sB[sb_col + j]; - resB[0] -= sB_reg * sA[(i - 0) * NB + j]; - resB[1] -= sB_reg * sA[(i - 1) * NB + j]; - resB[2] -= sB_reg * sA[(i - 2) * NB + j]; - resB[3] -= sB_reg * sA[(i - 3) * NB + j]; + resB[0] -= sB_reg * sA[(i - 0) * size_t(NB) + j]; + resB[1] -= sB_reg * sA[(i - 1) * size_t(NB) + j]; + resB[2] -= sB_reg * sA[(i - 2) * size_t(NB) + j]; + resB[3] -= sB_reg * sA[(i - 3) * size_t(NB) + j]; } - resB[0] /= sA[(i - 0) * NB + (i - 0)]; + resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; sB[sb_col + i - 0] = resB[0]; - resB[1] -= resB[0] * sA[(i - 1) * NB + (i - 0)]; - resB[1] /= sA[(i - 1) * NB + (i - 1)]; + resB[1] -= resB[0] * sA[(i - 1) * size_t(NB) + (i - 0)]; + resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; sB[sb_col + i - 1] = resB[1]; - resB[2] -= resB[0] * sA[(i - 2) * NB + (i - 0)]; - resB[2] -= resB[1] * sA[(i - 2) * NB + (i - 1)]; - resB[2] /= sA[(i - 2) * NB + (i - 2)]; + resB[2] -= resB[0] * sA[(i - 2) * size_t(NB) + (i - 0)]; + resB[2] -= resB[1] * sA[(i - 2) * size_t(NB) + (i - 1)]; + resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; sB[sb_col + i - 2] = resB[2]; - resB[3] -= resB[0] * sA[(i - 3) * NB + (i - 0)]; - resB[3] -= resB[1] * sA[(i - 3) * NB + (i - 1)]; - resB[3] -= resB[2] * sA[(i - 3) * NB + (i - 2)]; - resB[3] /= sA[(i - 3) * NB + (i - 3)]; + resB[3] -= resB[0] * sA[(i - 3) * size_t(NB) + (i - 0)]; + resB[3] -= resB[1] * sA[(i - 3) * size_t(NB) + (i - 1)]; + resB[3] -= resB[2] * sA[(i - 3) * size_t(NB) + (i - 2)]; + resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; sB[sb_col + i - 3] = resB[3]; } @@ -2499,9 +2499,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[sb_col + j] * sA[i * NB + j]; + resB[0] -= sB[sb_col + j] * sA[i * size_t(NB) + j]; } - sB[sb_col + i] = resB[0] / sA[i * NB + i]; + sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; } } else if(!LOWER) @@ -2518,28 +2518,28 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = 0; j < i; j++) { T sB_reg = sB[sb_col + j]; - resB[0] -= sB_reg * sA[(i + 0) * NB + j]; - resB[1] -= sB_reg * sA[(i + 1) * NB + j]; - resB[2] -= sB_reg * sA[(i + 2) * NB + j]; - resB[3] -= sB_reg * sA[(i + 3) * NB + j]; + resB[0] -= sB_reg * sA[(i + 0) * size_t(NB) + j]; + resB[1] -= sB_reg * sA[(i + 1) * size_t(NB) + j]; + resB[2] -= sB_reg * sA[(i + 2) * size_t(NB) + j]; + resB[3] -= sB_reg * sA[(i + 3) * size_t(NB) + j]; } - resB[0] /= sA[(i + 0) * NB + (i + 0)]; + resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; sB[sb_col + i + 0] = resB[0]; - resB[1] -= resB[0] * sA[(i + 1) * NB + (i + 0)]; - resB[1] /= sA[(i + 1) * NB + (i + 1)]; + resB[1] -= resB[0] * sA[(i + 1) * size_t(NB) + (i + 0)]; + resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; sB[sb_col + i + 1] = resB[1]; - resB[2] -= resB[0] * sA[(i + 2) * NB + (i + 0)]; - resB[2] -= resB[1] * sA[(i + 2) * NB + (i + 1)]; - resB[2] /= sA[(i + 2) * NB + (i + 2)]; + resB[2] -= resB[0] * sA[(i + 2) * size_t(NB) + (i + 0)]; + resB[2] -= resB[1] * sA[(i + 2) * size_t(NB) + (i + 1)]; + resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; sB[sb_col + i + 2] = resB[2]; - resB[3] -= resB[0] * sA[(i + 3) * NB + (i + 0)]; - resB[3] -= resB[1] * sA[(i + 3) * NB + (i + 1)]; - resB[3] -= resB[2] * sA[(i + 3) * NB + (i + 2)]; - resB[3] /= sA[(i + 3) * NB + (i + 3)]; + resB[3] -= resB[0] * sA[(i + 3) * size_t(NB) + (i + 0)]; + resB[3] -= resB[1] * sA[(i + 3) * size_t(NB) + (i + 1)]; + resB[3] -= resB[2] * sA[(i + 3) * size_t(NB) + (i + 2)]; + resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; sB[sb_col + i + 3] = resB[3]; } @@ -2549,9 +2549,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = 0; j < i; j++) { - resB[0] -= sB[sb_col + j] * sA[i * NB + j]; + resB[0] -= sB[sb_col + j] * sA[i * size_t(NB) + j]; } - sB[sb_col + i] = resB[0] / sA[i * NB + i]; + sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; } } @@ -2561,7 +2561,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, if(tx < m) { for(int i = 0; i < maxColB; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; } } @@ -2603,7 +2603,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? NB : n - bx * NB; // offset B into correct block column - B += bx * NB * ldb; + B += bx * NB * size_t(ldb); // shared B __shared__ T sB[NB * NB]; @@ -2611,7 +2611,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, { // Load B into sB and multiply by alpha for(int i = 0; i < maxColB; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -2624,56 +2624,56 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, // Subtract previously solved parts for(int j = 0; j < i; j++) { - T valA = A[j * lda + i]; - sB[tx * NB + i] -= sB[tx * NB + j] * valA; + T valA = A[j * size_t(lda) + i]; + sB[tx * size_t(NB) + i] -= sB[tx * size_t(NB) + j] * valA; } if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= A[i * lda + i]; + sB[tx * size_t(NB) + i] /= A[i * size_t(lda) + i]; } } else if(!LOWER && transA == rocblas_operation_none) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[tx * NB + i]; + T temp_reg_B = sB[tx * size_t(NB) + i]; for(int j = maxColA; j > i; j--) { - T valA = A[j * lda + i]; - temp_reg_B -= sB[tx * NB + j] * valA; + T valA = A[j * size_t(lda) + i]; + temp_reg_B -= sB[tx * size_t(NB) + j] * valA; } - sB[tx * NB + i] = temp_reg_B; + sB[tx * size_t(NB) + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= A[i * lda + i]; + sB[tx * size_t(NB) + i] /= A[i * size_t(lda) + i]; } } else if(LOWER) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[tx * NB + i]; + T temp_reg_B = sB[tx * size_t(NB) + i]; for(int j = maxColA; j > i; j--) { - T valA = (CONJ) ? conj(A[i * lda + j]) : A[i * lda + j]; - temp_reg_B -= sB[tx * NB + j] * valA; + T valA = (CONJ) ? conj(A[i * size_t(lda) + j]) : A[i * size_t(lda) + j]; + temp_reg_B -= sB[tx * size_t(NB) + j] * valA; } - sB[tx * NB + i] = temp_reg_B; + sB[tx * size_t(NB) + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= (CONJ) ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[tx * size_t(NB) + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else if(!LOWER) { for(int i = 0; i <= maxColA; i++) { - T temp_reg_B = sB[tx * NB + i]; + T temp_reg_B = sB[tx * size_t(NB) + i]; for(int j = 0; j < i; j++) { - T valA = (CONJ) ? conj(A[i * lda + j]) : A[i * lda + j]; - temp_reg_B -= sB[tx * NB + j] * valA; + T valA = (CONJ) ? conj(A[i * size_t(lda) + j]) : A[i * size_t(lda) + j]; + temp_reg_B -= sB[tx * size_t(NB) + j] * valA; } - sB[tx * NB + i] = temp_reg_B; + sB[tx * size_t(NB) + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= (CONJ) ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[tx * size_t(NB) + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2683,7 +2683,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, if(tx < m) { for(int i = 0; i < maxColB; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; } } @@ -2808,7 +2808,7 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_backward_substitution(int if(offY < n && tx < m) { - T valB = alpha * B[offY * ldb_norm + tx * ldb_trans]; + T valB = alpha * B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)]; for(int i = m - 1; i > 0; i--) { // tx is row of B, ty is col of B @@ -2817,23 +2817,23 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_backward_substitution(int if(tx == i) { // solve cur row - valB = UNIT ? valB : valB / A[tx * lda_norm + tx * lda_trans]; + valB = UNIT ? valB : valB / A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; sB[ty] = valB; } __syncthreads(); if(tx < i) - valB -= (CONJ ? conj(A[i * lda_norm + tx * lda_trans]) - : A[i * lda_norm + tx * lda_trans]) + valB -= (CONJ ? conj(A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) + : A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) * sB[ty]; } if(!UNIT && tx == 0) - valB /= A[tx * lda_norm + tx * lda_trans]; + valB /= A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; // store back to mem - B[offY * ldb_norm + tx * ldb_trans] = valB; + B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)] = valB; } } @@ -2876,7 +2876,7 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_forward_substitution(int if(offY < n && tx < m) { - T valB = alpha * B[offY * ldb_norm + tx * ldb_trans]; + T valB = alpha * B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)]; for(int i = 0; i < m - 1; i++) { // tx is row of B, ty is col of B @@ -2885,22 +2885,22 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_forward_substitution(int if(tx == i) { // solve cur row - valB = UNIT ? valB : valB / A[tx * lda_norm + tx * lda_trans]; + valB = UNIT ? valB : valB / A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; sB[ty] = valB; } __syncthreads(); if(tx > i) - valB -= (CONJ ? conj(A[i * lda_norm + tx * lda_trans]) - : A[i * lda_norm + tx * lda_trans]) + valB -= (CONJ ? conj(A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) + : A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) * sB[ty]; } if(!UNIT && tx == m - 1) - valB /= A[tx * lda_norm + tx * lda_trans]; + valB /= A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; // store back to mem - B[offY * ldb_norm + tx * ldb_trans] = valB; + B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)] = valB; } } @@ -2949,7 +2949,7 @@ void rocblas_trsm_small_substitution(rocblas_handle handle, T negative_one = -1; T one = 1; rocblas_int j = 0; - rocblas_int offA_sub, offB_sub; + size_t offA_sub, offB_sub; size_t smem_size; // Different kernels for forward substitution vs. backward substitution @@ -2959,17 +2959,17 @@ void rocblas_trsm_small_substitution(rocblas_handle handle, { const rocblas_int j_next = j + NBX; - rocblas_int offA_gemm = LEFT ? (!TRANSA ? j * lda + j_next : j + j_next * lda) - : (!TRANSA ? j + j_next * lda : j * lda + j_next); - rocblas_int offB_gemm = LEFT ? j : j * ldb; - rocblas_int offC_gemm = LEFT ? j_next : j_next * ldb; + size_t offA_gemm = LEFT ? (!TRANSA ? j * size_t(lda) + j_next : j + j_next * size_t(lda)) + : (!TRANSA ? j + j_next * size_t(lda) : j * size_t(lda) + j_next); + size_t offB_gemm = LEFT ? j : j * size_t(ldb); + size_t offC_gemm = LEFT ? j_next : j_next * size_t(ldb); smem_size = (1024 / NBX) * sizeof(T); // 1. call trsm subtitution/solve if(FORWARD_SUB) { - offA_sub = j * lda + j; - offB_sub = LEFT ? j : j * ldb; + offA_sub = j * size_t(lda) + j; + offB_sub = LEFT ? j : j * size_t(ldb); hipLaunchKernelGGL((rocblas_trsm_block_forward_substitution Date: Mon, 6 Mar 2023 14:37:58 -0500 Subject: [PATCH 02/11] remove redundant trsm offset datatype promotions --- library/src/blas3/rocblas_trsm.hpp | 81 +++++++++++++++++------------- 1 file changed, 46 insertions(+), 35 deletions(-) diff --git a/library/src/blas3/rocblas_trsm.hpp b/library/src/blas3/rocblas_trsm.hpp index 0243c04cd..2d20e5ec5 100644 --- a/library/src/blas3/rocblas_trsm.hpp +++ b/library/src/blas3/rocblas_trsm.hpp @@ -409,7 +409,8 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i + BLOCK + i * size_t(lda) + offset_Ain, + i + BLOCK + i * size_t(lda) + + offset_Ain, lda, stride_A, (U)X, @@ -728,7 +729,8 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i + (i + BLOCK) * size_t(lda) + offset_Ain, + i + (i + BLOCK) * size_t(lda) + + offset_Ain, lda, stride_A, (U)X, @@ -968,7 +970,8 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, m, stride_X, A, - i + (i + BLOCK) * size_t(lda) + offset_Ain, + i + (i + BLOCK) * size_t(lda) + + offset_Ain, lda, stride_A, &beta_1, @@ -1071,7 +1074,8 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, m, stride_X, A, - BLOCK + i + i * size_t(lda) + offset_Ain, + BLOCK + i + i * size_t(lda) + + offset_Ain, lda, stride_A, &beta_1, @@ -1222,7 +1226,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, for(size_t w = 0; w < W; w++) { - size_t width = std::min(bsize - w * size_t(B_chunk_size), B_chunk_size); + size_t width = std::min(bsize - w * B_chunk_size, B_chunk_size); if(side == rocblas_side_left) { @@ -1243,19 +1247,19 @@ rocblas_status special_trsm_template(rocblas_handle handle, BLOCK, stride_X, batch_count, - j * size_t(BLOCK) + w * size_t(B_chunk_size) * size_t(ldb) + offset_Bin, + j * BLOCK + w * B_chunk_size * ldb + offset_Bin, 0); if(r) { rocblas_stride offsetA = 0; - rocblas_stride offsetB = parity ? w * size_t(B_chunk_size) * size_t(ldb) - : w * size_t(B_chunk_size) * size_t(ldb) + (q + 1) * BLOCK; + rocblas_stride offsetB = parity ? w * B_chunk_size * ldb + : w * B_chunk_size * ldb + (q + 1) * BLOCK; if(transA == rocblas_operation_none) - offsetA = parity ? r * BLOCK : BLOCK * (q * size_t(lda) + q + lda); + offsetA = parity ? r * BLOCK : BLOCK * (q * lda + q + lda); else - offsetA = parity ? r * BLOCK * size_t(lda) : BLOCK * (q * size_t(lda) + q + 1); + offsetA = parity ? r * BLOCK * lda : BLOCK * (q * lda + q + 1); if(!tensile_supports_ldc_ne_ldd) { @@ -1308,7 +1312,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, alpha, B, compute_type, - j * size_t(BLOCK) + w * size_t(B_chunk_size) * size_t(ldb) + j * BLOCK + w * B_chunk_size * ldb + offset_Bin, ldb, stride_B, @@ -1334,7 +1338,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, BLOCK, r ? &alpha_1 : alpha, invA, - size_t(j * size_t(BLOCK) * BLOCK + offset_invAin), + size_t(j * BLOCK * BLOCK + offset_invAin), size_t(BLOCK), stride_invA, (U)w_x_temp, @@ -1343,7 +1347,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, stride_X, &beta_0, B, - size_t(w * size_t(B_chunk_size) * size_t(ldb) + j * size_t(BLOCK) + offset_Bin), + size_t(w * B_chunk_size * ldb + j * BLOCK + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1368,18 +1372,18 @@ rocblas_status special_trsm_template(rocblas_handle handle, width, stride_X, batch_count, - j * size_t(BLOCK) * size_t(ldb) + w * size_t(B_chunk_size) + offset_Bin, + j * BLOCK * ldb + w * B_chunk_size + offset_Bin, 0); if(r) { rocblas_stride offsetA = 0; rocblas_stride offsetB - = parity ? w * size_t(B_chunk_size) + (q + 1) * BLOCK * size_t(ldb) : w * size_t(B_chunk_size); + = parity ? w * B_chunk_size + (q + 1) * BLOCK * ldb : w * B_chunk_size; if(transA == rocblas_operation_none) - offsetA = parity ? BLOCK * (q * size_t(lda) + q + 1) : r * BLOCK * size_t(lda); + offsetA = parity ? BLOCK * (q * lda + q + 1) : r * BLOCK * lda; else - offsetA = parity ? BLOCK * (q * size_t(lda) + q + lda) : r * BLOCK; + offsetA = parity ? BLOCK * (q * lda + q + lda) : r * BLOCK; if(!tensile_supports_ldc_ne_ldd) { @@ -1432,7 +1436,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, alpha, B, compute_type, - j * size_t(BLOCK) * size_t(ldb) + w * size_t(B_chunk_size) + j * BLOCK * ldb + w * B_chunk_size + offset_Bin, ldb, stride_B, @@ -1462,12 +1466,12 @@ rocblas_status special_trsm_template(rocblas_handle handle, width, stride_X, invA, - size_t(j * size_t(BLOCK) * BLOCK + offset_invAin), + size_t(j * BLOCK * BLOCK + offset_invAin), size_t(BLOCK), stride_invA, &beta_0, B, - size_t(w * size_t(B_chunk_size) * size_t(ldb) + j * size_t(BLOCK) * size_t(ldb) + offset_Bin), + size_t(w * B_chunk_size * ldb + j * BLOCK * ldb + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1944,7 +1948,8 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * size_t(NB) + tx] = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; + sA[i * size_t(NB) + tx] + = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) @@ -2071,7 +2076,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, for(int j = maxColA; j > i; j--) { size_t col_off = j * size_t(NB); - T sB_reg = sB[col_off + tx]; + T sB_reg = sB[col_off + tx]; resB[0] -= sB_reg * sA[col_off + (i - 0)]; resB[1] -= sB_reg * sA[col_off + (i - 1)]; resB[2] -= sB_reg * sA[col_off + (i - 2)]; @@ -2121,7 +2126,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, for(int j = 0; j < i; j++) { size_t col_off = j * size_t(NB); - T sB_reg = sB[col_off + tx]; + T sB_reg = sB[col_off + tx]; resB[0] -= sB_reg * sA[col_off + (i + 0)]; resB[1] -= sB_reg * sA[col_off + (i + 1)]; resB[2] -= sB_reg * sA[col_off + (i + 2)]; @@ -2263,7 +2268,8 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, } sB[i * size_t(NB) + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * size_t(NB) + tx] /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[i * size_t(NB) + tx] + /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else // lower (conjugate-)transpose @@ -2278,7 +2284,8 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, } sB[i * size_t(NB) + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * size_t(NB) + tx] /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[i * size_t(NB) + tx] + /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2343,7 +2350,8 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * size_t(NB) + tx] = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; + sA[i * size_t(NB) + tx] + = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) @@ -2370,7 +2378,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = 0; j < i; j++) { size_t col_off = j * size_t(NB); - T sB_reg = sB[sb_col + j]; + T sB_reg = sB[sb_col + j]; resB[0] -= sB_reg * sA[col_off + i]; resB[1] -= sB_reg * sA[col_off + (i + 1)]; resB[2] -= sB_reg * sA[col_off + (i + 2)]; @@ -2420,7 +2428,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = maxColA; j > i; j--) { size_t col_off = j * size_t(NB); - T sB_reg = sB[sb_col + j]; + T sB_reg = sB[sb_col + j]; resB[0] -= sB_reg * sA[col_off + (i - 0)]; resB[1] -= sB_reg * sA[col_off + (i - 1)]; resB[2] -= sB_reg * sA[col_off + (i - 2)]; @@ -2658,7 +2666,8 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, } sB[tx * size_t(NB) + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * size_t(NB) + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[tx * size_t(NB) + i] + /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else if(!LOWER) @@ -2673,7 +2682,8 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, } sB[tx * size_t(NB) + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * size_t(NB) + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[tx * size_t(NB) + i] + /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2949,7 +2959,7 @@ void rocblas_trsm_small_substitution(rocblas_handle handle, T negative_one = -1; T one = 1; rocblas_int j = 0; - size_t offA_sub, offB_sub; + size_t offA_sub, offB_sub; size_t smem_size; // Different kernels for forward substitution vs. backward substitution @@ -2960,10 +2970,10 @@ void rocblas_trsm_small_substitution(rocblas_handle handle, const rocblas_int j_next = j + NBX; size_t offA_gemm = LEFT ? (!TRANSA ? j * size_t(lda) + j_next : j + j_next * size_t(lda)) - : (!TRANSA ? j + j_next * size_t(lda) : j * size_t(lda) + j_next); + : (!TRANSA ? j + j_next * size_t(lda) : j * size_t(lda) + j_next); size_t offB_gemm = LEFT ? j : j * size_t(ldb); size_t offC_gemm = LEFT ? j_next : j_next * size_t(ldb); - smem_size = (1024 / NBX) * sizeof(T); + smem_size = (1024 / NBX) * sizeof(T); // 1. call trsm subtitution/solve if(FORWARD_SUB) @@ -2996,8 +3006,9 @@ void rocblas_trsm_small_substitution(rocblas_handle handle, } else { - offA_sub = LEFT ? (m - j_next) * size_t(lda) + (m - j_next) : (n - j_next) * size_t(lda) + (n - j_next); - offB_sub = LEFT ? m - j_next : (n - j_next) * size_t(ldb); + offA_sub = LEFT ? (m - j_next) * size_t(lda) + (m - j_next) + : (n - j_next) * size_t(lda) + (n - j_next); + offB_sub = LEFT ? m - j_next : (n - j_next) * size_t(ldb); offA_gemm = LEFT ? (!TRANSA ? (m - j_next) * size_t(lda) : m - j_next) : (!TRANSA ? n - j_next : (n - j_next) * size_t(lda)); offB_gemm = LEFT ? m - j_next : (n - j_next) * size_t(ldb); From 869461a0bad8db2cb40648c0723fea664fd20b33 Mon Sep 17 00:00:00 2001 From: amcamd Date: Wed, 8 Mar 2023 08:50:00 -0600 Subject: [PATCH 03/11] add trsm test cases to PTS --- .../pts/benchmarks/trsm_problems.yaml | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/scripts/performance/pts/benchmarks/trsm_problems.yaml b/scripts/performance/pts/benchmarks/trsm_problems.yaml index 6545bfe5b..92061bb20 100644 --- a/scripts/performance/pts/benchmarks/trsm_problems.yaml +++ b/scripts/performance/pts/benchmarks/trsm_problems.yaml @@ -29,6 +29,76 @@ Definitions: - { M: 120, N: 120, lda: 120, ldb: 120 } - { M: 124, N: 124, lda: 124, ldb: 124 } + - &testset1_small_matrix_size_range + - { M: 256, N: 256, lda: 256, ldb: 256 } + - { M: 512, N: 256, lda: 256, ldb: 512 } + - { M: 512, N: 512, lda: 512, ldb: 512 } + - { M: 768, N: 256, lda: 256, ldb: 768 } + - { M: 1024, N: 256, lda: 256, ldb: 1024 } + - { M: 1280, N: 256, lda: 256, ldb: 1280 } + - { M: 1536, N: 256, lda: 256, ldb: 1536 } + - { M: 1792, N: 256, lda: 256, ldb: 1792 } + - { M: 2048, N: 256, lda: 256, ldb: 2048 } + - { M: 384, N: 384, lda: 384, ldb: 384 } + - { M: 768, N: 384, lda: 384, ldb: 768 } + - { M: 1152, N: 384, lda: 384, ldb: 1152 } + - { M: 1536, N: 384, lda: 384, ldb: 1536 } + - { M: 1920, N: 384, lda: 384, ldb: 1920 } + - { M: 2304, N: 384, lda: 384, ldb: 2304 } + - { M: 2688, N: 384, lda: 384, ldb: 2688 } + - { M: 3072, N: 384, lda: 384, ldb: 3072 } + + - &testset2_small_matrix_size_range + - { M: 256, N: 256, lda: 256, ldb: 256 } + - { M: 256, N: 512, lda: 512, ldb: 256 } + - { M: 256, N: 768, lda: 768, ldb: 256 } + - { M: 256, N: 1024, lda: 1024, ldb: 256 } + - { M: 256, N: 1280, lda: 1280, ldb: 256 } + - { M: 256, N: 1536, lda: 1536, ldb: 256 } + - { M: 256, N: 1792, lda: 1792, ldb: 256 } + - { M: 256, N: 2048, lda: 2048, ldb: 256 } + - { M: 384, N: 384, lda: 384, ldb: 384 } + - { M: 384, N: 768, lda: 768, ldb: 384 } + - { M: 384, N: 1152, lda: 1152, ldb: 384 } + - { M: 384, N: 1536, lda: 1536, ldb: 384 } + - { M: 384, N: 1920, lda: 1920, ldb: 384 } + - { M: 384, N: 2304, lda: 2304, ldb: 384 } + - { M: 384, N: 2688, lda: 2688, ldb: 384 } + - { M: 384, N: 3072, lda: 3072, ldb: 384 } + + - &testset1_matrix_size_range + - { M: 128, N: 2048, lda: 128, ldb: 128 } + - { M: 128, N: 16848, lda: 128, ldb: 128 } +# - { M: 128, N: 29696, lda: 128, ldb: 128 } +# - { M: 128, N: 44544, lda: 128, ldb: 128 } +# - { M: 128, N: 53632, lda: 128, ldb: 128 } + - { M: 256, N: 2048, lda: 256, ldb: 256 } + - { M: 256, N: 29696, lda: 256, ldb: 256 } +# - { M: 256, N: 44544, lda: 256, ldb: 256 } +# - { M: 256, N: 53504, lda: 256, ldb: 256 } + - { M: 384, N: 2048, lda: 384, ldb: 384 } + - { M: 384, N: 14976, lda: 384, ldb: 384 } +# - { M: 384, N: 29952, lda: 384, ldb: 384 } +# - { M: 384, N: 44928, lda: 384, ldb: 384 } +# - { M: 384, N: 53376, lda: 384, ldb: 384 } + + - &testset2_matrix_size_range + - { M: 2048, N: 128, lda: 2048, ldb: 2048 } + - { M: 16848, N: 128, lda: 16848, ldb: 16848 } +# - { M: 29696, N: 128, lda: 29696, ldb: 29696 } +# - { M: 44544, N: 128, lda: 44544, ldb: 44544 } +# - { M: 53632, N: 128, lda: 53632, ldb: 53632 } + - { M: 2048, N: 256, lda: 2048, ldb: 2048 } + - { M: 14848, N: 256, lda: 14848, ldb: 14848 } +# - { M: 29696, N: 256, lda: 29696, ldb: 29696 } +# - { M: 44544, N: 256, lda: 44544, ldb: 44544 } +# - { M: 53504, N: 256, lda: 53504, ldb: 53504 } + - { M: 2048, N: 384, lda: 2048, ldb: 2048 } + - { M: 14976, N: 384, lda: 14976, ldb: 14976 } +# - { M: 29952, N: 384, lda: 29952, ldb: 29952 } +# - { M: 44928, N: 384, lda: 44928, ldb: 44928 } +# - { M: 53376, N: 384, lda: 53376, ldb: 53376 } + Tests: - name: trsm_bench_const_n category: bench @@ -77,4 +147,60 @@ Tests: incy: 1 matrix_size: *m_equals_n_range iters: 20 + + - name: trsm_bench_1_small_matrix_size + category: bench + function: trsm + precision: *double_precision + transA: [ N, T ] + side: [ L, R ] + uplo: L + diag: U + alpha: 1 + incx: 1 + incy: 1 + matrix_size: *testset1_small_matrix_size_range + iters: 10 + + - name: trsm_bench_2_small_matrix_size + category: bench + function: trsm + precision: *double_precision + transA: [ N, T ] + side: [ L, R ] + uplo: L + diag: U + alpha: 1 + incx: 1 + incy: 1 + matrix_size: *testset2_small_matrix_size_range + iters: 10 + + - name: trsm_bench_1_matrix_size + category: bench + function: trsm + precision: *double_precision + transA: [ N, T ] + side: [ L, R ] + uplo: L + diag: U + alpha: 1 + incx: 1 + incy: 1 + matrix_size: *testset1_matrix_size_range + iters: 5 + + - name: trsm_bench_2_matrix_size + category: bench + function: trsm + precision: *double_precision + transA: [ N, T ] + side: [ L, R ] + uplo: L + diag: U + alpha: 1 + incx: 1 + incy: 1 + matrix_size: *testset2_matrix_size_range + iters: 5 ... From 8933d40d08c921a9f4574b26e51aa3b16bceadb5 Mon Sep 17 00:00:00 2001 From: amcamd Date: Wed, 8 Mar 2023 21:04:01 -0600 Subject: [PATCH 04/11] add trsm size_t test and size_t corrections to code --- clients/gtest/trsm_gtest.yaml | 27 ++ library/src/blas3/rocblas_trsm.hpp | 426 +++++++++--------- library/src/blas3/rocblas_trtri.hpp | 180 ++++---- library/src/blas3/trtri_trsm.hpp | 31 +- .../pts/benchmarks/trsm_problems.yaml | 100 ++-- 5 files changed, 399 insertions(+), 365 deletions(-) diff --git a/clients/gtest/trsm_gtest.yaml b/clients/gtest/trsm_gtest.yaml index 42a8d1cd6..6e91f52a7 100644 --- a/clients/gtest/trsm_gtest.yaml +++ b/clients/gtest/trsm_gtest.yaml @@ -60,6 +60,15 @@ Definitions: - &large_memory_matrix_size_range - { M: 8320, N: 128, lda: 8320, ldb: 8320 } + - &size_t_left_matrix_size_range +# - { M: 4, N: 46435, lda: 4, ldb: 46435 } + - { M: 46345, N: 4, lda: 46345, ldb: 46345 } +# - { M: 47000, N: 4, lda: 47000, ldb: 47000 } # calls rocblas_internal_gemm_template with batch_count=367, stride_a=6016128 + + - &size_t_right_matrix_size_range + - { M: 4, N: 46345, lda: 46345, ldb: 4 } +# - { M: 4, N: 47000, lda: 47000, ldb: 4 } # calls rocblas_internal_gemm_template with batch_count=367, stride_a=6016128 + - &substitution_size_range_thorough - { M: 1, N: 1, lda: 100, ldb: 100 } - { M: 1, N: 32, lda: 100, ldb: 100 } @@ -486,6 +495,24 @@ Tests: matrix_size: *testset2_matrix_size_range alpha: [ 1 ] +- name: trsm_size_t_left + category: nightly + function: trsm + precision: *single_precision + arguments: + - { side: L, uplo: L, transA: N, diag: N } + matrix_size: *size_t_left_matrix_size_range + alpha: [2] + +- name: trsm_size_t_right + category: nightly + function: trsm + precision: *single_precision + arguments: + - { side: R, uplo: L, transA: N, diag: N } + matrix_size: *size_t_right_matrix_size_range + alpha: [2] + - name: trsm_large category: nightly function: trsm diff --git a/library/src/blas3/rocblas_trsm.hpp b/library/src/blas3/rocblas_trsm.hpp index 2d20e5ec5..6533afd91 100644 --- a/library/src/blas3/rocblas_trsm.hpp +++ b/library/src/blas3/rocblas_trsm.hpp @@ -198,7 +198,7 @@ copy_matrix_trsm(rocblas_int rows, size_t ty = blockIdx.y * blockDim.y + threadIdx.y; if(tx < rows && ty < cols) - xb[tx + ldb * ty] = xa[tx + lda * ty]; + xb[tx + size_t(ldb) * ty] = xa[tx + size_t(lda) * ty]; } /* ===============copy helper============================================= */ @@ -256,7 +256,7 @@ set_matrix_trsm(rocblas_int rows, size_t ty = blockIdx.y * blockDim.y + threadIdx.y; if(tx < rows && ty < cols) - xa[tx + lda * ty] = T(0.0); + xa[tx + size_t(lda) * ty] = T(0.0); } /* ===============set helper============================================= */ @@ -1253,8 +1253,9 @@ rocblas_status special_trsm_template(rocblas_handle handle, if(r) { rocblas_stride offsetA = 0; - rocblas_stride offsetB = parity ? w * B_chunk_size * ldb - : w * B_chunk_size * ldb + (q + 1) * BLOCK; + rocblas_stride offsetB = parity + ? w * B_chunk_size * size_t(ldb) + : w * B_chunk_size * size_t(ldb) + (q + 1) * BLOCK; if(transA == rocblas_operation_none) offsetA = parity ? r * BLOCK : BLOCK * (q * lda + q + lda); @@ -1347,7 +1348,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, stride_X, &beta_0, B, - size_t(w * B_chunk_size * ldb + j * BLOCK + offset_Bin), + size_t(w * B_chunk_size * size_t(ldb) + j * BLOCK + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1372,14 +1373,15 @@ rocblas_status special_trsm_template(rocblas_handle handle, width, stride_X, batch_count, - j * BLOCK * ldb + w * B_chunk_size + offset_Bin, + j * BLOCK * size_t(ldb) + w * B_chunk_size + offset_Bin, 0); if(r) { rocblas_stride offsetA = 0; - rocblas_stride offsetB - = parity ? w * B_chunk_size + (q + 1) * BLOCK * ldb : w * B_chunk_size; + rocblas_stride offsetB = parity + ? w * B_chunk_size + (q + 1) * BLOCK * size_t(ldb) + : w * B_chunk_size; if(transA == rocblas_operation_none) offsetA = parity ? BLOCK * (q * lda + q + 1) : r * BLOCK * lda; else @@ -1436,7 +1438,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, alpha, B, compute_type, - j * BLOCK * ldb + w * B_chunk_size + j * BLOCK * size_t(ldb) + w * B_chunk_size + offset_Bin, ldb, stride_B, @@ -1471,7 +1473,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, stride_invA, &beta_0, B, - size_t(w * B_chunk_size * ldb + j * BLOCK * ldb + offset_Bin), + size_t(w * B_chunk_size * size_t(ldb) + j * BLOCK * size_t(ldb) + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1948,19 +1950,18 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * size_t(NB) + tx] - = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; + sA[i * NB + tx] = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) - sA[tx * size_t(NB) + tx] = T(1.0); + sA[tx * NB + tx] = T(1.0); } if(tx < maxColB) { // Load B into sB and multiply by alpha for(int i = 0; i < n; i++) - sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -1971,48 +1972,48 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, for(i = 0; i + 3 <= maxColA; i += 4) { // Subtract previously solved parts - resB[0] = sB[(i + 0) * size_t(NB) + tx]; - resB[1] = sB[(i + 1) * size_t(NB) + tx]; - resB[2] = sB[(i + 2) * size_t(NB) + tx]; - resB[3] = sB[(i + 3) * size_t(NB) + tx]; + resB[0] = sB[(i + 0) * NB + tx]; + resB[1] = sB[(i + 1) * NB + tx]; + resB[2] = sB[(i + 2) * NB + tx]; + resB[3] = sB[(i + 3) * NB + tx]; for(int j = 0; j < i; j++) { - T sB_reg = sB[j * size_t(NB) + tx]; - resB[0] -= sB_reg * sA[(i + 0) * size_t(NB) + j]; - resB[1] -= sB_reg * sA[(i + 1) * size_t(NB) + j]; - resB[2] -= sB_reg * sA[(i + 2) * size_t(NB) + j]; - resB[3] -= sB_reg * sA[(i + 3) * size_t(NB) + j]; + T sB_reg = sB[j * NB + tx]; + resB[0] -= sB_reg * sA[(i + 0) * NB + j]; + resB[1] -= sB_reg * sA[(i + 1) * NB + j]; + resB[2] -= sB_reg * sA[(i + 2) * NB + j]; + resB[3] -= sB_reg * sA[(i + 3) * NB + j]; } - resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; - sB[(i + 0) * size_t(NB) + tx] = resB[0]; + resB[0] /= sA[(i + 0) * NB + (i + 0)]; + sB[(i + 0) * NB + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i + 1) * size_t(NB) + (i + 0)]; - resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; - sB[(i + 1) * size_t(NB) + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i + 1) * NB + (i + 0)]; + resB[1] /= sA[(i + 1) * NB + (i + 1)]; + sB[(i + 1) * NB + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i + 2) * size_t(NB) + (i + 0)]; - resB[2] -= resB[1] * sA[(i + 2) * size_t(NB) + (i + 1)]; - resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; - sB[(i + 2) * size_t(NB) + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i + 2) * NB + (i + 0)]; + resB[2] -= resB[1] * sA[(i + 2) * NB + (i + 1)]; + resB[2] /= sA[(i + 2) * NB + (i + 2)]; + sB[(i + 2) * NB + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i + 3) * size_t(NB) + (i + 0)]; - resB[3] -= resB[1] * sA[(i + 3) * size_t(NB) + (i + 1)]; - resB[3] -= resB[2] * sA[(i + 3) * size_t(NB) + (i + 2)]; - resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; - sB[(i + 3) * size_t(NB) + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i + 3) * NB + (i + 0)]; + resB[3] -= resB[1] * sA[(i + 3) * NB + (i + 1)]; + resB[3] -= resB[2] * sA[(i + 3) * NB + (i + 2)]; + resB[3] /= sA[(i + 3) * NB + (i + 3)]; + sB[(i + 3) * NB + tx] = resB[3]; } // tail end if not divisible by 4 for(; i <= maxColA; i++) { - resB[0] = sB[i * size_t(NB) + tx]; + resB[0] = sB[i * NB + tx]; for(int j = 0; j < i; j++) { - resB[0] -= sB[j * size_t(NB) + tx] * sA[i * size_t(NB) + j]; + resB[0] -= sB[j * NB + tx] * sA[i * NB + j]; } - sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; + sB[i * NB + tx] = resB[0] / sA[i * NB + i]; } } else if(transA == rocblas_operation_none && uplo == rocblas_fill_lower) @@ -2020,47 +2021,47 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, int i; for(i = maxColA; i >= 3; i -= 4) { - resB[0] = sB[(i - 0) * size_t(NB) + tx]; - resB[1] = sB[(i - 1) * size_t(NB) + tx]; - resB[2] = sB[(i - 2) * size_t(NB) + tx]; - resB[3] = sB[(i - 3) * size_t(NB) + tx]; + resB[0] = sB[(i - 0) * NB + tx]; + resB[1] = sB[(i - 1) * NB + tx]; + resB[2] = sB[(i - 2) * NB + tx]; + resB[3] = sB[(i - 3) * NB + tx]; for(int j = maxColA; j > i; j--) { - T sB_reg = sB[j * size_t(NB) + tx]; - resB[0] -= sB_reg * sA[(i - 0) * size_t(NB) + j]; - resB[1] -= sB_reg * sA[(i - 1) * size_t(NB) + j]; - resB[2] -= sB_reg * sA[(i - 2) * size_t(NB) + j]; - resB[3] -= sB_reg * sA[(i - 3) * size_t(NB) + j]; + T sB_reg = sB[j * NB + tx]; + resB[0] -= sB_reg * sA[(i - 0) * NB + j]; + resB[1] -= sB_reg * sA[(i - 1) * NB + j]; + resB[2] -= sB_reg * sA[(i - 2) * NB + j]; + resB[3] -= sB_reg * sA[(i - 3) * NB + j]; } - resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; - sB[(i - 0) * size_t(NB) + tx] = resB[0]; + resB[0] /= sA[(i - 0) * NB + (i - 0)]; + sB[(i - 0) * NB + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i - 1) * size_t(NB) + (i - 0)]; - resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; - sB[(i - 1) * size_t(NB) + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i - 1) * NB + (i - 0)]; + resB[1] /= sA[(i - 1) * NB + (i - 1)]; + sB[(i - 1) * NB + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i - 2) * size_t(NB) + (i - 0)]; - resB[2] -= resB[1] * sA[(i - 2) * size_t(NB) + (i - 1)]; - resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; - sB[(i - 2) * size_t(NB) + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i - 2) * NB + (i - 0)]; + resB[2] -= resB[1] * sA[(i - 2) * NB + (i - 1)]; + resB[2] /= sA[(i - 2) * NB + (i - 2)]; + sB[(i - 2) * NB + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i - 3) * size_t(NB) + (i - 0)]; - resB[3] -= resB[1] * sA[(i - 3) * size_t(NB) + (i - 1)]; - resB[3] -= resB[2] * sA[(i - 3) * size_t(NB) + (i - 2)]; - resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; - sB[(i - 3) * size_t(NB) + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i - 3) * NB + (i - 0)]; + resB[3] -= resB[1] * sA[(i - 3) * NB + (i - 1)]; + resB[3] -= resB[2] * sA[(i - 3) * NB + (i - 2)]; + resB[3] /= sA[(i - 3) * NB + (i - 3)]; + sB[(i - 3) * NB + tx] = resB[3]; } for(; i >= 0; i--) { - resB[0] = sB[i * size_t(NB) + tx]; + resB[0] = sB[i * NB + tx]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[j * size_t(NB) + tx] * sA[i * size_t(NB) + j]; + resB[0] -= sB[j * NB + tx] * sA[i * NB + j]; } - sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; + sB[i * NB + tx] = resB[0] / sA[i * NB + i]; } } else if(uplo == rocblas_fill_upper) @@ -2068,48 +2069,48 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, int i; for(i = maxColA; i >= 3; i -= 4) { - resB[0] = sB[(i - 0) * size_t(NB) + tx]; - resB[1] = sB[(i - 1) * size_t(NB) + tx]; - resB[2] = sB[(i - 2) * size_t(NB) + tx]; - resB[3] = sB[(i - 3) * size_t(NB) + tx]; + resB[0] = sB[(i - 0) * NB + tx]; + resB[1] = sB[(i - 1) * NB + tx]; + resB[2] = sB[(i - 2) * NB + tx]; + resB[3] = sB[(i - 3) * NB + tx]; for(int j = maxColA; j > i; j--) { - size_t col_off = j * size_t(NB); - T sB_reg = sB[col_off + tx]; + rocblas_int col_off = j * NB; + T sB_reg = sB[col_off + tx]; resB[0] -= sB_reg * sA[col_off + (i - 0)]; resB[1] -= sB_reg * sA[col_off + (i - 1)]; resB[2] -= sB_reg * sA[col_off + (i - 2)]; resB[3] -= sB_reg * sA[col_off + (i - 3)]; } - resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; - sB[(i - 0) * size_t(NB) + tx] = resB[0]; + resB[0] /= sA[(i - 0) * NB + (i - 0)]; + sB[(i - 0) * NB + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 1)]; - resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; - sB[(i - 1) * size_t(NB) + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i - 0) * NB + (i - 1)]; + resB[1] /= sA[(i - 1) * NB + (i - 1)]; + sB[(i - 1) * NB + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 2)]; - resB[2] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 2)]; - resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; - sB[(i - 2) * size_t(NB) + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i - 0) * NB + (i - 2)]; + resB[2] -= resB[1] * sA[(i - 1) * NB + (i - 2)]; + resB[2] /= sA[(i - 2) * NB + (i - 2)]; + sB[(i - 2) * NB + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 3)]; - resB[3] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 3)]; - resB[3] -= resB[2] * sA[(i - 2) * size_t(NB) + (i - 3)]; - resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; - sB[(i - 3) * size_t(NB) + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i - 0) * NB + (i - 3)]; + resB[3] -= resB[1] * sA[(i - 1) * NB + (i - 3)]; + resB[3] -= resB[2] * sA[(i - 2) * NB + (i - 3)]; + resB[3] /= sA[(i - 3) * NB + (i - 3)]; + sB[(i - 3) * NB + tx] = resB[3]; } for(; i >= 0; i--) { - resB[0] = sB[i * size_t(NB) + tx]; + resB[0] = sB[i * NB + tx]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[j * size_t(NB) + tx] * sA[j * size_t(NB) + i]; + resB[0] -= sB[j * NB + tx] * sA[j * NB + i]; } - sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; + sB[i * NB + tx] = resB[0] / sA[i * NB + i]; } } else // lower (conjugate-)transpose @@ -2118,49 +2119,49 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, for(i = 0; i + 3 <= maxColA; i += 4) { // Subtract previously solved parts - resB[0] = sB[(i + 0) * size_t(NB) + tx]; - resB[1] = sB[(i + 1) * size_t(NB) + tx]; - resB[2] = sB[(i + 2) * size_t(NB) + tx]; - resB[3] = sB[(i + 3) * size_t(NB) + tx]; + resB[0] = sB[(i + 0) * NB + tx]; + resB[1] = sB[(i + 1) * NB + tx]; + resB[2] = sB[(i + 2) * NB + tx]; + resB[3] = sB[(i + 3) * NB + tx]; for(int j = 0; j < i; j++) { - size_t col_off = j * size_t(NB); - T sB_reg = sB[col_off + tx]; + rocblas_int col_off = j * NB; + T sB_reg = sB[col_off + tx]; resB[0] -= sB_reg * sA[col_off + (i + 0)]; resB[1] -= sB_reg * sA[col_off + (i + 1)]; resB[2] -= sB_reg * sA[col_off + (i + 2)]; resB[3] -= sB_reg * sA[col_off + (i + 3)]; } - resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; - sB[(i + 0) * size_t(NB) + tx] = resB[0]; + resB[0] /= sA[(i + 0) * NB + (i + 0)]; + sB[(i + 0) * NB + tx] = resB[0]; - resB[1] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 1)]; - resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; - sB[(i + 1) * size_t(NB) + tx] = resB[1]; + resB[1] -= resB[0] * sA[(i + 0) * NB + (i + 1)]; + resB[1] /= sA[(i + 1) * NB + (i + 1)]; + sB[(i + 1) * NB + tx] = resB[1]; - resB[2] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 2)]; - resB[2] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 2)]; - resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; - sB[(i + 2) * size_t(NB) + tx] = resB[2]; + resB[2] -= resB[0] * sA[(i + 0) * NB + (i + 2)]; + resB[2] -= resB[1] * sA[(i + 1) * NB + (i + 2)]; + resB[2] /= sA[(i + 2) * NB + (i + 2)]; + sB[(i + 2) * NB + tx] = resB[2]; - resB[3] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 3)]; - resB[3] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 3)]; - resB[3] -= resB[2] * sA[(i + 2) * size_t(NB) + (i + 3)]; - resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; - sB[(i + 3) * size_t(NB) + tx] = resB[3]; + resB[3] -= resB[0] * sA[(i + 0) * NB + (i + 3)]; + resB[3] -= resB[1] * sA[(i + 1) * NB + (i + 3)]; + resB[3] -= resB[2] * sA[(i + 2) * NB + (i + 3)]; + resB[3] /= sA[(i + 3) * NB + (i + 3)]; + sB[(i + 3) * NB + tx] = resB[3]; } // tail end if not divisible by 4 for(; i <= maxColA; i++) { - resB[0] = sB[i * size_t(NB) + tx]; + resB[0] = sB[i * NB + tx]; for(int j = 0; j < i; j++) { - resB[0] -= sB[j * size_t(NB) + tx] * sA[j * size_t(NB) + i]; + resB[0] -= sB[j * NB + tx] * sA[j * NB + i]; } - sB[i * size_t(NB) + tx] = resB[0] / sA[i * size_t(NB) + i]; + sB[i * NB + tx] = resB[0] / sA[i * NB + i]; } } @@ -2168,7 +2169,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, if(tx < maxColB) { for(int i = 0; i < n; i++) - B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } @@ -2218,7 +2219,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, { // Load B into sB and multiply by alpha for(int i = 0; i < n; i++) - sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); // Solve for B in shared memory @@ -2229,63 +2230,61 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, for(int i = 0; i <= maxColA; i++) { // Subtract previously solved parts - T temp_reg_B = sB[i * size_t(NB) + tx]; + T temp_reg_B = sB[i * NB + tx]; for(int j = 0; j < i; j++) { T valA = A[i * size_t(lda) + j]; - temp_reg_B -= sB[j * size_t(NB) + tx] * valA; + temp_reg_B -= sB[j * NB + tx] * valA; } // Solve - sB[i * size_t(NB) + tx] = temp_reg_B; + sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * size_t(NB) + tx] /= A[i * size_t(lda) + i]; + sB[i * NB + tx] /= A[i * size_t(lda) + i]; } } else if(transA == rocblas_operation_none && uplo == rocblas_fill_lower) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[i * size_t(NB) + tx]; + T temp_reg_B = sB[i * NB + tx]; for(int j = maxColA; j > i; j--) { T valA = A[i * size_t(lda) + j]; - temp_reg_B -= sB[j * size_t(NB) + tx] * valA; + temp_reg_B -= sB[j * NB + tx] * valA; } - sB[i * size_t(NB) + tx] = temp_reg_B; + sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * size_t(NB) + tx] /= A[i * size_t(lda) + i]; + sB[i * NB + tx] /= A[i * size_t(lda) + i]; } } else if(uplo == rocblas_fill_upper) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[i * size_t(NB) + tx]; + T temp_reg_B = sB[i * NB + tx]; for(int j = maxColA; j > i; j--) { T valA = CONJ ? conj(A[j * size_t(lda) + i]) : A[j * size_t(lda) + i]; - temp_reg_B -= sB[j * size_t(NB) + tx] * valA; + temp_reg_B -= sB[j * NB + tx] * valA; } - sB[i * size_t(NB) + tx] = temp_reg_B; + sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * size_t(NB) + tx] - /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[i * NB + tx] /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else // lower (conjugate-)transpose { for(int i = 0; i <= maxColA; i++) { - T temp_reg_B = sB[i * size_t(NB) + tx]; + T temp_reg_B = sB[i * NB + tx]; for(int j = 0; j < i; j++) { T valA = CONJ ? conj(A[j * size_t(lda) + i]) : A[j * size_t(lda) + i]; - temp_reg_B -= sB[j * size_t(NB) + tx] * valA; + temp_reg_B -= sB[j * NB + tx] * valA; } - sB[i * size_t(NB) + tx] = temp_reg_B; + sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * size_t(NB) + tx] - /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[i * NB + tx] /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2293,7 +2292,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, if(tx < maxColB) { for(int i = 0; i < n; i++) - B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } @@ -2350,8 +2349,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * size_t(NB) + tx] - = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; + sA[i * NB + tx] = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) @@ -2359,7 +2357,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, // Load B into sB and multiply by alpha for(int i = 0; i < maxColB; i++) - sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -2377,30 +2375,30 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = 0; j < i; j++) { - size_t col_off = j * size_t(NB); - T sB_reg = sB[sb_col + j]; + rocblas_int col_off = j * NB; + T sB_reg = sB[sb_col + j]; resB[0] -= sB_reg * sA[col_off + i]; resB[1] -= sB_reg * sA[col_off + (i + 1)]; resB[2] -= sB_reg * sA[col_off + (i + 2)]; resB[3] -= sB_reg * sA[col_off + (i + 3)]; } - resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; + resB[0] /= sA[(i + 0) * NB + (i + 0)]; sB[sb_col + i + 0] = resB[0]; - resB[1] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 1)]; - resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; + resB[1] -= resB[0] * sA[(i + 0) * NB + (i + 1)]; + resB[1] /= sA[(i + 1) * NB + (i + 1)]; sB[sb_col + i + 1] = resB[1]; - resB[2] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 2)]; - resB[2] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 2)]; - resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; + resB[2] -= resB[0] * sA[(i + 0) * NB + (i + 2)]; + resB[2] -= resB[1] * sA[(i + 1) * NB + (i + 2)]; + resB[2] /= sA[(i + 2) * NB + (i + 2)]; sB[sb_col + i + 2] = resB[2]; - resB[3] -= resB[0] * sA[(i + 0) * size_t(NB) + (i + 3)]; - resB[3] -= resB[1] * sA[(i + 1) * size_t(NB) + (i + 3)]; - resB[3] -= resB[2] * sA[(i + 2) * size_t(NB) + (i + 3)]; - resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; + resB[3] -= resB[0] * sA[(i + 0) * NB + (i + 3)]; + resB[3] -= resB[1] * sA[(i + 1) * NB + (i + 3)]; + resB[3] -= resB[2] * sA[(i + 2) * NB + (i + 3)]; + resB[3] /= sA[(i + 3) * NB + (i + 3)]; sB[sb_col + i + 3] = resB[3]; } @@ -2410,9 +2408,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = 0; j < i; j++) { - resB[0] -= sB[sb_col + j] * sA[j * size_t(NB) + i]; + resB[0] -= sB[sb_col + j] * sA[j * NB + i]; } - sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; + sB[sb_col + i] = resB[0] / sA[i * NB + i]; } } else if(!LOWER && transA == rocblas_operation_none) @@ -2427,30 +2425,30 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = maxColA; j > i; j--) { - size_t col_off = j * size_t(NB); - T sB_reg = sB[sb_col + j]; + rocblas_int col_off = j * NB; + T sB_reg = sB[sb_col + j]; resB[0] -= sB_reg * sA[col_off + (i - 0)]; resB[1] -= sB_reg * sA[col_off + (i - 1)]; resB[2] -= sB_reg * sA[col_off + (i - 2)]; resB[3] -= sB_reg * sA[col_off + (i - 3)]; } - resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; + resB[0] /= sA[(i - 0) * NB + (i - 0)]; sB[sb_col + i - 0] = resB[0]; - resB[1] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 1)]; - resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; + resB[1] -= resB[0] * sA[(i - 0) * NB + (i - 1)]; + resB[1] /= sA[(i - 1) * NB + (i - 1)]; sB[sb_col + i - 1] = resB[1]; - resB[2] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 2)]; - resB[2] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 2)]; - resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; + resB[2] -= resB[0] * sA[(i - 0) * NB + (i - 2)]; + resB[2] -= resB[1] * sA[(i - 1) * NB + (i - 2)]; + resB[2] /= sA[(i - 2) * NB + (i - 2)]; sB[sb_col + i - 2] = resB[2]; - resB[3] -= resB[0] * sA[(i - 0) * size_t(NB) + (i - 3)]; - resB[3] -= resB[1] * sA[(i - 1) * size_t(NB) + (i - 3)]; - resB[3] -= resB[2] * sA[(i - 2) * size_t(NB) + (i - 3)]; - resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; + resB[3] -= resB[0] * sA[(i - 0) * NB + (i - 3)]; + resB[3] -= resB[1] * sA[(i - 1) * NB + (i - 3)]; + resB[3] -= resB[2] * sA[(i - 2) * NB + (i - 3)]; + resB[3] /= sA[(i - 3) * NB + (i - 3)]; sB[sb_col + i - 3] = resB[3]; } @@ -2459,9 +2457,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[sb_col + j] * sA[j * size_t(NB) + i]; + resB[0] -= sB[sb_col + j] * sA[j * NB + i]; } - sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; + sB[sb_col + i] = resB[0] / sA[i * NB + i]; } } else if(LOWER) @@ -2477,28 +2475,28 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = maxColA; j > i; j--) { T sB_reg = sB[sb_col + j]; - resB[0] -= sB_reg * sA[(i - 0) * size_t(NB) + j]; - resB[1] -= sB_reg * sA[(i - 1) * size_t(NB) + j]; - resB[2] -= sB_reg * sA[(i - 2) * size_t(NB) + j]; - resB[3] -= sB_reg * sA[(i - 3) * size_t(NB) + j]; + resB[0] -= sB_reg * sA[(i - 0) * NB + j]; + resB[1] -= sB_reg * sA[(i - 1) * NB + j]; + resB[2] -= sB_reg * sA[(i - 2) * NB + j]; + resB[3] -= sB_reg * sA[(i - 3) * NB + j]; } - resB[0] /= sA[(i - 0) * size_t(NB) + (i - 0)]; + resB[0] /= sA[(i - 0) * NB + (i - 0)]; sB[sb_col + i - 0] = resB[0]; - resB[1] -= resB[0] * sA[(i - 1) * size_t(NB) + (i - 0)]; - resB[1] /= sA[(i - 1) * size_t(NB) + (i - 1)]; + resB[1] -= resB[0] * sA[(i - 1) * NB + (i - 0)]; + resB[1] /= sA[(i - 1) * NB + (i - 1)]; sB[sb_col + i - 1] = resB[1]; - resB[2] -= resB[0] * sA[(i - 2) * size_t(NB) + (i - 0)]; - resB[2] -= resB[1] * sA[(i - 2) * size_t(NB) + (i - 1)]; - resB[2] /= sA[(i - 2) * size_t(NB) + (i - 2)]; + resB[2] -= resB[0] * sA[(i - 2) * NB + (i - 0)]; + resB[2] -= resB[1] * sA[(i - 2) * NB + (i - 1)]; + resB[2] /= sA[(i - 2) * NB + (i - 2)]; sB[sb_col + i - 2] = resB[2]; - resB[3] -= resB[0] * sA[(i - 3) * size_t(NB) + (i - 0)]; - resB[3] -= resB[1] * sA[(i - 3) * size_t(NB) + (i - 1)]; - resB[3] -= resB[2] * sA[(i - 3) * size_t(NB) + (i - 2)]; - resB[3] /= sA[(i - 3) * size_t(NB) + (i - 3)]; + resB[3] -= resB[0] * sA[(i - 3) * NB + (i - 0)]; + resB[3] -= resB[1] * sA[(i - 3) * NB + (i - 1)]; + resB[3] -= resB[2] * sA[(i - 3) * NB + (i - 2)]; + resB[3] /= sA[(i - 3) * NB + (i - 3)]; sB[sb_col + i - 3] = resB[3]; } @@ -2507,9 +2505,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = maxColA; j > i; j--) { - resB[0] -= sB[sb_col + j] * sA[i * size_t(NB) + j]; + resB[0] -= sB[sb_col + j] * sA[i * NB + j]; } - sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; + sB[sb_col + i] = resB[0] / sA[i * NB + i]; } } else if(!LOWER) @@ -2526,28 +2524,28 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, for(int j = 0; j < i; j++) { T sB_reg = sB[sb_col + j]; - resB[0] -= sB_reg * sA[(i + 0) * size_t(NB) + j]; - resB[1] -= sB_reg * sA[(i + 1) * size_t(NB) + j]; - resB[2] -= sB_reg * sA[(i + 2) * size_t(NB) + j]; - resB[3] -= sB_reg * sA[(i + 3) * size_t(NB) + j]; + resB[0] -= sB_reg * sA[(i + 0) * NB + j]; + resB[1] -= sB_reg * sA[(i + 1) * NB + j]; + resB[2] -= sB_reg * sA[(i + 2) * NB + j]; + resB[3] -= sB_reg * sA[(i + 3) * NB + j]; } - resB[0] /= sA[(i + 0) * size_t(NB) + (i + 0)]; + resB[0] /= sA[(i + 0) * NB + (i + 0)]; sB[sb_col + i + 0] = resB[0]; - resB[1] -= resB[0] * sA[(i + 1) * size_t(NB) + (i + 0)]; - resB[1] /= sA[(i + 1) * size_t(NB) + (i + 1)]; + resB[1] -= resB[0] * sA[(i + 1) * NB + (i + 0)]; + resB[1] /= sA[(i + 1) * NB + (i + 1)]; sB[sb_col + i + 1] = resB[1]; - resB[2] -= resB[0] * sA[(i + 2) * size_t(NB) + (i + 0)]; - resB[2] -= resB[1] * sA[(i + 2) * size_t(NB) + (i + 1)]; - resB[2] /= sA[(i + 2) * size_t(NB) + (i + 2)]; + resB[2] -= resB[0] * sA[(i + 2) * NB + (i + 0)]; + resB[2] -= resB[1] * sA[(i + 2) * NB + (i + 1)]; + resB[2] /= sA[(i + 2) * NB + (i + 2)]; sB[sb_col + i + 2] = resB[2]; - resB[3] -= resB[0] * sA[(i + 3) * size_t(NB) + (i + 0)]; - resB[3] -= resB[1] * sA[(i + 3) * size_t(NB) + (i + 1)]; - resB[3] -= resB[2] * sA[(i + 3) * size_t(NB) + (i + 2)]; - resB[3] /= sA[(i + 3) * size_t(NB) + (i + 3)]; + resB[3] -= resB[0] * sA[(i + 3) * NB + (i + 0)]; + resB[3] -= resB[1] * sA[(i + 3) * NB + (i + 1)]; + resB[3] -= resB[2] * sA[(i + 3) * NB + (i + 2)]; + resB[3] /= sA[(i + 3) * NB + (i + 3)]; sB[sb_col + i + 3] = resB[3]; } @@ -2557,9 +2555,9 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, resB[0] = sB[sb_col + i]; for(int j = 0; j < i; j++) { - resB[0] -= sB[sb_col + j] * sA[i * size_t(NB) + j]; + resB[0] -= sB[sb_col + j] * sA[i * NB + j]; } - sB[sb_col + i] = resB[0] / sA[i * size_t(NB) + i]; + sB[sb_col + i] = resB[0] / sA[i * NB + i]; } } @@ -2569,7 +2567,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, if(tx < m) { for(int i = 0; i < maxColB; i++) - B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } @@ -2619,7 +2617,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, { // Load B into sB and multiply by alpha for(int i = 0; i < maxColB; i++) - sB[i * size_t(NB) + tx] = alpha * B[i * size_t(ldb) + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -2633,57 +2631,55 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, for(int j = 0; j < i; j++) { T valA = A[j * size_t(lda) + i]; - sB[tx * size_t(NB) + i] -= sB[tx * size_t(NB) + j] * valA; + sB[tx * NB + i] -= sB[tx * NB + j] * valA; } if(diag != rocblas_diagonal_unit) - sB[tx * size_t(NB) + i] /= A[i * size_t(lda) + i]; + sB[tx * NB + i] /= A[i * size_t(lda) + i]; } } else if(!LOWER && transA == rocblas_operation_none) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[tx * size_t(NB) + i]; + T temp_reg_B = sB[tx * NB + i]; for(int j = maxColA; j > i; j--) { T valA = A[j * size_t(lda) + i]; - temp_reg_B -= sB[tx * size_t(NB) + j] * valA; + temp_reg_B -= sB[tx * NB + j] * valA; } - sB[tx * size_t(NB) + i] = temp_reg_B; + sB[tx * NB + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * size_t(NB) + i] /= A[i * size_t(lda) + i]; + sB[tx * NB + i] /= A[i * size_t(lda) + i]; } } else if(LOWER) { for(int i = maxColA; i >= 0; i--) { - T temp_reg_B = sB[tx * size_t(NB) + i]; + T temp_reg_B = sB[tx * NB + i]; for(int j = maxColA; j > i; j--) { T valA = (CONJ) ? conj(A[i * size_t(lda) + j]) : A[i * size_t(lda) + j]; - temp_reg_B -= sB[tx * size_t(NB) + j] * valA; + temp_reg_B -= sB[tx * NB + j] * valA; } - sB[tx * size_t(NB) + i] = temp_reg_B; + sB[tx * NB + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * size_t(NB) + i] - /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[tx * NB + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else if(!LOWER) { for(int i = 0; i <= maxColA; i++) { - T temp_reg_B = sB[tx * size_t(NB) + i]; + T temp_reg_B = sB[tx * NB + i]; for(int j = 0; j < i; j++) { T valA = (CONJ) ? conj(A[i * size_t(lda) + j]) : A[i * size_t(lda) + j]; - temp_reg_B -= sB[tx * size_t(NB) + j] * valA; + temp_reg_B -= sB[tx * NB + j] * valA; } - sB[tx * size_t(NB) + i] = temp_reg_B; + sB[tx * NB + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * size_t(NB) + i] - /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; + sB[tx * NB + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2693,7 +2689,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, if(tx < m) { for(int i = 0; i < maxColB; i++) - B[i * size_t(ldb) + tx] = sB[i * size_t(NB) + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } diff --git a/library/src/blas3/rocblas_trtri.hpp b/library/src/blas3/rocblas_trtri.hpp index 3739c932c..682497990 100644 --- a/library/src/blas3/rocblas_trtri.hpp +++ b/library/src/blas3/rocblas_trtri.hpp @@ -86,19 +86,19 @@ ROCBLAS_KERNEL_ILF void custom_trtri_device(rocblas_fill uplo, if(tx < 2 * n) { - int Aoffset = tx < n ? 0 : n * lda + n; + rocblas_stride Aoffset = tx < n ? 0 : n * size_t(lda) + n; if(uplo == rocblas_fill_lower) { for(int i = 0; i < n; i++) - diagP[index + i * n] = i <= index ? A[Aoffset + index + i * lda] : 0.0f; + diagP[index + i * n] = i <= index ? A[Aoffset + index + i * size_t(lda)] : 0.0f; } else { // transpose A in sA if upper for(int i = n - 1; i >= 0; i--) { diagP[(n - 1 - index) + (n - 1 - i) * n] - = i >= index ? A[Aoffset + index + i * lda] : 0.0f; + = i >= index ? A[Aoffset + index + i * size_t(lda)] : 0.0f; } } } @@ -107,12 +107,12 @@ ROCBLAS_KERNEL_ILF void custom_trtri_device(rocblas_fill uplo, if(uplo == rocblas_fill_lower) { for(int i = 0; i < n; i++) - diagP[index + i * n] = A[n + index + i * lda]; + diagP[index + i * n] = A[n + index + i * size_t(lda)]; } else { // transpose A in diag1 if upper for(int i = n - 1; i >= 0; i--) - diagP[index + i * n] = A[n * lda + index + i * lda]; + diagP[index + i * n] = A[n * size_t(lda) + index + i * size_t(lda)]; } } @@ -213,7 +213,7 @@ ROCBLAS_KERNEL_ILF void custom_trtri_device(rocblas_fill uplo, T sum(0); for(int k = 0; k < r + 1; k++) sum += -1.0f * diag2[r + k * n] * temp[k + c * n]; - invA[n + r + c * ldinvA] = sum; + invA[n + r + c * size_t(ldinvA)] = sum; } } else @@ -223,23 +223,24 @@ ROCBLAS_KERNEL_ILF void custom_trtri_device(rocblas_fill uplo, T sum(0); for(int k = r; k < IB; k++) sum += -1.0f * diag1[(n - 1 - r) + (n - 1 - k) * n] * temp[k + c * n]; - invA[n * ldinvA + r + c * ldinvA] = sum; + invA[n * size_t(ldinvA) + r + c * size_t(ldinvA)] = sum; } } if(tx < 2 * n) { - int AInvoffset = tx < n ? 0 : n * ldinvA + n; + size_t AInvoffset = tx < n ? 0 : n * size_t(ldinvA) + n; if(uplo == rocblas_fill_lower) { for(int i = 0; i <= index; i++) - invA[AInvoffset + index + i * ldinvA] = diagP[index + i * n]; + invA[AInvoffset + index + i * size_t(ldinvA)] = diagP[index + i * n]; } else { // transpose back to A from sA if upper for(int i = n - 1; i >= index; i--) - invA[AInvoffset + index + i * ldinvA] = diagP[(n - 1 - index) + (n - 1 - i) * n]; + invA[AInvoffset + index + i * size_t(ldinvA)] + = diagP[(n - 1 - index) + (n - 1 - i) * n]; } } } @@ -270,12 +271,12 @@ ROCBLAS_KERNEL_ILF void trtri_device(rocblas_fill uplo, { // compute only diagonal element for(int i = 0; i <= tx; i++) - sA[tx + i * n] = A[tx + i * lda]; + sA[tx + i * n] = A[tx + i * size_t(lda)]; } else { // transpose A in sA if upper for(int i = n - 1; i >= tx; i--) - sA[(n - 1 - tx) + (n - 1 - i) * n] = A[tx + i * lda]; + sA[(n - 1 - tx) + (n - 1 - i) * n] = A[tx + i * size_t(lda)]; } } __syncthreads(); // if NB < 64, this synch can be avoided @@ -336,12 +337,12 @@ ROCBLAS_KERNEL_ILF void trtri_device(rocblas_fill uplo, if(uplo == rocblas_fill_lower) { for(int i = 0; i <= tx; i++) - invA[tx + i * ldinvA] = sA[tx + i * n]; + invA[tx + i * size_t(ldinvA)] = sA[tx + i * n]; } else { // transpose back to A from sA if upper for(int i = n - 1; i >= tx; i--) - invA[tx + i * ldinvA] = sA[(n - 1 - tx) + (n - 1 - i) * n]; + invA[tx + i * size_t(ldinvA)] = sA[(n - 1 - tx) + (n - 1 - i) * n]; } } } @@ -364,7 +365,7 @@ ROCBLAS_KERNEL_ILF void rocblas_tritri_fill_upper(rocblas_stride offset, rocblas_int row = n - 2 - floor(sqrt(4 * n * (n - 1) - 7 - 8 * idx) / 2.0 - 0.5); rocblas_int col = idx + row + 1 - n * (n - 1) / 2 + (n - row) * (n - row - 1) / 2; - size_t final_offset = offset * sub_stride_A + (row * lda) + col; + size_t final_offset = offset * sub_stride_A + (row * size_t(lda)) + col; A[final_offset] = value; } @@ -376,7 +377,7 @@ ROCBLAS_KERNEL_ILF void rocblas_tritri_fill_lower( rocblas_int row = (rocblas_int)((-1 + sqrt(8 * idx + 1)) / 2); rocblas_int col = idx - row * (row + 1) / 2; - size_t final_offset = offset * sub_stride_A + ((row + 1) * lda) + col; + size_t final_offset = offset * sub_stride_A + ((row + 1) * size_t(lda)) + col; A[final_offset] = value; } @@ -503,7 +504,7 @@ rocblas_status rocblas_trtri_small(rocblas_handle handle, n, num_non_tri_elements(n), ldinvA, - n * ldinvA, + n * size_t(ldinvA), invA, offset_invA, 0, @@ -557,10 +558,10 @@ trtri_diagonal_kernel(rocblas_fill uplo, rocblas_int tiles = n / IB / 2; const T* individual_A = load_ptr_batch(A, blockIdx.y, offset_A, stride_A) - + (IB * 2 * lda + IB * 2) * (blockIdx.x % tiles) + + (IB * 2 * size_t(lda) + IB * 2) * (blockIdx.x % tiles) + sub_stride_A * (blockIdx.x / tiles); T* individual_invA = load_ptr_batch(invA, blockIdx.y, offset_invA, stride_invA) - + (IB * 2 * ldinvA + IB * 2) * (blockIdx.x % tiles) + + (IB * 2 * size_t(ldinvA) + IB * 2) * (blockIdx.x % tiles) + sub_stride_invA * (blockIdx.x / tiles); auto rem = n - (blockIdx.x % tiles) * IB; @@ -757,8 +758,10 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, dim3 grid_remainder(sub_batch_count, batch_count); dim3 threads_remainder(remainder); - rocblas_int offset_A2 = (n - remainder) + (n - remainder) * lda + offset_Ain; - rocblas_int offset_invA2 = (n - remainder) + (n - remainder) * ldinvA + offset_invAin; + rocblas_stride offset_A2 = (n - remainder) + (n - remainder) * size_t(lda) + offset_Ain; + rocblas_stride offset_invA2 + = (n - remainder) + (n - remainder) * size_t(ldinvA) + offset_invAin; + hipLaunchKernelGGL((trtri_remainder_kernel), grid_remainder, threads_remainder, @@ -799,7 +802,7 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, n, num_non_tri_elements(n), ldinvA, - n * ldinvA, + n * size_t(ldinvA), invA, offset_invAin, stride_invA, @@ -816,23 +819,24 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, { for(int i = 0; i < sub_batch_count; i++) { - rocblas_int offset_A + rocblas_stride offset_A = (uplo == rocblas_fill_lower ? current_n + i * sub_stride_Ain - : current_n * lda + i * sub_stride_Ain); - rocblas_int offset_invA1 + : current_n * size_t(lda) + i * sub_stride_Ain); + rocblas_stride offset_invA1 = (uplo == rocblas_fill_lower ? 0 + i * sub_stride_invAin - : current_n * ldinvA + current_n + i * sub_stride_invAin); - rocblas_int offset_invA2 + : current_n * size_t(ldinvA) + current_n + i * sub_stride_invAin); + rocblas_stride offset_invA2 = (uplo == rocblas_fill_lower - ? current_n * ldinvA + current_n + i * sub_stride_invAin + ? current_n * size_t(ldinvA) + current_n + i * sub_stride_invAin : 0 + i * sub_stride_invAin); - rocblas_int offset_invA3 - = (uplo == rocblas_fill_lower ? current_n + i * sub_stride_invAin - : current_n * ldinvA + i * sub_stride_invAin); - rocblas_int offset_C + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower + ? current_n + i * sub_stride_invAin + : current_n * size_t(ldinvA) + i * sub_stride_invAin); + rocblas_stride offset_C = (uplo == rocblas_fill_lower - ? (n - current_n) * ldinvA + i * sub_stride_invAin + ? (n - current_n) * size_t(ldinvA) + i * sub_stride_invAin : (n - current_n * tiles_per_batch) + i * sub_stride_invAin); offset_A += offset_Ain; @@ -841,60 +845,62 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, offset_invA3 += offset_invAin; offset_C += offset_invAin; - trtri_gemm_block(handle, - current_n, - current_n, - (U)A, - lda, - stride_A, - 2 * current_n * lda + 2 * current_n, - (U)invA, - (U)invA, - (V)invA, - ldinvA, - stride_invA, - 2 * current_n * ldinvA + 2 * current_n, - (V)invA, - ldinvA, - stride_invA, - current_n, - batch_count, - tiles_per_batch, - offset_A, - offset_invA1, - offset_invA2, - offset_invA3, - offset_C); + trtri_gemm_block( + handle, + current_n, + current_n, + (U)A, + lda, + stride_A, + 2 * current_n * size_t(lda) + 2 * current_n, + (U)invA, + (U)invA, + (V)invA, + ldinvA, + stride_invA, + 2 * current_n * size_t(ldinvA) + 2 * current_n, + (V)invA, + ldinvA, + stride_invA, + current_n, + batch_count, + tiles_per_batch, + offset_A, + offset_invA1, + offset_invA2, + offset_invA3, + offset_C); } } else { for(int i = 0; i < tiles_per_batch; i++) { - rocblas_int sub_stride_A2 = (2 * current_n * lda + 2 * current_n); - rocblas_int sub_stride_invA2 = (2 * current_n * ldinvA + 2 * current_n); + rocblas_stride sub_stride_A2 = (2 * current_n * size_t(lda) + 2 * current_n); + rocblas_stride sub_stride_invA2 = (2 * current_n * size_t(ldinvA) + 2 * current_n); - rocblas_int offset_A + rocblas_stride offset_A = (uplo == rocblas_fill_lower ? current_n + i * sub_stride_A2 - : current_n * lda + i * sub_stride_A2); + : current_n * size_t(lda) + i * sub_stride_A2); - rocblas_int offset_invA1 + rocblas_stride offset_invA1 = (uplo == rocblas_fill_lower ? 0 + i * sub_stride_invA2 - : current_n * ldinvA + current_n + i * sub_stride_invA2); + : current_n * size_t(ldinvA) + current_n + i * sub_stride_invA2); - rocblas_int offset_invA2 + rocblas_stride offset_invA2 = (uplo == rocblas_fill_lower - ? current_n * ldinvA + current_n + i * sub_stride_invA2 + ? current_n * size_t(ldinvA) + current_n + i * sub_stride_invA2 : 0 + i * sub_stride_invA2); - rocblas_int offset_invA3 - = (uplo == rocblas_fill_lower ? current_n + i * sub_stride_invA2 - : current_n * ldinvA + i * sub_stride_invA2); + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower + ? current_n + i * sub_stride_invA2 + : current_n * size_t(ldinvA) + i * sub_stride_invA2); - rocblas_int offset_C = (uplo == rocblas_fill_lower - ? (n - current_n) * ldinvA + i * current_n - : (n - current_n * tiles_per_batch) + i * current_n); + rocblas_stride offset_C = (uplo == rocblas_fill_lower + ? (n - current_n) * size_t(ldinvA) + i * current_n + : (n - current_n * tiles_per_batch) + i * current_n); offset_A += offset_Ain; offset_invA1 += offset_invAin; @@ -940,7 +946,7 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, n, num_non_tri_elements(n), ldinvA, - n * ldinvA, + n * size_t(ldinvA), invA, offset_invAin, stride_invA, @@ -959,12 +965,14 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, // and in some cases this happens with multiple sizes. if(remainder > 0) { - rocblas_int offset_A = (uplo == rocblas_fill_lower ? current_n : current_n * lda); - rocblas_int offset_invA1 - = (uplo == rocblas_fill_lower ? 0 : current_n * ldinvA + current_n); - rocblas_int offset_invA2 - = (uplo == rocblas_fill_lower ? current_n * ldinvA + current_n : 0); - rocblas_int offset_invA3 = (uplo == rocblas_fill_lower ? current_n : current_n * ldinvA); + rocblas_stride offset_A + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(lda)); + rocblas_stride offset_invA1 + = (uplo == rocblas_fill_lower ? 0 : current_n * size_t(ldinvA) + current_n); + rocblas_stride offset_invA2 + = (uplo == rocblas_fill_lower ? current_n * size_t(ldinvA) + current_n : 0); + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(ldinvA)); offset_A += offset_Ain; offset_invA1 += offset_invAin; @@ -999,13 +1007,15 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, while(oddRemainder) { - current_n = n - oddRemainder; - rocblas_int offset_A = (uplo == rocblas_fill_lower ? current_n : current_n * lda); - rocblas_int offset_invA1 - = (uplo == rocblas_fill_lower ? 0 : current_n * ldinvA + current_n); - rocblas_int offset_invA2 - = (uplo == rocblas_fill_lower ? current_n * ldinvA + current_n : 0); - rocblas_int offset_invA3 = (uplo == rocblas_fill_lower ? current_n : current_n * ldinvA); + current_n = n - oddRemainder; + rocblas_stride offset_A + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(lda)); + rocblas_stride offset_invA1 + = (uplo == rocblas_fill_lower ? 0 : current_n * size_t(ldinvA) + current_n); + rocblas_stride offset_invA2 + = (uplo == rocblas_fill_lower ? current_n * size_t(ldinvA) + current_n : 0); + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(ldinvA)); offset_A += offset_Ain; offset_invA1 += offset_invAin; diff --git a/library/src/blas3/trtri_trsm.hpp b/library/src/blas3/trtri_trsm.hpp index c2f5fab18..5f21ba7e1 100644 --- a/library/src/blas3/trtri_trsm.hpp +++ b/library/src/blas3/trtri_trsm.hpp @@ -54,9 +54,9 @@ trtri_trsm_kernel(rocblas_fill uplo, // device function only see one matrix // each hip thread Block compute a inverse of a IB * IB diagonal block of A - rocblas_int offA = (2 * blockIdx.x) * (IB * lda + IB) + offset_A; - rocblas_int offinvA = ((2 * blockIdx.x) / IBD) * (NB * NB) - + ((2 * blockIdx.x) % IBD) * (IB * NB + IB) + offset_invA; + rocblas_stride offA = (2 * blockIdx.x) * (IB * size_t(lda) + IB) + offset_A; + rocblas_stride offinvA = ((2 * blockIdx.x) / IBD) * (NB * size_t(NB)) + + ((2 * blockIdx.x) % IBD) * (IB * size_t(NB) + IB) + offset_invA; const T* a_i = load_ptr_batch(A, blockIdx.y, offA, stride_A); T* invA_i = load_ptr_batch(invA, blockIdx.y, offinvA, stride_invA); @@ -205,13 +205,13 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, sub_blocks); constexpr rocblas_int JB = IB * 4; - rocblas_int sub_stride_A = NB * lda + NB; - rocblas_int sub_stride_invA = NB * NB; - rocblas_int sub_stride_C = JB * JB; - rocblas_int offset_A = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * lda); - rocblas_int offset_invA1 = (uplo == rocblas_fill_lower ? 0 : IB * 2 * NB + IB * 2); - rocblas_int offset_invA2 = (uplo == rocblas_fill_lower ? IB * 2 * NB + IB * 2 : 0); - rocblas_int offset_invA3 = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * NB); + rocblas_stride sub_stride_A = NB * size_t(lda) + NB; + rocblas_stride sub_stride_invA = NB * NB; + rocblas_stride sub_stride_C = JB * JB; + rocblas_stride offset_A = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * size_t(lda)); + rocblas_stride offset_invA1 = (uplo == rocblas_fill_lower ? 0 : IB * 2 * NB + IB * 2); + rocblas_stride offset_invA2 = (uplo == rocblas_fill_lower ? IB * 2 * NB + IB * 2 : 0); + rocblas_stride offset_invA3 = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * NB); trtri_gemm_block(handle, IB * 2, @@ -238,7 +238,8 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, offset_invAin + offset_invA3, 0); - offset_A = (uplo == rocblas_fill_lower ? IB * 4 * lda + IB * 6 : IB * 6 * lda + IB * 4); + offset_A = (uplo == rocblas_fill_lower ? IB * 4 * size_t(lda) + IB * 6 + : IB * 6 * size_t(lda) + IB * 4); offset_invA1 = (uplo == rocblas_fill_lower ? IB * 4 * NB + IB * 4 : IB * 6 * NB + IB * 6); offset_invA2 = (uplo == rocblas_fill_lower ? IB * 6 * NB + IB * 6 : IB * 4 * NB + IB * 4); offset_invA3 = (uplo == rocblas_fill_lower ? IB * 4 * NB + IB * 6 : IB * 6 * NB + IB * 4); @@ -268,7 +269,7 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, offset_invAin + offset_invA3, 0); - offset_A = (uplo == rocblas_fill_lower ? JB : JB * lda); + offset_A = (uplo == rocblas_fill_lower ? JB : JB * size_t(lda)); offset_invA1 = (uplo == rocblas_fill_lower ? 0 : JB * NB + JB); offset_invA2 = (uplo == rocblas_fill_lower ? JB * NB + JB : 0); offset_invA3 = (uplo == rocblas_fill_lower ? JB : JB * NB); @@ -320,7 +321,7 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, NB, 0, invA, - sub_blocks * NB * NB + offset_invAin, + sub_blocks * NB * size_t(NB) + offset_invAin, stride_invA, 1); @@ -330,12 +331,12 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, diag, rem, A, - sub_blocks * NB * lda + sub_blocks * NB + offset_Ain, + sub_blocks * NB * size_t(lda) + sub_blocks * NB + offset_Ain, lda, stride_A, 0, invA, - sub_blocks * NB * NB + offset_invAin, + sub_blocks * NB * size_t(NB) + offset_invAin, NB, stride_invA, 0, diff --git a/scripts/performance/pts/benchmarks/trsm_problems.yaml b/scripts/performance/pts/benchmarks/trsm_problems.yaml index 92061bb20..4e3e134cf 100644 --- a/scripts/performance/pts/benchmarks/trsm_problems.yaml +++ b/scripts/performance/pts/benchmarks/trsm_problems.yaml @@ -69,7 +69,7 @@ Definitions: - &testset1_matrix_size_range - { M: 128, N: 2048, lda: 128, ldb: 128 } - { M: 128, N: 16848, lda: 128, ldb: 128 } -# - { M: 128, N: 29696, lda: 128, ldb: 128 } + - { M: 128, N: 29696, lda: 128, ldb: 128 } # - { M: 128, N: 44544, lda: 128, ldb: 128 } # - { M: 128, N: 53632, lda: 128, ldb: 128 } - { M: 256, N: 2048, lda: 256, ldb: 256 } @@ -78,75 +78,75 @@ Definitions: # - { M: 256, N: 53504, lda: 256, ldb: 256 } - { M: 384, N: 2048, lda: 384, ldb: 384 } - { M: 384, N: 14976, lda: 384, ldb: 384 } -# - { M: 384, N: 29952, lda: 384, ldb: 384 } + - { M: 384, N: 29952, lda: 384, ldb: 384 } # - { M: 384, N: 44928, lda: 384, ldb: 384 } # - { M: 384, N: 53376, lda: 384, ldb: 384 } - &testset2_matrix_size_range - { M: 2048, N: 128, lda: 2048, ldb: 2048 } - { M: 16848, N: 128, lda: 16848, ldb: 16848 } -# - { M: 29696, N: 128, lda: 29696, ldb: 29696 } + - { M: 29696, N: 128, lda: 29696, ldb: 29696 } # - { M: 44544, N: 128, lda: 44544, ldb: 44544 } # - { M: 53632, N: 128, lda: 53632, ldb: 53632 } - { M: 2048, N: 256, lda: 2048, ldb: 2048 } - { M: 14848, N: 256, lda: 14848, ldb: 14848 } -# - { M: 29696, N: 256, lda: 29696, ldb: 29696 } + - { M: 29696, N: 256, lda: 29696, ldb: 29696 } # - { M: 44544, N: 256, lda: 44544, ldb: 44544 } # - { M: 53504, N: 256, lda: 53504, ldb: 53504 } - { M: 2048, N: 384, lda: 2048, ldb: 2048 } - { M: 14976, N: 384, lda: 14976, ldb: 14976 } -# - { M: 29952, N: 384, lda: 29952, ldb: 29952 } + - { M: 29952, N: 384, lda: 29952, ldb: 29952 } # - { M: 44928, N: 384, lda: 44928, ldb: 44928 } # - { M: 53376, N: 384, lda: 53376, ldb: 53376 } Tests: - - name: trsm_bench_const_n - category: bench - function: trsm - precision: *single_precision - transA: [ N, T ] - side: L - uplo: U - diag: U - alpha: 1 - incx: 1 - incy: 1 - N: 32 - M: 32..120..8 - lda: 120 # TODO: easy way to increment lda in lockstep with M? - ldb: 120 - iters: 20 +# - name: trsm_bench_const_n +# category: bench +# function: trsm +# precision: *single_precision +# transA: [ N, T ] +# side: L +# uplo: U +# diag: U +# alpha: 1 +# incx: 1 +# incy: 1 +# N: 32 +# M: 32..120..8 +# lda: 120 # TODO: easy way to increment lda in lockstep with M? +# ldb: 120 +# iters: 20 - - name: trsm_bench_const_m - category: bench - function: trsm - precision: *single_precision - transA: [ N, T ] - side: L - uplo: U - diag: U - alpha: 1 - incx: 1 - incy: 1 - N: 32..480..32 - M: 32 - lda: 32 - ldb: 32 - iters: 20 +# - name: trsm_bench_const_m +# category: bench +# function: trsm +# precision: *single_precision +# transA: [ N, T ] +# side: L +# uplo: U +# diag: U +# alpha: 1 +# incx: 1 +# incy: 1 +# N: 32..480..32 +# M: 32 +# lda: 32 +# ldb: 32 +# iters: 20 - - name: trsm_bench_m_equals_n - category: bench - function: trsm - precision: *single_precision - transA: [ N, T ] - side: L - uplo: U - diag: U - alpha: 1 - incx: 1 - incy: 1 - matrix_size: *m_equals_n_range - iters: 20 +# - name: trsm_bench_m_equals_n +# category: bench +# function: trsm +# precision: *single_precision +# transA: [ N, T ] +# side: L +# uplo: U +# diag: U +# alpha: 1 +# incx: 1 +# incy: 1 +# matrix_size: *m_equals_n_range +# iters: 20 - name: trsm_bench_1_small_matrix_size category: bench From 6e75eb38dc224ad034c03e366d2f4d116252c09c Mon Sep 17 00:00:00 2001 From: Torre Zuk <42548444+TorreZuk@users.noreply.github.com> Date: Mon, 6 Mar 2023 12:19:48 -0700 Subject: [PATCH 05/11] pre-apply 64bit offsets for non-batched and strided (#1689) * works around tensile side offset size limitations --- library/src/blas3/Tensile/gemm_tensile.hpp | 40 +++++++++++++++------- library/src/blas_ex/rocblas_gemm_ex.hpp | 37 +++++++++++++++++--- library/src/blas_ex/rocblas_gemm_ext2.hpp | 17 ++++----- 3 files changed, 68 insertions(+), 26 deletions(-) diff --git a/library/src/blas3/Tensile/gemm_tensile.hpp b/library/src/blas3/Tensile/gemm_tensile.hpp index 6703f12d8..ffc338a84 100644 --- a/library/src/blas3/Tensile/gemm_tensile.hpp +++ b/library/src/blas3/Tensile/gemm_tensile.hpp @@ -139,19 +139,33 @@ inline rocblas_status call_tensile(rocblas_handle handle, } #endif - RocblasContractionProblem problem{handle, trans_a, - trans_b, m, - n, k, - alpha, A, - nullptr, ld_a, - stride_a, offset_a, - B, nullptr, - ld_b, stride_b, - offset_b, beta, - C, nullptr, - ld_c, stride_c, - offset_c, batch_count, - true, rocblas_gemm_flags_none}; + // pre apply offsets for non-batched and strided + RocblasContractionProblem problem{handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A + offset_a, + nullptr, + ld_a, + stride_a, + 0 /* offset_a */, + B + offset_b, + nullptr, + ld_b, + stride_b, + 0 /* offset_b */, + beta, + C + offset_c, + nullptr, + ld_c, + stride_c, + 0 /* offset_c */, + batch_count, + true, + rocblas_gemm_flags_none}; return runContractionProblem(problem); } diff --git a/library/src/blas_ex/rocblas_gemm_ex.hpp b/library/src/blas_ex/rocblas_gemm_ex.hpp index 28169c916..90b299640 100644 --- a/library/src/blas_ex/rocblas_gemm_ex.hpp +++ b/library/src/blas_ex/rocblas_gemm_ex.hpp @@ -263,11 +263,38 @@ rocblas_status gemm_ex_batched_template(rocblas_handle handle, int32_t solution_index, rocblas_gemm_flags flags) { - RocblasContractionProblem problem{ - handle, trans_a, trans_b, m, n, k, alpha, a, - nullptr, lda, stride_a, offset_a, b, nullptr, ldb, stride_b, - offset_b, beta, c, nullptr, ldc, stride_c, offset_c, d, - nullptr, ldd, stride_d, offset_d, batch_count, true, flags}; + // pre apply offsets for non-batched and strided + RocblasContractionProblem problem{handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + a + offset_a, + nullptr, + lda, + stride_a, + 0 /* offset_a */, + b + offset_b, + nullptr, + ldb, + stride_b, + 0 /* offset_b */, + beta, + c + offset_c, + nullptr, + ldc, + stride_c, + 0 /* offset_c */, + d + offset_d, + nullptr, + ldd, + stride_d, + 0 /* offset_d */, + batch_count, + true, + flags}; return runContractionProblem(problem, algo, solution_index); } diff --git a/library/src/blas_ex/rocblas_gemm_ext2.hpp b/library/src/blas_ex/rocblas_gemm_ext2.hpp index 3ae4efad7..72dfca72e 100644 --- a/library/src/blas_ex/rocblas_gemm_ext2.hpp +++ b/library/src/blas_ex/rocblas_gemm_ext2.hpp @@ -58,36 +58,37 @@ rocblas_status gemm_ext2_batched_template(rocblas_handle handle, rocblas_int batch_count = 1, bool strided_batch = true) { + // pre apply offsets for non-batched and strided RocblasContractionProblem problem{handle, m, n, k, alpha, - a, + a + offset_a, nullptr, row_stride_a, col_stride_a, batch_stride_a, - offset_a, - b, + 0 /* offset_a */, + b + offset_b, nullptr, row_stride_b, col_stride_b, batch_stride_b, - offset_b, + 0 /* offset_b */, beta, - c, + c + offset_c, nullptr, row_stride_c, col_stride_c, batch_stride_c, - offset_c, - d, + 0 /* offset_c */, + d + offset_d, nullptr, row_stride_d, col_stride_d, batch_stride_d, - offset_d, + 0 /* offset_d */, batch_count, strided_batch}; From 3ca4c1ceba8b16ebdd12016d6f57d0d3441cee83 Mon Sep 17 00:00:00 2001 From: amcamd Date: Fri, 17 Mar 2023 13:58:13 -0500 Subject: [PATCH 06/11] tag from Tensile hotfix pr#1687 64-bit offset parameters for post kernels --- tensile_tag.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensile_tag.txt b/tensile_tag.txt index 85fd25953..c0735d6f6 100644 --- a/tensile_tag.txt +++ b/tensile_tag.txt @@ -1 +1 @@ -e8a3c7d15ec1848a53888747345087ad74ce63f3 +9ef81616d17104869349d547493c64132fe4baa2 From f7bd389e62e45b1f1b2dd1812dc9cc3bc96c02a7 Mon Sep 17 00:00:00 2001 From: amcamd Date: Fri, 17 Mar 2023 16:34:46 -0500 Subject: [PATCH 07/11] clang-format fix for single file --- library/src/blas3/rocblas_trtri.hpp | 50 ++++++++++++++--------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/library/src/blas3/rocblas_trtri.hpp b/library/src/blas3/rocblas_trtri.hpp index 682497990..17604c339 100644 --- a/library/src/blas3/rocblas_trtri.hpp +++ b/library/src/blas3/rocblas_trtri.hpp @@ -845,31 +845,31 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, offset_invA3 += offset_invAin; offset_C += offset_invAin; - trtri_gemm_block( - handle, - current_n, - current_n, - (U)A, - lda, - stride_A, - 2 * current_n * size_t(lda) + 2 * current_n, - (U)invA, - (U)invA, - (V)invA, - ldinvA, - stride_invA, - 2 * current_n * size_t(ldinvA) + 2 * current_n, - (V)invA, - ldinvA, - stride_invA, - current_n, - batch_count, - tiles_per_batch, - offset_A, - offset_invA1, - offset_invA2, - offset_invA3, - offset_C); + trtri_gemm_block(handle, + current_n, + current_n, + (U)A, + lda, + stride_A, + 2 * current_n * size_t(lda) + 2 * current_n, + (U)invA, + (U)invA, + (V)invA, + ldinvA, + stride_invA, + 2 * current_n * size_t(ldinvA) + + 2 * current_n, + (V)invA, + ldinvA, + stride_invA, + current_n, + batch_count, + tiles_per_batch, + offset_A, + offset_invA1, + offset_invA2, + offset_invA3, + offset_C); } } else From 31414d783128df855c0bae4f692e6b1a75c62946 Mon Sep 17 00:00:00 2001 From: daineAMD Date: Wed, 22 Mar 2023 14:29:12 -0600 Subject: [PATCH 08/11] Fix to make_unit_diagonal in tests. --- clients/include/utility.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/clients/include/utility.hpp b/clients/include/utility.hpp index 125388544..d5d197e16 100644 --- a/clients/include/utility.hpp +++ b/clients/include/utility.hpp @@ -341,24 +341,24 @@ void make_unit_diagonal(rocblas_fill uplo, T* hA, rocblas_int lda, rocblas_int N { for(int i = 0; i < N; i++) { - T diag = hA[i + i * lda]; + T diag = hA[i + i * size_t(lda)]; for(int j = 0; j <= i; j++) - hA[i + j * lda] = hA[i + j * lda] / diag; + hA[i + j * size_t(lda)] = hA[i + j * size_t(lda)] / diag; } } else // rocblas_fill_upper { for(int j = 0; j < N; j++) { - T diag = hA[j + j * lda]; + T diag = hA[j + j * size_t(lda)]; for(int i = 0; i <= j; i++) - hA[i + j * lda] = hA[i + j * lda] / diag; + hA[i + j * size_t(lda)] = hA[i + j * size_t(lda)] / diag; } } // randomly initalize diagonal to ensure we aren't using it's values for tests. for(int i = 0; i < N; i++) { - rocblas_init(hA + i * lda + i, 1, 1, 1); + rocblas_init(hA + i * size_t(lda) + i, 1, 1, 1); } } From a8a6cab45d5d24c275ea55aa118ac6a27bf3f17d Mon Sep 17 00:00:00 2001 From: daineAMD Date: Fri, 24 Mar 2023 08:46:44 -0600 Subject: [PATCH 09/11] Fixing trsm overflow tests and changing sizes. --- clients/gtest/trsm_gtest.yaml | 47 +++++++++++++++---- .../blas3/testing_trsm_strided_batched.hpp | 16 +++---- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/clients/gtest/trsm_gtest.yaml b/clients/gtest/trsm_gtest.yaml index 6e91f52a7..89f752369 100644 --- a/clients/gtest/trsm_gtest.yaml +++ b/clients/gtest/trsm_gtest.yaml @@ -61,13 +61,15 @@ Definitions: - { M: 8320, N: 128, lda: 8320, ldb: 8320 } - &size_t_left_matrix_size_range -# - { M: 4, N: 46435, lda: 4, ldb: 46435 } - - { M: 46345, N: 4, lda: 46345, ldb: 46345 } -# - { M: 47000, N: 4, lda: 47000, ldb: 47000 } # calls rocblas_internal_gemm_template with batch_count=367, stride_a=6016128 + - { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb + - { M: 1000, N: 4, lda: 2147500, ldb: 2147500 } # trsm "left" kernel with overflow lda/ldb + #- { M: 16, N: 16, lda: 134217728, ldb: 134217728 } # trsm "small" kernel with overflow lda/ldb + #- { M: 128, N: 64, lda: 33554450, ldb: 33554450 } # trsm "subsitution" kernel with overflow lda/ldb - &size_t_right_matrix_size_range - - { M: 4, N: 46345, lda: 46345, ldb: 4 } -# - { M: 4, N: 47000, lda: 47000, ldb: 4 } # calls rocblas_internal_gemm_template with batch_count=367, stride_a=6016128 + - { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb + - { M: 4, N: 1000, lda: 2147500, ldb: 2147500 } # trsm "right" kernel with overflow lda/ldb + #- { M: 16, N: 16, lda: 134217728, ldb: 134217728 } # trsm "small" kernel with overflow lda/ldb - &substitution_size_range_thorough - { M: 1, N: 1, lda: 100, ldb: 100 } @@ -497,21 +499,46 @@ Tests: - name: trsm_size_t_left category: nightly - function: trsm - precision: *single_precision + function: + - trsm + - trsm_batched + - trsm_strided_batched + precision: *single_double_precisions_complex_real +# precision: *single_precision arguments: - - { side: L, uplo: L, transA: N, diag: N } + #- { side: L, uplo: L, transA: N, diag: N } + #- { side: L, uplo: L, transA: N, diag: U } + #- { side: L, uplo: L, transA: T, diag: N } + #- { side: L, uplo: L, transA: T, diag: U } + #- { side: L, uplo: U, transA: N, diag: N } + - { side: L, uplo: U, transA: N, diag: U } + #- { side: L, uplo: U, transA: T, diag: N } + #- { side: L, uplo: U, transA: T, diag: U } matrix_size: *size_t_left_matrix_size_range alpha: [2] + batch_count: [2] + stride_scale: [1] - name: trsm_size_t_right category: nightly - function: trsm - precision: *single_precision + function: + - trsm + - trsm_batched + - trsm_strided_batched + precision: *single_double_precisions_complex_real +# precision: *single_precision arguments: - { side: R, uplo: L, transA: N, diag: N } + #- { side: R, uplo: L, transA: N, diag: U } + #- { side: R, uplo: L, transA: T, diag: N } + #- { side: R, uplo: L, transA: T, diag: U } + #- { side: R, uplo: U, transA: N, diag: N } + #- { side: R, uplo: U, transA: N, diag: U } + #- { side: R, uplo: U, transA: T, diag: N } + #- { side: R, uplo: U, transA: T, diag: U } matrix_size: *size_t_right_matrix_size_range alpha: [2] + stride_scale: [1] - name: trsm_large category: nightly diff --git a/clients/include/blas3/testing_trsm_strided_batched.hpp b/clients/include/blas3/testing_trsm_strided_batched.hpp index 0891abaa0..89f4e3c24 100644 --- a/clients/include/blas3/testing_trsm_strided_batched.hpp +++ b/clients/include/blas3/testing_trsm_strided_batched.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -413,13 +413,13 @@ void testing_trsm_strided_batched(const Arguments& arg) auto rocblas_trsm_strided_batched_fn = arg.fortran ? rocblas_trsm_strided_batched : rocblas_trsm_strided_batched; - rocblas_int M = arg.M; - rocblas_int N = arg.N; - rocblas_int lda = arg.lda; - rocblas_int ldb = arg.ldb; - rocblas_int stride_A = arg.stride_a; - rocblas_int stride_B = arg.stride_b; - rocblas_int batch_count = arg.batch_count; + rocblas_int M = arg.M; + rocblas_int N = arg.N; + rocblas_int lda = arg.lda; + rocblas_int ldb = arg.ldb; + rocblas_stride stride_A = arg.stride_a; + rocblas_stride stride_B = arg.stride_b; + rocblas_int batch_count = arg.batch_count; char char_side = arg.side; char char_uplo = arg.uplo; From 8e90720f33c90c727f71c9cc405877e7098fd02b Mon Sep 17 00:00:00 2001 From: daineAMD Date: Fri, 24 Mar 2023 11:52:35 -0600 Subject: [PATCH 10/11] Reducing trsm large lda tests. --- clients/gtest/trsm_gtest.yaml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/clients/gtest/trsm_gtest.yaml b/clients/gtest/trsm_gtest.yaml index 89f752369..019bdae81 100644 --- a/clients/gtest/trsm_gtest.yaml +++ b/clients/gtest/trsm_gtest.yaml @@ -67,7 +67,7 @@ Definitions: #- { M: 128, N: 64, lda: 33554450, ldb: 33554450 } # trsm "subsitution" kernel with overflow lda/ldb - &size_t_right_matrix_size_range - - { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb + #- { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb - { M: 4, N: 1000, lda: 2147500, ldb: 2147500 } # trsm "right" kernel with overflow lda/ldb #- { M: 16, N: 16, lda: 134217728, ldb: 134217728 } # trsm "small" kernel with overflow lda/ldb @@ -501,10 +501,10 @@ Tests: category: nightly function: - trsm - - trsm_batched - - trsm_strided_batched - precision: *single_double_precisions_complex_real -# precision: *single_precision + #- trsm_batched + #- trsm_strided_batched + # precision: *single_double_precisions_complex_real + precision: *single_precision arguments: #- { side: L, uplo: L, transA: N, diag: N } #- { side: L, uplo: L, transA: N, diag: U } @@ -523,10 +523,10 @@ Tests: category: nightly function: - trsm - - trsm_batched - - trsm_strided_batched - precision: *single_double_precisions_complex_real -# precision: *single_precision + #- trsm_batched + #- trsm_strided_batched + # precision: *single_double_precisions_complex_real + precision: *single_precision arguments: - { side: R, uplo: L, transA: N, diag: N } #- { side: R, uplo: L, transA: N, diag: U } From 936df4f3e418c5108544644d01c7fbffd73b45ba Mon Sep 17 00:00:00 2001 From: amcamd Date: Fri, 24 Mar 2023 14:37:52 -0500 Subject: [PATCH 11/11] update tensile tag --- tensile_tag.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensile_tag.txt b/tensile_tag.txt index c0735d6f6..8c220d5f4 100644 --- a/tensile_tag.txt +++ b/tensile_tag.txt @@ -1 +1 @@ -9ef81616d17104869349d547493c64132fe4baa2 +38d444a9f2b6cddfeaeedcb39a5688150fa27093