[Dlight] Fix GeMV shared memory estimation (apache#16731)
Prior to this PR, one part was missing from the shared memory
estimation of the GeMV rule. The GeMV rule optimizes by using
cross-thread reduction. When the target does not support warp
reduction primitives, the cross-thread reduction is further
lowered to a shared memory implementation, which consumes an
additional amount of shared memory.
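
For intuition, here is a minimal standalone sketch of the estimation
(plain Python for illustration, not the rule itself): `ts` and `tr`
stand in for the rule's `TS`/`TR` thread extents, and the shape list
stands in for `vector_input_buffers`; the key point is the extra
`TS * TR * dtype_bytes` term when warp shuffle is unavailable.

```python
from functools import reduce

def estimate_smem_bytes(vector_buffer_shapes, dtype_bytes, ts, tr, support_warp_shuffle):
    """Illustrative sketch of the GeMV shared memory estimate."""
    usage = 0
    for shape in vector_buffer_shapes:
        # Vector operands staged in shared memory for reuse across threads.
        usage += reduce(lambda x, y: x * y, shape, 1) * dtype_bytes
    if not support_warp_shuffle:
        # Without warp shuffle, cross-thread allreduce is lowered to a
        # shared-memory workspace of one element per thread in the block.
        usage += ts * tr * dtype_bytes
    return usage

# Example: one fp16 vector of length 4096 on a 4 x 64 thread block
# without warp shuffle: 4096 * 2 + 4 * 64 * 2 = 8704 bytes.
print(estimate_smem_bytes([(4096,)], dtype_bytes=2, ts=4, tr=64, support_warp_shuffle=False))
```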

If this part is not accounted for in the GeMV rule, the total
shared memory usage can exceed the target's shared memory limit.
For example, mlc-ai/mlc-llm#1841 reports an issue where the
Vulkan shared memory limit is exceeded.

This PR fixes the issue by introducing a flag `SUPPORT_WARP_SHUFFLE`
to the GeMV rule. We enable warp shuffle only for the CUDA and Metal
backends, and turn it off for all other backends. This is basically
aligned with the lowering rule of the thread allreduce intrinsic.
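
In effect (a simplified paraphrase of the per-backend branches in the
diff below; the helper name is introduced here for illustration only):

```python
def supports_warp_shuffle(target_kind: str) -> bool:
    # Only the CUDA and Metal branches set SUPPORT_WARP_SHUFFLE = True
    # in the GeMV rule; every other target is treated conservatively.
    return target_kind in ("cuda", "metal")
```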

P.S. ROCm also supports warp shuffle, but with some limitations
that not every set of parameters in the GeMV rule can meet.
Therefore, we regard ROCm as "not supported" here. This only means
we are conservative in the shared memory estimation for ROCm; it
does not mean warp shuffle is never used during lowering when the
workload is eligible.
MasterJH5574 authored Mar 16, 2024
1 parent b8f64c2 commit 1c73491
Showing 1 changed file with 15 additions and 3 deletions.
python/tvm/dlight/gpu/gemv.py: 15 additions & 3 deletions
@@ -244,6 +244,7 @@ def apply(
             LOAD_V_SHARED,
             LOAD_V_VEC,
             UNROLL,
+            SUPPORT_WARP_SHUFFLE,
         ):
             # rfactor: reduce to tx * vec_c
             _, s, r, c = sch.get_loops(block=gemv)
@@ -273,10 +274,17 @@ def apply(

             shared_mem_usage = 0
             for buf in vector_input_buffers:
-                buf_size = reduce(
-                    lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1)
-                ) * get_bytes(buf.dtype)
+                dtype_bytes = get_bytes(buf.dtype)
+                buf_size = (
+                    reduce(lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1))
+                    * dtype_bytes
+                )
                 shared_mem_usage += buf_size
+            if not SUPPORT_WARP_SHUFFLE:
+                # When warp shuffle is not available, cross-thread allreduce
+                # is implemented with shared memory.
+                shared_mem_usage += TS * TR * dtype_bytes
+
             LOAD_V_SHARED = (
                 LOAD_V_SHARED
                 and isinstance(shared_mem_usage, tir.IntImm)
@@ -421,11 +429,13 @@ def apply(
         len_R = len_r * len_c

         TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
+        SUPPORT_WARP_SHUFFLE = False
         if target.kind.name == "cuda":
             VEC_C = 4
             LOAD_V_SHARED = True
             LOAD_V_VEC = 8
             UNROLL = 256
+            SUPPORT_WARP_SHUFFLE = True
             if isinstance(len_S, int):
                 if len_S > len_R:
                     TS, TR = 4, 64
@@ -438,6 +448,7 @@ def apply(
             LOAD_V_SHARED = False
             LOAD_V_VEC = -1
             UNROLL = 256
+            SUPPORT_WARP_SHUFFLE = True
             if isinstance(len_S, int):
                 if len_S > len_R:
                     TS, TR = 4, 16
@@ -515,6 +526,7 @@ def apply(
             LOAD_V_SHARED=LOAD_V_SHARED,
             LOAD_V_VEC=LOAD_V_VEC,
             UNROLL=UNROLL,
+            SUPPORT_WARP_SHUFFLE=SUPPORT_WARP_SHUFFLE,
         )

     def sch_outer_reduction(  # pylint: disable=too-many-arguments, invalid-name, unused-argument
