From 8c4239d4b06361729ba9359d2095cc639841ee9c Mon Sep 17 00:00:00 2001 From: Chris Erb Date: Mon, 9 Oct 2023 18:22:16 -0500 Subject: [PATCH] [Bugfix] Add cast swapping for swapped gemm inputs. (#2443) * add swapping for cast types when swapping A+B for gemm --- src/gemm_v2.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gemm_v2.cpp b/src/gemm_v2.cpp index 19e302f166..1750f625da 100644 --- a/src/gemm_v2.cpp +++ b/src/gemm_v2.cpp @@ -413,6 +413,7 @@ miopenStatus_t CallGemm(const Handle& handle, gemm_desc.isColMajor = !gemm_desc.isColMajor; std::swap(A, B); std::swap(a_offset, b_offset); + std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type); std::swap(gemm_desc.transA, gemm_desc.transB); std::swap(gemm_desc.m, gemm_desc.n); std::swap(gemm_desc.lda, gemm_desc.ldb); @@ -665,6 +666,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle, gemm_desc.isColMajor = !gemm_desc.isColMajor; std::swap(A, B); std::swap(a_offset, b_offset); + std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type); std::swap(gemm_desc.transA, gemm_desc.transB); std::swap(gemm_desc.m, gemm_desc.n); std::swap(gemm_desc.lda, gemm_desc.ldb); @@ -938,6 +940,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle, gemm_desc.isColMajor = !gemm_desc.isColMajor; std::swap(A, B); std::swap(a_offset, b_offset); + std::swap(gemm_desc.a_cast_type, gemm_desc.b_cast_type); std::swap(gemm_desc.transA, gemm_desc.transB); std::swap(gemm_desc.m, gemm_desc.n); std::swap(gemm_desc.lda, gemm_desc.ldb);