diff --git a/src/gemm_v2.cpp b/src/gemm_v2.cpp index 19e302f166..1750f625da 100644 --- a/src/gemm_v2.cpp +++ b/src/gemm_v2.cpp @@ -413,6 +413,7 @@ miopenStatus_t CallGemm(const Handle& handle, gemm_desc.isColMajor = !gemm_desc.isColMajor; std::swap(A, B); std::swap(a_offset, b_offset); + std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type); std::swap(gemm_desc.transA, gemm_desc.transB); std::swap(gemm_desc.m, gemm_desc.n); std::swap(gemm_desc.lda, gemm_desc.ldb); @@ -665,6 +666,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle, gemm_desc.isColMajor = !gemm_desc.isColMajor; std::swap(A, B); std::swap(a_offset, b_offset); + std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type); std::swap(gemm_desc.transA, gemm_desc.transB); std::swap(gemm_desc.m, gemm_desc.n); std::swap(gemm_desc.lda, gemm_desc.ldb); @@ -938,6 +940,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle, gemm_desc.isColMajor = !gemm_desc.isColMajor; std::swap(A, B); std::swap(a_offset, b_offset); + std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type); std::swap(gemm_desc.transA, gemm_desc.transB); std::swap(gemm_desc.m, gemm_desc.n); std::swap(gemm_desc.lda, gemm_desc.ldb);