From 02048fc96ce4f5044a66c10636610a6bc9643181 Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 15:43:34 +0300
Subject: [PATCH 01/10] adding chunking mechanism to fused_moe to handle long seqs

---
 vllm/envs.py                                      |   4 +
 .../layers/fused_moe/fused_moe.py                 | 115 +++++++++++-------
 2 files changed, 72 insertions(+), 47 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index e8257535f1bf5..5cf93ba7c9d66 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -32,6 +32,7 @@
     VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
     VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
     VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
+    VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
     VLLM_USE_RAY_COMPILED_DAG: bool = False
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -248,6 +249,9 @@
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH":
     lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
+
+    "VLLM_FUSED_MOE_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index ecab77a8b6dfb..7498454bbf505 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -10,6 +10,7 @@
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+import vllm.envs as envs
 
 logger = init_logger(__name__)
 
@@ -420,13 +421,10 @@ def fused_experts(hidden_states: torch.Tensor,
         torch.float32, torch.float16, torch.bfloat16
     ]
 
-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
-
-    if M > 65536:
-        # https://github.com/vllm-project/vllm/issues/5938
-        raise ValueError("MoE kernel does not support more than 65536 tokens, "
-                         f"but got {M}")
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)
 
     if override_config:
         config = override_config
@@ -455,51 +453,74 @@ def fused_experts(hidden_states: torch.Tensor,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
     compute_type = (tl.bfloat16
                     if hidden_states.dtype == torch.bfloat16 else tl.float16)
 
-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
+    states_acc = []
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE:
+            # will only happen in the last chunk
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(curr_topk_ids,
+            config['BLOCK_SIZE_M'], E)
+
+        invoke_fused_moe_kernel(curr_hidden_states,
+                                w1,
+                                intermediate_cache1,
+                                a1_scale,
+                                w1_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                topk_ids.shape[1],
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(intermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2_scale,
+                                w2_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                True,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        if inplace:
+            torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                      dim=1,
+                      out=hidden_states[begin_chunk_idx:end_chunk_idx])
+        else:
+            states_acc.append(torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1))
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        return hidden_states
+    else:
+        return torch.cat(states_acc, dim=0)
 
 
 def fused_moe(

From 867da75b7e6c2fafa3380045fdf23521c73c3f4e Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 16:01:59 +0300
Subject: [PATCH 02/10] fix ruff

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 7498454bbf505..455d0f1d0b4c0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -458,7 +458,8 @@ def fused_experts(hidden_states: torch.Tensor,
 
     states_acc = []
     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
-        begin_chunk_idx, end_chunk_idx = chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)
+        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
+                                          min((chunk + 1) * CHUNK_SIZE, num_tokens))
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
         tokens_in_chunk, _ = curr_hidden_states.shape
 
@@ -474,8 +475,8 @@ def fused_experts(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
 
-        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(curr_topk_ids,
-            config['BLOCK_SIZE_M'], E)
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            curr_topk_ids, config['BLOCK_SIZE_M'], E)
 
         invoke_fused_moe_kernel(curr_hidden_states,
                                 w1,
@@ -516,7 +517,8 @@ def fused_experts(hidden_states: torch.Tensor,
                       dim=1,
                       out=hidden_states[begin_chunk_idx:end_chunk_idx])
         else:
-            states_acc.append(torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1))
+            states_acc.append(torch.sum(intermediate_cache3.view(
+                *intermediate_cache3.shape), dim=1))
     if inplace:
         return hidden_states
     else:

From 5dd87b7649e43e0d02d26a6adeb61f9b982b8e2c Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 16:28:37 +0300
Subject: [PATCH 03/10] fix ruff #2

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 455d0f1d0b4c0..c2f7e5c579563 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -459,7 +459,8 @@ def fused_experts(hidden_states: torch.Tensor,
     states_acc = []
     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
         begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
-                                          min((chunk + 1) * CHUNK_SIZE, num_tokens))
+                                          min((chunk + 1) * CHUNK_SIZE,
+                                              num_tokens))
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
         tokens_in_chunk, _ = curr_hidden_states.shape
 
@@ -475,7 +476,7 @@ def fused_experts(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
 
-        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( # noqa: E501
             curr_topk_ids, config['BLOCK_SIZE_M'], E)
 
         invoke_fused_moe_kernel(curr_hidden_states,

From a8312baaa35acbbef7f694cfe812c34659e7ecef Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 16:31:42 +0300
Subject: [PATCH 04/10] fixing isort

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index c2f7e5c579563..b385d2d317e05 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,9 +8,9 @@
 import triton
 import triton.language as tl
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-import vllm.envs as envs
 
 logger = init_logger(__name__)

From 3c1e4f4a693b99e3b27b954313081775f040caef Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 16:43:59 +0300
Subject: [PATCH 05/10] fix yapf

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index b385d2d317e05..ddd973b4a23d9 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -476,7 +476,7 @@ def fused_experts(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
 
-        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( # noqa: E501
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(  # noqa: E501
             curr_topk_ids, config['BLOCK_SIZE_M'], E)
 
         invoke_fused_moe_kernel(curr_hidden_states,
@@ -518,8 +518,9 @@ def fused_experts(hidden_states: torch.Tensor,
                       dim=1,
                       out=hidden_states[begin_chunk_idx:end_chunk_idx])
         else:
-            states_acc.append(torch.sum(intermediate_cache3.view(
-                *intermediate_cache3.shape), dim=1))
+            states_acc.append(
+                torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                          dim=1))
     if inplace:
         return hidden_states
     else:

From 4bde0be9b1b1d6a220b9b6d0de39ae137c26597d Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 16:46:35 +0300
Subject: [PATCH 06/10] fix yapf #2

---
 vllm/envs.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 5cf93ba7c9d66..c624510c7ea1a 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -249,7 +249,6 @@
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH":
     lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
-
     "VLLM_FUSED_MOE_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
 }

From 054bce70f7b0a132d3604ea813940295fc86dd29 Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 20:25:57 +0300
Subject: [PATCH 07/10] pre-allocate output buffer

---
 .../layers/fused_moe/fused_moe.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index ddd973b4a23d9..ed4777fea3e87 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -423,6 +423,8 @@ def fused_experts(hidden_states: torch.Tensor,
 
     num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
 
@@ -456,7 +458,11 @@ def fused_experts(hidden_states: torch.Tensor,
     compute_type = (tl.bfloat16
                     if hidden_states.dtype == torch.bfloat16 else tl.float16)
 
-    states_acc = []
+    if inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.zeros_like(hidden_states)
+
     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
         begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                           min((chunk + 1) * CHUNK_SIZE,
@@ -513,18 +519,10 @@ def fused_experts(hidden_states: torch.Tensor,
                                 compute_type=compute_type,
                                 use_fp8=use_fp8)
 
-        if inplace:
-            torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+        torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                       dim=1,
-                      out=hidden_states[begin_chunk_idx:end_chunk_idx])
-        else:
-            states_acc.append(
-                torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                          dim=1))
-    if inplace:
-        return hidden_states
-    else:
-        return torch.cat(states_acc, dim=0)
+                      out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
+    return out_hidden_states
 
 
 def fused_moe(

From f462d8bbd3dfd2f04e817c1cd0de5efdca5cc71c Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 20:31:52 +0300
Subject: [PATCH 08/10] fixing yapf #3

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index ed4777fea3e87..e930f4f4d3d6a 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -520,8 +520,8 @@ def fused_experts(hidden_states: torch.Tensor,
                                 use_fp8=use_fp8)
 
         torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                      dim=1,
-                      out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
+                  dim=1,
+                  out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
     return out_hidden_states
 
 

From 1fda88469e9459d78fd95fb2dc6ae39f2e01457b Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 20:36:46 +0300
Subject: [PATCH 09/10] add long input test case

---
 tests/kernels/test_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 2356b9ec18b0d..22b6769ac3f23 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
-@pytest.mark.parametrize("m", [512, 222, 33, 1])
+@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", [8, 64])

From fe51961d2bf31312e327993c634230513ddeefa6 Mon Sep 17 00:00:00 2001
From: Avshalom
Date: Mon, 1 Jul 2024 20:42:00 +0300
Subject: [PATCH 10/10] zeros_like -> empty_like

---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index e930f4f4d3d6a..99a5c7d78a67e 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -461,7 +461,7 @@ def fused_experts(hidden_states: torch.Tensor,
     if inplace:
         out_hidden_states = hidden_states
     else:
-        out_hidden_states = torch.zeros_like(hidden_states)
+        out_hidden_states = torch.empty_like(hidden_states)
 
     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
         begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
@@ -482,8 +482,8 @@ def fused_experts(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
 
-        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(  # noqa: E501
-            curr_topk_ids, config['BLOCK_SIZE_M'], E)
+        sorted_token_ids, expert_ids, num_tokens_post_padded = (
+            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
 
         invoke_fused_moe_kernel(curr_hidden_states,
                                 w1,
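
The net effect of the series: fused_experts no longer rejects inputs above 65536 tokens (issue #5938); it walks the token dimension in chunks of VLLM_FUSED_MOE_CHUNK_SIZE (default 64 * 1024), reuses intermediate caches sized for min(num_tokens, CHUNK_SIZE) rows, and writes each chunk's result into a pre-allocated output buffer (hidden_states itself when inplace, an empty_like buffer otherwise). Below is a minimal self-contained sketch of that dispatch pattern in plain PyTorch; apply_experts and chunked_fused_moe are hypothetical names standing in for the two invoke_fused_moe_kernel launches plus silu_and_mul, not vLLM API.

from typing import Callable

import torch

CHUNK_SIZE = 64 * 1024  # mirrors the VLLM_FUSED_MOE_CHUNK_SIZE default


def chunked_fused_moe(
        hidden_states: torch.Tensor,
        apply_experts: Callable[[torch.Tensor], torch.Tensor],
        inplace: bool = False) -> torch.Tensor:
    """Run apply_experts over row chunks of at most CHUNK_SIZE tokens."""
    num_tokens = hidden_states.shape[0]
    # Reuse the input buffer when inplace; otherwise pre-allocate the output
    # once. empty_like (not zeros_like) suffices: every row is overwritten.
    out = hidden_states if inplace else torch.empty_like(hidden_states)
    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin = chunk * CHUNK_SIZE
        end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
        if begin >= end:
            # Extra loop iteration when num_tokens divides CHUNK_SIZE evenly.
            break
        # The chunk is read before the assignment writes it back, so this is
        # safe even when out aliases hidden_states (the inplace path).
        out[begin:end] = apply_experts(hidden_states[begin:end])
    return out

Passing apply_experts=lambda x: x is a cheap way to sanity-check the chunk arithmetic: the output must equal the input for any num_tokens, including exact multiples of CHUNK_SIZE, where the final iteration hits the begin >= end guard. Because the per-chunk working set is bounded by CHUNK_SIZE rather than by sequence length, the new 1024 * 128 test case in tests/kernels/test_moe.py (131072 tokens, above the old 65536 limit) can run without the previous ValueError.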