Skip to content

Commit

Permalink
Merge pull request #1298 from ROCmSoftwarePlatform/hotfix-5.5-swdev-3…
Browse files Browse the repository at this point in the history
…81033-trsm

Hotfix: Fix offset calculation to prevent overflow if offset is really large
  • Loading branch information
amcamd authored Mar 28, 2023
2 parents 3ec7630 + 936df4f commit cdd561f
Show file tree
Hide file tree
Showing 11 changed files with 495 additions and 255 deletions.
54 changes: 54 additions & 0 deletions clients/gtest/trsm_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ Definitions:
- &large_memory_matrix_size_range
- { M: 8320, N: 128, lda: 8320, ldb: 8320 }

- &size_t_left_matrix_size_range
- { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb
- { M: 1000, N: 4, lda: 2147500, ldb: 2147500 } # trsm "left" kernel with overflow lda/ldb
#- { M: 16, N: 16, lda: 134217728, ldb: 134217728 } # trsm "small" kernel with overflow lda/ldb
#- { M: 128, N: 64, lda: 33554450, ldb: 33554450 } # trsm "substitution" kernel with overflow lda/ldb

- &size_t_right_matrix_size_range
#- { M: 128, N: 128, lda: 16777220, ldb: 16777220 } # trsm "special" kernel with overflow lda/ldb
- { M: 4, N: 1000, lda: 2147500, ldb: 2147500 } # trsm "right" kernel with overflow lda/ldb
#- { M: 16, N: 16, lda: 134217728, ldb: 134217728 } # trsm "small" kernel with overflow lda/ldb

- &substitution_size_range_thorough
- { M: 1, N: 1, lda: 100, ldb: 100 }
- { M: 1, N: 32, lda: 100, ldb: 100 }
Expand Down Expand Up @@ -486,6 +497,49 @@ Tests:
matrix_size: *testset2_matrix_size_range
alpha: [ 1 ]

- name: trsm_size_t_left
category: nightly
function:
- trsm
#- trsm_batched
#- trsm_strided_batched
# precision: *single_double_precisions_complex_real
precision: *single_precision
arguments:
#- { side: L, uplo: L, transA: N, diag: N }
#- { side: L, uplo: L, transA: N, diag: U }
#- { side: L, uplo: L, transA: T, diag: N }
#- { side: L, uplo: L, transA: T, diag: U }
#- { side: L, uplo: U, transA: N, diag: N }
- { side: L, uplo: U, transA: N, diag: U }
#- { side: L, uplo: U, transA: T, diag: N }
#- { side: L, uplo: U, transA: T, diag: U }
matrix_size: *size_t_left_matrix_size_range
alpha: [2]
batch_count: [2]
stride_scale: [1]

- name: trsm_size_t_right
category: nightly
function:
- trsm
#- trsm_batched
#- trsm_strided_batched
# precision: *single_double_precisions_complex_real
precision: *single_precision
arguments:
- { side: R, uplo: L, transA: N, diag: N }
#- { side: R, uplo: L, transA: N, diag: U }
#- { side: R, uplo: L, transA: T, diag: N }
#- { side: R, uplo: L, transA: T, diag: U }
#- { side: R, uplo: U, transA: N, diag: N }
#- { side: R, uplo: U, transA: N, diag: U }
#- { side: R, uplo: U, transA: T, diag: N }
#- { side: R, uplo: U, transA: T, diag: U }
matrix_size: *size_t_right_matrix_size_range
alpha: [2]
stride_scale: [1]

- name: trsm_large
category: nightly
function: trsm
Expand Down
16 changes: 8 additions & 8 deletions clients/include/blas3/testing_trsm_strided_batched.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2018-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -413,13 +413,13 @@ void testing_trsm_strided_batched(const Arguments& arg)
auto rocblas_trsm_strided_batched_fn = arg.fortran ? rocblas_trsm_strided_batched<T, true>
: rocblas_trsm_strided_batched<T, false>;

rocblas_int M = arg.M;
rocblas_int N = arg.N;
rocblas_int lda = arg.lda;
rocblas_int ldb = arg.ldb;
rocblas_int stride_A = arg.stride_a;
rocblas_int stride_B = arg.stride_b;
rocblas_int batch_count = arg.batch_count;
rocblas_int M = arg.M;
rocblas_int N = arg.N;
rocblas_int lda = arg.lda;
rocblas_int ldb = arg.ldb;
rocblas_stride stride_A = arg.stride_a;
rocblas_stride stride_B = arg.stride_b;
rocblas_int batch_count = arg.batch_count;

char char_side = arg.side;
char char_uplo = arg.uplo;
Expand Down
10 changes: 5 additions & 5 deletions clients/include/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,24 +341,24 @@ void make_unit_diagonal(rocblas_fill uplo, T* hA, rocblas_int lda, rocblas_int N
{
for(int i = 0; i < N; i++)
{
T diag = hA[i + i * lda];
T diag = hA[i + i * size_t(lda)];
for(int j = 0; j <= i; j++)
hA[i + j * lda] = hA[i + j * lda] / diag;
hA[i + j * size_t(lda)] = hA[i + j * size_t(lda)] / diag;
}
}
else // rocblas_fill_upper
{
for(int j = 0; j < N; j++)
{
T diag = hA[j + j * lda];
T diag = hA[j + j * size_t(lda)];
for(int i = 0; i <= j; i++)
hA[i + j * lda] = hA[i + j * lda] / diag;
hA[i + j * size_t(lda)] = hA[i + j * size_t(lda)] / diag;
}
}
// randomly initialize diagonal to ensure we aren't using its values for tests.
for(int i = 0; i < N; i++)
{
rocblas_init<T>(hA + i * lda + i, 1, 1, 1);
rocblas_init<T>(hA + i * size_t(lda) + i, 1, 1, 1);
}
}

Expand Down
40 changes: 27 additions & 13 deletions library/src/blas3/Tensile/gemm_tensile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,33 @@ inline rocblas_status call_tensile(rocblas_handle handle,
}
#endif

RocblasContractionProblem<T> problem{handle, trans_a,
trans_b, m,
n, k,
alpha, A,
nullptr, ld_a,
stride_a, offset_a,
B, nullptr,
ld_b, stride_b,
offset_b, beta,
C, nullptr,
ld_c, stride_c,
offset_c, batch_count,
true, rocblas_gemm_flags_none};
// pre-apply offsets to the A, B and C pointers for non-batched and strided
RocblasContractionProblem<T> problem{handle,
trans_a,
trans_b,
m,
n,
k,
alpha,
A + offset_a,
nullptr,
ld_a,
stride_a,
0 /* offset_a */,
B + offset_b,
nullptr,
ld_b,
stride_b,
0 /* offset_b */,
beta,
C + offset_c,
nullptr,
ld_c,
stride_c,
0 /* offset_c */,
batch_count,
true,
rocblas_gemm_flags_none};

return runContractionProblem(problem);
}
Expand Down
Loading

0 comments on commit cdd561f

Please sign in to comment.