Skip to content

Commit

Permalink
cherry-pick Remove workaround for K==0 from rocBLAS-internal commit f…
Browse files Browse the repository at this point in the history
  • Loading branch information
amcamd committed Aug 4, 2020
1 parent 505fdb2 commit be61384
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 17 deletions.
5 changes: 3 additions & 2 deletions clients/gtest/gemm_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2911,9 +2911,10 @@ Tests:
alpha_beta: *alpha_beta_range
K: 0
matrix_size:
- { M: 1, N: 2 }
- { M: 3, N: 5 }
- { M: 1, N: 2 }
- { M: 3, N: 5 }
- { M: 512, N: 100 }
- { M: 63, N: 512 }
- { M: 100, N: 1000 }

...
19 changes: 5 additions & 14 deletions library/src/tensile_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,14 @@ namespace
freeIndex[0].c = freeIndex[0].d = 0;
freeIndex[1].c = freeIndex[1].d = 1;

// Tensile does not support 0-sized dimensions. For when k == 0, we still need to
// multiply C by beta, but not add any of the rank-0 dot products. As a workaround,
// we pass k = 1 and set alpha == 0, since alpha == 0 has the same effect as k == 0.
auto k = prob.k == 0 ? 1 : prob.k;

// clang-format off

// If A is transposed, swap the free and bound dimensions and their ranks
if(prob.trans_a != rocblas_operation_none)
{
a = {
Tensile_Ti,
{k, prob.m, prob.batch_count},
{prob.k, prob.m, prob.batch_count},
{prob.row_stride_a, prob.col_stride_a, prob.batch_stride_a}
};
freeIndex[0].i = 1;
Expand All @@ -162,7 +157,7 @@ namespace
{
a = {
Tensile_Ti,
{prob.m, k, prob.batch_count},
{prob.m, prob.k, prob.batch_count},
{prob.row_stride_a, prob.col_stride_a, prob.batch_stride_a}
};
freeIndex[0].i = 0;
Expand All @@ -178,7 +173,7 @@ namespace
{
b = {
Tensile_Ti,
{prob.n, k, prob.batch_count},
{prob.n, prob.k, prob.batch_count},
{prob.row_stride_b, prob.col_stride_b, prob.batch_stride_b}
};
freeIndex[1].i = 0;
Expand All @@ -188,7 +183,7 @@ namespace
{
b = {
Tensile_Ti,
{k, prob.n, prob.batch_count},
{prob.k, prob.n, prob.batch_count},
{prob.row_stride_b, prob.col_stride_b, prob.batch_stride_b}
};
freeIndex[1].i = 1;
Expand Down Expand Up @@ -303,11 +298,7 @@ namespace

// alpha and beta are stored by value in Tensile::TypedContractionInputs
// alpha and beta are copied from host to Tensile::TypedContractionInputs
// We set alpha = 0 if k == 0 (see above)
if(prob.k == 0)
memset(&inputs.alpha, 0, sizeof(inputs.alpha));
else
AlphaBeta<Ti, To, Tc>::copy(&inputs.alpha, prob.alpha);
AlphaBeta<Ti, To, Tc>::copy(&inputs.alpha, prob.alpha);
AlphaBeta<Ti, To, Tc>::copy(&inputs.beta, prob.beta);

return inputs;
Expand Down
2 changes: 1 addition & 1 deletion tensile_tag.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
19319869e83243c5e3ca649532d2951de0ba35be
af71ea890a893e647bf2cf4571a90297d65689ca

0 comments on commit be61384

Please sign in to comment.