Adding more CUDA instead of GPU
qihqi committed Oct 6, 2023 · 1 parent e08c75a · commit 6c59c2c
Showing 7 changed files with 22 additions and 15 deletions.
8 changes: 4 additions & 4 deletions torch_xla/_internal/pjrt.py
@@ -138,7 +138,7 @@ def run_multiprocess(fn: Callable[..., R],
"""
if runtime.device_type() == 'TPU':
num_processes = tpu.num_local_processes()
elif runtime.device_type() == 'GPU':
elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
num_processes = gpu.num_local_processes()
gpu.initialize_distributed_runtime(num_processes)
elif runtime.device_type() == 'NEURON':
@@ -160,7 +160,7 @@ def run_multiprocess(fn: Callable[..., R],
       itertools.chain.from_iterable(
           result.items() for result in process_results))
 
-  if runtime.device_type() == 'GPU':
+  if runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
     gpu.shutdown_distributed_runtime()
 
   return _merge_replica_results(replica_results)
@@ -210,8 +210,8 @@ def _initialize_single_process(local_rank: int, local_world_size: int):

 def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
   """Run function in one process with one thread per addressable device."""
-  assert runtime.device_type(
-  ) != 'GPU', "spawn_threads does not support GPU device"
+  assert runtime.device_type() not in (
+      'GPU', 'ROCM', 'CUDA'), "spawn_threads does not support GPU device"
   spawn_fn = _SpawnFn(fn, *args)
   _run_thread_per_device(
       local_rank=0,
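For context, a minimal sketch of how these multiprocess entry points are typically reached from user code, assuming a CUDA build of torch_xla (the worker function and its body are illustrative, not part of this commit):

import os

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# With this change, PJRT_DEVICE=CUDA (or ROCM) takes the same code path
# through run_multiprocess as the legacy PJRT_DEVICE=GPU value.
os.environ['PJRT_DEVICE'] = 'CUDA'

def _mp_fn(index):
  # Each spawned process owns one addressable XLA device.
  print(index, xm.xla_device())

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())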
6 changes: 3 additions & 3 deletions torch_xla/amp/autocast_mode.py
@@ -25,7 +25,7 @@ def __init__(self,

     self._enabled = enabled
     self._xla_device = xm.xla_device_hw(device)
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       backend = 'cuda'
       self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
       if dtype is None:
@@ -70,7 +70,7 @@ def __init__(self,
   def __enter__(self):
     # This ensures that xla autocast is enabled even for XLA:GPU, which calls
     # `torch.amp.autocast_mode.autocast` with `cuda` backend.
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
       self.prev_dtype = torch.get_autocast_xla_dtype(
       )  # type: ignore[attr-defined]
@@ -86,7 +86,7 @@ def __enter__(self):

   def __exit__(self, exc_type: Any, exc_val: Any,
                exc_tb: Any):  # type: ignore[override]
-    if self._xla_device == 'GPU':
+    if self._xla_device in ('GPU', 'ROCM', 'CUDA'):
       if self._xla_bfloat16:
         # autocast_xla flags will be set by `torch.autocast` and we need to
         # set autocast flags as we call into `torch.autocast` apis.
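A sketch of the user-level path into this autocast wrapper, assuming an XLA:CUDA device is available (the model and input shapes are placeholders):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
model = torch.nn.Linear(4, 4).to(device)
data = torch.randn(2, 4, device=device)

# On CUDA/ROCM (as on legacy GPU) hardware, xla_device_hw() now reports
# the new names, so the wrapper above still selects the 'cuda' backend.
with autocast(device):
  output = model(data)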
12 changes: 7 additions & 5 deletions torch_xla/core/xla_model.py
@@ -71,7 +71,7 @@ def is_xla_tensor(tensor):


 def parse_xla_device(device):
-  m = re.match(r'(CPU|TPU|GPU|XPU|NEURON):(\d+)$', device)
+  m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device)
   if m:
     return (m.group(1), int(m.group(2)))
 
@@ -89,7 +89,9 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
     The list of device strings.
   """
   xla_devices = _DEVICES.value
-  devkind = [devkind] if devkind else ['TPU', 'GPU', 'XPU', 'NEURON', 'CPU']
+  devkind = [devkind] if devkind else [
+      'TPU', 'GPU', 'XPU', 'NEURON', 'CPU', 'CUDA', 'ROCM'
+  ]
   for kind in devkind:
     kind_devices = []
     for i, device in enumerate(xla_devices):
@@ -181,8 +183,8 @@ def xla_device(n=None, devkind=None):
     n (int, optional): The specific instance (ordinal) to be returned. If
       specified, the specific XLA device instance will be returned. Otherwise
       the first device of `devkind` will be returned.
-    devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`
-      `NEURON` or `CPU`.
+    devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU`,
+      `NEURON`, `ROCM` or `CPU`.
 
   Returns:
     A `torch.device` with the requested instance.
@@ -217,7 +219,7 @@ def xla_device_hw(device):
       real device.
 
   Returns:
-    A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`)
+    A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`)
     of the given device.
   """
   real_device = _xla_real_device(device)
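A quick sketch of the user-visible effect of the widened regex and devkind defaults (the device strings and ordinal are illustrative; this assumes the process was started on a CUDA PJRT device):

import torch_xla.core.xla_model as xm

# parse_xla_device now recognizes the new hardware names.
assert xm.parse_xla_device('CUDA:0') == ('CUDA', 0)
assert xm.parse_xla_device('ROCM:1') == ('ROCM', 1)

# 'CUDA' is now a valid devkind for device selection.
device = xm.xla_device(devkind='CUDA')
print(xm.xla_device_hw(device))  # e.g. 'CUDA'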
2 changes: 2 additions & 0 deletions torch_xla/csrc/random.cpp
@@ -21,6 +21,8 @@ std::string GetDefaultGitGeneratorName() {
       static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
   switch (hw_type) {
     case XlaDeviceType::GPU:
+    case XlaDeviceType::CUDA:
+    case XlaDeviceType::ROCM:
       return "three_fry";
     default:
       return "default";
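The generator choice is internal, but here is a small sketch of the call path it affects, assuming a CUDA device (values are illustrative):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# RNG ops on CUDA/ROCM devices (like legacy GPU) now default to the
# 'three_fry' bit generator selected in GetDefaultGitGeneratorName().
xm.set_rng_state(0)
x = torch.rand(4, device=device)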
3 changes: 2 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -118,7 +118,8 @@ PjRtComputationClient::PjRtComputationClient() {
     client_ = std::move(xla::GetCApiClient("TPU").value());
   } else if (device_type == "TPU_LEGACY") {
     XLA_ERROR() << "TPU_LEGACY client is no longer available.";
-  } else if (device_type == "GPU") {
+  } else if (device_type == "GPU" || device_type == "CUDA" ||
+             device_type == "ROCM") {
     TF_VLOG(1) << "Initializing PjRt GPU client...";
     bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
     int local_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);
4 changes: 3 additions & 1 deletion torch_xla/csrc/tensor_impl.cpp
@@ -75,7 +75,9 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor)
   // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU
   // so we must manually update Autocast to AutocastCUDA on XLA:GPU.
   torch::lazy::BackendDevice current_device = bridge::GetCurrentDevice();
-  if (static_cast<XlaDeviceType>(current_device.type()) == XlaDeviceType::GPU) {
+  auto dev_type = static_cast<XlaDeviceType>(current_device.type());
+  if (dev_type == XlaDeviceType::GPU || dev_type == XlaDeviceType::CUDA ||
+      dev_type == XlaDeviceType::ROCM) {
     auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
     auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
     key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
2 changes: 1 addition & 1 deletion torch_xla/runtime.py
@@ -41,7 +41,7 @@ def _maybe_select_default_device():
   # TODO(wcromar): Detect GPU device
   elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0:
     logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=GPU')
-    os.environ[xenv.PJRT_DEVICE] = 'GPU'
+    os.environ[xenv.PJRT_DEVICE] = 'CUDA'
   else:
     logging.warning('Defaulting to PJRT_DEVICE=CPU')
     os.environ[xenv.PJRT_DEVICE] = 'CPU'
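A sketch of the resulting default selection, assuming GPU_NUM_DEVICES is set before torch_xla initializes (the value is illustrative):

import os

# Advertise local GPUs the legacy way, before importing torch_xla.
os.environ['GPU_NUM_DEVICES'] = '2'

import torch_xla.runtime as xr

# _maybe_select_default_device now writes PJRT_DEVICE=CUDA instead of
# the legacy PJRT_DEVICE=GPU (note the warning text still says GPU).
print(xr.device_type())  # expected: 'CUDA'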
