Skip to content

Commit

Permalink
[torch.compile] register allreduce operations as custom ops (vllm-pro…
Browse files Browse the repository at this point in the history
…ject#8526)

Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
  • Loading branch information
youkaichao authored and garg-amit committed Oct 28, 2024
1 parent 63d8e62 commit 46f5089
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 50 deletions.
10 changes: 3 additions & 7 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@ steps:
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py

- label: torch compile integration test
source_file_dependencies:
- vllm/
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_wrapper.py

- label: Prefix Caching Test # 7min
#mirror_hardwares: [amd]
source_file_dependencies:
Expand Down Expand Up @@ -348,7 +341,10 @@ steps:
- vllm/executor/
- vllm/model_executor/models/
- tests/distributed/
- vllm/compilation
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error
Expand Down
12 changes: 0 additions & 12 deletions csrc/custom_all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t.numel() * t.element_size());
}

bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16
if (inp_size % 16 != 0) return false;
if (!_is_weak_contiguous(inp)) return false;
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.
return false;
}

void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
Expand Down
2 changes: 0 additions & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
Expand Down
5 changes: 0 additions & 5 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

custom_ar.def(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool");
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);

custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);

Expand Down
Empty file added tests/compile/__init__.py
Empty file.
15 changes: 13 additions & 2 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@

import pytest

from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model):
@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):

# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")

# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
Expand All @@ -17,7 +28,7 @@ def test_full_graph(model):
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model, enforce_eager=True)
llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size)

outputs = llm.generate(prompts, sampling_params)

Expand Down
6 changes: 0 additions & 6 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
full_nvlink: bool) -> bool:
return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)

Expand Down
21 changes: 19 additions & 2 deletions vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True


def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())


class CustomAllreduce:

_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
Expand Down Expand Up @@ -224,8 +230,19 @@ def register_graph_buffers(self):
ops.register_graph_buffers(self._ptr, handles, offsets)

def should_custom_ar(self, inp: torch.Tensor):
return ops.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False

# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
Expand Down
116 changes: 102 additions & 14 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
"""
import contextlib
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -69,6 +70,58 @@ def _split_tensor_dict(
return metadata_list, tensor_list


_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if name not in _group_name_counter:
_group_name_counter[name] = 0
newname = f"{name}:{_group_name_counter[name]}"
_group_name_counter[name] += 1
return newname


_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}


def _register_group(group: "GroupCoordinator") -> None:
# looks like Python 3.8 does not understand `ReferenceType`
_groups[group.unique_name] = weakref.ref(group) # type: ignore


@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce(tensor)


@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
return


@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce(tensor)


@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)


class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
Expand Down Expand Up @@ -111,7 +164,11 @@ def __init__(
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)

self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
Expand Down Expand Up @@ -149,28 +206,24 @@ def __init__(
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)

self.pynccl_comm: Optional[PyNcclCommunicator]
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
else:
self.pynccl_comm = None

self.ca_comm: Optional[CustomAllreduce]
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
else:
self.ca_comm = None

from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator]
self.tpu_communicator: Optional[TpuCommunicator] = None
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

Expand Down Expand Up @@ -264,16 +317,46 @@ def graph_capture(

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_

if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self._all_reduce(input_)

if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
torch.ops.vllm.inplace_all_reduce(input_,
group_name=self.unique_name)
return input_

def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm = self.ca_comm

# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_

# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
Expand Down Expand Up @@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
group_name="world",
)


Expand All @@ -767,6 +851,7 @@ def init_model_parallel_group(
backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
Expand All @@ -778,6 +863,7 @@ def init_model_parallel_group(
use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)


Expand Down Expand Up @@ -931,7 +1017,8 @@ def initialize_model_parallel(
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True)
use_message_queue_broadcaster=True,
group_name="tp")

# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
Expand All @@ -947,7 +1034,8 @@ def initialize_model_parallel(
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False)
use_custom_allreduce=False,
group_name="pp")


def ensure_model_parallel_initialized(
Expand Down

0 comments on commit 46f5089

Please sign in to comment.