Add first version of GPUPrefetcher (#7376)
rusty1s authored May 17, 2023
1 parent f0e91ad commit c51436c
Showing 3 changed files with 101 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `GPUPrefetcher` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376))
- Added an example for hierarchical sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244))
- Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298))
- Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
19 changes: 19 additions & 0 deletions test/loader/test_prefetch.py
@@ -0,0 +1,19 @@
import torch

from torch_geometric.loader.prefetch import GPUPrefetcher
from torch_geometric.testing import onlyCUDA


@onlyCUDA
def test_gpu_prefetcher():
    data = [torch.randn(5, 5) for _ in range(10)]

    loader = GPUPrefetcher(data, device='cuda')
    assert str(loader).startswith('GPUPrefetcher')
    assert len(loader) == 10

    for i, batch in enumerate(loader):
        assert batch.is_cuda
        assert torch.equal(batch.cpu(), data[i])
        assert loader.idx > 0
    assert loader.idx == 0
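
Note the final assertion: `loader.idx` is reset to zero once iteration is exhausted (see `__next__` in `prefetch.py` below), so the same prefetcher instance can be iterated again, e.g. across epochs.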
81 changes: 81 additions & 0 deletions torch_geometric/loader/prefetch.py
@@ -0,0 +1,81 @@
from queue import Queue
from threading import Thread
from typing import Any, Optional

import torch
from torch.utils.data import DataLoader


class GPUPrefetcher:
r"""A GPU prefetcher class for asynchronously loading data from a
:class:`torch.utils.data.DataLoader` from host memory to device memory.
Args:
loader (torch.utils.DataLoader): A data loader object.
device (torch.device): The CUDA device to load the data to.
prefetch_size (int, optional): The number of batches to prefetch at
once. (default: :obj:`1`)
"""
    def __init__(
        self,
        loader: DataLoader,
        device: torch.device,
        prefetch_size: int = 1,
    ):
        if prefetch_size < 1:
            raise ValueError(f"'prefetch_size' must be greater than 0 "
                             f"(got {prefetch_size})")

        self.loader = loader
        self.device = torch.device(device)
        self.prefetch_size = prefetch_size

        # Transfers run on a side stream so they can overlap with compute on
        # the default stream; the bounded queue holds at most `prefetch_size`
        # batches that are already on (or in flight to) the device.
        self.load_stream = torch.cuda.Stream(device=device)
        self.queue = Queue(maxsize=prefetch_size)
        self.worker: Optional[Thread] = None

        self.idx = 0  # Number of batches served in the current iteration.

    def non_blocking_transfer(self, batch: Any) -> Any:
        # (Recursive) non-blocking device transfer:
        if isinstance(batch, (list, tuple)):
            return [self.non_blocking_transfer(v) for v in batch]
        if isinstance(batch, dict):
            return {k: self.non_blocking_transfer(v) for k, v in batch.items()}

        with torch.cuda.stream(self.load_stream):
            # Asynchronous host-to-device copies require page-locked (pinned)
            # host memory:
            if not batch.is_pinned():
                batch = batch.pin_memory()
            return batch.to(self.device, non_blocking=True)

    def load_loop(self):
        # Runs in the background thread; `Queue.put` blocks once the queue
        # holds `prefetch_size` batches, which throttles prefetching.
        for batch in self.loader:
            self.queue.put(self.non_blocking_transfer(batch))

    def __iter__(self) -> 'GPUPrefetcher':
        # (Re-)start the background worker when beginning a fresh iteration:
        is_dead = self.worker is None or not self.worker.is_alive()
        if is_dead and self.queue.empty() and self.idx == 0:
            self.worker = Thread(target=self.load_loop)
            self.worker.daemon = True
            self.worker.start()

        return self

    def __next__(self) -> Any:
        is_dead = not self.worker.is_alive()
        if (is_dead and self.queue.empty()) or self.idx >= len(self):
            self.idx = 0
            self.queue.join()
            self.worker.join()
            raise StopIteration

        out = self.queue.get()
        self.queue.task_done()
        self.idx += 1
        return out

    def __len__(self) -> int:
        return len(self.loader)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.loader})'
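
For context, a minimal usage sketch (not part of this commit): wrapping a plain `DataLoader` over CPU tensors. The dataset, batch size, and `prefetch_size` below are illustrative, and a CUDA device is assumed.

# Minimal usage sketch (illustrative, not part of this commit); assumes a
# CUDA device is available. Dataset and sizes are placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset

from torch_geometric.loader.prefetch import GPUPrefetcher

dataset = TensorDataset(torch.randn(100, 16))
loader = DataLoader(dataset, batch_size=10)
prefetcher = GPUPrefetcher(loader, device='cuda', prefetch_size=2)

for epoch in range(2):  # `idx` resets after each pass, so re-iteration works
    for (batch,) in prefetcher:  # batches arrive already on the GPU
        assert batch.is_cuda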
