[Train] Decouple device-related modules and add Huawei NPU support to Ray Train (ray-project#44086)

We are looking to expand the hardware support range of Ray Train by incorporating Huawei Ascend NPU support. However, as the number of supported hardware types grows, scattered, device-specific modifications have been made to the code, which hurts future compatibility and maintainability. To address this, we have extracted the device-related modules from Ray Train and consolidated them into `accelerator_utils`. This keeps the device-specific code independent of the rest of the codebase and improves maintainability.

Signed-off-by: liuxsh9 <liuxiaoshuang4@huawei.com>
Signed-off-by: matthewdeng <matt@anyscale.com>
Co-authored-by: matthewdeng <matt@anyscale.com>
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
1 parent a1cbb6a, commit 9049ad4
Showing 14 changed files with 541 additions and 67 deletions.
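For orientation, a minimal sketch (not part of this commit) of how the consolidated entry point added below might be called from device-agnostic code. The module path comes from the imports in the new files; note that ray.air._internal is private API, and the CPU case is used so the snippet needs no accelerator.

from ray.air._internal.device_manager import get_torch_device_manager_by_device_type

manager = get_torch_device_manager_by_device_type("cpu")
print(manager.get_devices())      # [device(type='cpu')]
print(manager.supports_stream())  # False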
ray/air/_internal/device_manager/__init__.py
@@ -0,0 +1,92 @@
import logging
import threading
from typing import Optional

import ray
import ray._private.ray_constants as ray_constants
from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
from ray.air._internal.device_manager.npu import NPUTorchDeviceManager
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager

logger = logging.getLogger(__name__)


DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager


SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
    ray_constants.GPU: CUDATorchDeviceManager,
    ray_constants.HPU: HPUTorchDeviceManager,
    ray_constants.NPU: NPUTorchDeviceManager,
}


def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
    if backend == "hccl":
        # Habana and torch-npu both name their communication backend "hccl",
        # so register both.
        HPUTorchDeviceManager.register_custom_torch_dist_backend()

        NPUTorchDeviceManager.register_custom_torch_dist_backend()


_torch_device_manager = None
_torch_device_manager_lock = threading.Lock()


def get_torch_device_manager_by_context() -> TorchDeviceManager:
    global _torch_device_manager

    with _torch_device_manager_lock:
        if not _torch_device_manager:
            existing_device_manager_cls = None
            resources = ray.get_runtime_context().get_accelerator_ids()

            # Select the accelerator type from the resources assigned to this worker.
            for resource_type, resource_value in resources.items():
                device_manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
                    resource_type, None
                )
                if resource_value and device_manager_cls:
                    # An error is raised when multiple accelerator types are specified.
                    if existing_device_manager_cls:
                        raise RuntimeError(
                            "Unable to determine the appropriate DeviceManager "
                            f"for the specified resources {resources}."
                        )
                    else:
                        existing_device_manager_cls = device_manager_cls

            device_manager_cls = (
                existing_device_manager_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS
            )

            _torch_device_manager = device_manager_cls()

    return _torch_device_manager


def get_torch_device_manager_by_device_type(device_type: str) -> TorchDeviceManager:
    if device_type.lower() == ray_constants.GPU.lower() or device_type == "cuda":
        return CUDATorchDeviceManager()
    elif device_type.lower() == ray_constants.NPU.lower():
        return NPUTorchDeviceManager()
    elif device_type.lower() == ray_constants.HPU.lower():
        return HPUTorchDeviceManager()
    elif device_type.lower() == "cpu":
        return CPUTorchDeviceManager()

    raise RuntimeError(f"Device type {device_type} cannot be recognized.")


__all__ = [
    "TorchDeviceManager",
    "CPUTorchDeviceManager",
    "CUDATorchDeviceManager",
    "HPUTorchDeviceManager",
    "NPUTorchDeviceManager",
    "register_custom_torch_dist_backend",
    "get_torch_device_manager_by_context",
    "get_torch_device_manager_by_device_type",
]
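To restate the selection rule in get_torch_device_manager_by_context as a toy, Ray-free sketch (hypothetical helper, not part of the diff): exactly one supported accelerator type may carry a non-empty ID list, two competing types raise, and the CPU manager is the fallback.

from typing import Dict, List

_SUPPORTED = {"GPU": "CUDA", "HPU": "HPU", "NPU": "NPU"}  # mirrors the registry above


def pick_manager(accelerator_ids: Dict[str, List[str]]) -> str:
    chosen = None
    for resource_type, ids in accelerator_ids.items():
        if ids and resource_type in _SUPPORTED:
            if chosen is not None:
                raise RuntimeError(f"Ambiguous accelerator resources: {accelerator_ids}")
            chosen = _SUPPORTED[resource_type]
    return chosen or "CPU"


assert pick_manager({"GPU": ["0"]}) == "CUDA"
assert pick_manager({"GPU": [], "NPU": ["1"]}) == "NPU"
assert pick_manager({}) == "CPU"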
ray/air/_internal/device_manager/cpu.py
@@ -0,0 +1,30 @@
from contextlib import contextmanager
from typing import List

import torch

from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


class CPUTorchDeviceManager(TorchDeviceManager):
    """CPU device manager."""

    def is_available(self) -> bool:
        return True

    def get_devices(self) -> List[torch.device]:
        """Gets the correct torch device list configured for this process."""
        return [torch.device("cpu")]

    def supports_stream(self) -> bool:
        """Check whether this device type supports creating a stream."""
        return False

    def get_stream_context(self, stream):
        """Return an empty context manager for CPU."""

        @contextmanager
        def default_context_manager():
            yield

        return default_context_manager()
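The empty context manager above is what lets callers wrap compute in get_stream_context() unconditionally, whatever the device; a minimal sketch of that call pattern (assuming a local torch install, not part of the diff):

import torch

from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager

manager = CPUTorchDeviceManager()
# supports_stream() is False here, so no stream is created and the context
# below is a no-op; the same call pattern works for the CUDA and NPU managers.
stream = (
    manager.create_stream(manager.get_devices()[0])
    if manager.supports_stream()
    else None
)
with manager.get_stream_context(stream):
    out = torch.ones(2, 2) @ torch.ones(2, 2)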
ray/air/_internal/device_manager/hpu.py
@@ -0,0 +1,50 @@
from contextlib import contextmanager
from typing import List, Union

import torch

from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager

if HPU_PACKAGE_AVAILABLE:
    import habana_frameworks.torch.hpu as torch_hpu


class HPUTorchDeviceManager(TorchDeviceManager):
    """HPU device manager."""

    @staticmethod
    def register_custom_torch_dist_backend():
        if HPU_PACKAGE_AVAILABLE:
            # Importing these modules registers the HCCL backend
            # with torch.distributed.
            import habana_frameworks.torch.core  # noqa: F401
            import habana_frameworks.torch.distributed.hccl  # noqa: F401

    def is_available(self) -> bool:
        if not HPU_PACKAGE_AVAILABLE:
            return False

        return torch_hpu.is_available()

    def get_devices(self) -> List[torch.device]:
        if not self.is_available():
            raise RuntimeError(
                "Using HPUTorchDeviceManager but torch hpu is not available."
            )

        return [torch.device("hpu")]

    def set_device(self, device: Union[torch.device, int, str, None]):
        torch_hpu.set_device(device)

    def supports_stream(self) -> bool:
        """Check whether this device type supports creating a stream."""
        return False

    def get_stream_context(self, stream):
        """Return an empty context manager; HPU streams are not supported yet."""

        @contextmanager
        def default_context_manager():
            yield

        return default_context_manager()
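Since get_devices() raises when the Habana stack is missing, one way a caller might guard is to check is_available() and fall back to the CPU manager; a small hedged sketch (not part of the diff):

from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager

manager = HPUTorchDeviceManager()
if not manager.is_available():
    # No habana_frameworks / HPU on this host: fall back to the CPU manager
    # instead of letting get_devices() raise.
    manager = CPUTorchDeviceManager()
print(manager.get_devices())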
ray/air/_internal/device_manager/npu.py
@@ -0,0 +1,105 @@
import os
from importlib.util import find_spec
from typing import List, Union

import torch

import ray
import ray._private.ray_constants as ray_constants
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


def is_package_present(package_name: str) -> bool:
    try:
        return find_spec(package_name) is not None
    except ModuleNotFoundError:
        return False


NPU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_npu")


if NPU_TORCH_PACKAGE_AVAILABLE:
    import torch_npu  # noqa: F401


class NPUTorchDeviceManager(TorchDeviceManager):
    """Ascend NPU device manager."""

    @staticmethod
    def register_custom_torch_dist_backend():
        if NPU_TORCH_PACKAGE_AVAILABLE:
            # Importing torch_npu registers the HCCL backend with torch.distributed.
            import torch_npu  # noqa: F401, F811

    def is_available(self) -> bool:
        if not NPU_TORCH_PACKAGE_AVAILABLE:
            return False

        return torch.npu.is_available()

    def get_devices(self) -> List[torch.device]:
        """Gets the correct torch device list configured for this process.

        Returns a list of torch NPU devices allocated for the current worker.
        If no NPUs are assigned, it returns a list containing the 0th NPU device.
        """
        if NPU_TORCH_PACKAGE_AVAILABLE and torch.npu.is_available():
            npu_ids = [
                str(id)
                for id in ray.get_runtime_context().get_accelerator_ids()[
                    ray_constants.NPU
                ]
            ]

            device_ids = []

            if len(npu_ids) > 0:
                npu_visible_str = os.environ.get(
                    ray_constants.NPU_RT_VISIBLE_DEVICES_ENV_VAR, ""
                )
                if npu_visible_str and npu_visible_str != "NoDevFiles":
                    npu_visible_list = npu_visible_str.split(",")
                else:
                    npu_visible_list = []

                for npu_id in npu_ids:
                    try:
                        device_ids.append(npu_visible_list.index(npu_id))
                    except ValueError:
                        raise RuntimeError(
                            "ASCEND_RT_VISIBLE_DEVICES set incorrectly. "
                            f"Got {npu_visible_str}, expected to include {npu_id}. "
                            "Did you override the `ASCEND_RT_VISIBLE_DEVICES` "
                            "environment variable?"
                        )
            else:
                # If called on the driver or outside of Ray Train, return the
                # 0th device.
                device_ids.append(0)

            devices = [torch.device(f"npu:{device_id}") for device_id in device_ids]
        else:
            raise RuntimeError(
                "Using NPUTorchDeviceManager but torch npu is not available."
            )

        return devices

    def set_device(self, device: Union[torch.device, int]):
        torch.npu.set_device(device)

    def supports_stream(self) -> bool:
        """Check whether this device type supports creating a stream."""
        return True

    def create_stream(self, device):
        """Create an NPU stream on the given device."""
        return torch.npu.Stream(device)

    def get_stream_context(self, stream):
        """Get a torch stream context for the NPU device."""
        return torch.npu.stream(stream)

    def get_current_stream(self):
        """Get the current stream for the NPU device."""
        return torch.npu.current_stream()
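The NPU manager uses the same remapping idea as the CUDA manager below: Ray reports global accelerator IDs, while torch indexes only the devices the process can see through ASCEND_RT_VISIBLE_DEVICES. A toy, Ray-free illustration with hypothetical values (not part of the diff):

from typing import List


def to_local_indices(assigned_ids: List[str], visible_str: str) -> List[int]:
    visible = visible_str.split(",") if visible_str else []
    # list.index raises ValueError when an assigned ID is not visible,
    # which the manager surfaces as a RuntimeError.
    return [visible.index(i) for i in assigned_ids]


print(to_local_indices(["2", "3"], "2,3"))  # [0, 1] -> npu:0, npu:1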
ray/air/_internal/device_manager/nvidia_gpu.py
@@ -0,0 +1,79 @@
import os
from typing import List, Union

import torch

import ray
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


class CUDATorchDeviceManager(TorchDeviceManager):
    """CUDA device manager."""

    def is_available(self) -> bool:
        return torch.cuda.is_available()

    def get_devices(self) -> List[torch.device]:
        """Gets the correct torch device list configured for this process.

        Returns a list of torch CUDA devices allocated for the current worker.
        If no GPUs are assigned, it returns a list containing the 0th CUDA device.
        Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
        superset of the `ray.get_gpu_ids()`.
        """

        # GPU IDs are assigned by Ray after you specify "use_gpu".
        # `ray.get_gpu_ids()` may return ints or strings,
        # so always convert them to strings.
        gpu_ids = [str(id) for id in ray.get_gpu_ids()]

        device_ids = []

        if len(gpu_ids) > 0:
            cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
            if cuda_visible_str and cuda_visible_str != "NoDevFiles":
                cuda_visible_list = cuda_visible_str.split(",")
            else:
                cuda_visible_list = []

            # By default, there should only be one GPU ID if `use_gpu=True`.
            # If there are multiple GPUs, return a list of devices.
            # If using fractional GPUs, these IDs are not guaranteed
            # to be unique across different processes.
            for gpu_id in gpu_ids:
                try:
                    device_ids.append(cuda_visible_list.index(gpu_id))
                except ValueError:
                    raise RuntimeError(
                        "CUDA_VISIBLE_DEVICES set incorrectly. "
                        f"Got {cuda_visible_str}, expected to include {gpu_id}. "
                        "Did you override the `CUDA_VISIBLE_DEVICES` environment"
                        " variable? If not, please help file an issue on Github."
                    )
        else:
            # If called on the driver or outside of Ray Train, return the
            # 0th device.
            device_ids.append(0)

        return [torch.device(f"cuda:{device_id}") for device_id in device_ids]

    def set_device(self, device: Union[torch.device, int, str, None]):
        torch.cuda.set_device(device)

    def supports_stream(self) -> bool:
        """Check whether this device type supports creating a stream."""
        return True

    def create_stream(self, device: torch.device) -> torch.cuda.Stream:
        """Create a CUDA stream on the given device."""
        return torch.cuda.Stream(device)

    def get_stream_context(self, stream):
        """Get a stream context for the CUDA device."""
        return torch.cuda.stream(stream)

    def get_current_stream(self) -> torch.cuda.Stream:
        """Get the current stream for the CUDA device."""
        return torch.cuda.current_stream()
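Finally, a hedged sketch of how the stream hooks might compose on CUDA, e.g. to run work on a side stream (guarded so it does nothing without a CUDA build of torch; not part of the diff):

import torch

from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager

manager = CUDATorchDeviceManager()
if manager.is_available() and manager.supports_stream():
    device = torch.device("cuda:0")
    side_stream = manager.create_stream(device)
    with manager.get_stream_context(side_stream):
        # Work enqueued here runs on the side stream.
        y = torch.ones(1024, device=device) * 2
    # Make the default stream wait before reusing the result elsewhere.
    manager.get_current_stream().wait_stream(side_stream)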