From e15632fd8b668bbb389a3baf41cf9cd129937378 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Thu, 28 Nov 2024 22:19:33 +0800 Subject: [PATCH 01/10] move vllm distributed to sglang --- python/sglang/srt/_custom_ops.py | 1002 +++++++++++++ python/sglang/srt/distributed/__init__.py | 3 + .../srt/distributed/communication_op.py | 33 + .../device_communicators/__init__.py | 0 .../device_communicators/cuda_wrapper.py | 181 +++ .../device_communicators/custom_all_reduce.py | 309 ++++ .../custom_all_reduce_utils.py | 275 ++++ .../device_communicators/hpu_communicator.py | 46 + .../device_communicators/pynccl.py | 204 +++ .../device_communicators/pynccl_wrapper.py | 332 +++++ .../device_communicators/shm_broadcast.py | 496 +++++++ .../device_communicators/tpu_communicator.py | 59 + .../device_communicators/xpu_communicator.py | 45 + .../sglang/srt/distributed/parallel_state.py | 1291 +++++++++++++++++ python/sglang/srt/distributed/utils.py | 221 +++ python/sglang/srt/utils.py | 6 + 16 files changed, 4503 insertions(+) create mode 100644 python/sglang/srt/_custom_ops.py create mode 100644 python/sglang/srt/distributed/__init__.py create mode 100644 python/sglang/srt/distributed/communication_op.py create mode 100644 python/sglang/srt/distributed/device_communicators/__init__.py create mode 100644 python/sglang/srt/distributed/device_communicators/cuda_wrapper.py create mode 100644 python/sglang/srt/distributed/device_communicators/custom_all_reduce.py create mode 100644 python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py create mode 100644 python/sglang/srt/distributed/device_communicators/hpu_communicator.py create mode 100644 python/sglang/srt/distributed/device_communicators/pynccl.py create mode 100644 python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py create mode 100644 python/sglang/srt/distributed/device_communicators/shm_broadcast.py create mode 100644 python/sglang/srt/distributed/device_communicators/tpu_communicator.py create mode 100644 python/sglang/srt/distributed/device_communicators/xpu_communicator.py create mode 100644 python/sglang/srt/distributed/parallel_state.py create mode 100644 python/sglang/srt/distributed/utils.py diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py new file mode 100644 index 0000000000..604b75bb95 --- /dev/null +++ b/python/sglang/srt/_custom_ops.py @@ -0,0 +1,1002 @@ +import contextlib +import functools +import importlib +import logging +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch +import torch.library + +# import vllm.envs as envs +from vllm.platforms import current_platform + +# from vllm.scalar_type import ScalarType + +logger = logging.getLogger(__name__) + +if not current_platform.is_tpu() and not current_platform.is_hpu(): + try: + import custom_ar + except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + +""" +if current_platform.is_rocm(): + import vllm._rocm_C # noqa: F401 + +supports_moe_ops = False +with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + supports_moe_ops = True + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING or current_platform.is_neuron(): + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake +""" + + +def hint_on_error(fn): + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + 
try: + return fn(*args, **kwargs) + + except NotImplementedError as e: + msg = ( + "Error in calling custom op %s: %s\n" + "Not implemented or built, mostly likely because the current current device " + "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set " + "incorrectly while building)" + ) + logger.error(msg, fn.__name__, e) + raise NotImplementedError(msg % (fn.__name__, e)) from e + except AttributeError as e: + msg = ( + "Error in calling custom op %s: %s\n" + "Possibly you have built or installed an obsolete version of vllm.\n" + "Please try a clean build and install of vllm," + "or remove old built files such as vllm/*cpython*.so and build/ ." + ) + logger.error(msg, fn.__name__, e) + raise e + + return wrapper + + +''' +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + torch.ops._C.silu_and_mul(out, x) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + torch.ops._C.gelu_and_mul(out, x) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + torch.ops._C.gelu_tanh_and_mul(out, x) + + +def fatrelu_and_mul(out: torch.Tensor, + x: torch.Tensor, + threshold: float = 0.0) -> None: + torch.ops._C.fatrelu_and_mul(out, x, threshold) + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + torch.ops._C.gelu_fast(out, x) + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + torch.ops._C.gelu_new(out, x) + + +def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + torch.ops._C.gelu_quick(out, x) + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + torch.ops._C.paged_attention_v1( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + torch.ops._C.paged_attention_v2( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + 
value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale) + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + torch.ops._C.rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + torch.ops._C.rms_norm(out, input, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) + + +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, + block_size, input_tokens, + sampled_token_ids, + input_positions, seq_lens, + slot_mapping, block_tables) + + +def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + block_table_bound: torch.Tensor) -> None: + + return torch.ops._C.advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, + paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + block_table_bound) + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton) + return awq_dequantize_triton(qweight, scales, zeros) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, + thx, thy) + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_gemm_triton) + return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) + return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, 
split_k_iters) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) + + +if hasattr(torch.ops._C, "gptq_gemm"): + + @register_fake("_C::gptq_gemm") + def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, + use_exllama: bool, bit: int) -> torch.Tensor: + return torch.empty((a.size(0), b_q_weight.size(1)), + dtype=a.dtype, + device=a.device) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) + + +# marlin_24 +def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + workspace, b_q_type.id, size_m, + size_n, size_k) + + +if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): + + @register_fake("_C::gptq_marlin_24_gemm") + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::gptq_marlin_gemm") + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, + m: torch.SymInt, + n: torch.SymInt) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + return torch.empty((1, row), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=torch.float16, device=W.device) + + @register_fake("_C::marlin_qqq_gemm") + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + 
@register_fake("_C::marlin_gemm") + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake("_C::awq_dequantize") + def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: torch.SymInt, + thx: int, thy: int) -> torch.Tensor: + in_c = qweight.size(0) + qout_c = qweight.size(1) + out_c = qout_c * 8 + return torch.empty((in_c, out_c), + dtype=scales.dtype, + device=scales.device) + + @register_fake("_C::awq_gemm") + def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, + qzeros: torch.Tensor, scales: torch.Tensor, + split_k_iters: torch.SymInt) -> torch.Tensor: + num_in_feats = input.size(0) + return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), + dtype=input.dtype, + device=input.device).sum(0) + + @register_fake("_C::aqlm_gemm") + def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: List[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + out_features = codes.size(0) * codebooks.size(2) + flat_input = input.reshape((-1, input.size(-1))) + flat_output = torch.empty((flat_input.size(0), out_features), + dtype=input.dtype, + device=input.device) + + output_sizes = list(input.shape) + output_sizes.pop() + output_sizes.append(-1) + return flat_output.reshape(tuple(output_sizes)) + + @register_fake("_C::aqlm_dequant") + def _aqlm_dequant_fake( + codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: List[int]) -> torch.Tensor: + in_features = codes.size(1) * 8 + out_features = codes.size(0) + return torch.empty((out_features, in_features), + dtype=codebooks.dtype, + device=codebooks.device) + + @register_fake("_C::fp8_marlin_gemm") + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake("_C::machete_gemm") + def machete_gemm_fake( + a: torch.Tensor, + # Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + b_scales: Optional[torch.Tensor] = None, + b_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + c: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + schedule: Optional[str] = None, + ) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + @register_fake("_C::machete_prepack_B") + def machete_prepack_B_fake(b_q_weight: torch.Tensor, + b_type: ScalarType) -> torch.Tensor: + return torch.empty_like(b_q_weight, + memory_format=torch.contiguous_format) + + +# cutlass +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert 
bias is None or bias.shape[0] == b.shape[ + 1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + if current_platform.is_rocm(): + triton_scaled_mm_module = importlib.import_module( + "vllm.model_executor.layers.quantization.compressed_tensors." + "triton_scaled_mm") + triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. + """ + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.numel( + ) == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, + azp, bias) + return out + + +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: List[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: List[int]) -> torch.Tensor: + return torch.ops._C.aqlm_dequant(codes, codebooks, + codebook_partition_sizes) + + +# gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) + + +# gptq_marlin +def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], + size_k, size_n, num_bits) + return output + + +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + +def gptq_marlin_gemm(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + 
b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False) -> torch.Tensor: + return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, + g_idx, perm, workspace, b_q_type.id, + size_m, size_n, size_k, is_k_full, + has_zp, use_fp32_reduce) + + +# fp8 marlin +def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, + size_k: int) -> torch.Tensor: + return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, + num_bits, size_m, size_n, size_k) + + +# machete +def machete_supported_schedules(b_type: ScalarType) -> List[str]: + return torch.ops._C.machete_supported_schedules(b_type.id) + + +def machete_gemm( + a: torch.Tensor, + b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B + b_type: ScalarType, + b_scales: Optional[torch.Tensor] = None, + b_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + c: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros, + b_group_size, c, alpha, beta, schedule) + + +def machete_prepack_B(b_q_weight: torch.Tensor, + b_type: ScalarType) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id) + + +if hasattr(torch.ops._C, "permute_cols"): + + @register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. 
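+
+    Example (a sketch; dynamic per-tensor quantization when ``scale`` is omitted):
+        output, scale = scaled_fp8_quant(input)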
+ """ + # This code assumes batch_dim and num_tokens are flattened + assert (input.ndim == 2) + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + if current_platform.is_rocm() else torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant( + output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert (scale.numel() == 1 or num_token_padding is None) + torch.ops._C.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is + None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
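+    # One scale per token (row); a zero point is only allocated for
+    # asymmetric quantization.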
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp + + +# qqq ops +def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, + workspace, size_m, size_n, size_k) + + +# gguf +def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, + n: int) -> torch.Tensor: + return torch.ops._C.ggml_dequantize(W, quant_type, m, n) + + +def ggml_mul_mat_vec_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) + + +def ggml_mul_mat_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) + + +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + silu_activation: bool, pad_slot_id: int): + torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, + query_start_loc, cache_indices, + has_initial_state, silu_activation, + pad_slot_id) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int): + torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation, cache_seqlens, + conv_state_indices, pad_slot_id) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): + torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, + delta_softplus, query_start_loc, + cache_indices, has_initial_state, + ssm_states, pad_slot_id) + + +# moe +def moe_sum(input: torch.Tensor, output: torch.Tensor): + torch.ops._moe_C.moe_sum(input, output) + + +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: float) -> None: + torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, + token_expert_indicies, gating_output) + + +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + 
topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, size_k: torch.SymInt, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + + +def copy_blocks(key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) + + +def get_device_attribute(attribute: int, device: int) -> int: + return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + # ruff: noqa: E501 + return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( + device) +''' + + +# custom ar +def init_custom_ar( + ipc_tensors: List[torch.Tensor], + rank_data: torch.Tensor, + rank: int, + full_nvlink: bool, +) -> int: + return torch.ops._C_vllm_ar.init_custom_ar( + ipc_tensors, rank_data, rank, full_nvlink + ) + + +def all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + reg_buffer: int, + reg_buffer_sz_bytes: int, +) -> None: + torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) + + +def dispose(fa: int) -> None: + torch.ops._C_vllm_ar.dispose(fa) + + +def meta_size() -> int: + return torch.ops._C_vllm_ar.meta_size() + + +def register_buffer(fa: int, ipc_tensors: List[int]) -> None: + return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors) + + +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] +) -> None: + torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets) + + +# temporary fix for https://github.com/vllm-project/vllm/issues/5456 +# TODO: remove this in v0.6.0 +names_and_values = globals() +names_and_values_to_update = {} +# prepare variables to avoid dict size change during iteration +k, v, arg = None, None, None +fn_type = type(lambda x: x) +for k, v in names_and_values.items(): + # find functions that are defined in this file and have torch.Tensor + # 
in their annotations. `arg == "torch.Tensor"` is used to handle + # the case when users use `import __annotations__` to turn type + # hints into strings. + if ( + isinstance(v, fn_type) + and v.__code__.co_filename == __file__ + and any( + arg is torch.Tensor or arg == "torch.Tensor" + for arg in v.__annotations__.values() + ) + ): + names_and_values_to_update[k] = hint_on_error(v) + +names_and_values.update(names_and_values_to_update) +del names_and_values_to_update, names_and_values, v, k, fn_type diff --git a/python/sglang/srt/distributed/__init__.py b/python/sglang/srt/distributed/__init__.py new file mode 100644 index 0000000000..db325cfabf --- /dev/null +++ b/python/sglang/srt/distributed/__init__.py @@ -0,0 +1,3 @@ +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py new file mode 100644 index 0000000000..07b89a0bd5 --- /dev/null +++ b/python/sglang/srt/distributed/communication_op.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, Optional, Union + +import torch +import torch.distributed + +from .parallel_state import get_tp_group + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) + + +def tensor_model_parallel_all_gather( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +def tensor_model_parallel_gather( + input_: torch.Tensor, dst: int = 0, dim: int = -1 +) -> Optional[torch.Tensor]: + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) + + +def broadcast_tensor_dict( + tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0 +): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/python/sglang/srt/distributed/device_communicators/__init__.py b/python/sglang/srt/distributed/device_communicators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py new file mode 100644 index 0000000000..75c5cc93bc --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py @@ -0,0 +1,181 @@ +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. 
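+
+Example (a sketch; assumes libcudart is already loaded, e.g. by importing a
+CUDA build of torch):
+
+    lib = CudaRTLibrary()
+    buf = lib.cudaMalloc(1024)
+    handle = lib.cudaIpcGetMemHandle(buf)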
+""" + +import ctypes +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +logger = logging.getLogger(__name__) + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith( + lib_name + ), f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function( + "cudaMalloc", + cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t], + ), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function( + "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] + ), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function( + "cudaMemcpy", + cudaError_t, + [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind], + ), + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function( + "cudaIpcGetMemHandle", + cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p], + ), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function( + "cudaIpcOpenMemHandle", + cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint], + ), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = 
find_loaded_library("libcudart") + assert so_file is not None, "libcudart is not loaded in the current process" + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy( + self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int + ) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK( + self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr) + ) + return handle + + def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK( + self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess + ) + ) + return devPtr diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py new file mode 100644 index 0000000000..f36eead9b8 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -0,0 +1,309 @@ +import ctypes +import logging +from contextlib import contextmanager +from typing import List, Optional, Union + +import torch +import torch.distributed as dist +import vllm.envs as envs +from torch.distributed import ProcessGroup +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +from sglang.srt import _custom_ops as ops +from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( + gpu_p2p_access_check, +) +from sglang.srt.distributed.parallel_state import in_the_same_node_as + +try: + ops.meta_size() + custom_ar = True +except Exception: + # For AMD GPUs and CPUs + custom_ar = False + +logger = logging.getLogger(__name__) + + +def 
_can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if envs.VLLM_SKIP_P2P_CHECK: + logger.info("Skipping P2P check and trusting the driver's P2P report.") + return torch.cuda.can_device_access_peer(rank, i) + if not gpu_p2p_access_check(rank, i): + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024, + ) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + self.group = group + + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "CustomAllreduce should be attached to a non-NCCL group." + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom allreduce for multi-node case. + logger.warning( + "Custom allreduce is disabled because this process group" + " spans across nodes." + ) + return + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, + str(CustomAllreduce._SUPPORTED_WORLD_SIZES), + ) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda() + from vllm.platforms.cuda import CudaPlatform + + cuda_platform: CudaPlatform = current_platform + full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) + if world_size > 2 and not full_nvlink: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. 
To silence this warning, " + "specify disable_custom_all_reduce=True explicitly." + ) + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly." + ) + return + + self.disabled = False + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer( + ops.meta_size() + max_size, group=group + ) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + self.full_nvlink = full_nvlink + self._ptr = ops.init_custom_ar( + self.meta_ptrs, self.rank_data, rank, self.full_nvlink + ) + ops.register_buffer(self._ptr, self.buffer_ptrs) + + @staticmethod + def create_shared_buffer( + size_in_bytes: int, group: Optional[ProcessGroup] = None + ) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. + """ + lib = CudaRTLibrary() + pointer = lib.cudaMalloc(size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer.value) # type: ignore + else: + pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + + return pointers + + @staticmethod + def free_shared_buffer( + pointers: List[int], group: Optional[ProcessGroup] = None + ) -> None: + rank = dist.get_rank(group=group) + lib = CudaRTLibrary() + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + logger.info("Registering %d cuda graph addresses", len(offset)) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. 
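+        # Instead, each rank broadcasts its own (handle, offset) entry in the
+        # loop below, so every rank ends up with the full list.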
+ all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list( + all_data[i], src=rank, group=self.group, device="cpu" + ) + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + ops.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + def all_reduce( + self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False + ): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if out is None: + out = torch.empty_like(inp) + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size + ) + return out + + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + """The main allreduce API that provides support for cuda graph.""" + # When custom allreduce is disabled, this will be None. + if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce(input, registered=True) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. 
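+                # The returned values are never read during warm-up; only the
+                # output allocation needs to match the captured run.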
+ return torch.empty_like(input) + else: + # Note: outside of cuda graph context, custom allreduce incurs a + # cost of cudaMemcpy, which should be small (<=1% of overall + # latency) compared to the performance gain of using custom kernels + return self.all_reduce(input, registered=False) + + def close(self): + if not self.disabled and self._ptr: + ops.dispose(self._ptr) + self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) + + def __del__(self): + self.close() diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py new file mode 100644 index 0000000000..6c3f3c5c6b --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -0,0 +1,275 @@ +import ctypes +import json +import logging +import os +import pickle +import subprocess +import sys +import tempfile +from itertools import product +from typing import Dict, List, Optional, Sequence + +import torch.distributed as dist +import torch.multiprocessing as mp +import vllm.envs as envs +from vllm.utils import cuda_device_count_stateless, update_environment_variables + +from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + +logger = logging.getLogger(__name__) + + +def producer( + batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): + if cuda_visible_devices is not None: + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer( + batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): + if cuda_visible_devices is not None: + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. 
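+            # open_success stays False and is reported through the queue below.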
+ pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + lib.cudaDeviceSynchronize() + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +) -> Sequence[bool]: + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU src --> cuda context src --> tensor src --> process src + + We need to combine p2p and cuda IPC, so that: + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because + they are the same memory segment. + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). + """ # noqa + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the processes are spawned + smp = mp.get_context("spawn") + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process( + target=producer, + args=( + batch_src, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) + p_tgt = smp.Process( + target=consumer, + args=( + batch_tgt, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 + result: List[bool] = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + if a != b: + logger.warning( + "Two processes do not agree on the P2P access" + " status on %d -> %d, treat as disabled.", + src, + tgt, + ) + result.append(False) + else: + result.append(a) + return result + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. 
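+# (process creation dominates the cost; see the batched trick in
+# can_actually_p2p above)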
+# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + is_distributed = dist.is_initialized() + + num_dev = cuda_device_count_stateless() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + + path = os.path.join( + envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) + os.makedirs(os.path.dirname(path), exist_ok=True) + from vllm.distributed.parallel_state import get_world_group + + if (not is_distributed or get_world_group().local_rank == 0) and ( + not os.path.exists(path) + ): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache in %s", path) + cache: Dict[str, bool] = {} + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + # NOTE: we use `subprocess` rather than `multiprocessing` here + # because the caller might not have `if __name__ == "__main__":`, + # in that case we cannot use spawn method in multiprocessing. + # However, `can_actually_p2p` requires spawn method. + # The fix is, we use `subprocess` to call the function, + # where we have `if __name__ == "__main__":` in this file. 
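+            # Roughly, the exchange with the child process below is:
+            #   parent stdin -> child : pickle((batch_src, batch_tgt, output_file.name))
+            #   child                 : result = can_actually_p2p(batch_src, batch_tgt)
+            #   child -> output_file  : pickle(result)  # one bool per (src, tgt) pair
+            # (see the `if __name__ == "__main__":` block at the end of this file)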
+ + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name)) + returned = subprocess.run( + [sys.executable, __file__], input=input_bytes, capture_output=True + ) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}" + ) from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + get_world_group().barrier() + logger.info("reading GPU P2P access cache from %s", path) + with open(path) as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + +__all__ = ["gpu_p2p_access_check"] + +if __name__ == "__main__": + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) + result = can_actually_p2p(batch_src, batch_tgt) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py new file mode 100644 index 0000000000..4ce060e091 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -0,0 +1,46 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from vllm.platforms import current_platform + +if current_platform.is_hpu(): + import habana_frameworks.torch as htorch # noqa: F401 + + +class HpuCommunicator: + + def __init__(self, group: ProcessGroup): + if not current_platform.is_hpu(): + self.disabled = True + return + self.disabled = False + self.group = group + self.world_size = dist.get_world_size(self.group) + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + dist.all_reduce(x, group=self.group) + return x + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + if dim < 0: + # Convert negative dim to positive. + dim += x.dim() + input_size = x.size() + # Allocate output tensor. + output_tensor = torch.empty( + (world_size,) + input_size, dtype=x.dtype, device=x.device + ) + # All-gather. 
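+        # Shape illustration: with world_size=2, x of shape (4, 8) and dim=-1
+        # (normalized to dim=1), output_tensor is gathered as (2, 4, 8),
+        # moved to (4, 2, 8), and reshaped to (4, 16), so each rank's shard
+        # ends up concatenated along `dim`.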
+ htorch.core.mark_step() + dist.all_gather_into_tensor(output_tensor, x, group=self.group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] + ) + return output_tensor diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py new file mode 100644 index 0000000000..b711830628 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -0,0 +1,204 @@ +# ===================== import region ===================== +import logging +from contextlib import contextmanager +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from sglang.srt.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) +from sglang.srt.distributed.utils import StatelessProcessGroup + +logger = logging.getLogger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group." + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + self.stream = None + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. 
in a non-GPU environment + self.available = False + self.disabled = True + self.stream = None + return + + self.available = True + self.disabled = False + + logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank + ) + self.stream = torch.cuda.Stream() + + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + self.stream.synchronize() + del data + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + self.disabled = True + + def all_reduce( + self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclAllReduce( + buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + @contextmanager + def change_state( + self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = 
None + ): + """ + A context manager to change the state of the communicator. + """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000000..41b759a8d7 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,332 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. 
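As an aside for readers of this wrapper, the ctypes pattern it builds on is small enough to sketch in a few lines. The snippet below is an illustrative, standalone sketch only: it assumes a shared library named "libnccl.so.2" is on the loader path (the wrapper below instead resolves the path via find_nccl_library() and VLLM_NCCL_SO_PATH), and it calls only ncclGetVersion.

import ctypes

# Load the NCCL shared library; "libnccl.so.2" is assumed here for illustration.
nccl = ctypes.CDLL("libnccl.so.2")

# Declare the C signature: ncclResult_t ncclGetVersion(int* version);
nccl.ncclGetVersion.restype = ctypes.c_int
nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

version = ctypes.c_int()
ret = nccl.ncclGetVersion(ctypes.byref(version))
assert ret == 0, "ncclGetVersion failed"
# prints something like 21903, which the wrapper below formats as "2.19.3"
print(version.value)

Switching NCCL versions then only means pointing ctypes.CDLL at a different .so, which is exactly the flexibility the comment above argues for.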
+ +import ctypes +import logging +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp +from vllm.utils import find_nccl_library + +logger = logging.getLogger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function( + "ncclCommInitRank", + ncclResult_t, + [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int], + ), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclSend", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, 
+ cudaStream_t, + ], + ), + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclRecv", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank( + self, world_size: int, unique_id: ncclUniqueId, rank: int + ) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK( + self._funcs["ncclCommInitRank"]( + ctypes.byref(comm), world_size, unique_id, rank + ) + ) + return comm + + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to 
`ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllReduce"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + ) + + def ncclRecv( + self, + recvbuff: buffer_type, + count: int, + datatype: int, + src: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + ) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", +] diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py new file mode 100644 index 0000000000..e1a337ab7c --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -0,0 +1,496 @@ +import logging +import os +import pickle +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from multiprocessing import shared_memory +from typing import List, Optional +from unittest.mock import patch + +import torch +import torch.distributed as dist +import vllm.envs as envs +from torch.distributed import ProcessGroup +from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address +from zmq import IPV6 # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore + +VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL + +logger = logging.getLogger(__name__) + + +class ShmRingBuffer: + + def __init__( + self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None, + ): + """ + A shared memory ring buffer implementation for broadcast communication. + Essentially, it is a queue where only one will `enqueue` and multiple + will `dequeue`. The max size of each item, together with the max number + of items that can be stored in the buffer are known in advance. + In this case, we don't need to synchronize the access to + the buffer. + + Buffer memory layout: + data metadata + | | + | (current_idx) | (current_idx) + v v + +-------------------------------+----------------------------------------+ + | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata | + +-------------------------------+----------------------------------------+ + | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes | + + metadata memory layout: each byte is a flag, the first byte is the written + flag, and the rest are reader flags. The flags are set to 0 by default. + +--------------+--------------+--------------+-----+--------------+ + | written_flag | reader0_flag | reader1_flag | ... 
| readerN_flag | + +--------------+--------------+--------------+-----+--------------+ + + The state of metadata is as follows: + + (case 1) 0???...???: the block is not written yet, cannot read, can write + (case 2) 1000...000: the block is just written, can read, cannot write + (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write + (case 4) 1111...111: the block is written and read by all readers, cannot read, can write + + State transition for readers: + + When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. + Only after the caller finishes reading the block, the reader can mark the block as read. + Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0). + + State transition for writer: + + When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case + to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer + can reset the reader flags to 0, and mark the block as written (from 0 to 1). + NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. + + During creation, `name` is None and the buffer is created. We can pass the + created object to other processes by pickling it. The other processes will + get the name of the shared memory and open it, so that they can access the + same shared memory buffer. + """ # noqa + self.n_reader = n_reader + self.metadata_size = 1 + n_reader + self.max_chunk_bytes = max_chunk_bytes + self.max_chunks = max_chunks + self.total_bytes_of_buffer = ( + self.max_chunk_bytes + self.metadata_size + ) * self.max_chunks + self.data_offset = 0 + self.metadata_offset = self.max_chunk_bytes * self.max_chunks + + if name is None: + # we are creating a buffer + self.is_creator = True + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.total_bytes_of_buffer + ) + # initialize the metadata section to 0 + with memoryview( + self.shared_memory.buf[self.metadata_offset :] + ) as metadata_buffer: + torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) + else: + # we are opening an existing buffer + self.is_creator = False + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. 
+ with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + try: + self.shared_memory = shared_memory.SharedMemory(name=name) + assert self.shared_memory.size == self.total_bytes_of_buffer + except FileNotFoundError: + # we might deserialize the object in a different node + # in this case, this object is not used, + # and we should suppress the error + pass + + def __reduce__(self): + return ( + self.__class__, + ( + self.n_reader, + self.max_chunk_bytes, + self.max_chunks, + self.shared_memory.name, + ), + ) + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_creator: + self.shared_memory.unlink() + + @contextmanager + def get_data(self, current_idx: int): + start = self.data_offset + current_idx * self.max_chunk_bytes + end = start + self.max_chunk_bytes + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + @contextmanager + def get_metadata(self, current_idx: int): + start = self.metadata_offset + current_idx * self.metadata_size + end = start + self.metadata_size + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + +@dataclass +class Handle: + connect_ip: str + local_reader_ranks: List[int] = field(default_factory=list) + + buffer: Optional[ShmRingBuffer] = None + local_subscribe_port: Optional[int] = None + remote_subscribe_port: Optional[int] = None + + +class MessageQueue: + + def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[List[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, + ): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + if connect_ip is None: + connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks) + + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. 
otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) + local_subscribe_port = get_open_port() + socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" + logger.debug("Binding to %s", socket_addr) + self.local_socket.bind(socket_addr) + + self.current_idx = 0 + + else: + self.buffer = None # type: ignore + local_subscribe_port = None + self.local_socket = None + self.current_idx = -1 + + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) + remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = f"tcp://*:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) + + else: + remote_subscribe_port = None + self.remote_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + + self.handle = Handle( + connect_ip=connect_ip, + local_reader_ranks=local_reader_ranks, + buffer=self.buffer, + local_subscribe_port=local_subscribe_port, + remote_subscribe_port=remote_subscribe_port, + ) + + logger.info("vLLM message queue communication handle: %s", self.handle) + + def export_handle(self) -> Handle: + return self.handle + + @staticmethod + def create_from_handle(handle: Handle, rank) -> "MessageQueue": + self = MessageQueue.__new__(MessageQueue) + self.handle = handle + self._is_writer = False + + context = Context() + + if rank in handle.local_reader_ranks: + assert handle.buffer is not None + self.buffer = handle.buffer + self.current_idx = 0 + self.local_reader_rank = handle.local_reader_ranks.index(rank) + self._is_local_reader = True + self._is_remote_reader = False + + self.local_socket = context.socket(SUB) + self.local_socket.setsockopt_string(SUBSCRIBE, "") + socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) + + self.remote_socket = None + else: + self.buffer = None # type: ignore + self.current_idx = -1 + self.local_reader_rank = -1 + self._is_local_reader = False + self._is_remote_reader = True + + self.local_socket = None + + self.remote_socket = context.socket(SUB) + self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if is_valid_ipv6_address(handle.connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) + + return self + + def wait_until_ready(self): + """This is a collective operation. All processes (including the + readers and the writer) should call this function. 
+ """ + if self._is_writer: + # wait for all readers to connect + + # local readers + for i in range(self.n_local_reader): + # wait for subscription messages from all local readers + self.local_socket.recv() + if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working + self.local_socket.send(b"READY") + + # remote readers + for i in range(self.n_remote_reader): + # wait for subscription messages from all remote readers + self.remote_socket.recv() + if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working + self.remote_socket.send(b"READY") + elif self._is_local_reader: + # wait for the writer to send a message + recv = self.local_socket.recv() + assert recv == b"READY" + elif self._is_remote_reader: + # wait for the writer to send a message + recv = self.remote_socket.recv() + assert recv == b"READY" + + @contextmanager + def acquire_write(self): + assert self._is_writer, "Only writers can acquire write" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_count = sum(metadata_buffer[1:]) + written_flag = metadata_buffer[0] + if written_flag and read_count != self.buffer.n_reader: + # this block is written and not read by all readers + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # Release the processor to other threads + os.sched_yield() + + # if we wait for a long time, we should warn the user + if ( + time.monotonic() - start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning + ): + logger.warning( + "No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks + break + + @contextmanager + def acquire_read(self): + assert self._is_local_reader, "Only readers can acquire read" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_flag = metadata_buffer[self.local_reader_rank + 1] + written_flag = metadata_buffer[0] + if not written_flag or read_flag: + # this block is either + # (1) not written + # (2) already read by this reader + + # for readers, `self.current_idx` is the next block to read + # if this block is not ready, + # we need to wait until it is written + + # Release the processor to other threads + os.sched_yield() + + # if we wait for a long time, we should warn the user + if ( + time.monotonic() - start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning + ): + logger.warning( + "No available block found in %s second. 
", + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + continue + # found a block that is not read by this reader + # let caller read from the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has read from the buffer + # set the read flag + metadata_buffer[self.local_reader_rank + 1] = 1 + self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks + break + + def enqueue(self, obj): + assert self._is_writer, "Only writers can enqueue" + serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + if self.n_local_reader > 0: + if len(serialized_obj) >= self.buffer.max_chunk_bytes: + with self.acquire_write() as buf: + buf[0] = 1 # overflow + self.local_socket.send(serialized_obj) + else: + with self.acquire_write() as buf: + buf[0] = 0 # not overflow + buf[1 : len(serialized_obj) + 1] = serialized_obj + if self.n_remote_reader > 0: + self.remote_socket.send(serialized_obj) + + def dequeue(self): + if self._is_local_reader: + with self.acquire_read() as buf: + overflow = buf[0] == 1 + if not overflow: + # no need to know the size of serialized object + # pickle format contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf[1:]) + if overflow: + recv = self.local_socket.recv() + obj = pickle.loads(recv) + elif self._is_remote_reader: + recv = self.remote_socket.recv() + obj = pickle.loads(recv) + else: + raise RuntimeError("Only readers can dequeue") + return obj + + def broadcast_object(self, obj=None): + if self._is_writer: + self.enqueue(obj) + return obj + else: + return self.dequeue() + + @staticmethod + def create_from_process_group( + pg: ProcessGroup, max_chunk_bytes, max_chunks, writer_rank=0 + ) -> "MessageQueue": + group_rank = dist.get_rank(pg) + group_world_size = dist.get_world_size(pg) + global_ranks = dist.get_process_group_ranks(pg) + + from vllm.distributed.parallel_state import in_the_same_node_as + + status = in_the_same_node_as(pg, source_rank=writer_rank) + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io: MessageQueue + if group_rank == writer_rank: + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + dist.broadcast_object_list( + [handle], src=global_ranks[writer_rank], group=pg + ) + else: + recv = [None] + dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg) + handle = recv[0] # type: ignore + buffer_io = MessageQueue.create_from_handle(handle, group_rank) + buffer_io.wait_until_ready() + return buffer_io diff --git a/python/sglang/srt/distributed/device_communicators/tpu_communicator.py b/python/sglang/srt/distributed/device_communicators/tpu_communicator.py new file mode 100644 index 0000000000..9027b68b38 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/tpu_communicator.py @@ -0,0 +1,59 @@ +import os + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from vllm.platforms import current_platform + +if current_platform.is_tpu(): + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from vllm.executor import ray_utils + + +class TpuCommunicator: + + def 
__init__(self, group: ProcessGroup): + if not current_platform.is_tpu(): + self.disabled = True + return + self.disabled = False + + # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node + # must be used together. Therefore, the local rank and world size can + # be simply calculated as follows. + global_rank = dist.get_rank(group) + global_world_size = dist.get_world_size(group) + + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + + local_world_size = global_world_size // num_nodes + local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + + pjrt.initialize_multiprocess(local_rank, local_world_size) + xr._init_world_size_ordinal() + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return xm.all_reduce(xm.REDUCE_SUM, x) + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(x, dim=dim) diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py new file mode 100644 index 0000000000..e64fd83d04 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -0,0 +1,45 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from vllm.platforms import current_platform + + +class XpuCommunicator: + + def __init__(self, group: ProcessGroup): + if not current_platform.is_xpu(): + self.disabled = True + return + self.disabled = False + self.group = group + self.world_size = dist.get_world_size(self.group) + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + dist.all_reduce(x, group=self.group) + return x + + def gather( + self, input_: torch.Tensor, rank_in_group: int, dst: int = 0, dim: int = -1 + ): + # For xpu path, gather doesn't work properly together with ray + # cluster so we use all_gather instead for now. + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty( + (self.world_size,) + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. 
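+        # Every rank performs the all-gather, but only `dst` reshapes and keeps
+        # the result below (the other ranks return None), which preserves the
+        # gather contract at the cost of extra communication.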
+ torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.group + ) + if rank_in_group == dst: + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) + else: + output_tensor = None + return output_tensor diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py new file mode 100644 index 0000000000..b8e6e9fe8e --- /dev/null +++ b/python/sglang/srt/distributed/parallel_state.py @@ -0,0 +1,1291 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. +""" +import contextlib +import gc +import logging +import pickle +import weakref +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import torch +import torch.distributed +import vllm.envs as envs +from torch.distributed import Backend, ProcessGroup +from vllm.platforms import current_platform + +from sglang.srt.utils import direct_register_custom_op, supports_custom_op + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list: List[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. 
+ Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + _groups[group.unique_name] = weakref.ref(group) + + +if supports_custom_op(): + + def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce_in_place(tensor) + + def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: + return + + direct_register_custom_op( + op_name="inplace_all_reduce", + op_func=inplace_all_reduce, + mutates_args=["tensor"], + fake_impl=inplace_all_reduce_fake, + ) + + def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) + + def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + direct_register_custom_op( + op_name="outplace_all_reduce", + op_func=outplace_all_reduce, + mutates_args=[], + fake_impl=outplace_all_reduce_fake, + ) + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). 
+ """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_pynccl: bool # a hint of whether to use PyNccl + use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + # communicators are only created for world size > 1 + pynccl_comm: Optional[Any] # PyNccl communicator + ca_comm: Optional[Any] # Custom allreduce communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_custom_allreduce: bool, + use_tpu_communicator: bool, + use_hpu_communicator: bool, + use_xpu_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator + self.use_hpu_communicator = use_hpu_communicator + self.use_xpu_communicator = use_xpu_communicator + + # lazy import to avoid documentation build error + from sglang.srt.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce, + ) + from sglang.srt.distributed.device_communicators.pynccl import ( + PyNcclCommunicator, + ) + + self.pynccl_comm: Optional[PyNcclCommunicator] = None + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + self.ca_comm: Optional[CustomAllreduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. 
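+            # Note: as with PyNcclCommunicator above (which broadcasts the NCCL
+            # unique id over it), the gloo-backed cpu_group is used only for
+            # out-of-band setup and coordination; the all-reduce itself runs on
+            # `self.device`.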
+ self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + from sglang.srt.distributed.device_communicators.tpu_communicator import ( + TpuCommunicator, + ) + + self.tpu_communicator: Optional[TpuCommunicator] = None + if use_tpu_communicator and self.world_size > 1: + self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + + from sglang.srt.distributed.device_communicators.hpu_communicator import ( + HpuCommunicator, + ) + + self.hpu_communicator: Optional[HpuCommunicator] + if use_hpu_communicator and self.world_size > 1: + self.hpu_communicator = HpuCommunicator(group=self.device_group) + + from sglang.srt.distributed.device_communicators.xpu_communicator import ( + XpuCommunicator, + ) + + self.xpu_communicator: Optional[XpuCommunicator] + if use_xpu_communicator and self.world_size > 1: + self.xpu_communicator = XpuCommunicator(group=self.device_group) + + from sglang.srt.distributed.device_communicators.shm_broadcast import ( + MessageQueue, + ) + + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6 + ) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None + ): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + ca_comm = self.ca_comm + maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the + # tensor size is too large, it will fallback to the next + # available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using + # CUDA graph, we use either custom all-reduce kernel or + # PyTorch NCCL. 
We always prioritize using custom all-reduce + # kernel but fall back to PyTorch or pynccl if it is + # disabled or not supported. + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ) + with maybe_pynccl_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if input_.is_cpu: + import intel_extension_for_pytorch as ipex + + ipex.distributed.all_reduce(input_, group=self.device_group) + return input_ + + if not supports_custom_op(): + self._all_reduce_in_place(input_) + return input_ + + if self.tpu_communicator is not None and not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self.tpu_communicator.all_reduce(input_) + + if self.hpu_communicator is not None and not self.hpu_communicator.disabled: + return self.hpu_communicator.all_reduce(input_) + + if self.xpu_communicator is not None and not self.xpu_communicator.disabled: + return self.xpu_communicator.all_reduce(input_) + + if ( + self.ca_comm is not None + and not self.ca_comm.disabled + and self.ca_comm.should_custom_ar(input_) + ): + return torch.ops.sglang.outplace_all_reduce( + input_, group_name=self.unique_name + ) + else: + torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name) + return input_ + + def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + ca_comm = self.ca_comm + assert ca_comm is not None + assert not ca_comm.disabled + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out + + def _all_reduce_in_place(self, input_: torch.Tensor) -> None: + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.all_reduce(input_) + else: + torch.distributed.all_reduce(input_, group=self.device_group) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) + + # For HPUs, use HPU communicator. + hpu_comm = self.hpu_communicator + if hpu_comm is not None and not hpu_comm.disabled: + return hpu_comm.all_gather(input_, dim) + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . 
see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * world_size,) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + # Reshape + output_tensor = output_tensor.reshape((world_size,) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] + ) + return output_tensor + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + if self.xpu_communicator is not None and not self.xpu_communicator.disabled: + return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim) + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. 
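+        # NOTE: torch.distributed.broadcast_object_list pickles the objects on
+        # the source rank and fills `obj_list` in place on the receiving ranks,
+        # so every rank returns the same list. The `group` argument of this
+        # method is currently unused; the coordinator's device_group is always
+        # used for the collective.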
+ torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank." + ) + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert ( + src != self.rank_in_group + ), "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + assert ( + rank_object == rank_size + ), "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True, + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=group, async_op=True + ) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. + if all_gather_group is not None and tensor.numel() % all_gather_size == 0: + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
+ if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. + use_all_gather = ( + all_gather_group is not None + and tensor.numel() % all_gather_size == 0 + ) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + +_WORLD: Optional[GroupCoordinator] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_custom_allreduce=False, + use_tpu_communicator=False, + use_hpu_communicator=False, + use_xpu_communicator=False, + group_name="world", + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_custom_allreduce: Optional[bool] = None, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, +) -> GroupCoordinator: + if use_custom_allreduce is None: + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=True, + use_custom_allreduce=use_custom_allreduce, + use_tpu_communicator=True, + use_hpu_communicator=True, + use_xpu_communicator=True, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + + +_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, "pipeline model parallel group is not initialized" + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. 
Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + with get_tp_group().graph_capture() as context, get_pp_group().graph_capture( + context + ): + yield context + + +logger = logging.getLogger(__name__) + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert ( + _WORLD.world_size == torch.distributed.get_world_size() + ), "world group already initialized with a different world size" + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. 
+ assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + ) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp", + ) + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + group_name="pp", + ) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel( + tensor_model_parallel_size, pipeline_model_parallel_size, backend + ) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + "tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}" + ) + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}" + ) + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return _TP is not None and _PP is not None + + +_TP_STATE_PATCHED = False + + +@contextmanager +def patch_tensor_parallel_group(tp_group: GroupCoordinator): + """Patch the tp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. 
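+    For example, a draft worker can temporarily install its own (usually
+    smaller) coordinator around the draft forward pass:
+    `with patch_tensor_parallel_group(draft_tp_group): ...`, where
+    `draft_tp_group` is the coordinator built for the draft ranks.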
+ + Args: + tp_group (GroupCoordinator): the tp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TP + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + if shutdown_ray: + import ray # Lazy import Ray + + ray.shutdown() + gc.collect() + if not current_platform.is_cpu(): + torch.cuda.empty_cache() + + +def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). + """ + assert ( + torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL + ), "in_the_same_node_as should be tested with a non-NCCL group." + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[: len(magic_message)] = magic_message + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg + ) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg + ) + name = recv[0] + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. 
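+                # Replacing `register` with a no-op prevents this (non-creator)
+                # process from registering the segment with the resource
+                # tracker, which would otherwise unlink a shared memory block
+                # it does not own when the process exits.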
+ with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[: len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + torch.distributed.barrier(group=pg) + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + + return [x == 1 for x in is_in_the_same_node.tolist()] + + +vllm_get_pp_group = None +vllm_get_tp_group = None +vllm_get_world_group = None + + +def monkey_patch_vllm_parallel_state(reverse: bool = False): + import vllm.distributed.parallel_state as vllm_parrlel_state + + global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group + if vllm_get_pp_group is None: + vllm_get_pp_group = vllm_parrlel_state.get_pp_group + vllm_get_tp_group = vllm_parrlel_state.get_tp_group + vllm_get_world_group = vllm_parrlel_state.get_world_group + if reverse: + setattr(vllm_parrlel_state, "get_pp_group", vllm_get_pp_group) + setattr(vllm_parrlel_state, "get_tp_group", vllm_get_tp_group) + setattr(vllm_parrlel_state, "get_world_group", vllm_get_world_group) + else: + setattr(vllm_parrlel_state, "get_pp_group", get_pp_group) + setattr(vllm_parrlel_state, "get_tp_group", get_tp_group) + setattr(vllm_parrlel_state, "get_world_group", get_world_group) diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py new file mode 100644 index 0000000000..aa6dbab955 --- /dev/null +++ b/python/sglang/srt/distributed/utils.py @@ -0,0 +1,221 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import dataclasses +import logging +import pickle +import time +from collections import deque +from typing import Any, Deque, Dict, Optional, Sequence, Tuple + +import torch +import vllm.envs as envs +from torch.distributed import TCPStore + +logger = logging.getLogger(__name__) + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator + ) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. 
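+    # Each chunk above is a view into `tensor`; `.contiguous()` copies the
+    # data when a chunk is not already contiguous, so it is only done when the
+    # caller explicitly requests contiguous chunks.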
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def get_pp_indices( + num_hidden_layers: int, pp_rank: int, pp_size: int +) -> Tuple[int, int]: + """Try to evenly distribute layers across partitions. + If the number of layers is not divisible by the number of partitions, + the last partition will have the remaining layers. + """ + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + if partition_list_str is not None: + try: + partitions = [int(layer) for layer in partition_list_str.split(",")] + except ValueError as err: + raise ValueError( + "Invalid partition string: {}".format(partition_list_str) + ) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_hidden_layers: + raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.") + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] + else: + layers_per_partition = num_hidden_layers // pp_size + start_layer = pp_rank * layers_per_partition + end_layer = start_layer + layers_per_partition + + if pp_rank == pp_size - 1: + end_layer = num_hidden_layers + + return (start_layer, end_layer) + + +@dataclasses.dataclass +class StatelessProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + + rank: int + world_size: int + store: torch._C._distributed_c10d.Store + data_expiration_seconds: int = 3600 # 1 hour + + # dst rank -> counter + send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + # src rank -> counter + recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + broadcast_send_counter: int = 0 + broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) + + # A deque to store the data entries, with key and timestamp. + entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + assert self.rank < self.world_size + self.send_dst_counter = {i: 0 for i in range(self.world_size)} + self.recv_src_counter = {i: 0 for i in range(self.world_size)} + self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)} + + def send_obj(self, obj: Any, dst: int): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self, src: int) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}") + ) + self.recv_src_counter[src] += 1 + return obj + + def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: + """Broadcast an object from a source rank to all other ranks. + It does not clean up after all ranks have received the object. + Use it for limited times, e.g., for initialization. 
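+
+        For example, rank 0 can publish a piece of metadata and every rank
+        gets the same value back:
+        `meta = group.broadcast_obj(meta if group.rank == 0 else None, src=0)`.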
+ """ + if self.rank == src: + self.expire_data() + key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}" + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return obj + else: + key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}" + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj + + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs + + def barrier(self): + """A barrier to synchronize all ranks.""" + for i in range(self.world_size): + if i == self.rank: + self.broadcast_obj(None, src=self.rank) + else: + self.broadcast_obj(None, src=i) + + @staticmethod + def create( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. + """ # noqa + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=(rank == 0), + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + data_expiration_seconds=data_expiration_seconds, + ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 46b4db8e88..d9a36742ee 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -944,6 +944,12 @@ def get_device_name(device_id: int = 0) -> str: sglang_lib = Library("sglang", "FRAGMENT") # noqa +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. 
+def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + def direct_register_custom_op( op_name: str, op_func: Callable, From 9fecd6e25a583e9d60b9d2e6299fda683f79f70e Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Fri, 29 Nov 2024 00:14:33 +0800 Subject: [PATCH 02/10] add is_hpu func since vllm=0.6.3.post1 does not support current_platform.is_hpu() --- python/sglang/srt/_custom_ops.py | 7 +++++-- .../device_communicators/custom_all_reduce_utils.py | 2 +- .../distributed/device_communicators/hpu_communicator.py | 7 ++++--- .../srt/distributed/device_communicators/shm_broadcast.py | 2 +- python/sglang/srt/utils.py | 4 ++++ 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 604b75bb95..22442181f4 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -10,15 +10,18 @@ # import vllm.envs as envs from vllm.platforms import current_platform +from sglang.srt.utils import is_hpu + # from vllm.scalar_type import ScalarType logger = logging.getLogger(__name__) -if not current_platform.is_tpu() and not current_platform.is_hpu(): +# if not current_platform.is_tpu() and not current_platform.is_hpu(): +if not current_platform.is_tpu() and not is_hpu(): try: import custom_ar except ImportError as e: - logger.warning("Failed to import from vllm._C with %r", e) + logger.warning("Failed to import from custom_ar with %r", e) """ if current_platform.is_rocm(): diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index 6c3f3c5c6b..6b73e1b138 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -214,7 +214,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" ) os.makedirs(os.path.dirname(path), exist_ok=True) - from vllm.distributed.parallel_state import get_world_group + from sglang.srt.distributed.parallel_state import get_world_group if (not is_distributed or get_world_group().local_rank == 0) and ( not os.path.exists(path) diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 4ce060e091..671e7d283b 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -1,16 +1,17 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from vllm.platforms import current_platform -if current_platform.is_hpu(): +from sglang.srt.utils import is_hpu + +if is_hpu(): import habana_frameworks.torch as htorch # noqa: F401 class HpuCommunicator: def __init__(self, group: ProcessGroup): - if not current_platform.is_hpu(): + if not is_hpu(): self.disabled = True return self.disabled = False diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index e1a337ab7c..bf91a99158 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -467,7 +467,7 @@ def create_from_process_group( group_world_size = 
dist.get_world_size(pg) global_ranks = dist.get_process_group_ranks(pg) - from vllm.distributed.parallel_state import in_the_same_node_as + from sglang.srt.distributed.parallel_state import in_the_same_node_as status = in_the_same_node_as(pg, source_rank=writer_rank) same_node_ranks = [i for i, s in enumerate(status) if s] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d9a36742ee..a84dec9e8d 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -67,6 +67,10 @@ def is_hip() -> bool: return torch.version.hip is not None +def is_hpu() -> bool: + return hasattr(torch, "hpu") and torch.hpu.is_available() + + def is_flashinfer_available(): """ Check whether flashinfer is available. From d637d0b04cf7968154799a7c93e455a9eaa2cdb2 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 30 Nov 2024 00:26:18 +0800 Subject: [PATCH 03/10] add vllm version and remove unused code in _custom_ops --- python/sglang/srt/_custom_ops.py | 883 +----------------- .../srt/distributed/communication_op.py | 1 + .../device_communicators/cuda_wrapper.py | 1 + .../device_communicators/custom_all_reduce.py | 1 + .../custom_all_reduce_utils.py | 1 + .../device_communicators/hpu_communicator.py | 1 + .../device_communicators/pynccl.py | 2 +- .../device_communicators/pynccl_wrapper.py | 1 + .../device_communicators/shm_broadcast.py | 1 + .../sglang/srt/distributed/parallel_state.py | 1 + python/sglang/srt/distributed/utils.py | 1 + 11 files changed, 11 insertions(+), 883 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 22442181f4..778aef899f 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 import contextlib import functools import importlib @@ -23,27 +24,6 @@ except ImportError as e: logger.warning("Failed to import from custom_ar with %r", e) -""" -if current_platform.is_rocm(): - import vllm._rocm_C # noqa: F401 - -supports_moe_ops = False -with contextlib.suppress(ImportError): - import vllm._moe_C # noqa: F401 - supports_moe_ops = True - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING or current_platform.is_neuron(): - - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake -""" - def hint_on_error(fn): @@ -74,867 +54,6 @@ def wrapper(*args, **kwargs): return wrapper -''' -# activation ops -def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - torch.ops._C.silu_and_mul(out, x) - - -def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - torch.ops._C.gelu_and_mul(out, x) - - -def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - torch.ops._C.gelu_tanh_and_mul(out, x) - - -def fatrelu_and_mul(out: torch.Tensor, - x: torch.Tensor, - threshold: float = 0.0) -> None: - torch.ops._C.fatrelu_and_mul(out, x, threshold) - - -def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: - torch.ops._C.gelu_fast(out, x) - - -def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: - torch.ops._C.gelu_new(out, x) - - -def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: - torch.ops._C.gelu_quick(out, x) - - -# page attention ops -def paged_attention_v1( - out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - seq_lens: 
torch.Tensor, - block_size: int, - max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - tp_rank: int = 0, - blocksparse_local_blocks: int = 0, - blocksparse_vert_stride: int = 0, - blocksparse_block_size: int = 64, - blocksparse_head_sliding_step: int = 0, -) -> None: - torch.ops._C.paged_attention_v1( - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step) - - -def paged_attention_v2( - out: torch.Tensor, - exp_sum: torch.Tensor, - max_logits: torch.Tensor, - tmp_out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - block_size: int, - max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - tp_rank: int = 0, - blocksparse_local_blocks: int = 0, - blocksparse_vert_stride: int = 0, - blocksparse_block_size: int = 64, - blocksparse_head_sliding_step: int = 0, -) -> None: - torch.ops._C.paged_attention_v2( - out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) - - -def paged_attention_rocm( - out: torch.Tensor, - exp_sum: torch.Tensor, - max_logits: torch.Tensor, - tmp_out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - block_size: int, - max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, - scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale) - - -# pos encoding ops -def rotary_embedding( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - head_size: int, - cos_sin_cache: torch.Tensor, - is_neox: bool, -) -> None: - torch.ops._C.rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox) - - -def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) - - -# layer norm ops -def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> None: - torch.ops._C.rms_norm(out, input, weight, epsilon) - - -def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float) -> None: - torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) - - -def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - 
block_tables: torch.Tensor) -> None: - """Advance a step on GPU for existing inputs for a multi-step runner""" - return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, - block_size, input_tokens, - sampled_token_ids, - input_positions, seq_lens, - slot_mapping, block_tables) - - -def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - block_tables: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - block_table_bound: torch.Tensor) -> None: - - return torch.ops._C.advance_step_flashinfer( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables, - paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, - block_table_bound) - - -# quantization ops -# awq -def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: int, thx: int, - thy: int) -> torch.Tensor: - if envs.VLLM_USE_TRITON_AWQ: - from vllm.model_executor.layers.quantization.awq_triton import ( - awq_dequantize_triton) - return awq_dequantize_triton(qweight, scales, zeros) - return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, - thx, thy) - - -def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, - scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: - if envs.VLLM_USE_TRITON_AWQ: - from vllm.model_executor.layers.quantization.awq_triton import ( - awq_gemm_triton) - return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) - return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) - - -# gptq -def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, use_exllama: bool, - bit: int) -> torch.Tensor: - return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) - - -if hasattr(torch.ops._C, "gptq_gemm"): - - @register_fake("_C::gptq_gemm") - def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, - use_exllama: bool, bit: int) -> torch.Tensor: - return torch.empty((a.size(0), b_q_weight.size(1)), - dtype=a.dtype, - device=a.device) - - -def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, - bit: int) -> None: - torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) - - -# marlin -def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) - - -# marlin_24 -def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, b_q_type: ScalarType, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, b_q_type.id, size_m, - size_n, size_k) - - -if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): - - @register_fake("_C::gptq_marlin_24_gemm") - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, 
size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake("_C::gptq_marlin_gemm") - def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake("_C::ggml_dequantize") - def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, - m: torch.SymInt, - n: torch.SymInt) -> torch.Tensor: - return torch.empty((m, n), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_vec_a8") - def _ggml_mul_mat_vec_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((1, row), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_a8") - def _ggml_mul_mat_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - batch = X.size(0) - return torch.empty((batch, row), dtype=torch.float16, device=W.device) - - @register_fake("_C::marlin_qqq_gemm") - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake("_C::marlin_gemm") - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake("_C::awq_dequantize") - def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: torch.SymInt, - thx: int, thy: int) -> torch.Tensor: - in_c = qweight.size(0) - qout_c = qweight.size(1) - out_c = qout_c * 8 - return torch.empty((in_c, out_c), - dtype=scales.dtype, - device=scales.device) - - @register_fake("_C::awq_gemm") - def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, - qzeros: torch.Tensor, scales: torch.Tensor, - split_k_iters: torch.SymInt) -> torch.Tensor: - num_in_feats = input.size(0) - return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), - dtype=input.dtype, - device=input.device).sum(0) - - @register_fake("_C::aqlm_gemm") - def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: List[int], - bias: Optional[torch.Tensor]) -> torch.Tensor: - out_features = codes.size(0) * codebooks.size(2) - flat_input = input.reshape((-1, input.size(-1))) - flat_output = torch.empty((flat_input.size(0), out_features), - dtype=input.dtype, - device=input.device) - - output_sizes = list(input.shape) - output_sizes.pop() - output_sizes.append(-1) - return flat_output.reshape(tuple(output_sizes)) - - @register_fake("_C::aqlm_dequant") - def _aqlm_dequant_fake( - codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: List[int]) -> torch.Tensor: - 
in_features = codes.size(1) * 8 - out_features = codes.size(0) - return torch.empty((out_features, in_features), - dtype=codebooks.dtype, - device=codebooks.device) - - @register_fake("_C::fp8_marlin_gemm") - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake("_C::machete_gemm") - def machete_gemm_fake( - a: torch.Tensor, - # Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, - ) -> torch.Tensor: - m = a.size(0) - n = b_q.size(1) - return torch.empty((m, n), device=a.device, dtype=a.dtype) - - @register_fake("_C::machete_prepack_B") - def machete_prepack_B_fake(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: - return torch.empty_like(b_q_weight, - memory_format=torch.contiguous_format) - - -# cutlass -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - if current_platform.is_rocm(): - triton_scaled_mm_module = importlib.import_module( - "vllm.model_executor.layers.quantization.compressed_tensors." - "triton_scaled_mm") - triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. 
- """ - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.numel( - ) == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, - azp, bias) - return out - - -# aqlm -def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: List[int], - bias: Optional[torch.Tensor]) -> torch.Tensor: - return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) - - -def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: List[int]) -> torch.Tensor: - return torch.ops._C.aqlm_dequant(codes, codebooks, - codebook_partition_sizes) - - -# gptq_marlin -def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) - - -# gptq_marlin -def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - num_experts = b_q_weight.shape[0] - assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype) - for e in range(num_experts): - output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], - size_k, size_n, num_bits) - return output - - -def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - num_experts = b_q_weight.shape[0] - assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype) - for e in range(num_experts): - output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, - size_n, num_bits) - return output - - -def gptq_marlin_gemm(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, b_q_type.id, - size_m, size_n, size_k, is_k_full, - has_zp, use_fp32_reduce) - - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) - - -# machete -def machete_supported_schedules(b_type: ScalarType) -> List[str]: - return torch.ops._C.machete_supported_schedules(b_type.id) - - -def machete_gemm( - a: torch.Tensor, - b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B - b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: 
Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, -) -> torch.Tensor: - return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros, - b_group_size, c, alpha, beta, schedule) - - -def machete_prepack_B(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id) - - -if hasattr(torch.ops._C, "permute_cols"): - - @register_fake("_C::permute_cols") - def _permute_cols_fake(a: torch.Tensor, - perm: torch.Tensor) -> torch.Tensor: - return torch.empty_like(a) - - -def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: - return torch.ops._C.permute_cols(a, perm) - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - torch.ops._C.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. 
- Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." - torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - - -# qqq ops -def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, - workspace, size_m, size_n, size_k) - - -# gguf -def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, - n: int) -> torch.Tensor: - return torch.ops._C.ggml_dequantize(W, quant_type, m, n) - - -def ggml_mul_mat_vec_a8( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: int, -) -> torch.Tensor: - return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) - - -def ggml_mul_mat_a8( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: int, -) -> torch.Tensor: - return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) - - -# mamba -def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], - conv_states: Optional[torch.Tensor], - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - silu_activation: bool, pad_slot_id: int): - torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation, - pad_slot_id) - - -def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, bias_: Optional[torch.Tensor], - silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor], - pad_slot_id: int): - torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, - conv_state_indices, pad_slot_id) - - -def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, - B: torch.Tensor, C: torch.Tensor, - D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], - delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - ssm_states: torch.Tensor, pad_slot_id: int): - torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, - delta_softplus, query_start_loc, - cache_indices, has_initial_state, - ssm_states, pad_slot_id) - - -# moe -def moe_sum(input: torch.Tensor, output: torch.Tensor): - torch.ops._moe_C.moe_sum(input, output) - - -def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: 
torch.Tensor) -> None: - torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) - - -def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indicies: torch.Tensor, - gating_output: float) -> None: - torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, - token_expert_indicies, gating_output) - - -if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): - - @register_fake("_moe_C::marlin_gemm_moe") - def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, - sorted_ids: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, b_scales: torch.Tensor, - b_zero_points: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, size_k: torch.SymInt, - is_k_full: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, - apply_weights: bool) -> torch.Tensor: - return torch.empty((size_m, topk, size_n), - dtype=a.dtype, - device=a.device) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale) - - -def reshape_and_cache_flash( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, - v_scale) - - -def copy_blocks(key_caches: List[torch.Tensor], - value_caches: List[torch.Tensor], - block_mapping: torch.Tensor) -> None: - torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) - - -def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: torch.Tensor) -> None: - torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) - - -def convert_fp8(output: torch.Tensor, - input: torch.Tensor, - scale: float = 1.0, - kv_dtype: str = "fp8") -> None: - torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) - - -def get_device_attribute(attribute: int, device: int) -> int: - return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) - - -def get_max_shared_memory_per_block_device_attribute(device: int) -> int: - # ruff: noqa: E501 - return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( - device) -''' - - # custom ar def init_custom_ar( ipc_tensors: List[torch.Tensor], diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index 07b89a0bd5..1a9e75845b 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 from typing import Any, Dict, Optional, Union import torch diff --git a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py index 75c5cc93bc..91c9f64b16 100644 --- a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 
"""This file is a pure Python wrapper for the cudart library. It avoids the need to compile a separate shared library, and is convenient for use when we just need to call a few functions. diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index f36eead9b8..8eb8cf749c 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 import ctypes import logging from contextlib import contextmanager diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index 6b73e1b138..0f8cb69cd2 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 import ctypes import json import logging diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 671e7d283b..72292f1345 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index b711830628..bdb961bb5d 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -1,4 +1,4 @@ -# ===================== import region ===================== +# reference: VLLM 0.6.4.post1 import logging from contextlib import contextmanager from typing import Optional, Union diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index 41b759a8d7..fe1db89fa3 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 # This file is a pure Python wrapper for the NCCL library. # The main purpose is to use NCCL combined with CUDA graph. # Before writing this script, we tried the following approach: diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index bf91a99158..925426ecc3 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 import logging import os import pickle diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index b8e6e9fe8e..8d5c03fa78 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 # Copyright 2023 The vLLM team. 
# Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py index aa6dbab955..cabbb86c4b 100644 --- a/python/sglang/srt/distributed/utils.py +++ b/python/sglang/srt/distributed/utils.py @@ -1,3 +1,4 @@ +# reference: VLLM 0.6.4.post1 # Copyright 2023 The vLLM team. # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py From e01d48aee032b19eda63a639fd0564a8d8c63676 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 30 Nov 2024 01:56:43 +0800 Subject: [PATCH 04/10] remove vllm.utils, vllm.platforms and vllm.envs --- python/sglang/srt/_custom_ops.py | 6 +- .../device_communicators/custom_all_reduce.py | 62 +++++++++++--- .../custom_all_reduce_utils.py | 25 ++++-- .../device_communicators/pynccl_wrapper.py | 30 ++++++- .../device_communicators/shm_broadcast.py | 81 +++++++++++++++++-- .../device_communicators/xpu_communicator.py | 5 +- .../sglang/srt/distributed/parallel_state.py | 13 +-- python/sglang/srt/distributed/utils.py | 5 +- python/sglang/srt/utils.py | 54 +++++++++++++ 9 files changed, 246 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 778aef899f..375839689a 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -8,17 +8,13 @@ import torch import torch.library -# import vllm.envs as envs -from vllm.platforms import current_platform - from sglang.srt.utils import is_hpu # from vllm.scalar_type import ScalarType logger = logging.getLogger(__name__) -# if not current_platform.is_tpu() and not current_platform.is_hpu(): -if not current_platform.is_tpu() and not is_hpu(): +if not is_hpu(): try: import custom_ar except ImportError as e: diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 8eb8cf749c..cb029db447 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -1,15 +1,16 @@ # reference: VLLM 0.6.4.post1 import ctypes import logging +import os from contextlib import contextmanager -from typing import List, Optional, Union +from functools import wraps +from typing import Callable, List, Optional, TypeVar, Union +import pynvml import torch import torch.distributed as dist -import vllm.envs as envs from torch.distributed import ProcessGroup -from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless +from typing_extensions import ParamSpec from sglang.srt import _custom_ops as ops from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary @@ -17,6 +18,7 @@ gpu_p2p_access_check, ) from sglang.srt.distributed.parallel_state import in_the_same_node_as +from sglang.srt.utils import cuda_device_count_stateless, is_cuda try: ops.meta_size() @@ -28,11 +30,53 @@ logger = logging.getLogger(__name__) +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: + @wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() + + return wrapper + + +@with_nvml_context +def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: + """ + query if the set of gpus are fully 
connected by nvlink (1 hop) + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + logger.exception( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped." + ) + return False + return True + + def _can_p2p(rank: int, world_size: int) -> bool: + # SGLANG_SKIP_P2P_CHECK can be set to False in sglang + SGLANG_SKIP_P2P_CHECK = os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1" for i in range(world_size): if i == rank: continue - if envs.VLLM_SKIP_P2P_CHECK: + if SGLANG_SKIP_P2P_CHECK: logger.info("Skipping P2P check and trusting the driver's P2P report.") return torch.cuda.can_device_access_peer(rank, i) if not gpu_p2p_access_check(rank, i): @@ -114,7 +158,7 @@ def __init__( assert isinstance(device, torch.device) self.device = device - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) else: @@ -131,11 +175,9 @@ def __init__( # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert current_platform.is_cuda() - from vllm.platforms.cuda import CudaPlatform + assert is_cuda() - cuda_platform: CudaPlatform = current_platform - full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) + full_nvlink = is_full_nvlink(physical_device_ids) if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index 0f8cb69cd2..c58d057a4b 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -7,19 +7,31 @@ import subprocess import sys import tempfile +from functools import lru_cache from itertools import product from typing import Dict, List, Optional, Sequence import torch.distributed as dist import torch.multiprocessing as mp -import vllm.envs as envs -from vllm.utils import cuda_device_count_stateless, update_environment_variables from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from sglang.srt.utils import cuda_device_count_stateless logger = logging.getLogger(__name__) +def update_environment_variables(envs: Dict[str, str]): + for k, v in envs.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s " "from '%s' to '%s'", + k, + os.environ[k], + v, + ) + os.environ[k] = v + + def producer( batch_src: Sequence[int], producer_queue, @@ -129,7 +141,7 @@ def can_actually_p2p( processes for testing all pairs of GPUs in batch. The trick is to reset the device after each test (which is not available in PyTorch). 
""" # noqa - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) # pass the CUDA_VISIBLE_DEVICES to the child process # to make sure they see the same set of GPUs @@ -207,12 +219,15 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: is_distributed = dist.is_initialized() num_dev = cuda_device_count_stateless() - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) if cuda_visible_devices is None: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + # VLLM_CACHE_ROOT -> SGLANG_CACHE_ROOT + # "~/.cache/vllm" -> "~/.cache/sglang" + SGLANG_CACHE_ROOT = os.path.expanduser("~/.cache/sglang") path = os.path.join( - envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + SGLANG_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" ) os.makedirs(os.path.dirname(path), exist_ok=True) from sglang.srt.distributed.parallel_state import get_world_group diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index fe1db89fa3..421ff7bf53 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -22,16 +22,44 @@ import ctypes import logging +import os import platform from dataclasses import dataclass from typing import Any, Dict, List, Optional import torch from torch.distributed import ReduceOp -from vllm.utils import find_nccl_library logger = logging.getLogger(__name__) + +def find_nccl_library() -> str: + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. 
+ """ + + # so_file can be set to None in sglang + so_file = os.environ.get("VLLM_NCCL_SO_PATH", None) + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + ) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + # === export types and functions from nccl to Python === # for the original nccl definition, please check # https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index 925426ecc3..fe3a7156e1 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -1,8 +1,11 @@ # reference: VLLM 0.6.4.post1 +import ipaddress import logging import os import pickle +import socket import time +import warnings from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory @@ -11,17 +14,85 @@ import torch import torch.distributed as dist -import vllm.envs as envs from torch.distributed import ProcessGroup -from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore -VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL +# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60 +SGLANG_RINGBUFFER_WARNING_INTERVAL = int( + os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60") +) logger = logging.getLogger(__name__) +def get_ip() -> str: + # VLLM_HOST_IP env can be ignore + host_ip = os.getenv("VLLM_HOST_IP", "") or os.getenv("HOST_IP", "") + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." 
+ "The value can be set by the environment variable" + " SGLANG_HOST_IP or HOST_IP.", + stacklevel=2, + ) + return "0.0.0.0" + + +def get_open_port() -> int: + + port = os.getenv("SGLANG_PORT") + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + class ShmRingBuffer: def __init__( @@ -343,11 +414,11 @@ def acquire_write(self): # if we wait for a long time, we should warn the user if ( time.monotonic() - start_time - > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning + > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning ): logger.warning( "No available block found in %s second. ", - VLLM_RINGBUFFER_WARNING_INTERVAL, + SGLANG_RINGBUFFER_WARNING_INTERVAL, ) n_warning += 1 diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py index e64fd83d04..61b29b77c7 100644 --- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -1,13 +1,14 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from vllm.platforms import current_platform + +from sglang.srt.utils import is_xpu class XpuCommunicator: def __init__(self, group: ProcessGroup): - if not current_platform.is_xpu(): + if not is_xpu(): self.disabled = True return self.disabled = False diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 8d5c03fa78..7286d7a93d 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -23,6 +23,7 @@ import contextlib import gc import logging +import os import pickle import weakref from collections import namedtuple @@ -34,11 +35,13 @@ import torch import torch.distributed -import vllm.envs as envs from torch.distributed import Backend, ProcessGroup -from vllm.platforms import current_platform -from sglang.srt.utils import direct_register_custom_op, supports_custom_op +from sglang.srt.utils import ( + direct_register_custom_op, + is_cuda_alike, + supports_custom_op, +) @dataclass @@ -207,7 +210,7 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None - if current_platform.is_cuda_alike(): + if is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") @@ -1005,7 +1008,7 @@ def init_distributed_environment( # local rank not set, this usually happens in single-node # setting, where we can use rank as local rank if distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK + local_rank = int(os.environ.get("LOCAL_RANK", "0")) else: local_rank = rank global _WORLD diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py index cabbb86c4b..bcc4a90e45 100644 --- a/python/sglang/srt/distributed/utils.py +++ 
b/python/sglang/srt/distributed/utils.py @@ -5,13 +5,13 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import dataclasses import logging +import os import pickle import time from collections import deque from typing import Any, Deque, Dict, Optional, Sequence, Tuple import torch -import vllm.envs as envs from torch.distributed import TCPStore logger = logging.getLogger(__name__) @@ -66,7 +66,8 @@ def get_pp_indices( If the number of layers is not divisible by the number of partitions, the last partition will have the remaining layers. """ - partition_list_str = envs.VLLM_PP_LAYER_PARTITION + # partition_list_str can be set to None in sglang + partition_list_str = os.getenv("VLLM_PP_LAYER_PARTITION", None) if partition_list_str is not None: try: partitions = [int(layer) for layer in partition_list_str.split(",")] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index a84dec9e8d..e080608470 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -30,6 +30,7 @@ import tempfile import time import warnings +from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from io import BytesIO from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union @@ -67,10 +68,22 @@ def is_hip() -> bool: return torch.version.hip is not None +def is_cuda(): + return hasattr(torch, "cuda") and torch.cuda.is_available() + + +def is_cuda_alike(): + return is_cuda() or is_hip() + + def is_hpu() -> bool: return hasattr(torch, "hpu") and torch.hpu.is_available() +def is_xpu() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + def is_flashinfer_available(): """ Check whether flashinfer is available. @@ -1030,3 +1043,44 @@ def set_gpu_proc_affinity( def get_bool_env_var(name: str, default: str = "false") -> bool: value = os.getenv(name, default) return value.lower() in ("true", "1") + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. + + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + if not torch.cuda._is_compiled(): + return 0 + if is_hip(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. 
+ return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) From 4cdd03ad213edf9e572d28c7688d8ef8af7f3d47 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 30 Nov 2024 02:08:20 +0800 Subject: [PATCH 05/10] add detail vllm version --- python/sglang/srt/_custom_ops.py | 2 +- python/sglang/srt/distributed/communication_op.py | 2 +- .../srt/distributed/device_communicators/cuda_wrapper.py | 2 +- .../distributed/device_communicators/custom_all_reduce.py | 2 +- .../device_communicators/custom_all_reduce_utils.py | 2 +- .../distributed/device_communicators/hpu_communicator.py | 2 +- .../sglang/srt/distributed/device_communicators/pynccl.py | 2 +- .../srt/distributed/device_communicators/pynccl_wrapper.py | 3 ++- .../srt/distributed/device_communicators/shm_broadcast.py | 2 +- .../distributed/device_communicators/xpu_communicator.py | 1 + python/sglang/srt/distributed/parallel_state.py | 7 ++----- python/sglang/srt/distributed/utils.py | 2 +- 12 files changed, 14 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 375839689a..ede8e313d7 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py import contextlib import functools import importlib diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index 1a9e75845b..ddf3b8ef56 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py from typing import Any, Dict, Optional, Union import torch diff --git a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py index 91c9f64b16..ab4ee33fcf 100644 --- a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py """This file is a pure Python wrapper for the cudart library. It avoids the need to compile a separate shared library, and is convenient for use when we just need to call a few functions. 
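The NVLink and peer-to-peer probing that patch 04 inlines into custom_all_reduce.py (the pynvml-based is_full_nvlink query plus the driver-level fallback in _can_p2p) can be exercised on its own. A minimal illustrative sketch, assuming the pynvml package from nvidia-ml-py and at least two visible GPUs; the helper name and the device indices 0/1 are only for illustration:

    import pynvml
    import torch


    def gpus_nvlink_connected(dev_a: int, dev_b: int) -> bool:
        # Ask NVML whether the two physical devices have a direct (1-hop)
        # NVLink connection, mirroring the pairwise query in is_full_nvlink.
        pynvml.nvmlInit()
        try:
            handle_a = pynvml.nvmlDeviceGetHandleByIndex(dev_a)
            handle_b = pynvml.nvmlDeviceGetHandleByIndex(dev_b)
            status = pynvml.nvmlDeviceGetP2PStatus(
                handle_a, handle_b, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
            )
            return status == pynvml.NVML_P2P_STATUS_OK
        except pynvml.NVMLError:
            return False
        finally:
            pynvml.nvmlShutdown()


    if torch.cuda.device_count() >= 2:
        print("NVLink 0<->1:", gpus_nvlink_connected(0, 1))
        # The cheaper driver-level report that _can_p2p trusts when the
        # skip-check environment variable is set.
        print("P2P 0->1:", torch.cuda.can_device_access_peer(0, 1))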
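A usage note for the cuda_device_count_stateless helper that patch 04 adds to python/sglang/srt/utils.py: the result is cached on the current value of CUDA_VISIBLE_DEVICES rather than once per process, so the count follows later changes to that variable, which is the documented reason to prefer it over torch.cuda.device_count(). A minimal sketch, assuming a machine with at least two GPUs:

    import os

    from sglang.srt.utils import cuda_device_count_stateless

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    print(cuda_device_count_stateless())  # -> 2; the cache key is the env value "0,1"

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print(cuda_device_count_stateless())  # -> 1; a new key forces a fresh device count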
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index cb029db447..a09346521c 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py import ctypes import logging import os diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index c58d057a4b..d807dfd5ce 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py import ctypes import json import logging diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 72292f1345..72ef3889e0 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index bdb961bb5d..baee270da9 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py import logging from contextlib import contextmanager from typing import Optional, Union diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index 421ff7bf53..b1721cec80 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -1,4 +1,5 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl_wrapper.py + # This file is a pure Python wrapper for the NCCL library. # The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach: diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index fe3a7156e1..c454a9c0ae 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py import ipaddress import logging import os diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py index 61b29b77c7..ff0981b80b 100644 --- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -1,3 +1,4 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 7286d7a93d..de8ee4957b 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1,4 +1,5 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py + # Copyright 2023 The vLLM team. # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py @@ -389,10 +390,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: self._all_reduce_in_place(input_) return input_ - if self.tpu_communicator is not None and not self.tpu_communicator.disabled: - # TPU handles Dynamo with its own logic. - return self.tpu_communicator.all_reduce(input_) - if self.hpu_communicator is not None and not self.hpu_communicator.disabled: return self.hpu_communicator.all_reduce(input_) diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py index bcc4a90e45..7fdf6a394d 100644 --- a/python/sglang/srt/distributed/utils.py +++ b/python/sglang/srt/distributed/utils.py @@ -1,4 +1,4 @@ -# reference: VLLM 0.6.4.post1 +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py # Copyright 2023 The vLLM team. 
# Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py From 855c06938ec9fc0ee996f8c909ee79b92bfdb5da Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 30 Nov 2024 02:12:59 +0800 Subject: [PATCH 06/10] remove tpu_comm since it is useless in sglang --- .../device_communicators/tpu_communicator.py | 59 ------------------- .../sglang/srt/distributed/parallel_state.py | 17 ------ 2 files changed, 76 deletions(-) delete mode 100644 python/sglang/srt/distributed/device_communicators/tpu_communicator.py diff --git a/python/sglang/srt/distributed/device_communicators/tpu_communicator.py b/python/sglang/srt/distributed/device_communicators/tpu_communicator.py deleted file mode 100644 index 9027b68b38..0000000000 --- a/python/sglang/srt/distributed/device_communicators/tpu_communicator.py +++ /dev/null @@ -1,59 +0,0 @@ -import os - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from vllm.platforms import current_platform - -if current_platform.is_tpu(): - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from vllm.executor import ray_utils - - -class TpuCommunicator: - - def __init__(self, group: ProcessGroup): - if not current_platform.is_tpu(): - self.disabled = True - return - self.disabled = False - - # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node - # must be used together. Therefore, the local rank and world size can - # be simply calculated as follows. - global_rank = dist.get_rank(group) - global_world_size = dist.get_world_size(group) - - # Calculate how many TPU nodes are in the current deployment. This - # is the Ray placement group if it is deployed with Ray. Default - # to the number of TPU nodes in the Ray cluster. The number of TPU - # nodes is computed by the total number of TPUs divided by the - # number of TPU accelerators per node, to account for clusters - # with both CPUs and TPUs. - num_nodes = ray_utils.get_num_tpu_nodes() - num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() - if num_nodes_in_pg > 0: - num_nodes = num_nodes_in_pg - - local_world_size = global_world_size // num_nodes - local_rank = global_rank % local_world_size - - # Ensure environment variables are set for multihost deployments. - # On GKE, this is needed for libtpu and TPU driver to know which TPU - # chip is actually visible. Otherwise the TPU driver will fail to - # initialize because the number of devices would be different from - # the number of visible worker addresses. - os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) - os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) - - pjrt.initialize_multiprocess(local_rank, local_world_size) - xr._init_world_size_ordinal() - - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: - return xm.all_reduce(xm.REDUCE_SUM, x) - - def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: - assert dim == -1, "TPUs only support dim=-1 for all-gather." 
- return xm.all_gather(x, dim=dim) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index de8ee4957b..26d04b04ce 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -179,7 +179,6 @@ def __init__( torch_distributed_backend: Union[str, Backend], use_pynccl: bool, use_custom_allreduce: bool, - use_tpu_communicator: bool, use_hpu_communicator: bool, use_xpu_communicator: bool, use_message_queue_broadcaster: bool = False, @@ -218,7 +217,6 @@ def __init__( self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce - self.use_tpu_communicator = use_tpu_communicator self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator @@ -245,14 +243,6 @@ def __init__( device=self.device, ) - from sglang.srt.distributed.device_communicators.tpu_communicator import ( - TpuCommunicator, - ) - - self.tpu_communicator: Optional[TpuCommunicator] = None - if use_tpu_communicator and self.world_size > 1: - self.tpu_communicator = TpuCommunicator(group=self.cpu_group) - from sglang.srt.distributed.device_communicators.hpu_communicator import ( HpuCommunicator, ) @@ -432,11 +422,6 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - # For TPUs, use TPU communicator. - tpu_comm = self.tpu_communicator - if tpu_comm is not None and not tpu_comm.disabled: - return tpu_comm.all_gather(input_, dim) - # For HPUs, use HPU communicator. hpu_comm = self.hpu_communicator if hpu_comm is not None and not hpu_comm.disabled: @@ -886,7 +871,6 @@ def init_world_group( torch_distributed_backend=backend, use_pynccl=False, use_custom_allreduce=False, - use_tpu_communicator=False, use_hpu_communicator=False, use_xpu_communicator=False, group_name="world", @@ -909,7 +893,6 @@ def init_model_parallel_group( torch_distributed_backend=backend, use_pynccl=True, use_custom_allreduce=use_custom_allreduce, - use_tpu_communicator=True, use_hpu_communicator=True, use_xpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, From a1f4006be1a7c8d1f98765b044a7d38386da7348 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 30 Nov 2024 02:23:29 +0800 Subject: [PATCH 07/10] add reference for cuda_device_count_stateless --- python/sglang/srt/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e080608470..d61d908eb1 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1073,6 +1073,7 @@ def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> return r +# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. @@ -1083,4 +1084,4 @@ def cuda_device_count_stateless() -> int: # This can be removed and simply replaced with torch.cuda.get_device_count # after https://github.com/pytorch/pytorch/pull/122815 is released. 
- return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None)) From 6e1a3a92dfc62171d1f08365aef2d5390e5d4e93 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 30 Nov 2024 13:37:54 +0800 Subject: [PATCH 08/10] upd --- python/sglang/srt/_custom_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index ede8e313d7..1685ee1161 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -10,7 +10,6 @@ from sglang.srt.utils import is_hpu -# from vllm.scalar_type import ScalarType logger = logging.getLogger(__name__) From 6e2346388e66107a02478e69112fe96beadba9b7 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 30 Nov 2024 22:19:02 +0800 Subject: [PATCH 09/10] move all envs which prefix is VLLM -> SGLANG --- .../device_communicators/custom_all_reduce.py | 2 +- .../distributed/device_communicators/pynccl_wrapper.py | 10 +++++----- .../distributed/device_communicators/shm_broadcast.py | 10 +++++----- python/sglang/srt/distributed/utils.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index a09346521c..b6df234407 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -72,7 +72,7 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: def _can_p2p(rank: int, world_size: int) -> bool: # SGLANG_SKIP_P2P_CHECK can be set to False in sglang - SGLANG_SKIP_P2P_CHECK = os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1" + SGLANG_SKIP_P2P_CHECK = os.getenv("SGLANG_SKIP_P2P_CHECK", "0") == "1" for i in range(world_size): if i == rank: continue diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index b1721cec80..e72284f511 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -18,7 +18,7 @@ # recompilation of the code every time we want to switch between different # versions. This current implementation, with a **pure** Python wrapper, is # more flexible. We can easily switch between different versions of NCCL by -# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# changing the environment variable `SGLANG_NCCL_SO_PATH`, or the `so_file` # variable in the code. import ctypes @@ -36,19 +36,19 @@ def find_nccl_library() -> str: """ - We either use the library file specified by the `VLLM_NCCL_SO_PATH` + We either use the library file specified by the `SGLANG_NCCL_SO_PATH` environment variable, or we find the library file brought by PyTorch. After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be found by `ctypes` automatically. 
""" # so_file can be set to None in sglang - so_file = os.environ.get("VLLM_NCCL_SO_PATH", None) + so_file = os.environ.get("SGLANG_NCCL_SO_PATH", None) # manually load the nccl library if so_file: logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + "Found nccl from environment variable SGLANG_NCCL_SO_PATH=%s", so_file ) else: if torch.version.cuda is not None: @@ -249,7 +249,7 @@ def __init__(self, so_file: Optional[str] = None): "Otherwise, the nccl library might not exist, be corrupted " "or it does not support the current platform %s." "If you already have the library, please set the " - "environment variable VLLM_NCCL_SO_PATH" + "environment variable SGLANG_NCCL_SO_PATH" " to point to the correct nccl library path.", so_file, platform.platform(), diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index c454a9c0ae..1afe6fca52 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -20,15 +20,15 @@ # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60 SGLANG_RINGBUFFER_WARNING_INTERVAL = int( - os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60") + os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60") ) logger = logging.getLogger(__name__) def get_ip() -> str: - # VLLM_HOST_IP env can be ignore - host_ip = os.getenv("VLLM_HOST_IP", "") or os.getenv("HOST_IP", "") + # SGLANG_HOST_IP env can be ignore + host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "") if host_ip: return host_ip @@ -470,11 +470,11 @@ def acquire_read(self): # if we wait for a long time, we should warn the user if ( time.monotonic() - start_time - > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning + > SGLANG_RINGBUFFER_WARNING_INTERVAL * n_warning ): logger.warning( "No available block found in %s second. ", - VLLM_RINGBUFFER_WARNING_INTERVAL, + SGLANG_RINGBUFFER_WARNING_INTERVAL, ) n_warning += 1 diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py index 7fdf6a394d..a225fbb918 100644 --- a/python/sglang/srt/distributed/utils.py +++ b/python/sglang/srt/distributed/utils.py @@ -67,7 +67,7 @@ def get_pp_indices( the last partition will have the remaining layers. """ # partition_list_str can be set to None in sglang - partition_list_str = os.getenv("VLLM_PP_LAYER_PARTITION", None) + partition_list_str = os.getenv("SGLANG_PP_LAYER_PARTITION", None) if partition_list_str is not None: try: partitions = [int(layer) for layer in partition_list_str.split(",")] From 1cea9d48335d0087db5bf6d17058f6b69f32b2a1 Mon Sep 17 00:00:00 2001 From: Zhangyi <1109276519@qq.com> Date: Sat, 30 Nov 2024 22:22:58 +0800 Subject: [PATCH 10/10] fix format --- python/sglang/srt/_custom_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 1685ee1161..9eb7caa1bb 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -10,7 +10,6 @@ from sglang.srt.utils import is_hpu - logger = logging.getLogger(__name__) if not is_hpu():