forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Ray backend] Better error when pg topology is bad. (vllm-project#7584)
Co-authored-by: youkaichao <youkaichao@126.com>
- Loading branch information
1 parent
3ac607b
commit f988b5f
Showing
3 changed files
with
197 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Make sure ray assigns GPU workers to the correct node. | ||
Run: | ||
```sh | ||
cd $VLLM_PATH/tests | ||
pytest distributed/test_multi_node_assignment.py | ||
``` | ||
""" | ||
|
||
import os | ||
|
||
import pytest | ||
import ray | ||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy | ||
|
||
from vllm import initialize_ray_cluster | ||
from vllm.config import ParallelConfig | ||
from vllm.executor.ray_utils import _wait_until_pg_removed | ||
from vllm.utils import get_ip | ||
|
||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" | ||
|
||
|
||
@pytest.mark.skipif(not VLLM_MULTI_NODE, | ||
reason="Need at least 2 nodes to run the test.") | ||
def test_multi_node_assignment() -> None: | ||
|
||
# NOTE: important to keep this class definition here | ||
# to let ray use cloudpickle to serialize it. | ||
class Actor: | ||
|
||
def get_ip(self): | ||
return get_ip() | ||
|
||
for _ in range(10): | ||
config = ParallelConfig(1, 2) | ||
initialize_ray_cluster(config) | ||
|
||
current_ip = get_ip() | ||
workers = [] | ||
for bundle_id, bundle in enumerate( | ||
config.placement_group.bundle_specs): | ||
if not bundle.get("GPU", 0): | ||
continue | ||
scheduling_strategy = PlacementGroupSchedulingStrategy( | ||
placement_group=config.placement_group, | ||
placement_group_capture_child_tasks=True, | ||
placement_group_bundle_index=bundle_id, | ||
) | ||
|
||
worker = ray.remote( | ||
num_cpus=0, | ||
num_gpus=1, | ||
scheduling_strategy=scheduling_strategy, | ||
)(Actor).remote() | ||
worker_ip = ray.get(worker.get_ip.remote()) | ||
assert worker_ip == current_ip | ||
workers.append(worker) | ||
|
||
for worker in workers: | ||
ray.kill(worker) | ||
|
||
_wait_until_pg_removed(config.placement_group) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters