Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for XPU device in PrefetchLoader #7918

Merged
merged 8 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))
- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))
- Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895))
- Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827))
Expand Down
67 changes: 49 additions & 18 deletions torch_geometric/loader/prefetch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,52 @@
import warnings
from contextlib import nullcontext
from functools import partial
from typing import Any, Optional

import torch
from torch.utils.data import DataLoader

from torch_geometric.typing import WITH_IPEX


class DeviceHelper:
    r"""Resolves the target device for data prefetching and manages an
    optional side stream for asynchronous host-to-device transfers.

    Args:
        device (torch.device, optional): The target device.
            If :obj:`None`, automatically selects :obj:`"cuda"`, then
            :obj:`"xpu"` (when IPEX is available), then :obj:`"cpu"`.
            Falls back to CPU (with a warning) if the requested
            accelerator is unavailable.
    """
    def __init__(self, device: Optional[torch.device] = None):
        with_cuda = torch.cuda.is_available()
        # `torch.xpu` only exists once `intel_extension_for_pytorch` has
        # been imported, hence the `WITH_IPEX` guard:
        with_xpu = torch.xpu.is_available() if WITH_IPEX else False

        if device is None:
            if with_cuda:
                device = 'cuda'
            elif with_xpu:
                device = 'xpu'
            else:
                device = 'cpu'

        device = torch.device(device)

        # Apply the CPU fallback *before* deriving `is_gpu` and `module`.
        # Deriving them first (from the requested device) would leave the
        # helper inconsistent on fallback: `is_gpu=True` with a CPU device,
        # causing `maybe_init_stream` to call `torch.cpu.Stream()`.
        if ((device.type == 'cuda' and not with_cuda)
                or (device.type == 'xpu' and not with_xpu)):
            warnings.warn(f"Requested device '{device.type}' is not "
                          f"available, falling back to CPU")
            device = torch.device('cpu')

        self.device = device
        self.is_gpu = device.type in ['cuda', 'xpu']

        # Stream state; populated lazily via `maybe_init_stream`:
        self.stream = None
        self.stream_context = nullcontext
        # Device namespace (`torch.cuda` or `torch.xpu`) providing the
        # `Stream`/`stream`/`current_stream` API; `None` on CPU:
        self.module = getattr(torch, self.device.type) if self.is_gpu else None

    def maybe_init_stream(self) -> None:
        """Creates a side stream for transfers when running on a GPU."""
        if self.is_gpu:
            self.stream = self.module.Stream()
            self.stream_context = partial(
                self.module.stream,
                stream=self.stream,
            )

    def maybe_wait_stream(self) -> None:
        """Blocks the current stream until the side stream has finished."""
        if self.stream is not None:
            self.module.current_stream().wait_stream(self.stream)


class PrefetchLoader:
r"""A GPU prefetcher class for asynchronously transferring data of a
Expand All @@ -20,46 +62,35 @@ def __init__(
loader: DataLoader,
device: Optional[torch.device] = None,
):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

self.loader = loader
self.device = torch.device(device)

self.is_cuda = torch.cuda.is_available() and self.device.type == 'cuda'
self.device_helper = DeviceHelper(device)

def non_blocking_transfer(self, batch: Any) -> Any:
if not self.is_cuda:
if not self.device_helper.is_gpu:
return batch
if isinstance(batch, (list, tuple)):
return [self.non_blocking_transfer(v) for v in batch]
if isinstance(batch, dict):
return {k: self.non_blocking_transfer(v) for k, v in batch.items()}

batch = batch.pin_memory()
return batch.to(self.device, non_blocking=True)
batch = batch.pin_memory(self.device_helper.device)
return batch.to(self.device_helper.device, non_blocking=True)

def __iter__(self) -> Any:
first = True
if self.is_cuda:
stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream)
else:
stream = None
stream_context = nullcontext
self.device_helper.maybe_init_stream()

for next_batch in self.loader:

with stream_context():
with self.device_helper.stream_context():
next_batch = self.non_blocking_transfer(next_batch)

if not first:
yield batch # noqa
else:
first = False

if stream is not None:
torch.cuda.current_stream().wait_stream(stream)
self.device_helper.maybe_wait_stream()

batch = next_batch

Expand Down
7 changes: 7 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor,
raise ImportError("'masked_select_nnz' requires 'torch-sparse'")


# Feature-detect Intel Extension for PyTorch (IPEX), which registers the
# `torch.xpu` backend on import. An `OSError` can occur when the package
# is installed but its native libraries fail to load.
try:
    import intel_extension_for_pytorch  # noqa
except (ImportError, OSError):
    WITH_IPEX = False
else:
    WITH_IPEX = True


class MockTorchCSCTensor:
def __init__(
self,
Expand Down