diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e9a9240f605..968fa1e512c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918)) - Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915)) - Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895)) - Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827)) diff --git a/torch_geometric/loader/prefetch.py b/torch_geometric/loader/prefetch.py index 3bbfd69c6978..1340c9e7660f 100644 --- a/torch_geometric/loader/prefetch.py +++ b/torch_geometric/loader/prefetch.py @@ -1,3 +1,4 @@ +import warnings from contextlib import nullcontext from functools import partial from typing import Any, Optional @@ -5,6 +6,47 @@ import torch from torch.utils.data import DataLoader +from torch_geometric.typing import WITH_IPEX + + +class DeviceHelper: + def __init__(self, device: Optional[torch.device] = None): + with_cuda = torch.cuda.is_available() + with_xpu = torch.xpu.is_available() if WITH_IPEX else False + + if device is None: + if with_cuda: + device = 'cuda' + elif with_xpu: + device = 'xpu' + else: + device = 'cpu' + + self.device = torch.device(device) + self.is_gpu = self.device.type in ['cuda', 'xpu'] + + if ((self.device.type == 'cuda' and not with_cuda) + or (self.device.type == 'xpu' and not with_xpu)): + warnings.warn(f"Requested device '{self.device.type}' is not " + f"available, falling back to CPU") + self.device = torch.device('cpu') + + self.stream = None + self.stream_context = nullcontext + self.module = getattr(torch, self.device.type) if self.is_gpu else None + + def maybe_init_stream(self) -> None: + if self.is_gpu: + self.stream = self.module.Stream() + self.stream_context = partial( + self.module.stream, + stream=self.stream, + ) + + def maybe_wait_stream(self) -> None: + if self.stream is not None: + self.module.current_stream().wait_stream(self.stream) + class PrefetchLoader: r"""A GPU prefetcher class for asynchronously transferring data of a @@ -20,37 +62,27 @@ def __init__( loader: DataLoader, device: Optional[torch.device] = None, ): - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.loader = loader - self.device = torch.device(device) - - self.is_cuda = torch.cuda.is_available() and self.device.type == 'cuda' + self.device_helper = DeviceHelper(device) def non_blocking_transfer(self, batch: Any) -> Any: - if not self.is_cuda: + if not self.device_helper.is_gpu: return batch if isinstance(batch, (list, tuple)): return [self.non_blocking_transfer(v) for v in batch] if isinstance(batch, dict): return {k: self.non_blocking_transfer(v) for k, v in batch.items()} - batch = batch.pin_memory() - return batch.to(self.device, non_blocking=True) + batch = batch.pin_memory(self.device_helper.device) + return batch.to(self.device_helper.device, non_blocking=True) def __iter__(self) -> Any: first = True - if self.is_cuda: - stream = torch.cuda.Stream() - stream_context = partial(torch.cuda.stream, stream=stream) - else: - stream = None - stream_context = nullcontext + self.device_helper.maybe_init_stream() for next_batch in self.loader: - with stream_context(): + with self.device_helper.stream_context(): next_batch = self.non_blocking_transfer(next_batch) if not first: @@ -58,8 +90,7 @@ def __iter__(self) -> Any: else: first = False - if stream is not None: - torch.cuda.current_stream().wait_stream(stream) + self.device_helper.maybe_wait_stream() batch = next_batch diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index c412069c1344..38a3a01fbb46 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -205,6 +205,13 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, raise ImportError("'masked_select_nnz' requires 'torch-sparse'") +try: + import intel_extension_for_pytorch # noqa + WITH_IPEX = True +except (ImportError, OSError): + WITH_IPEX = False + + class MockTorchCSCTensor: def __init__( self,