From 786bcb3ea25c49ab8305b13a2493b49b4691b950 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Thu, 17 Aug 2023 12:21:48 +0200 Subject: [PATCH 1/8] Add support for XPU device in PrefetchLoader --- torch_geometric/loader/prefetch.py | 75 +++++++++++++++++++++++------- torch_geometric/typing.py | 10 ++++ 2 files changed, 67 insertions(+), 18 deletions(-) diff --git a/torch_geometric/loader/prefetch.py b/torch_geometric/loader/prefetch.py index 3bbfd69c6978..69c5bab6f309 100644 --- a/torch_geometric/loader/prefetch.py +++ b/torch_geometric/loader/prefetch.py @@ -4,6 +4,7 @@ import torch from torch.utils.data import DataLoader +from torch_geometric.typing import WITH_IPEX class PrefetchLoader: @@ -15,42 +16,81 @@ class PrefetchLoader: device (torch.device, optional): The device to load the data to. (default: :obj:`None`) """ + class PrefetchLoaderDevice: + def __init__(self, device: Optional[torch.device] = None): + cuda_present = torch.cuda.is_available() + xpu_present = torch.xpu.is_available() if WITH_IPEX else False + + if device is None: + if cuda_present: + device = 'cuda' + elif xpu_present: + device = 'xpu' + else: + device = 'cpu' + + self.device = torch.device(device) + + if ((self.device.type == 'cuda' and not cuda_present) or + (self.device.type == 'xpu' and not xpu_present)): + print(f'Requested device[{self.device.type}] is not available ' + '- fallback to CPU') + self.device = torch.device('cpu') + + self.is_gpu = self.device.type in ['cuda', 'xpu'] + self.stream = None + self.stream_context = nullcontext + + if self.is_gpu: + gpu_module = torch.cuda if self.device.type == 'cuda' else torch.xpu + else: + gpu_module = None + + self.gpu_module = gpu_module + + def maybe_init_stream(self) -> None: + if self.is_gpu: + self.stream = self.gpu_module.Stream() + self.stream_context = partial(self.gpu_module.stream, stream=self.stream) + + def maybe_wait_stream(self) -> None: + if self.stream is not None: + self.gpu_module.current_stream().wait_stream(self.stream) + + def get_device(self) -> torch.device: + return self.device + + def get_stream_context(self) -> Any: + return self.stream_context() + + def __init__( self, loader: DataLoader, device: Optional[torch.device] = None, ): - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.loader = loader - self.device = torch.device(device) - - self.is_cuda = torch.cuda.is_available() and self.device.type == 'cuda' + self.device_mgr = self.PrefetchLoaderDevice(device) def non_blocking_transfer(self, batch: Any) -> Any: - if not self.is_cuda: + if not self.device_mgr.is_gpu: return batch if isinstance(batch, (list, tuple)): return [self.non_blocking_transfer(v) for v in batch] if isinstance(batch, dict): return {k: self.non_blocking_transfer(v) for k, v in batch.items()} - batch = batch.pin_memory() - return batch.to(self.device, non_blocking=True) + device = self.device_mgr.get_device() + batch = batch.pin_memory(device) + return batch.to(device, non_blocking=True) def __iter__(self) -> Any: first = True - if self.is_cuda: - stream = torch.cuda.Stream() - stream_context = partial(torch.cuda.stream, stream=stream) - else: - stream = None - stream_context = nullcontext + self.device_mgr.maybe_init_stream() for next_batch in self.loader: - with stream_context(): + with self.device_mgr.get_stream_context(): next_batch = self.non_blocking_transfer(next_batch) if not first: @@ -58,8 +98,7 @@ def __iter__(self) -> Any: else: first = False - if stream is not None: - torch.cuda.current_stream().wait_stream(stream) + self.device_mgr.maybe_wait_stream() batch = next_batch diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index c412069c1344..2ea63f5daee3 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -204,6 +204,16 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, layout: Optional[str] = None) -> SparseTensor: raise ImportError("'masked_select_nnz' requires 'torch-sparse'") +try: + import intel_extension_for_pytorch # noqa + WITH_IPEX = True +except (ImportError, OSError) as e: + if isinstance(e, OSError): + warnings.warn("An issue occurred while importing" + "'intel-extension-for-pytorch'. " + f"Disabling its usage. Stacktrace: {e}") + WITH_IPEX = False + class MockTorchCSCTensor: def __init__( From bfd8d77a4c9fc08694a6bd1cc927e90971c88846 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 23 Aug 2023 08:43:54 +0200 Subject: [PATCH 2/8] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e9a9240f605..968fa1e512c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918)) - Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915)) - Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895)) - Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827)) From f4b4c0a2316205e0fa52af7fb3bb2f5036c3bd48 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 06:52:19 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/loader/prefetch.py | 9 +++++---- torch_geometric/typing.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/torch_geometric/loader/prefetch.py b/torch_geometric/loader/prefetch.py index 69c5bab6f309..c442b7c35d97 100644 --- a/torch_geometric/loader/prefetch.py +++ b/torch_geometric/loader/prefetch.py @@ -4,6 +4,7 @@ import torch from torch.utils.data import DataLoader + from torch_geometric.typing import WITH_IPEX @@ -31,8 +32,8 @@ def __init__(self, device: Optional[torch.device] = None): self.device = torch.device(device) - if ((self.device.type == 'cuda' and not cuda_present) or - (self.device.type == 'xpu' and not xpu_present)): + if ((self.device.type == 'cuda' and not cuda_present) + or (self.device.type == 'xpu' and not xpu_present)): print(f'Requested device[{self.device.type}] is not available ' '- fallback to CPU') self.device = torch.device('cpu') @@ -51,7 +52,8 @@ def __init__(self, device: Optional[torch.device] = None): def maybe_init_stream(self) -> None: if self.is_gpu: self.stream = self.gpu_module.Stream() - self.stream_context = partial(self.gpu_module.stream, stream=self.stream) + self.stream_context = partial(self.gpu_module.stream, + stream=self.stream) def maybe_wait_stream(self) -> None: if self.stream is not None: @@ -63,7 +65,6 @@ def get_device(self) -> torch.device: def get_stream_context(self) -> Any: return self.stream_context() - def __init__( self, loader: DataLoader, diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 2ea63f5daee3..48ba3120a10d 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -204,6 +204,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, layout: Optional[str] = None) -> SparseTensor: raise ImportError("'masked_select_nnz' requires 'torch-sparse'") + try: import intel_extension_for_pytorch # noqa WITH_IPEX = True From edf3cb3f5069944c53a391e93e638ce1907f87a2 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 23 Aug 2023 08:56:31 +0200 Subject: [PATCH 4/8] Fix pre-commit errors --- torch_geometric/loader/prefetch.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_geometric/loader/prefetch.py b/torch_geometric/loader/prefetch.py index c442b7c35d97..0ea0aca2e0ac 100644 --- a/torch_geometric/loader/prefetch.py +++ b/torch_geometric/loader/prefetch.py @@ -42,10 +42,12 @@ def __init__(self, device: Optional[torch.device] = None): self.stream = None self.stream_context = nullcontext + gpu_module = None if self.is_gpu: - gpu_module = torch.cuda if self.device.type == 'cuda' else torch.xpu - else: - gpu_module = None + if self.device.type == 'cuda': + gpu_module = torch.cuda + else: + gpu_module = torch.xpu self.gpu_module = gpu_module From 6db848cb310012f4eeb73be6ea21375bad27c91a Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 23 Aug 2023 13:47:39 +0200 Subject: [PATCH 5/8] Update torch_geometric/typing.py --- torch_geometric/typing.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 48ba3120a10d..f6ad2bee1d8d 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -209,10 +209,6 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, import intel_extension_for_pytorch # noqa WITH_IPEX = True except (ImportError, OSError) as e: - if isinstance(e, OSError): - warnings.warn("An issue occurred while importing" - "'intel-extension-for-pytorch'. " - f"Disabling its usage. Stacktrace: {e}") WITH_IPEX = False From 24ae6fce303eb508928d17d35ac36bcafed475dc Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 23 Aug 2023 15:28:51 +0000 Subject: [PATCH 6/8] update --- torch_geometric/loader/prefetch.py | 105 +++++++++++++---------------- 1 file changed, 47 insertions(+), 58 deletions(-) diff --git a/torch_geometric/loader/prefetch.py b/torch_geometric/loader/prefetch.py index 0ea0aca2e0ac..1b975a4553b3 100644 --- a/torch_geometric/loader/prefetch.py +++ b/torch_geometric/loader/prefetch.py @@ -1,3 +1,4 @@ +import warnings from contextlib import nullcontext from functools import partial from typing import Any, Optional @@ -8,6 +9,45 @@ from torch_geometric.typing import WITH_IPEX +class DeviceHelper: + def __init__(self, device: Optional[torch.device] = None): + with_cuda = torch.cuda.is_available() + with_xpu = torch.xpu.is_available() if WITH_IPEX else False + + if device is None: + if with_cuda: + device = 'cuda' + elif with_xpu: + device = 'xpu' + else: + device = 'cpu' + + self.device = torch.device(device) + self.is_gpu = self.device.type in ['cuda', 'xpu'] + + if ((self.device.type == 'cuda' and not with_cuda) + or (self.device.type == 'xpu' and not with_xpu)): + warnings.warn(f"Requested device '{self.device.type}' is not " + f"available, falling back to CPU") + self.device = torch.device('cpu') + + self.stream = None + self.stream_context = nullcontext + self.module = getattr(torch, self.device.type) if self.is_gpu else None + + def maybe_init_stream(self) -> None: + if self.is_gpu: + self.stream = self.module.Stream() + self.stream_context = partial( + self.module.stream, + stream=self.stream, + ) + + def maybe_wait_stream(self) -> None: + if self.stream is not None: + self.module.current_stream().wait_stream(self.stream) + + class PrefetchLoader: r"""A GPU prefetcher class for asynchronously transferring data of a :class:`torch.utils.data.DataLoader` from host memory to device memory. @@ -17,83 +57,32 @@ class PrefetchLoader: device (torch.device, optional): The device to load the data to. (default: :obj:`None`) """ - class PrefetchLoaderDevice: - def __init__(self, device: Optional[torch.device] = None): - cuda_present = torch.cuda.is_available() - xpu_present = torch.xpu.is_available() if WITH_IPEX else False - - if device is None: - if cuda_present: - device = 'cuda' - elif xpu_present: - device = 'xpu' - else: - device = 'cpu' - - self.device = torch.device(device) - - if ((self.device.type == 'cuda' and not cuda_present) - or (self.device.type == 'xpu' and not xpu_present)): - print(f'Requested device[{self.device.type}] is not available ' - '- fallback to CPU') - self.device = torch.device('cpu') - - self.is_gpu = self.device.type in ['cuda', 'xpu'] - self.stream = None - self.stream_context = nullcontext - - gpu_module = None - if self.is_gpu: - if self.device.type == 'cuda': - gpu_module = torch.cuda - else: - gpu_module = torch.xpu - - self.gpu_module = gpu_module - - def maybe_init_stream(self) -> None: - if self.is_gpu: - self.stream = self.gpu_module.Stream() - self.stream_context = partial(self.gpu_module.stream, - stream=self.stream) - - def maybe_wait_stream(self) -> None: - if self.stream is not None: - self.gpu_module.current_stream().wait_stream(self.stream) - - def get_device(self) -> torch.device: - return self.device - - def get_stream_context(self) -> Any: - return self.stream_context() - def __init__( self, loader: DataLoader, device: Optional[torch.device] = None, ): self.loader = loader - self.device_mgr = self.PrefetchLoaderDevice(device) + self.device_helper = self.DeviceHelper(device) def non_blocking_transfer(self, batch: Any) -> Any: - if not self.device_mgr.is_gpu: + if not self.device_helper.is_gpu: return batch if isinstance(batch, (list, tuple)): return [self.non_blocking_transfer(v) for v in batch] if isinstance(batch, dict): return {k: self.non_blocking_transfer(v) for k, v in batch.items()} - device = self.device_mgr.get_device() - batch = batch.pin_memory(device) - return batch.to(device, non_blocking=True) + batch = batch.pin_memory(self.device_helper.device) + return batch.to(self.device_helper.device, non_blocking=True) def __iter__(self) -> Any: first = True - self.device_mgr.maybe_init_stream() + self.device_helper.maybe_init_stream() for next_batch in self.loader: - with self.device_mgr.get_stream_context(): + with self.device_helper.stream_context(): next_batch = self.non_blocking_transfer(next_batch) if not first: @@ -101,7 +90,7 @@ def __iter__(self) -> Any: else: first = False - self.device_mgr.maybe_wait_stream() + self.device_helper.maybe_wait_stream() batch = next_batch From 03411f61327830dbad59fb617f9417ff08934772 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 23 Aug 2023 15:30:41 +0000 Subject: [PATCH 7/8] update --- torch_geometric/loader/prefetch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/loader/prefetch.py b/torch_geometric/loader/prefetch.py index 1b975a4553b3..1340c9e7660f 100644 --- a/torch_geometric/loader/prefetch.py +++ b/torch_geometric/loader/prefetch.py @@ -63,7 +63,7 @@ def __init__( device: Optional[torch.device] = None, ): self.loader = loader - self.device_helper = self.DeviceHelper(device) + self.device_helper = DeviceHelper(device) def non_blocking_transfer(self, batch: Any) -> Any: if not self.device_helper.is_gpu: From dc5e32aed21218fa04e3e35d2c7103eea1f049a4 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 23 Aug 2023 15:33:10 +0000 Subject: [PATCH 8/8] update --- torch_geometric/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index f6ad2bee1d8d..38a3a01fbb46 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -208,7 +208,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, try: import intel_extension_for_pytorch # noqa WITH_IPEX = True -except (ImportError, OSError) as e: +except (ImportError, OSError): WITH_IPEX = False