Support sampler argument in LightningDataModule (#5456)
* sampler support

* update

* update
rusty1s authored Sep 16, 2022
1 parent 67b8324 commit ec7400e
Showing 2 changed files with 74 additions and 30 deletions.
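
In practice, the change lets `DataLoader` keyword arguments such as `sampler` flow through the data module: they reach the training loader, while the evaluation loaders strip `sampler` and `batch_sampler`. Below is a hypothetical usage sketch, not part of the commit; `FakeDataset` and the uniform weights are placeholders for a real dataset and real per-sample weights. Note that torch's `DataLoader` treats `shuffle=True` and a custom `sampler` as mutually exclusive, so shuffling must be disabled whenever a sampler is supplied (handling that is not visible in the hunks below).

    import torch
    from torch.utils.data import WeightedRandomSampler

    from torch_geometric.data import LightningDataset
    from torch_geometric.datasets import FakeDataset

    # Placeholder graph-classification dataset and splits.
    dataset = FakeDataset(num_graphs=100, task='graph')
    train_dataset, val_dataset = dataset[:80], dataset[80:]

    # Uniform weights stand in for, e.g., inverse class frequencies.
    weights = torch.ones(len(train_dataset))
    sampler = WeightedRandomSampler(weights, num_samples=len(train_dataset))

    # `sampler` is forwarded to `train_dataloader()` only; the evaluation
    # loaders pop it before constructing their `DataLoader`.
    datamodule = LightningDataset(train_dataset, val_dataset, batch_size=16,
                                  sampler=sampler)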
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.2.0] - 2022-MM-DD
 ### Added
+- Added `sampler` support to `LightningDataModule` ([#5456](https://github.com/pyg-team/pytorch_geometric/pull/5456))
 - Added official splits to `MalNetTiny` dataset ([#5078](https://github.com/pyg-team/pytorch_geometric/pull/5078))
 - Added `IndexToMask` and `MaskToIndex` transforms ([#5375](https://github.com/pyg-team/pytorch_geometric/pull/5375), [#5455](https://github.com/pyg-team/pytorch_geometric/pull/5455))
 - Added `FeaturePropagation` transform ([#5387](https://github.com/pyg-team/pytorch_geometric/pull/5387))
103 changes: 73 additions & 30 deletions torch_geometric/data/lightning_datamodule.py
@@ -1,3 +1,4 @@
+import copy
 import warnings
 from typing import Optional, Tuple, Union
 
@@ -144,26 +145,32 @@ def __init__(
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
 
-    def dataloader(
-        self,
-        dataset: Dataset,
-        shuffle: bool = False,
-    ) -> DataLoader:
-        return DataLoader(dataset, shuffle=shuffle, **self.kwargs)
+    def dataloader(self, dataset: Dataset, **kwargs) -> DataLoader:
+        return DataLoader(dataset, **kwargs)
 
     def train_dataloader(self) -> DataLoader:
         """"""
         from torch.utils.data import IterableDataset
         shuffle = not isinstance(self.train_dataset, IterableDataset)
-        return self.dataloader(self.train_dataset, shuffle=shuffle)
+
+        return self.dataloader(self.train_dataset, shuffle=shuffle,
+                               **self.kwargs)
 
     def val_dataloader(self) -> DataLoader:
         """"""
-        return self.dataloader(self.val_dataset, shuffle=False)
+        kwargs = copy.copy(self.kwargs)
+        kwargs.pop('sampler', None)
+        kwargs.pop('batch_sampler', None)
+
+        return self.dataloader(self.val_dataset, shuffle=False, **kwargs)
 
     def test_dataloader(self) -> DataLoader:
         """"""
-        return self.dataloader(self.test_dataset, shuffle=False)
+        kwargs = copy.copy(self.kwargs)
+        kwargs.pop('sampler', None)
+        kwargs.pop('batch_sampler', None)
+
+        return self.dataloader(self.test_dataset, shuffle=False, **kwargs)
 
     def __repr__(self) -> str:
         kwargs = kwargs_repr(train_dataset=self.train_dataset,
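
In `LightningDataset`, the training loader keeps the full kwargs while `val_dataloader` and `test_dataloader` copy them and drop `sampler`/`batch_sampler`: a training-time sampler (weighted, distributed, etc.) should not reorder or re-weight evaluation data. A torch-only snippet demonstrating the underlying constraint:

    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

    data = list(range(8))

    # torch rejects a custom sampler combined with `shuffle=True`:
    try:
        DataLoader(data, shuffle=True, sampler=RandomSampler(data))
    except ValueError as err:
        print(err)  # "sampler option is mutually exclusive with shuffle"

    # `shuffle=False` plus a sampler is legal, but evaluation should visit
    # every example exactly once, hence the pops in the eval loaders above.
    loader = DataLoader(data, shuffle=False, sampler=SequentialSampler(data))
    print([batch.item() for batch in loader])  # [0, 1, ..., 7]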
@@ -331,37 +338,59 @@ def prepare_data(self):
     def dataloader(
         self,
         input_nodes: InputNodes,
-        shuffle: bool = False,
+        **kwargs,
     ) -> DataLoader:
         if self.loader == 'full':
             warnings.filterwarnings('ignore', '.*does not have many workers.*')
             warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
-            return torch.utils.data.DataLoader([self.data], shuffle=False,
-                                               collate_fn=lambda xs: xs[0],
-                                               **self.kwargs)
+
+            kwargs.pop('sampler', None)
+            kwargs.pop('batch_sampler', None)
+
+            return torch.utils.data.DataLoader(
+                [self.data],
+                collate_fn=lambda xs: xs[0],
+                **kwargs,
+            )
 
         if self.loader == 'neighbor':
-            return NeighborLoader(data=self.data, input_nodes=input_nodes,
-                                  neighbor_sampler=self.neighbor_sampler,
-                                  shuffle=shuffle, **self.kwargs)
+            return NeighborLoader(
+                self.data,
+                input_nodes=input_nodes,
+                neighbor_sampler=self.neighbor_sampler,
+                **kwargs,
+            )
 
         raise NotImplementedError
 
     def train_dataloader(self) -> DataLoader:
         """"""
-        return self.dataloader(self.input_train_nodes, shuffle=True)
+        return self.dataloader(self.input_train_nodes, shuffle=True,
+                               **self.kwargs)
 
     def val_dataloader(self) -> DataLoader:
         """"""
-        return self.dataloader(self.input_val_nodes, shuffle=False)
+        kwargs = copy.copy(self.kwargs)
+        kwargs.pop('sampler', None)
+        kwargs.pop('batch_sampler', None)
+
+        return self.dataloader(self.input_val_nodes, shuffle=False, **kwargs)
 
     def test_dataloader(self) -> DataLoader:
         """"""
-        return self.dataloader(self.input_test_nodes, shuffle=False)
+        kwargs = copy.copy(self.kwargs)
+        kwargs.pop('sampler', None)
+        kwargs.pop('batch_sampler', None)
+
+        return self.dataloader(self.input_test_nodes, shuffle=False, **kwargs)
 
     def predict_dataloader(self) -> DataLoader:
         """"""
-        return self.dataloader(self.input_pred_nodes, shuffle=False)
+        kwargs = copy.copy(self.kwargs)
+        kwargs.pop('sampler', None)
+        kwargs.pop('batch_sampler', None)
+
+        return self.dataloader(self.input_pred_nodes, shuffle=False, **kwargs)
 
     def __repr__(self) -> str:
         kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
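
For node-level data, the same plumbing forwards loader kwargs to `NeighborLoader`. A hypothetical construction sketch, assuming `pytorch_lightning` is installed; `FakeDataset`, the 80% mask, and all hyperparameters are illustrative:

    import torch

    from torch_geometric.data import LightningNodeData
    from torch_geometric.datasets import FakeDataset

    data = FakeDataset(num_graphs=1)[0]            # toy homogeneous graph
    train_mask = torch.rand(data.num_nodes) < 0.8  # illustrative split

    datamodule = LightningNodeData(
        data,
        input_train_nodes=train_mask,
        loader='neighbor',
        num_neighbors=[10, 10],  # forwarded to NeighborLoader: fan-out per hop
        batch_size=128,          # forwarded to the underlying DataLoader
        # A `sampler` over the seed nodes may be passed here as well; the
        # eval loaders above strip it automatically.
    )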
@@ -529,43 +558,57 @@ def dataloader(
         input_edges: InputEdges,
         input_labels: Optional[Tensor],
         input_time: Optional[Tensor] = None,
-        shuffle: bool = False,
+        **kwargs,
     ) -> DataLoader:
         if self.loader == 'full':
             warnings.filterwarnings('ignore', '.*does not have many workers.*')
             warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
-            return torch.utils.data.DataLoader([self.data], shuffle=False,
-                                               collate_fn=lambda xs: xs[0],
-                                               **self.kwargs)
+
+            kwargs.pop('sampler', None)
+            kwargs.pop('batch_sampler', None)
+
+            return torch.utils.data.DataLoader(
+                [self.data],
+                collate_fn=lambda xs: xs[0],
+                **kwargs,
+            )
 
         if self.loader in ['neighbor', 'link_neighbor']:
             return LinkNeighborLoader(
-                data=self.data,
+                self.data,
                 edge_label_index=input_edges,
                 edge_label=input_labels,
                 edge_label_time=input_time,
                 neighbor_sampler=self.neighbor_sampler,
-                shuffle=shuffle,
                 neg_sampling_ratio=self.neg_sampling_ratio,
-                **self.kwargs,
+                **kwargs,
             )
 
         raise NotImplementedError
 
     def train_dataloader(self) -> DataLoader:
         """"""
         return self.dataloader(self.input_train_edges, self.input_train_labels,
-                               self.input_train_time, shuffle=True)
+                               self.input_train_time, shuffle=True,
+                               **self.kwargs)
 
     def val_dataloader(self) -> DataLoader:
         """"""
+        kwargs = copy.copy(self.kwargs)
+        kwargs.pop('sampler', None)
+        kwargs.pop('batch_sampler', None)
+
         return self.dataloader(self.input_val_edges, self.input_val_labels,
-                               self.input_val_time, shuffle=False)
+                               self.input_val_time, shuffle=False, **kwargs)
 
     def test_dataloader(self) -> DataLoader:
         """"""
+        kwargs = copy.copy(self.kwargs)
+        kwargs.pop('sampler', None)
+        kwargs.pop('batch_sampler', None)
+
         return self.dataloader(self.input_test_edges, self.input_test_labels,
-                               self.input_test_time, shuffle=False)
+                               self.input_test_time, shuffle=False, **kwargs)
 
     def __repr__(self) -> str:
         kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
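
A detail shared by all three classes: `copy.copy` makes a shallow copy, which is enough here since only top-level keys are popped, and it leaves `self.kwargs` untouched for later `train_dataloader()` calls. A standalone illustration:

    import copy

    kwargs = {'batch_size': 32, 'sampler': object()}

    eval_kwargs = copy.copy(kwargs)        # shallow copy of the dict
    eval_kwargs.pop('sampler', None)       # only top-level keys are removed
    eval_kwargs.pop('batch_sampler', None)

    assert 'sampler' in kwargs             # training config stays intact
    assert 'sampler' not in eval_kwargs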