From 8fdabad234bbf9b177ac2deddbcbf069bef7d38f Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 16 Mar 2024 13:27:59 -0400 Subject: [PATCH] [Dlight] Fix GeMV shared memory estimation Prior to this PR, there is one part missing in the shared memory estimation of the GeMV rule. The GeMV rule optimizes by using cross-thread reduction. When the target does not support warp reduction primitives, the cross-thread reduction will be further lowered to shared memory implementation, which consumes another part of shared memory. If we do not consider this part in the GeMV rule, it is possible for the total shared memory usage to exceed the target shared memory limit. For example, mlc-ai/mlc-llm#1841 reports an issue on the Vulkan shared memory limit exceed. This PR fixes the issue by introducing a flag `SUPPORT_WARP_SHUFFLE` to the GeMV rule. We only enable warp shuffle for CUDA and Metal backend, and turn it off for all other backends. This is basically aligned with the lowering rule of thread allreduce intrinsic. P.S.. ROCm also supports warp shuffle but has some limitation, where not every set of parameters in the GeMV rule can meet. Therefore, we regard ROCm as "not supported". This just mean we will be conservative in the shared memory usage for ROCm, and does not mean we do not use the warp shuffle when the workload is eligible when lowering. --- python/tvm/dlight/gpu/gemv.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index d1a195fbad6f..ffd6b6d09533 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -244,6 +244,7 @@ def apply( LOAD_V_SHARED, LOAD_V_VEC, UNROLL, + SUPPORT_WARP_SHUFFLE, ): # rfactor: reduce to tx * vec_c _, s, r, c = sch.get_loops(block=gemv) @@ -273,10 +274,17 @@ def apply( shared_mem_usage = 0 for buf in vector_input_buffers: - buf_size = reduce( - lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1) - ) * get_bytes(buf.dtype) + dtype_bytes = get_bytes(buf.dtype) + buf_size = ( + reduce(lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1)) + * dtype_bytes + ) shared_mem_usage += buf_size + if not SUPPORT_WARP_SHUFFLE: + # When warp shuffle is not able, cross-thread allreduce + # is implemented with shared memory. + shared_mem_usage += TS * TR * dtype_bytes + LOAD_V_SHARED = ( LOAD_V_SHARED and isinstance(shared_mem_usage, tir.IntImm) @@ -421,11 +429,13 @@ def apply( len_R = len_r * len_c TAG_S, TAG_R = "threadIdx.y", "threadIdx.x" + SUPPORT_WARP_SHUFFLE = False if target.kind.name == "cuda": VEC_C = 4 LOAD_V_SHARED = True LOAD_V_VEC = 8 UNROLL = 256 + SUPPORT_WARP_SHUFFLE = True if isinstance(len_S, int): if len_S > len_R: TS, TR = 4, 64 @@ -438,6 +448,7 @@ def apply( LOAD_V_SHARED = False LOAD_V_VEC = -1 UNROLL = 256 + SUPPORT_WARP_SHUFFLE = True if isinstance(len_S, int): if len_S > len_R: TS, TR = 4, 16 @@ -515,6 +526,7 @@ def apply( LOAD_V_SHARED=LOAD_V_SHARED, LOAD_V_VEC=LOAD_V_VEC, UNROLL=UNROLL, + SUPPORT_WARP_SHUFFLE=SUPPORT_WARP_SHUFFLE, ) def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument