
Add support for XPU device in PrefetchLoader #7918

Merged · 8 commits · Aug 23, 2023
Changes from 4 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))
- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))
- Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895))
- Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827))
78 changes: 60 additions & 18 deletions torch_geometric/loader/prefetch.py
@@ -5,6 +5,8 @@
import torch
from torch.utils.data import DataLoader

+from torch_geometric.typing import WITH_IPEX
+

class PrefetchLoader:
r"""A GPU prefetcher class for asynchronously transferring data of a
@@ -15,51 +17,91 @@ class PrefetchLoader:
device (torch.device, optional): The device to load the data to.
(default: :obj:`None`)
"""
+    class PrefetchLoaderDevice:
+        def __init__(self, device: Optional[torch.device] = None):
+            cuda_present = torch.cuda.is_available()
+            xpu_present = torch.xpu.is_available() if WITH_IPEX else False
+
+            if device is None:
+                if cuda_present:
+                    device = 'cuda'
+                elif xpu_present:
+                    device = 'xpu'
+                else:
+                    device = 'cpu'
+
+            self.device = torch.device(device)
+
+            if ((self.device.type == 'cuda' and not cuda_present)
+                    or (self.device.type == 'xpu' and not xpu_present)):
+                print(f"Requested device '{self.device.type}' is not "
+                      "available, falling back to CPU")
+                self.device = torch.device('cpu')
+
+            self.is_gpu = self.device.type in ['cuda', 'xpu']
+            self.stream = None
+            self.stream_context = nullcontext
+
+            gpu_module = None
+            if self.is_gpu:
+                if self.device.type == 'cuda':
+                    gpu_module = torch.cuda
+                else:
+                    gpu_module = torch.xpu
+
+            self.gpu_module = gpu_module
+
+        def maybe_init_stream(self) -> None:
+            if self.is_gpu:
+                self.stream = self.gpu_module.Stream()
+                self.stream_context = partial(self.gpu_module.stream,
+                                              stream=self.stream)
+
+        def maybe_wait_stream(self) -> None:
+            if self.stream is not None:
+                self.gpu_module.current_stream().wait_stream(self.stream)
+
+        def get_device(self) -> torch.device:
+            return self.device
+
+        def get_stream_context(self) -> Any:
+            return self.stream_context()

    def __init__(
        self,
        loader: DataLoader,
        device: Optional[torch.device] = None,
    ):
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
        self.loader = loader
-        self.device = torch.device(device)
-
-        self.is_cuda = torch.cuda.is_available() and self.device.type == 'cuda'
+        self.device_mgr = self.PrefetchLoaderDevice(device)

    def non_blocking_transfer(self, batch: Any) -> Any:
-        if not self.is_cuda:
+        if not self.device_mgr.is_gpu:
            return batch
        if isinstance(batch, (list, tuple)):
            return [self.non_blocking_transfer(v) for v in batch]
        if isinstance(batch, dict):
            return {k: self.non_blocking_transfer(v) for k, v in batch.items()}

-        batch = batch.pin_memory()
-        return batch.to(self.device, non_blocking=True)
+        device = self.device_mgr.get_device()
+        batch = batch.pin_memory(device)
+        return batch.to(device, non_blocking=True)

    def __iter__(self) -> Any:
        first = True
-        if self.is_cuda:
-            stream = torch.cuda.Stream()
-            stream_context = partial(torch.cuda.stream, stream=stream)
-        else:
-            stream = None
-            stream_context = nullcontext
+        self.device_mgr.maybe_init_stream()

        for next_batch in self.loader:

-            with stream_context():
+            with self.device_mgr.get_stream_context():
                next_batch = self.non_blocking_transfer(next_batch)

            if not first:
                yield batch  # noqa
            else:
                first = False

-            if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+            self.device_mgr.maybe_wait_stream()

            batch = next_batch

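For reviewers trying this out: a minimal usage sketch of the refactored loader (the `DataLoader` built here is hypothetical; `PrefetchLoader` is re-exported from `torch_geometric.loader`). Requesting an unavailable device falls back to CPU, as implemented above:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from torch_geometric.loader import PrefetchLoader

# Hypothetical host-side data; any DataLoader-compatible dataset works.
dataset = TensorDataset(torch.randn(1024, 16))
loader = DataLoader(dataset, batch_size=128)

# device=None selects 'cuda' if available, else 'xpu' (when IPEX is
# installed), else 'cpu'. An explicitly requested but unavailable
# device falls back to CPU with a message.
prefetch_loader = PrefetchLoader(loader, device='xpu')

for batch in prefetch_loader:
    # On a GPU device, this batch was pinned and copied on a side stream
    # while the previous batch was being consumed.
    pass
```

Note that the list/dict recursion in `non_blocking_transfer` means nested batch structures (such as the list of tensors yielded by `TensorDataset`) are transferred element-wise.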
11 changes: 11 additions & 0 deletions torch_geometric/typing.py
@@ -205,6 +205,17 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor,
raise ImportError("'masked_select_nnz' requires 'torch-sparse'")


+try:
+    import intel_extension_for_pytorch  # noqa
+    WITH_IPEX = True
+except (ImportError, OSError) as e:
+    if isinstance(e, OSError):
+        warnings.warn("An issue occurred while importing "
+                      "'intel-extension-for-pytorch'. "
+                      f"Disabling its usage. Stacktrace: {e}")
+    WITH_IPEX = False
+

class MockTorchCSCTensor:
def __init__(
self,
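A note on the new flag: `WITH_IPEX` lets downstream code probe for XPU support without importing IPEX directly. A small sketch of the intended pattern (the `pick_device` helper is illustrative, not part of this PR; `torch.xpu` only exists once `intel_extension_for_pytorch` imports cleanly, which is exactly what the flag guards):

```python
import torch

from torch_geometric.typing import WITH_IPEX


def pick_device() -> torch.device:
    """Illustrative helper mirroring PrefetchLoader's selection order."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Short-circuit: torch.xpu is only touched when IPEX imported cleanly.
    if WITH_IPEX and torch.xpu.is_available():
        return torch.device('xpu')
    return torch.device('cpu')


print(pick_device())
```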