diff --git a/clients/gtest/trsm_gtest.yaml b/clients/gtest/trsm_gtest.yaml index 42a8d1cd6..019bdae81 100644 --- a/clients/gtest/trsm_gtest.yaml +++ b/clients/gtest/trsm_gtest.yaml @@ -60,6 +60,17 @@ Definitions: - &large_memory_matrix_size_range - { M: 8320, N: 128, lda: 8320, ldb: 8320 } + - &size_t_left_matrix_size_range + - { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb + - { M: 1000, N: 4, lda: 2147500, ldb: 2147500 } # trsm "left" kernel with overflow lda/ldb + #- { M: 16, N: 16, lda: 134217728, ldb: 134217728 } # trsm "small" kernel with overflow lda/ldb + #- { M: 128, N: 64, lda: 33554450, ldb: 33554450 } # trsm "substitution" kernel with overflow lda/ldb + + - &size_t_right_matrix_size_range + #- { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb + - { M: 4, N: 1000, lda: 2147500, ldb: 2147500 } # trsm "right" kernel with overflow lda/ldb + #- { M: 16, N: 16, lda: 134217728, ldb: 134217728 } # trsm "small" kernel with overflow lda/ldb + - &substitution_size_range_thorough - { M: 1, N: 1, lda: 100, ldb: 100 } - { M: 1, N: 32, lda: 100, ldb: 100 } @@ -486,6 +497,49 @@ Tests: matrix_size: *testset2_matrix_size_range alpha: [ 1 ] +- name: trsm_size_t_left + category: nightly + function: + - trsm + #- trsm_batched + #- trsm_strided_batched + # precision: *single_double_precisions_complex_real + precision: *single_precision + arguments: + #- { side: L, uplo: L, transA: N, diag: N } + #- { side: L, uplo: L, transA: N, diag: U } + #- { side: L, uplo: L, transA: T, diag: N } + #- { side: L, uplo: L, transA: T, diag: U } + #- { side: L, uplo: U, transA: N, diag: N } + - { side: L, uplo: U, transA: N, diag: U } + #- { side: L, uplo: U, transA: T, diag: N } + #- { side: L, uplo: U, transA: T, diag: U } + matrix_size: *size_t_left_matrix_size_range + alpha: [2] + batch_count: [2] + stride_scale: [1] + +- name: trsm_size_t_right + category: nightly + function: + - trsm + #- 
trsm_batched + #- trsm_strided_batched + # precision: *single_double_precisions_complex_real + precision: *single_precision + arguments: + - { side: R, uplo: L, transA: N, diag: N } + #- { side: R, uplo: L, transA: N, diag: U } + #- { side: R, uplo: L, transA: T, diag: N } + #- { side: R, uplo: L, transA: T, diag: U } + #- { side: R, uplo: U, transA: N, diag: N } + #- { side: R, uplo: U, transA: N, diag: U } + #- { side: R, uplo: U, transA: T, diag: N } + #- { side: R, uplo: U, transA: T, diag: U } + matrix_size: *size_t_right_matrix_size_range + alpha: [2] + stride_scale: [1] + - name: trsm_large category: nightly function: trsm diff --git a/clients/include/blas3/testing_trsm_strided_batched.hpp b/clients/include/blas3/testing_trsm_strided_batched.hpp index 0891abaa0..89f4e3c24 100644 --- a/clients/include/blas3/testing_trsm_strided_batched.hpp +++ b/clients/include/blas3/testing_trsm_strided_batched.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -413,13 +413,13 @@ void testing_trsm_strided_batched(const Arguments& arg) auto rocblas_trsm_strided_batched_fn = arg.fortran ? 
rocblas_trsm_strided_batched : rocblas_trsm_strided_batched; - rocblas_int M = arg.M; - rocblas_int N = arg.N; - rocblas_int lda = arg.lda; - rocblas_int ldb = arg.ldb; - rocblas_int stride_A = arg.stride_a; - rocblas_int stride_B = arg.stride_b; - rocblas_int batch_count = arg.batch_count; + rocblas_int M = arg.M; + rocblas_int N = arg.N; + rocblas_int lda = arg.lda; + rocblas_int ldb = arg.ldb; + rocblas_stride stride_A = arg.stride_a; + rocblas_stride stride_B = arg.stride_b; + rocblas_int batch_count = arg.batch_count; char char_side = arg.side; char char_uplo = arg.uplo; diff --git a/clients/include/utility.hpp b/clients/include/utility.hpp index 125388544..d5d197e16 100644 --- a/clients/include/utility.hpp +++ b/clients/include/utility.hpp @@ -341,24 +341,24 @@ void make_unit_diagonal(rocblas_fill uplo, T* hA, rocblas_int lda, rocblas_int N { for(int i = 0; i < N; i++) { - T diag = hA[i + i * lda]; + T diag = hA[i + i * size_t(lda)]; for(int j = 0; j <= i; j++) - hA[i + j * lda] = hA[i + j * lda] / diag; + hA[i + j * size_t(lda)] = hA[i + j * size_t(lda)] / diag; } } else // rocblas_fill_upper { for(int j = 0; j < N; j++) { - T diag = hA[j + j * lda]; + T diag = hA[j + j * size_t(lda)]; for(int i = 0; i <= j; i++) - hA[i + j * lda] = hA[i + j * lda] / diag; + hA[i + j * size_t(lda)] = hA[i + j * size_t(lda)] / diag; } } // randomly initalize diagonal to ensure we aren't using it's values for tests. 
for(int i = 0; i < N; i++) { - rocblas_init(hA + i * lda + i, 1, 1, 1); + rocblas_init(hA + i * size_t(lda) + i, 1, 1, 1); } } diff --git a/library/src/blas3/Tensile/gemm_tensile.hpp b/library/src/blas3/Tensile/gemm_tensile.hpp index 6703f12d8..ffc338a84 100644 --- a/library/src/blas3/Tensile/gemm_tensile.hpp +++ b/library/src/blas3/Tensile/gemm_tensile.hpp @@ -139,19 +139,33 @@ inline rocblas_status call_tensile(rocblas_handle handle, } #endif - RocblasContractionProblem problem{handle, trans_a, - trans_b, m, - n, k, - alpha, A, - nullptr, ld_a, - stride_a, offset_a, - B, nullptr, - ld_b, stride_b, - offset_b, beta, - C, nullptr, - ld_c, stride_c, - offset_c, batch_count, - true, rocblas_gemm_flags_none}; + // pre apply offsets for non-batched and strided + RocblasContractionProblem problem{handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + A + offset_a, + nullptr, + ld_a, + stride_a, + 0 /* offset_a */, + B + offset_b, + nullptr, + ld_b, + stride_b, + 0 /* offset_b */, + beta, + C + offset_c, + nullptr, + ld_c, + stride_c, + 0 /* offset_c */, + batch_count, + true, + rocblas_gemm_flags_none}; return runContractionProblem(problem); } diff --git a/library/src/blas3/rocblas_trsm.hpp b/library/src/blas3/rocblas_trsm.hpp index 8127acb0e..6533afd91 100644 --- a/library/src/blas3/rocblas_trsm.hpp +++ b/library/src/blas3/rocblas_trsm.hpp @@ -198,7 +198,7 @@ copy_matrix_trsm(rocblas_int rows, size_t ty = blockIdx.y * blockDim.y + threadIdx.y; if(tx < rows && ty < cols) - xb[tx + ldb * ty] = xa[tx + lda * ty]; + xb[tx + size_t(ldb) * ty] = xa[tx + size_t(lda) * ty]; } /* ===============copy helper============================================= */ @@ -256,7 +256,7 @@ set_matrix_trsm(rocblas_int rows, size_t ty = blockIdx.y * blockDim.y + threadIdx.y; if(tx < rows && ty < cols) - xa[tx + lda * ty] = T(0.0); + xa[tx + size_t(lda) * ty] = T(0.0); } /* ===============set helper============================================= */ @@ -384,7 +384,7 @@ rocblas_status 
rocblas_trsm_left(rocblas_handle handle, jb, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -409,7 +409,8 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i + BLOCK + i * lda + offset_Ain, + i + BLOCK + i * size_t(lda) + + offset_Ain, lda, stride_A, (U)X, @@ -452,7 +453,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, alpha, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -476,7 +477,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, &alpha_negative_one, A, - i * lda + offset_Ain, + i * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -502,7 +503,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -525,7 +526,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i * lda + offset_Ain, + i * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -557,7 +558,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, alpha, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -605,7 +606,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -679,7 +680,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - BLOCK * lda + offset_Ain, + BLOCK * size_t(lda) + offset_Ain, lda, stride_A, (U)X, @@ -705,7 +706,7 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, jb, &alpha_1, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, (U)B, @@ -728,7 +729,8 @@ rocblas_status rocblas_trsm_left(rocblas_handle handle, BLOCK, &alpha_negative_one, A, - i + (i + BLOCK) * lda + 
offset_Ain, + i + (i + BLOCK) * size_t(lda) + + offset_Ain, lda, stride_A, (U)X, @@ -793,16 +795,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, alpha, U(B), - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -816,7 +818,7 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, @@ -841,16 +843,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -864,7 +866,7 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, @@ -919,12 +921,12 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, m, stride_X, A, - BLOCK * lda + offset_Ain, + BLOCK * size_t(lda) + offset_Ain, lda, stride_A, alpha, B, - BLOCK * ldb + offset_Bin, + BLOCK * size_t(ldb) + offset_Bin, ldb, stride_B, batch_count); @@ -941,16 +943,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -964,16 +966,17 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - i + (i + BLOCK) * lda + offset_Ain, + i + (i + BLOCK) * size_t(lda) + + offset_Ain, lda, stride_A, &beta_1, B, - (i + BLOCK) * ldb + offset_Bin, + (i + BLOCK) * size_t(ldb) + offset_Bin, ldb, 
stride_B, batch_count); @@ -1027,7 +1030,7 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, stride_A, alpha, B, - BLOCK * ldb + offset_Bin, + BLOCK * size_t(ldb) + offset_Bin, ldb, stride_B, batch_count); @@ -1044,16 +1047,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -1067,16 +1070,17 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - BLOCK + i + i * lda + offset_Ain, + BLOCK + i + i * size_t(lda) + + offset_Ain, lda, stride_A, &beta_1, B, - (i + BLOCK) * ldb + offset_Bin, + (i + BLOCK) * size_t(ldb) + offset_Bin, ldb, stride_B, batch_count); @@ -1096,16 +1100,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, alpha, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -1119,11 +1123,11 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, jb, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - i * lda + offset_Ain, + i * size_t(lda) + offset_Ain, lda, stride_A, alpha, @@ -1144,16 +1148,16 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_1, (U)B, - i * ldb + offset_Bin, + i * size_t(ldb) + offset_Bin, ldb, stride_B, invA, - i * BLOCK + offset_invAin, + i * size_t(BLOCK) + offset_invAin, BLOCK, stride_invA, &beta_0, X, - i * m, + i * size_t(m), m, stride_X, batch_count); @@ -1167,11 +1171,11 @@ rocblas_status rocblas_trsm_right(rocblas_handle handle, BLOCK, &alpha_negative_one, (U)X, - i * m, + i * size_t(m), m, stride_X, A, - i * lda + offset_Ain, + i * 
size_t(lda) + offset_Ain, lda, stride_A, &beta_1, @@ -1248,9 +1252,10 @@ rocblas_status special_trsm_template(rocblas_handle handle, if(r) { - rocblas_int offsetA = 0; - rocblas_int offsetB = parity ? w * B_chunk_size * ldb - : w * B_chunk_size * ldb + (q + 1) * BLOCK; + rocblas_stride offsetA = 0; + rocblas_stride offsetB = parity + ? w * B_chunk_size * size_t(ldb) + : w * B_chunk_size * size_t(ldb) + (q + 1) * BLOCK; if(transA == rocblas_operation_none) offsetA = parity ? r * BLOCK : BLOCK * (q * lda + q + lda); @@ -1343,7 +1348,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, stride_X, &beta_0, B, - size_t(w * B_chunk_size * ldb + j * BLOCK + offset_Bin), + size_t(w * B_chunk_size * size_t(ldb) + j * BLOCK + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1368,14 +1373,15 @@ rocblas_status special_trsm_template(rocblas_handle handle, width, stride_X, batch_count, - j * BLOCK * ldb + w * B_chunk_size + offset_Bin, + j * BLOCK * size_t(ldb) + w * B_chunk_size + offset_Bin, 0); if(r) { - rocblas_int offsetA = 0; - rocblas_int offsetB - = parity ? w * B_chunk_size + (q + 1) * BLOCK * ldb : w * B_chunk_size; + rocblas_stride offsetA = 0; + rocblas_stride offsetB = parity + ? w * B_chunk_size + (q + 1) * BLOCK * size_t(ldb) + : w * B_chunk_size; if(transA == rocblas_operation_none) offsetA = parity ? 
BLOCK * (q * lda + q + 1) : r * BLOCK * lda; else @@ -1432,7 +1438,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, alpha, B, compute_type, - j * BLOCK * ldb + w * B_chunk_size + j * BLOCK * size_t(ldb) + w * B_chunk_size + offset_Bin, ldb, stride_B, @@ -1467,7 +1473,7 @@ rocblas_status special_trsm_template(rocblas_handle handle, stride_invA, &beta_0, B, - size_t(w * B_chunk_size * ldb + j * BLOCK * ldb + offset_Bin), + size_t(w * B_chunk_size * size_t(ldb) + j * BLOCK * size_t(ldb) + offset_Bin), size_t(ldb), stride_B, batch_count); @@ -1933,7 +1939,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? NB : m - bx * NB; // offset B into correct block row - B += bx * NB; + B += size_t(bx) * NB; __shared__ T sA[NB * NB]; __shared__ T sB[NB * NB]; @@ -1944,7 +1950,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * NB + tx] = (CONJ) ? conj(A[i * lda + tx]) : A[i * lda + tx]; + sA[i * NB + tx] = (CONJ) ? conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) @@ -1955,7 +1961,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, { // Load B into sB and multiply by alpha for(int i = 0; i < n; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -2163,7 +2169,7 @@ rocblas_trsm_small_right_device(rocblas_fill uplo, if(tx < maxColB) { for(int i = 0; i < n; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } @@ -2205,7 +2211,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? 
NB : m - bx * NB; // offset B into correct block row - B += bx * NB; + B += bx * size_t(NB); __shared__ T sB[NB * NB]; @@ -2213,7 +2219,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, { // Load B into sB and multiply by alpha for(int i = 0; i < n; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); // Solve for B in shared memory @@ -2227,13 +2233,13 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, T temp_reg_B = sB[i * NB + tx]; for(int j = 0; j < i; j++) { - T valA = A[i * lda + j]; + T valA = A[i * size_t(lda) + j]; temp_reg_B -= sB[j * NB + tx] * valA; } // Solve sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= A[i * lda + i]; + sB[i * NB + tx] /= A[i * size_t(lda) + i]; } } else if(transA == rocblas_operation_none && uplo == rocblas_fill_lower) @@ -2243,12 +2249,12 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, T temp_reg_B = sB[i * NB + tx]; for(int j = maxColA; j > i; j--) { - T valA = A[i * lda + j]; + T valA = A[i * size_t(lda) + j]; temp_reg_B -= sB[j * NB + tx] * valA; } sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= A[i * lda + i]; + sB[i * NB + tx] /= A[i * size_t(lda) + i]; } } else if(uplo == rocblas_fill_upper) @@ -2258,12 +2264,12 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, T temp_reg_B = sB[i * NB + tx]; for(int j = maxColA; j > i; j--) { - T valA = CONJ ? conj(A[j * lda + i]) : A[j * lda + i]; + T valA = CONJ ? conj(A[j * size_t(lda) + i]) : A[j * size_t(lda) + i]; temp_reg_B -= sB[j * NB + tx] * valA; } sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= CONJ ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[i * NB + tx] /= CONJ ? 
conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else // lower (conjugate-)transpose @@ -2273,12 +2279,12 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, T temp_reg_B = sB[i * NB + tx]; for(int j = 0; j < i; j++) { - T valA = CONJ ? conj(A[j * lda + i]) : A[j * lda + i]; + T valA = CONJ ? conj(A[j * size_t(lda) + i]) : A[j * size_t(lda) + i]; temp_reg_B -= sB[j * NB + tx] * valA; } sB[i * NB + tx] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[i * NB + tx] /= CONJ ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[i * NB + tx] /= CONJ ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2286,7 +2292,7 @@ rocblas_trsm_small_64_right_device(rocblas_fill uplo, if(tx < maxColB) { for(int i = 0; i < n; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } @@ -2330,7 +2336,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? NB : n - bx * NB; // offset B into correct block column - B += bx * NB * ldb; + B += bx * NB * size_t(ldb); // shared A and shared B __shared__ T sA[NB * NB]; @@ -2343,7 +2349,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, { // Load A into sA, handle conjugation if necessary for(int i = 0; i <= maxColA; i++) - sA[i * NB + tx] = (CONJ) ? conj(A[i * lda + tx]) : A[i * lda + tx]; + sA[i * NB + tx] = (CONJ) ? 
conj(A[i * size_t(lda) + tx]) : A[i * size_t(lda) + tx]; // set unit diagonal if needed if(diag == rocblas_diagonal_unit) @@ -2351,7 +2357,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, // Load B into sB and multiply by alpha for(int i = 0; i < maxColB; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -2561,7 +2567,7 @@ rocblas_trsm_small_left_device(rocblas_fill uplo, if(tx < m) { for(int i = 0; i < maxColB; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } @@ -2603,7 +2609,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, const int maxColB = (bx < gridDim.x - 1) ? NB : n - bx * NB; // offset B into correct block column - B += bx * NB * ldb; + B += bx * NB * size_t(ldb); // shared B __shared__ T sB[NB * NB]; @@ -2611,7 +2617,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, { // Load B into sB and multiply by alpha for(int i = 0; i < maxColB; i++) - sB[i * NB + tx] = alpha * B[i * ldb + tx]; + sB[i * NB + tx] = alpha * B[i * size_t(ldb) + tx]; } __syncthreads(); @@ -2624,11 +2630,11 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, // Subtract previously solved parts for(int j = 0; j < i; j++) { - T valA = A[j * lda + i]; + T valA = A[j * size_t(lda) + i]; sB[tx * NB + i] -= sB[tx * NB + j] * valA; } if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= A[i * lda + i]; + sB[tx * NB + i] /= A[i * size_t(lda) + i]; } } else if(!LOWER && transA == rocblas_operation_none) @@ -2638,12 +2644,12 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, T temp_reg_B = sB[tx * NB + i]; for(int j = maxColA; j > i; j--) { - T valA = A[j * lda + i]; + T valA = A[j * size_t(lda) + i]; temp_reg_B -= sB[tx * NB + j] * valA; } sB[tx * NB + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= A[i * lda + i]; + sB[tx * NB + i] /= A[i * size_t(lda) + i]; } } else if(LOWER) @@ -2653,12 +2659,12 @@ 
rocblas_trsm_small_64_left_device(rocblas_fill uplo, T temp_reg_B = sB[tx * NB + i]; for(int j = maxColA; j > i; j--) { - T valA = (CONJ) ? conj(A[i * lda + j]) : A[i * lda + j]; + T valA = (CONJ) ? conj(A[i * size_t(lda) + j]) : A[i * size_t(lda) + j]; temp_reg_B -= sB[tx * NB + j] * valA; } sB[tx * NB + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= (CONJ) ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[tx * NB + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } else if(!LOWER) @@ -2668,12 +2674,12 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, T temp_reg_B = sB[tx * NB + i]; for(int j = 0; j < i; j++) { - T valA = (CONJ) ? conj(A[i * lda + j]) : A[i * lda + j]; + T valA = (CONJ) ? conj(A[i * size_t(lda) + j]) : A[i * size_t(lda) + j]; temp_reg_B -= sB[tx * NB + j] * valA; } sB[tx * NB + i] = temp_reg_B; if(diag != rocblas_diagonal_unit) - sB[tx * NB + i] /= (CONJ) ? conj(A[i * lda + i]) : A[i * lda + i]; + sB[tx * NB + i] /= (CONJ) ? conj(A[i * size_t(lda) + i]) : A[i * size_t(lda) + i]; } } @@ -2683,7 +2689,7 @@ rocblas_trsm_small_64_left_device(rocblas_fill uplo, if(tx < m) { for(int i = 0; i < maxColB; i++) - B[i * ldb + tx] = sB[i * NB + tx]; + B[i * size_t(ldb) + tx] = sB[i * NB + tx]; } } @@ -2808,7 +2814,7 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_backward_substitution(int if(offY < n && tx < m) { - T valB = alpha * B[offY * ldb_norm + tx * ldb_trans]; + T valB = alpha * B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)]; for(int i = m - 1; i > 0; i--) { // tx is row of B, ty is col of B @@ -2817,23 +2823,23 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_backward_substitution(int if(tx == i) { // solve cur row - valB = UNIT ? valB : valB / A[tx * lda_norm + tx * lda_trans]; + valB = UNIT ? valB : valB / A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; sB[ty] = valB; } __syncthreads(); if(tx < i) - valB -= (CONJ ? 
conj(A[i * lda_norm + tx * lda_trans]) - : A[i * lda_norm + tx * lda_trans]) + valB -= (CONJ ? conj(A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) + : A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) * sB[ty]; } if(!UNIT && tx == 0) - valB /= A[tx * lda_norm + tx * lda_trans]; + valB /= A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; // store back to mem - B[offY * ldb_norm + tx * ldb_trans] = valB; + B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)] = valB; } } @@ -2876,7 +2882,7 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_forward_substitution(int if(offY < n && tx < m) { - T valB = alpha * B[offY * ldb_norm + tx * ldb_trans]; + T valB = alpha * B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)]; for(int i = 0; i < m - 1; i++) { // tx is row of B, ty is col of B @@ -2885,22 +2891,22 @@ ROCBLAS_KERNEL_NO_BOUNDS rocblas_trsm_block_forward_substitution(int if(tx == i) { // solve cur row - valB = UNIT ? valB : valB / A[tx * lda_norm + tx * lda_trans]; + valB = UNIT ? valB : valB / A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; sB[ty] = valB; } __syncthreads(); if(tx > i) - valB -= (CONJ ? conj(A[i * lda_norm + tx * lda_trans]) - : A[i * lda_norm + tx * lda_trans]) + valB -= (CONJ ? conj(A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) + : A[i * size_t(lda_norm) + tx * size_t(lda_trans)]) * sB[ty]; } if(!UNIT && tx == m - 1) - valB /= A[tx * lda_norm + tx * lda_trans]; + valB /= A[tx * size_t(lda_norm) + tx * size_t(lda_trans)]; // store back to mem - B[offY * ldb_norm + tx * ldb_trans] = valB; + B[offY * size_t(ldb_norm) + tx * size_t(ldb_trans)] = valB; } } @@ -2949,7 +2955,7 @@ void rocblas_trsm_small_substitution(rocblas_handle handle, T negative_one = -1; T one = 1; rocblas_int j = 0; - rocblas_int offA_sub, offB_sub; + size_t offA_sub, offB_sub; size_t smem_size; // Different kernels for forward substitution vs. 
backward substitution @@ -2959,17 +2965,17 @@ void rocblas_trsm_small_substitution(rocblas_handle handle, { const rocblas_int j_next = j + NBX; - rocblas_int offA_gemm = LEFT ? (!TRANSA ? j * lda + j_next : j + j_next * lda) - : (!TRANSA ? j + j_next * lda : j * lda + j_next); - rocblas_int offB_gemm = LEFT ? j : j * ldb; - rocblas_int offC_gemm = LEFT ? j_next : j_next * ldb; - smem_size = (1024 / NBX) * sizeof(T); + size_t offA_gemm = LEFT ? (!TRANSA ? j * size_t(lda) + j_next : j + j_next * size_t(lda)) + : (!TRANSA ? j + j_next * size_t(lda) : j * size_t(lda) + j_next); + size_t offB_gemm = LEFT ? j : j * size_t(ldb); + size_t offC_gemm = LEFT ? j_next : j_next * size_t(ldb); + smem_size = (1024 / NBX) * sizeof(T); // 1. call trsm subtitution/solve if(FORWARD_SUB) { - offA_sub = j * lda + j; - offB_sub = LEFT ? j : j * ldb; + offA_sub = j * size_t(lda) + j; + offB_sub = LEFT ? j : j * size_t(ldb); hipLaunchKernelGGL((rocblas_trsm_block_forward_substitution= 0; i--) { diagP[(n - 1 - index) + (n - 1 - i) * n] - = i >= index ? A[Aoffset + index + i * lda] : 0.0f; + = i >= index ? 
A[Aoffset + index + i * size_t(lda)] : 0.0f; } } } @@ -107,12 +107,12 @@ ROCBLAS_KERNEL_ILF void custom_trtri_device(rocblas_fill uplo, if(uplo == rocblas_fill_lower) { for(int i = 0; i < n; i++) - diagP[index + i * n] = A[n + index + i * lda]; + diagP[index + i * n] = A[n + index + i * size_t(lda)]; } else { // transpose A in diag1 if upper for(int i = n - 1; i >= 0; i--) - diagP[index + i * n] = A[n * lda + index + i * lda]; + diagP[index + i * n] = A[n * size_t(lda) + index + i * size_t(lda)]; } } @@ -213,7 +213,7 @@ ROCBLAS_KERNEL_ILF void custom_trtri_device(rocblas_fill uplo, T sum(0); for(int k = 0; k < r + 1; k++) sum += -1.0f * diag2[r + k * n] * temp[k + c * n]; - invA[n + r + c * ldinvA] = sum; + invA[n + r + c * size_t(ldinvA)] = sum; } } else @@ -223,23 +223,24 @@ ROCBLAS_KERNEL_ILF void custom_trtri_device(rocblas_fill uplo, T sum(0); for(int k = r; k < IB; k++) sum += -1.0f * diag1[(n - 1 - r) + (n - 1 - k) * n] * temp[k + c * n]; - invA[n * ldinvA + r + c * ldinvA] = sum; + invA[n * size_t(ldinvA) + r + c * size_t(ldinvA)] = sum; } } if(tx < 2 * n) { - int AInvoffset = tx < n ? 0 : n * ldinvA + n; + size_t AInvoffset = tx < n ? 
0 : n * size_t(ldinvA) + n; if(uplo == rocblas_fill_lower) { for(int i = 0; i <= index; i++) - invA[AInvoffset + index + i * ldinvA] = diagP[index + i * n]; + invA[AInvoffset + index + i * size_t(ldinvA)] = diagP[index + i * n]; } else { // transpose back to A from sA if upper for(int i = n - 1; i >= index; i--) - invA[AInvoffset + index + i * ldinvA] = diagP[(n - 1 - index) + (n - 1 - i) * n]; + invA[AInvoffset + index + i * size_t(ldinvA)] + = diagP[(n - 1 - index) + (n - 1 - i) * n]; } } } @@ -270,12 +271,12 @@ ROCBLAS_KERNEL_ILF void trtri_device(rocblas_fill uplo, { // compute only diagonal element for(int i = 0; i <= tx; i++) - sA[tx + i * n] = A[tx + i * lda]; + sA[tx + i * n] = A[tx + i * size_t(lda)]; } else { // transpose A in sA if upper for(int i = n - 1; i >= tx; i--) - sA[(n - 1 - tx) + (n - 1 - i) * n] = A[tx + i * lda]; + sA[(n - 1 - tx) + (n - 1 - i) * n] = A[tx + i * size_t(lda)]; } } __syncthreads(); // if NB < 64, this synch can be avoided @@ -336,12 +337,12 @@ ROCBLAS_KERNEL_ILF void trtri_device(rocblas_fill uplo, if(uplo == rocblas_fill_lower) { for(int i = 0; i <= tx; i++) - invA[tx + i * ldinvA] = sA[tx + i * n]; + invA[tx + i * size_t(ldinvA)] = sA[tx + i * n]; } else { // transpose back to A from sA if upper for(int i = n - 1; i >= tx; i--) - invA[tx + i * ldinvA] = sA[(n - 1 - tx) + (n - 1 - i) * n]; + invA[tx + i * size_t(ldinvA)] = sA[(n - 1 - tx) + (n - 1 - i) * n]; } } } @@ -364,7 +365,7 @@ ROCBLAS_KERNEL_ILF void rocblas_tritri_fill_upper(rocblas_stride offset, rocblas_int row = n - 2 - floor(sqrt(4 * n * (n - 1) - 7 - 8 * idx) / 2.0 - 0.5); rocblas_int col = idx + row + 1 - n * (n - 1) / 2 + (n - row) * (n - row - 1) / 2; - size_t final_offset = offset * sub_stride_A + (row * lda) + col; + size_t final_offset = offset * sub_stride_A + (row * size_t(lda)) + col; A[final_offset] = value; } @@ -376,7 +377,7 @@ ROCBLAS_KERNEL_ILF void rocblas_tritri_fill_lower( rocblas_int row = (rocblas_int)((-1 + sqrt(8 * idx + 1)) / 2); rocblas_int 
col = idx - row * (row + 1) / 2; - size_t final_offset = offset * sub_stride_A + ((row + 1) * lda) + col; + size_t final_offset = offset * sub_stride_A + ((row + 1) * size_t(lda)) + col; A[final_offset] = value; } @@ -503,7 +504,7 @@ rocblas_status rocblas_trtri_small(rocblas_handle handle, n, num_non_tri_elements(n), ldinvA, - n * ldinvA, + n * size_t(ldinvA), invA, offset_invA, 0, @@ -557,10 +558,10 @@ trtri_diagonal_kernel(rocblas_fill uplo, rocblas_int tiles = n / IB / 2; const T* individual_A = load_ptr_batch(A, blockIdx.y, offset_A, stride_A) - + (IB * 2 * lda + IB * 2) * (blockIdx.x % tiles) + + (IB * 2 * size_t(lda) + IB * 2) * (blockIdx.x % tiles) + sub_stride_A * (blockIdx.x / tiles); T* individual_invA = load_ptr_batch(invA, blockIdx.y, offset_invA, stride_invA) - + (IB * 2 * ldinvA + IB * 2) * (blockIdx.x % tiles) + + (IB * 2 * size_t(ldinvA) + IB * 2) * (blockIdx.x % tiles) + sub_stride_invA * (blockIdx.x / tiles); auto rem = n - (blockIdx.x % tiles) * IB; @@ -757,8 +758,10 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, dim3 grid_remainder(sub_batch_count, batch_count); dim3 threads_remainder(remainder); - rocblas_int offset_A2 = (n - remainder) + (n - remainder) * lda + offset_Ain; - rocblas_int offset_invA2 = (n - remainder) + (n - remainder) * ldinvA + offset_invAin; + rocblas_stride offset_A2 = (n - remainder) + (n - remainder) * size_t(lda) + offset_Ain; + rocblas_stride offset_invA2 + = (n - remainder) + (n - remainder) * size_t(ldinvA) + offset_invAin; + hipLaunchKernelGGL((trtri_remainder_kernel), grid_remainder, threads_remainder, @@ -799,7 +802,7 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, n, num_non_tri_elements(n), ldinvA, - n * ldinvA, + n * size_t(ldinvA), invA, offset_invAin, stride_invA, @@ -816,23 +819,24 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, { for(int i = 0; i < sub_batch_count; i++) { - rocblas_int offset_A + rocblas_stride offset_A = (uplo == rocblas_fill_lower ? 
current_n + i * sub_stride_Ain - : current_n * lda + i * sub_stride_Ain); - rocblas_int offset_invA1 + : current_n * size_t(lda) + i * sub_stride_Ain); + rocblas_stride offset_invA1 = (uplo == rocblas_fill_lower ? 0 + i * sub_stride_invAin - : current_n * ldinvA + current_n + i * sub_stride_invAin); - rocblas_int offset_invA2 + : current_n * size_t(ldinvA) + current_n + i * sub_stride_invAin); + rocblas_stride offset_invA2 = (uplo == rocblas_fill_lower - ? current_n * ldinvA + current_n + i * sub_stride_invAin + ? current_n * size_t(ldinvA) + current_n + i * sub_stride_invAin : 0 + i * sub_stride_invAin); - rocblas_int offset_invA3 - = (uplo == rocblas_fill_lower ? current_n + i * sub_stride_invAin - : current_n * ldinvA + i * sub_stride_invAin); - rocblas_int offset_C + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower + ? current_n + i * sub_stride_invAin + : current_n * size_t(ldinvA) + i * sub_stride_invAin); + rocblas_stride offset_C = (uplo == rocblas_fill_lower - ? (n - current_n) * ldinvA + i * sub_stride_invAin + ? 
(n - current_n) * size_t(ldinvA) + i * sub_stride_invAin : (n - current_n * tiles_per_batch) + i * sub_stride_invAin); offset_A += offset_Ain; @@ -847,13 +851,14 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, (U)A, lda, stride_A, - 2 * current_n * lda + 2 * current_n, + 2 * current_n * size_t(lda) + 2 * current_n, (U)invA, (U)invA, (V)invA, ldinvA, stride_invA, - 2 * current_n * ldinvA + 2 * current_n, + 2 * current_n * size_t(ldinvA) + + 2 * current_n, (V)invA, ldinvA, stride_invA, @@ -871,30 +876,31 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, { for(int i = 0; i < tiles_per_batch; i++) { - rocblas_int sub_stride_A2 = (2 * current_n * lda + 2 * current_n); - rocblas_int sub_stride_invA2 = (2 * current_n * ldinvA + 2 * current_n); + rocblas_stride sub_stride_A2 = (2 * current_n * size_t(lda) + 2 * current_n); + rocblas_stride sub_stride_invA2 = (2 * current_n * size_t(ldinvA) + 2 * current_n); - rocblas_int offset_A + rocblas_stride offset_A = (uplo == rocblas_fill_lower ? current_n + i * sub_stride_A2 - : current_n * lda + i * sub_stride_A2); + : current_n * size_t(lda) + i * sub_stride_A2); - rocblas_int offset_invA1 + rocblas_stride offset_invA1 = (uplo == rocblas_fill_lower ? 0 + i * sub_stride_invA2 - : current_n * ldinvA + current_n + i * sub_stride_invA2); + : current_n * size_t(ldinvA) + current_n + i * sub_stride_invA2); - rocblas_int offset_invA2 + rocblas_stride offset_invA2 = (uplo == rocblas_fill_lower - ? current_n * ldinvA + current_n + i * sub_stride_invA2 + ? current_n * size_t(ldinvA) + current_n + i * sub_stride_invA2 : 0 + i * sub_stride_invA2); - rocblas_int offset_invA3 - = (uplo == rocblas_fill_lower ? current_n + i * sub_stride_invA2 - : current_n * ldinvA + i * sub_stride_invA2); + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower + ? current_n + i * sub_stride_invA2 + : current_n * size_t(ldinvA) + i * sub_stride_invA2); - rocblas_int offset_C = (uplo == rocblas_fill_lower - ? 
(n - current_n) * ldinvA + i * current_n - : (n - current_n * tiles_per_batch) + i * current_n); + rocblas_stride offset_C = (uplo == rocblas_fill_lower + ? (n - current_n) * size_t(ldinvA) + i * current_n + : (n - current_n * tiles_per_batch) + i * current_n); offset_A += offset_Ain; offset_invA1 += offset_invAin; @@ -940,7 +946,7 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, n, num_non_tri_elements(n), ldinvA, - n * ldinvA, + n * size_t(ldinvA), invA, offset_invAin, stride_invA, @@ -959,12 +965,14 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, // and in some cases this happens with multiple sizes. if(remainder > 0) { - rocblas_int offset_A = (uplo == rocblas_fill_lower ? current_n : current_n * lda); - rocblas_int offset_invA1 - = (uplo == rocblas_fill_lower ? 0 : current_n * ldinvA + current_n); - rocblas_int offset_invA2 - = (uplo == rocblas_fill_lower ? current_n * ldinvA + current_n : 0); - rocblas_int offset_invA3 = (uplo == rocblas_fill_lower ? current_n : current_n * ldinvA); + rocblas_stride offset_A + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(lda)); + rocblas_stride offset_invA1 + = (uplo == rocblas_fill_lower ? 0 : current_n * size_t(ldinvA) + current_n); + rocblas_stride offset_invA2 + = (uplo == rocblas_fill_lower ? current_n * size_t(ldinvA) + current_n : 0); + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(ldinvA)); offset_A += offset_Ain; offset_invA1 += offset_invAin; @@ -999,13 +1007,15 @@ rocblas_status rocblas_trtri_large(rocblas_handle handle, while(oddRemainder) { - current_n = n - oddRemainder; - rocblas_int offset_A = (uplo == rocblas_fill_lower ? current_n : current_n * lda); - rocblas_int offset_invA1 - = (uplo == rocblas_fill_lower ? 0 : current_n * ldinvA + current_n); - rocblas_int offset_invA2 - = (uplo == rocblas_fill_lower ? current_n * ldinvA + current_n : 0); - rocblas_int offset_invA3 = (uplo == rocblas_fill_lower ? 
current_n : current_n * ldinvA); + current_n = n - oddRemainder; + rocblas_stride offset_A + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(lda)); + rocblas_stride offset_invA1 + = (uplo == rocblas_fill_lower ? 0 : current_n * size_t(ldinvA) + current_n); + rocblas_stride offset_invA2 + = (uplo == rocblas_fill_lower ? current_n * size_t(ldinvA) + current_n : 0); + rocblas_stride offset_invA3 + = (uplo == rocblas_fill_lower ? current_n : current_n * size_t(ldinvA)); offset_A += offset_Ain; offset_invA1 += offset_invAin; diff --git a/library/src/blas3/trtri_trsm.hpp b/library/src/blas3/trtri_trsm.hpp index c2f5fab18..5f21ba7e1 100644 --- a/library/src/blas3/trtri_trsm.hpp +++ b/library/src/blas3/trtri_trsm.hpp @@ -54,9 +54,9 @@ trtri_trsm_kernel(rocblas_fill uplo, // device function only see one matrix // each hip thread Block compute a inverse of a IB * IB diagonal block of A - rocblas_int offA = (2 * blockIdx.x) * (IB * lda + IB) + offset_A; - rocblas_int offinvA = ((2 * blockIdx.x) / IBD) * (NB * NB) - + ((2 * blockIdx.x) % IBD) * (IB * NB + IB) + offset_invA; + rocblas_stride offA = (2 * blockIdx.x) * (IB * size_t(lda) + IB) + offset_A; + rocblas_stride offinvA = ((2 * blockIdx.x) / IBD) * (NB * size_t(NB)) + + ((2 * blockIdx.x) % IBD) * (IB * size_t(NB) + IB) + offset_invA; const T* a_i = load_ptr_batch(A, blockIdx.y, offA, stride_A); T* invA_i = load_ptr_batch(invA, blockIdx.y, offinvA, stride_invA); @@ -205,13 +205,13 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, sub_blocks); constexpr rocblas_int JB = IB * 4; - rocblas_int sub_stride_A = NB * lda + NB; - rocblas_int sub_stride_invA = NB * NB; - rocblas_int sub_stride_C = JB * JB; - rocblas_int offset_A = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * lda); - rocblas_int offset_invA1 = (uplo == rocblas_fill_lower ? 0 : IB * 2 * NB + IB * 2); - rocblas_int offset_invA2 = (uplo == rocblas_fill_lower ? 
IB * 2 * NB + IB * 2 : 0); - rocblas_int offset_invA3 = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * NB); + rocblas_stride sub_stride_A = NB * size_t(lda) + NB; + rocblas_stride sub_stride_invA = NB * NB; + rocblas_stride sub_stride_C = JB * JB; + rocblas_stride offset_A = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * size_t(lda)); + rocblas_stride offset_invA1 = (uplo == rocblas_fill_lower ? 0 : IB * 2 * NB + IB * 2); + rocblas_stride offset_invA2 = (uplo == rocblas_fill_lower ? IB * 2 * NB + IB * 2 : 0); + rocblas_stride offset_invA3 = (uplo == rocblas_fill_lower ? IB * 2 : IB * 2 * NB); trtri_gemm_block(handle, IB * 2, @@ -238,7 +238,8 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, offset_invAin + offset_invA3, 0); - offset_A = (uplo == rocblas_fill_lower ? IB * 4 * lda + IB * 6 : IB * 6 * lda + IB * 4); + offset_A = (uplo == rocblas_fill_lower ? IB * 4 * size_t(lda) + IB * 6 + : IB * 6 * size_t(lda) + IB * 4); offset_invA1 = (uplo == rocblas_fill_lower ? IB * 4 * NB + IB * 4 : IB * 6 * NB + IB * 6); offset_invA2 = (uplo == rocblas_fill_lower ? IB * 6 * NB + IB * 6 : IB * 4 * NB + IB * 4); offset_invA3 = (uplo == rocblas_fill_lower ? IB * 4 * NB + IB * 6 : IB * 6 * NB + IB * 4); @@ -268,7 +269,7 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, offset_invAin + offset_invA3, 0); - offset_A = (uplo == rocblas_fill_lower ? JB : JB * lda); + offset_A = (uplo == rocblas_fill_lower ? JB : JB * size_t(lda)); offset_invA1 = (uplo == rocblas_fill_lower ? 0 : JB * NB + JB); offset_invA2 = (uplo == rocblas_fill_lower ? JB * NB + JB : 0); offset_invA3 = (uplo == rocblas_fill_lower ? 
JB : JB * NB); @@ -320,7 +321,7 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, NB, 0, invA, - sub_blocks * NB * NB + offset_invAin, + sub_blocks * NB * size_t(NB) + offset_invAin, stride_invA, 1); @@ -330,12 +331,12 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle, diag, rem, A, - sub_blocks * NB * lda + sub_blocks * NB + offset_Ain, + sub_blocks * NB * size_t(lda) + sub_blocks * NB + offset_Ain, lda, stride_A, 0, invA, - sub_blocks * NB * NB + offset_invAin, + sub_blocks * NB * size_t(NB) + offset_invAin, NB, stride_invA, 0, diff --git a/library/src/blas_ex/rocblas_gemm_ex.hpp b/library/src/blas_ex/rocblas_gemm_ex.hpp index 28169c916..90b299640 100644 --- a/library/src/blas_ex/rocblas_gemm_ex.hpp +++ b/library/src/blas_ex/rocblas_gemm_ex.hpp @@ -263,11 +263,38 @@ rocblas_status gemm_ex_batched_template(rocblas_handle handle, int32_t solution_index, rocblas_gemm_flags flags) { - RocblasContractionProblem problem{ - handle, trans_a, trans_b, m, n, k, alpha, a, - nullptr, lda, stride_a, offset_a, b, nullptr, ldb, stride_b, - offset_b, beta, c, nullptr, ldc, stride_c, offset_c, d, - nullptr, ldd, stride_d, offset_d, batch_count, true, flags}; + // pre apply offsets for non-batched and strided + RocblasContractionProblem problem{handle, + trans_a, + trans_b, + m, + n, + k, + alpha, + a + offset_a, + nullptr, + lda, + stride_a, + 0 /* offset_a */, + b + offset_b, + nullptr, + ldb, + stride_b, + 0 /* offset_b */, + beta, + c + offset_c, + nullptr, + ldc, + stride_c, + 0 /* offset_c */, + d + offset_d, + nullptr, + ldd, + stride_d, + 0 /* offset_d */, + batch_count, + true, + flags}; return runContractionProblem(problem, algo, solution_index); } diff --git a/library/src/blas_ex/rocblas_gemm_ext2.hpp b/library/src/blas_ex/rocblas_gemm_ext2.hpp index 3ae4efad7..72dfca72e 100644 --- a/library/src/blas_ex/rocblas_gemm_ext2.hpp +++ b/library/src/blas_ex/rocblas_gemm_ext2.hpp @@ -58,36 +58,37 @@ rocblas_status 
gemm_ext2_batched_template(rocblas_handle handle, rocblas_int batch_count = 1, bool strided_batch = true) { + // pre apply offsets for non-batched and strided RocblasContractionProblem problem{handle, m, n, k, alpha, - a, + a + offset_a, nullptr, row_stride_a, col_stride_a, batch_stride_a, - offset_a, - b, + 0 /* offset_a */, + b + offset_b, nullptr, row_stride_b, col_stride_b, batch_stride_b, - offset_b, + 0 /* offset_b */, beta, - c, + c + offset_c, nullptr, row_stride_c, col_stride_c, batch_stride_c, - offset_c, - d, + 0 /* offset_c */, + d + offset_d, nullptr, row_stride_d, col_stride_d, batch_stride_d, - offset_d, + 0 /* offset_d */, batch_count, strided_batch}; diff --git a/scripts/performance/pts/benchmarks/trsm_problems.yaml b/scripts/performance/pts/benchmarks/trsm_problems.yaml index 6545bfe5b..4e3e134cf 100644 --- a/scripts/performance/pts/benchmarks/trsm_problems.yaml +++ b/scripts/performance/pts/benchmarks/trsm_problems.yaml @@ -29,52 +29,178 @@ Definitions: - { M: 120, N: 120, lda: 120, ldb: 120 } - { M: 124, N: 124, lda: 124, ldb: 124 } + - &testset1_small_matrix_size_range + - { M: 256, N: 256, lda: 256, ldb: 256 } + - { M: 512, N: 256, lda: 256, ldb: 512 } + - { M: 512, N: 512, lda: 512, ldb: 512 } + - { M: 768, N: 256, lda: 256, ldb: 768 } + - { M: 1024, N: 256, lda: 256, ldb: 1024 } + - { M: 1280, N: 256, lda: 256, ldb: 1280 } + - { M: 1536, N: 256, lda: 256, ldb: 1536 } + - { M: 1792, N: 256, lda: 256, ldb: 1792 } + - { M: 2048, N: 256, lda: 256, ldb: 2048 } + - { M: 384, N: 384, lda: 384, ldb: 384 } + - { M: 768, N: 384, lda: 384, ldb: 768 } + - { M: 1152, N: 384, lda: 384, ldb: 1152 } + - { M: 1536, N: 384, lda: 384, ldb: 1536 } + - { M: 1920, N: 384, lda: 384, ldb: 1920 } + - { M: 2304, N: 384, lda: 384, ldb: 2304 } + - { M: 2688, N: 384, lda: 384, ldb: 2688 } + - { M: 3072, N: 384, lda: 384, ldb: 3072 } + + - &testset2_small_matrix_size_range + - { M: 256, N: 256, lda: 256, ldb: 256 } + - { M: 256, N: 512, lda: 512, ldb: 256 } + - { M: 256, 
N: 768, lda: 768, ldb: 256 } + - { M: 256, N: 1024, lda: 1024, ldb: 256 } + - { M: 256, N: 1280, lda: 1280, ldb: 256 } + - { M: 256, N: 1536, lda: 1536, ldb: 256 } + - { M: 256, N: 1792, lda: 1792, ldb: 256 } + - { M: 256, N: 2048, lda: 2048, ldb: 256 } + - { M: 384, N: 384, lda: 384, ldb: 384 } + - { M: 384, N: 768, lda: 768, ldb: 384 } + - { M: 384, N: 1152, lda: 1152, ldb: 384 } + - { M: 384, N: 1536, lda: 1536, ldb: 384 } + - { M: 384, N: 1920, lda: 1920, ldb: 384 } + - { M: 384, N: 2304, lda: 2304, ldb: 384 } + - { M: 384, N: 2688, lda: 2688, ldb: 384 } + - { M: 384, N: 3072, lda: 3072, ldb: 384 } + + - &testset1_matrix_size_range + - { M: 128, N: 2048, lda: 128, ldb: 128 } + - { M: 128, N: 16848, lda: 128, ldb: 128 } + - { M: 128, N: 29696, lda: 128, ldb: 128 } +# - { M: 128, N: 44544, lda: 128, ldb: 128 } +# - { M: 128, N: 53632, lda: 128, ldb: 128 } + - { M: 256, N: 2048, lda: 256, ldb: 256 } + - { M: 256, N: 29696, lda: 256, ldb: 256 } +# - { M: 256, N: 44544, lda: 256, ldb: 256 } +# - { M: 256, N: 53504, lda: 256, ldb: 256 } + - { M: 384, N: 2048, lda: 384, ldb: 384 } + - { M: 384, N: 14976, lda: 384, ldb: 384 } + - { M: 384, N: 29952, lda: 384, ldb: 384 } +# - { M: 384, N: 44928, lda: 384, ldb: 384 } +# - { M: 384, N: 53376, lda: 384, ldb: 384 } + + - &testset2_matrix_size_range + - { M: 2048, N: 128, lda: 2048, ldb: 2048 } + - { M: 16848, N: 128, lda: 16848, ldb: 16848 } + - { M: 29696, N: 128, lda: 29696, ldb: 29696 } +# - { M: 44544, N: 128, lda: 44544, ldb: 44544 } +# - { M: 53632, N: 128, lda: 53632, ldb: 53632 } + - { M: 2048, N: 256, lda: 2048, ldb: 2048 } + - { M: 14848, N: 256, lda: 14848, ldb: 14848 } + - { M: 29696, N: 256, lda: 29696, ldb: 29696 } +# - { M: 44544, N: 256, lda: 44544, ldb: 44544 } +# - { M: 53504, N: 256, lda: 53504, ldb: 53504 } + - { M: 2048, N: 384, lda: 2048, ldb: 2048 } + - { M: 14976, N: 384, lda: 14976, ldb: 14976 } + - { M: 29952, N: 384, lda: 29952, ldb: 29952 } +# - { M: 44928, N: 384, lda: 44928, ldb: 44928 } +# - { 
M: 53376, N: 384, lda: 53376, ldb: 53376 } + Tests: - - name: trsm_bench_const_n +# - name: trsm_bench_const_n +# category: bench +# function: trsm +# precision: *single_precision +# transA: [ N, T ] +# side: L +# uplo: U +# diag: U +# alpha: 1 +# incx: 1 +# incy: 1 +# N: 32 +# M: 32..120..8 +# lda: 120 # TODO: easy way to increment lda in lockstep with M? +# ldb: 120 +# iters: 20 + +# - name: trsm_bench_const_m +# category: bench +# function: trsm +# precision: *single_precision +# transA: [ N, T ] +# side: L +# uplo: U +# diag: U +# alpha: 1 +# incx: 1 +# incy: 1 +# N: 32..480..32 +# M: 32 +# lda: 32 +# ldb: 32 +# iters: 20 + +# - name: trsm_bench_m_equals_n +# category: bench +# function: trsm +# precision: *single_precision +# transA: [ N, T ] +# side: L +# uplo: U +# diag: U +# alpha: 1 +# incx: 1 +# incy: 1 +# matrix_size: *m_equals_n_range +# iters: 20 + + - name: trsm_bench_1_small_matrix_size + category: bench + function: trsm + precision: *double_precision + transA: [ N, T ] + side: [ L, R ] + uplo: L + diag: U + alpha: 1 + incx: 1 + incy: 1 + matrix_size: *testset1_small_matrix_size_range + iters: 10 + + - name: trsm_bench_2_small_matrix_size category: bench function: trsm - precision: *single_precision + precision: *double_precision transA: [ N, T ] - side: L - uplo: U + side: [ L, R ] + uplo: L diag: U alpha: 1 incx: 1 incy: 1 - N: 32 - M: 32..120..8 - lda: 120 # TODO: easy way to increment lda in lockstep with M? 
- ldb: 120 - iters: 20 + matrix_size: *testset2_small_matrix_size_range + iters: 10 - - name: trsm_bench_const_m + - name: trsm_bench_1_matrix_size category: bench function: trsm - precision: *single_precision + precision: *double_precision transA: [ N, T ] - side: L - uplo: U + side: [ L, R ] + uplo: L diag: U alpha: 1 incx: 1 incy: 1 - N: 32..480..32 - M: 32 - lda: 32 - ldb: 32 - iters: 20 + matrix_size: *testset1_matrix_size_range + iters: 5 - - name: trsm_bench_m_equals_n + - name: trsm_bench_2_matrix_size category: bench function: trsm - precision: *single_precision + precision: *double_precision transA: [ N, T ] - side: L - uplo: U + side: [ L, R ] + uplo: L diag: U alpha: 1 incx: 1 incy: 1 - matrix_size: *m_equals_n_range - iters: 20 + matrix_size: *testset2_matrix_size_range + iters: 5 ... diff --git a/tensile_tag.txt b/tensile_tag.txt index 85fd25953..8c220d5f4 100644 --- a/tensile_tag.txt +++ b/tensile_tag.txt @@ -1 +1 @@ -e8a3c7d15ec1848a53888747345087ad74ce63f3 +38d444a9f2b6cddfeaeedcb39a5688150fa27093