diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7c5a7d2e3204..e48601e889fd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Unified `LightningNodeData` and `LightningLinkData` code paths ([#6473](https://github.com/pyg-team/pytorch_geometric/pull/6473))
 - Allow indices with any integer type in `RGCNConv` ([#6463](https://github.com/pyg-team/pytorch_geometric/pull/6463))
 - Re-structured the documentation ([#6420](https://github.com/pyg-team/pytorch_geometric/pull/6420), [#6423](https://github.com/pyg-team/pytorch_geometric/pull/6423), [#6429](https://github.com/pyg-team/pytorch_geometric/pull/6429), [#6440](https://github.com/pyg-team/pytorch_geometric/pull/6440), [#6443](https://github.com/pyg-team/pytorch_geometric/pull/6443), [#6445](https://github.com/pyg-team/pytorch_geometric/pull/6445), [#6452](https://github.com/pyg-team/pytorch_geometric/pull/6452), [#6453](https://github.com/pyg-team/pytorch_geometric/pull/6453), [#6458](https://github.com/pyg-team/pytorch_geometric/pull/6458), [#6459](https://github.com/pyg-team/pytorch_geometric/pull/6459), [#6460](https://github.com/pyg-team/pytorch_geometric/pull/6460))
 - Fix the default arguments of `DataParallel` class ([#6376](https://github.com/pyg-team/pytorch_geometric/pull/6376))
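The changelog entry above refers to the refactor in this diff: both datamodules now derive from a shared base class. A rough schematic of the resulting hierarchy (names taken from the diff below; bodies elided, not the real implementation):

```python
# Sketch only: the unified class hierarchy introduced by this PR.
class LightningData:                       # new shared base; owns `data`,
    ...                                    # `loader`, `graph_sampler`, and
                                           # the split loader kwargs

class LightningNodeData(LightningData):    # node-level tasks (NodeLoader)
    ...

class LightningLinkData(LightningData):    # link-level tasks (LinkLoader)
    ...
```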
diff --git a/test/data/lightning/test_datamodule.py b/test/data/lightning/test_datamodule.py
index 0e1601382d61..8a4781eecc6f 100644
--- a/test/data/lightning/test_datamodule.py
+++ b/test/data/lightning/test_datamodule.py
@@ -45,7 +45,7 @@ def __init__(self, in_channels, hidden_channels, out_channels):
 
     def forward(self, x, batch):
         # Basic test to ensure that the dataset is not replicated:
-        self.trainer.datamodule.train_dataset.data.x.add_(1)
+        self.trainer.datamodule.train_dataset._data.x.add_(1)
 
         x = self.lin1(x).relu()
         x = global_mean_pool(x, batch)
@@ -93,14 +93,14 @@ def test_lightning_dataset(get_dataset, strategy_type):
                          max_epochs=1, log_every_n_steps=1)
     datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,
                                   batch_size=5, num_workers=3)
-    old_x = train_dataset.data.x.clone()
+    old_x = train_dataset._data.x.clone()
     assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
                                'val_dataset=MUTAG(30), '
                                'test_dataset=MUTAG(10), batch_size=5, '
                                'num_workers=3, pin_memory=True, '
                                'persistent_workers=True)')
     trainer.fit(model, datamodule)
-    new_x = train_dataset.data.x
+    new_x = train_dataset._data.x
     offset = 10 + 6 + 2 * devices  # `train_steps` + `val_steps` + `sanity`
     assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.
 
     if strategy_type is None:
@@ -274,7 +274,7 @@ def test_lightning_hetero_node_data(get_dataset):
                          max_epochs=5, log_every_n_steps=1)
     datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5],
                                    batch_size=32, num_workers=3)
-    assert isinstance(datamodule.neighbor_sampler, NeighborSampler)
+    assert isinstance(datamodule.graph_sampler, NeighborSampler)
     old_x = data['author'].x.clone()
     trainer.fit(model, datamodule)
     new_x = data['author'].x
@@ -298,12 +298,12 @@ def sample_from_nodes(self, *args, **kwargs):
 
     datamodule = LightningNodeData(data, node_sampler=DummySampler(),
                                    input_train_nodes=torch.arange(2))
-    assert isinstance(datamodule.neighbor_sampler, DummySampler)
+    assert isinstance(datamodule.graph_sampler, DummySampler)
 
     datamodule = LightningLinkData(
         data, link_sampler=DummySampler(),
         input_train_edges=torch.tensor([[0, 1], [0, 1]]))
-    assert isinstance(datamodule.neighbor_sampler, DummySampler)
+    assert isinstance(datamodule.graph_sampler, DummySampler)
 
 
 @onlyCUDA
@@ -330,7 +330,7 @@ def test_lightning_hetero_link_data():
         batch_size=32,
         num_workers=0,
     )
-    assert isinstance(datamodule.neighbor_sampler, NeighborSampler)
+    assert isinstance(datamodule.graph_sampler, NeighborSampler)
 
     for batch in datamodule.train_dataloader():
         assert 'edge_label_index' in batch['author', 'paper']
@@ -404,9 +404,9 @@ def test_eval_loader_kwargs(get_dataset):
     )
 
     assert datamodule.loader_kwargs['batch_size'] == 32
-    assert datamodule.neighbor_sampler.num_neighbors == [5]
+    assert datamodule.graph_sampler.num_neighbors == [5]
     assert datamodule.eval_loader_kwargs['batch_size'] == 64
-    assert datamodule.eval_neighbor_sampler.num_neighbors == [-1]
+    assert datamodule.eval_graph_sampler.num_neighbors == [-1]
 
     train_loader = datamodule.train_dataloader()
     assert math.ceil(int(data.train_mask.sum()) / 32) == len(train_loader)
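For reference, a minimal standalone sketch of the rename exercised by the tests above: a custom `BaseSampler` passed via `node_sampler` is now exposed as `graph_sampler` (previously `neighbor_sampler`). The tiny graph is made up for illustration, and `pytorch_lightning` is assumed to be installed:

```python
import torch

from torch_geometric.data import Data, LightningNodeData
from torch_geometric.sampler import BaseSampler


class DummySampler(BaseSampler):  # stand-in sampler, mirroring the test above
    def sample_from_nodes(self, *args, **kwargs):
        pass

    def sample_from_edges(self, *args, **kwargs):
        pass


data = Data(x=torch.randn(4, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))
datamodule = LightningNodeData(data, node_sampler=DummySampler(),
                               input_train_nodes=torch.arange(2))
assert isinstance(datamodule.graph_sampler, DummySampler)  # renamed attribute
```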
Remove it " f"from the argument list to disable this warning") del kwargs['shuffle'] - if 'pin_memory' not in kwargs: - kwargs['pin_memory'] = True - - if 'persistent_workers' not in kwargs: - kwargs['persistent_workers'] = kwargs.get('num_workers', 0) > 0 - self.kwargs = kwargs def prepare_data(self): @@ -78,7 +78,152 @@ def prepare_data(self): f"'pytorch_lightning' (got '{strategy.__class__.__name__}')") def __repr__(self) -> str: - return f'{self.__class__.__name__}({self._kwargs_repr(**self.kwargs)})' + return f'{self.__class__.__name__}({kwargs_repr(**self.kwargs)})' + + +class LightningData(LightningDataModule): + def __init__( + self, + data: Union[Data, HeteroData], + has_val: bool, + has_test: bool, + loader: str = 'neighbor', + graph_sampler: Optional[BaseSampler] = None, + eval_loader_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + kwargs.setdefault('batch_size', 1) + kwargs.setdefault('num_workers', 0) + + if graph_sampler is not None: + loader = 'custom' + + # For full-batch training, we use reasonable defaults for a lot of + # data-loading options: + if loader not in ['full', 'neighbor', 'link_neighbor', 'custom']: + raise ValueError("Undefined 'loader' option (got '{loader}')") + + if loader == 'full' and kwargs['batch_size'] != 1: + warnings.warn(f"Re-setting 'batch_size' to 1 in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got '{kwargs['batch_size']}')") + kwargs['batch_size'] = 1 + + if loader == 'full' and kwargs['num_workers'] != 0: + warnings.warn(f"Re-setting 'num_workers' to 0 in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got '{kwargs['num_workers']}')") + kwargs['num_workers'] = 0 + + if loader == 'full' and kwargs.get('sampler') is not None: + warnings.warn("'sampler' option is not supported for " + "loader='full'") + kwargs.pop('sampler', None) + + if loader == 'full' and kwargs.get('batch_sampler') is not None: + warnings.warn("'batch_sampler' option is not supported for " + "loader='full'") + kwargs.pop('sampler', None) + + super().__init__(has_val, has_test, **kwargs) + + if loader == 'full': + if kwargs.get('pin_memory', False): + warnings.warn(f"Re-setting 'pin_memory' to 'False' in " + f"'{self.__class__.__name__}' for loader='full' " + f"(got 'True')") + self.kwargs['pin_memory'] = False + + self.data = data + self.loader = loader + + # Determine sampler and loader arguments ############################## + + if loader in ['neighbor', 'link_neighbor']: + + # Define a new `NeighborSampler` to be re-used across data loaders: + sampler_kwargs, self.loader_kwargs = split_kwargs( + self.kwargs, + NeighborSampler, + ) + sampler_kwargs.setdefault('share_memory', + self.kwargs['num_workers'] > 0) + self.graph_sampler = NeighborSampler(data, **sampler_kwargs) + + elif graph_sampler is not None: + sampler_kwargs, self.loader_kwargs = split_kwargs( + self.kwargs, + graph_sampler.__class__, + ) + if len(sampler_kwargs) > 0: + warnings.warn(f"Ignoring the arguments " + f"{list(sampler_kwargs.keys())} in " + f"'{self.__class__.__name__}' since a custom " + f"'graph_sampler' was passed") + self.graph_sampler = graph_sampler + + else: + assert loader == 'full' + self.loader_kwargs = self.kwargs + + # Determine validation sampler and loader arguments ################### + + self.eval_loader_kwargs = copy.copy(self.loader_kwargs) + if eval_loader_kwargs is not None: + # If the user wants to override certain values during evaluation, + # we shallow-copy the graph sampler and update its attributes. 
+            if hasattr(self, 'graph_sampler'):
+                self.eval_graph_sampler = copy.copy(self.graph_sampler)
+
+                eval_sampler_kwargs, eval_loader_kwargs = split_kwargs(
+                    eval_loader_kwargs,
+                    self.graph_sampler.__class__,
+                )
+                for key, value in eval_sampler_kwargs.items():
+                    setattr(self.eval_graph_sampler, key, value)
+
+            self.eval_loader_kwargs.update(eval_loader_kwargs)
+
+        elif hasattr(self, 'graph_sampler'):
+            self.eval_graph_sampler = self.graph_sampler
+
+        self.eval_loader_kwargs.pop('sampler', None)
+        self.eval_loader_kwargs.pop('batch_sampler', None)
+
+    @property
+    def train_shuffle(self) -> bool:
+        shuffle = self.loader_kwargs.get('sampler', None) is None
+        shuffle &= self.loader_kwargs.get('batch_sampler', None) is None
+        return shuffle
+
+    def prepare_data(self):
+        if self.loader == 'full':
+            try:
+                num_devices = self.trainer.num_devices
+            except AttributeError:
+                # PyTorch Lightning < 1.6 backward compatibility:
+                num_devices = self.trainer.num_processes
+                num_devices = max(num_devices, self.trainer.num_gpus)
+
+            if num_devices > 1:
+                raise ValueError(
+                    f"'{self.__class__.__name__}' with loader='full' requires "
+                    f"training on a single device")
+        super().prepare_data()
+
+    def full_dataloader(self, **kwargs) -> DataLoader:
+        warnings.filterwarnings('ignore', '.*does not have many workers.*')
+        warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
+
+        return torch.utils.data.DataLoader(
+            [self.data],
+            collate_fn=lambda xs: xs[0],
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
+        return f'{self.__class__.__name__}({kwargs})'
 
 
 class LightningDataset(LightningDataModule):
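The `full_dataloader()` helper factored out above reduces full-batch training to a single-element `DataLoader` with an identity collate function. A standalone sketch of that trick:

```python
import torch

from torch_geometric.data import Data

data = Data(x=torch.randn(4, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))

# A one-element dataset whose "collate" simply unwraps the single graph:
loader = torch.utils.data.DataLoader([data], collate_fn=lambda xs: xs[0])

for batch in loader:  # yields the full graph exactly once per epoch
    assert batch is data
```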
""" @@ -128,15 +268,11 @@ def __init__( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, pred_dataset: Optional[Dataset] = None, - batch_size: int = 1, - num_workers: int = 0, **kwargs, ): super().__init__( has_val=val_dataset is not None, has_test=test_dataset is not None, - batch_size=batch_size, - num_workers=num_workers, **kwargs, ) @@ -149,7 +285,6 @@ def dataloader(self, dataset: Dataset, **kwargs) -> DataLoader: return DataLoader(dataset, **kwargs) def train_dataloader(self) -> DataLoader: - """""" from torch.utils.data import IterableDataset shuffle = not isinstance(self.train_dataset, IterableDataset) @@ -160,7 +295,6 @@ def train_dataloader(self) -> DataLoader: **self.kwargs) def val_dataloader(self) -> DataLoader: - """""" kwargs = copy.copy(self.kwargs) kwargs.pop('sampler', None) kwargs.pop('batch_sampler', None) @@ -168,7 +302,6 @@ def val_dataloader(self) -> DataLoader: return self.dataloader(self.val_dataset, shuffle=False, **kwargs) def test_dataloader(self) -> DataLoader: - """""" kwargs = copy.copy(self.kwargs) kwargs.pop('sampler', None) kwargs.pop('batch_sampler', None) @@ -176,7 +309,6 @@ def test_dataloader(self) -> DataLoader: return self.dataloader(self.test_dataset, shuffle=False, **kwargs) def predict_dataloader(self) -> DataLoader: - """""" kwargs = copy.copy(self.kwargs) kwargs.pop('sampler', None) kwargs.pop('batch_sampler', None) @@ -190,8 +322,7 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({kwargs})' -# TODO explicitly support Tuple[FeatureStore, GraphStore] -class LightningNodeData(LightningDataModule): +class LightningNodeData(LightningData): r"""Converts a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object into a :class:`pytorch_lightning.LightningDataModule` variant, which can be @@ -258,11 +389,6 @@ class LightningNodeData(LightningDataModule): node_sampler (BaseSampler, optional): A custom sampler object to generate mini-batches. If set, will ignore the :obj:`loader` option. (default: :obj:`None`) - batch_size (int, optional): How many samples per batch to load. - (default: :obj:`1`) - num_workers (int): How many subprocesses to use for data loading. - :obj:`0` means that the data will be loaded in the main process. - (default: :obj:`0`) eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments that override the :class:`torch_geometric.loader.NeighborLoader` configuration during evaluation. 
@@ -258,11 +389,6 @@ class LightningNodeData(LightningDataModule):
         node_sampler (BaseSampler, optional): A custom sampler object to
             generate mini-batches. If set, will ignore the :obj:`loader`
             option. (default: :obj:`None`)
-        batch_size (int, optional): How many samples per batch to load.
-            (default: :obj:`1`)
-        num_workers (int): How many subprocesses to use for data loading.
-            :obj:`0` means that the data will be loaded in the main process.
-            (default: :obj:`0`)
         eval_loader_kwargs (Dict[str, Any], optional): Custom keyword
             arguments that override the
             :class:`torch_geometric.loader.NeighborLoader` configuration
             during evaluation. (default: :obj:`None`)
@@ -280,18 +406,11 @@ def __init__(
         input_test_time: OptTensor = None,
         input_pred_nodes: InputNodes = None,
         input_pred_time: OptTensor = None,
-        loader: str = "neighbor",
+        loader: str = 'neighbor',
         node_sampler: Optional[BaseSampler] = None,
-        batch_size: int = 1,
-        num_workers: int = 0,
         eval_loader_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
-        if node_sampler is not None:
-            loader = 'custom'
-
-        assert loader in ['full', 'neighbor', 'custom']
-
         if input_train_nodes is None:
             input_train_nodes = infer_input_nodes(data, split='train')
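A small sketch of the (pre-existing) `infer_input_nodes` fallback kept above: when `input_train_nodes` is not given, it is inferred from boolean split attributes such as `data.train_mask`. The toy masks are made up for illustration; `pytorch_lightning` and a neighbor-sampling backend (e.g. `torch-sparse`) are assumed to be available:

```python
import torch

from torch_geometric.data import Data, LightningNodeData

data = Data(
    x=torch.randn(6, 16),
    edge_index=torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]),
    train_mask=torch.tensor([True, True, True, False, False, False]),
    val_mask=torch.tensor([False, False, False, True, True, False]),
)

# Equivalent to passing `input_train_nodes=data.train_mask` explicitly:
datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[2])
```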
@@ -306,127 +425,32 @@
         if input_pred_nodes is None:
             input_pred_nodes = infer_input_nodes(data, split='pred')
 
-        if loader == 'full' and batch_size != 1:
-            warnings.warn(f"Re-setting 'batch_size' to 1 in "
-                          f"'{self.__class__.__name__}' for loader='full' "
-                          f"(got '{batch_size}')")
-            batch_size = 1
-
-        if loader == 'full' and num_workers != 0:
-            warnings.warn(f"Re-setting 'num_workers' to 0 in "
-                          f"'{self.__class__.__name__}' for loader='full' "
-                          f"(got '{num_workers}')")
-            num_workers = 0
-
-        if loader == 'full' and kwargs.get('sampler') is not None:
-            warnings.warn("'sampler' option is not supported for "
-                          "loader='full'")
-            kwargs.pop('sampler', None)
-
-        if loader == 'full' and kwargs.get('batch_sampler') is not None:
-            warnings.warn("'batch_sampler' option is not supported for "
-                          "loader='full'")
-            kwargs.pop('sampler', None)
-
         super().__init__(
+            data=data,
             has_val=input_val_nodes is not None,
             has_test=input_test_nodes is not None,
-            batch_size=batch_size,
-            num_workers=num_workers,
+            loader=loader,
+            graph_sampler=node_sampler,
+            eval_loader_kwargs=eval_loader_kwargs,
             **kwargs,
         )
 
-        if loader == 'full':
-            if kwargs.get('pin_memory', False):
-                warnings.warn(f"Re-setting 'pin_memory' to 'False' in "
-                              f"'{self.__class__.__name__}' for loader='full' "
-                              f"(got 'True')")
-            self.kwargs['pin_memory'] = False
-
-        self.data = data
-        self.loader = loader
-
-        # Determine sampler and loader arguments ##############################
-
-        if loader == 'neighbor':
-            # Define a new `NeighborSampler` that can be re-used across
-            # different data loaders.
-            sampler_kwargs, self.loader_kwargs = split_kwargs(
-                self.kwargs,
-                NeighborSampler,
-            )
-            sampler_kwargs.setdefault('share_memory', num_workers > 0)
-
-            # TODO Consider renaming to `self.node_sampler`
-            self.neighbor_sampler = NeighborSampler(data, **sampler_kwargs)
-
-        elif node_sampler is not None:
-            _, self.loader_kwargs = split_kwargs(
-                self.kwargs,
-                node_sampler.__class__,
-            )
-            self.neighbor_sampler = node_sampler
-
-        else:
-            self.loader_kwargs = self.kwargs
-
-        # Determine validation sampler and loader arguments ###################
-
-        self.eval_loader_kwargs = copy.copy(self.loader_kwargs)
-        if eval_loader_kwargs is not None:
-            # If the user wants to override certain values during evaluation,
-            # we shallow-copy the sampler and update its attributes.
-            if hasattr(self, 'neighbor_sampler'):
-                self.eval_neighbor_sampler = copy.copy(self.neighbor_sampler)
-
-                eval_sampler_kwargs, eval_loader_kwargs = split_kwargs(
-                    eval_loader_kwargs,
-                    self.neighbor_sampler.__class__,
-                )
-                for key, value in eval_sampler_kwargs.items():
-                    setattr(self.eval_neighbor_sampler, key, value)
-
-            self.eval_loader_kwargs.update(eval_loader_kwargs)
-
-        elif hasattr(self, 'neighbor_sampler'):
-            self.eval_neighbor_sampler = self.neighbor_sampler
-
-        self.eval_loader_kwargs.pop('sampler', None)
-        self.eval_loader_kwargs.pop('batch_sampler', None)
-
-        #######################################################################
-
         self.input_train_nodes = input_train_nodes
         self.input_train_time = input_train_time
+        self.input_train_id: OptTensor = None
+
         self.input_val_nodes = input_val_nodes
         self.input_val_time = input_val_time
+        self.input_val_id: OptTensor = None
+
         self.input_test_nodes = input_test_nodes
         self.input_test_time = input_test_time
+        self.input_test_id: OptTensor = None
+
         self.input_pred_nodes = input_pred_nodes
         self.input_pred_time = input_pred_time
-
-        # Can be overriden to set input indices of the `NodeLoader`:
-        self.input_train_id: OptTensor = None
-        self.input_val_id: OptTensor = None
-        self.input_test_id: OptTensor = None
         self.input_pred_id: OptTensor = None
 
-    def prepare_data(self):
-        """"""
-        if self.loader == 'full':
-            try:
-                num_devices = self.trainer.num_devices
-            except AttributeError:
-                # PyTorch Lightning < 1.6 backward compatibility:
-                num_devices = self.trainer.num_processes
-                num_devices = max(num_devices, self.trainer.num_gpus)
-
-            if num_devices > 1:
-                raise ValueError(
-                    f"'{self.__class__.__name__}' with loader='full' requires "
-                    f"training on a single device")
-        super().prepare_data()
-
     def dataloader(
         self,
         input_nodes: InputNodes,
@@ -436,84 +460,61 @@
         **kwargs,
     ) -> DataLoader:
         if self.loader == 'full':
-            warnings.filterwarnings('ignore', '.*does not have many workers.*')
-            warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
+            return self.full_dataloader(**kwargs)
 
-            return torch.utils.data.DataLoader(
-                [self.data],
-                collate_fn=lambda xs: xs[0],
-                **kwargs,
-            )
+        assert node_sampler is not None
 
-        else:
-            if node_sampler is None:
-                warnings.warn("No 'node_sampler' specified. Falling back to "
-                              "using the default training sampler.")
Falling back to " - "using the default training sampler.") - node_sampler = self.neighbor_sampler - - return NodeLoader( - self.data, - node_sampler=node_sampler, - input_nodes=input_nodes, - input_time=input_time, - input_id=input_id, - **kwargs, - ) + return NodeLoader( + self.data, + node_sampler=node_sampler, + input_nodes=input_nodes, + input_time=input_time, + input_id=input_id, + **kwargs, + ) def train_dataloader(self) -> DataLoader: - """""" - shuffle = self.loader_kwargs.get('sampler', None) is None - shuffle &= self.loader_kwargs.get('batch_sampler', None) is None - return self.dataloader( self.input_train_nodes, self.input_train_time, self.input_train_id, - node_sampler=getattr(self, 'neighbor_sampler', None), - shuffle=shuffle, + node_sampler=getattr(self, 'graph_sampler', None), + shuffle=self.train_shuffle, **self.loader_kwargs, ) def val_dataloader(self) -> DataLoader: - """""" return self.dataloader( self.input_val_nodes, self.input_val_time, self.input_val_id, - node_sampler=getattr(self, 'eval_neighbor_sampler', None), + node_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def test_dataloader(self) -> DataLoader: - """""" return self.dataloader( self.input_test_nodes, self.input_test_time, self.input_test_id, - node_sampler=getattr(self, 'eval_neighbor_sampler', None), + node_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def predict_dataloader(self) -> DataLoader: - """""" return self.dataloader( self.input_pred_nodes, self.input_pred_time, self.input_pred_id, - node_sampler=getattr(self, 'eval_neighbor_sampler', None), + node_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) - def __repr__(self) -> str: - kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) - return f'{self.__class__.__name__}({kwargs})' - -# TODO: Unify implementation with LightningNodeData via a common base class. -class LightningLinkData(LightningDataModule): +class LightningLinkData(LightningData): r"""Converts a :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object into a :class:`pytorch_lightning.LightningDataModule` variant, which can be @@ -575,11 +576,6 @@ class LightningLinkData(LightningDataModule): link_sampler (BaseSampler, optional): A custom sampler object to generate mini-batches. If set, will ignore the :obj:`loader` option. (default: :obj:`None`) - batch_size (int, optional): How many samples per batch to load. - (default: :obj:`1`) - num_workers (int): How many subprocesses to use for data loading. - :obj:`0` means that the data will be loaded in the main process. 
@@ -575,11 +576,6 @@
         link_sampler (BaseSampler, optional): A custom sampler object to
             generate mini-batches. If set, will ignore the :obj:`loader`
             option. (default: :obj:`None`)
-        batch_size (int, optional): How many samples per batch to load.
-            (default: :obj:`1`)
-        num_workers (int): How many subprocesses to use for data loading.
-            :obj:`0` means that the data will be loaded in the main process.
-            (default: :obj:`0`)
         eval_loader_kwargs (Dict[str, Any], optional): Custom keyword
             arguments that override the
             :class:`torch_geometric.loader.LinkNeighborLoader` configuration
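Note the asymmetry in the shared base class: without `eval_loader_kwargs`, evaluation re-uses the training sampler object itself; only when overrides are given is it shallow-copied. A sketch, under the same toy-`data` assumption as above:

```python
import torch

from torch_geometric.data import LightningLinkData

datamodule = LightningLinkData(
    data,
    input_train_edges=torch.tensor([[0, 1], [1, 2]]),
    loader='neighbor',
    num_neighbors=[2],
)
# No `eval_loader_kwargs` -> the very same sampler object is shared:
assert datamodule.eval_graph_sampler is datamodule.graph_sampler
```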
@@ -602,145 +598,41 @@ def __init__(
         input_pred_edges: InputEdges = None,
         input_pred_labels: OptTensor = None,
         input_pred_time: OptTensor = None,
-        loader: str = "neighbor",
+        loader: str = 'neighbor',
         link_sampler: Optional[BaseSampler] = None,
-        batch_size: int = 1,
-        num_workers: int = 0,
         eval_loader_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
-        if link_sampler is not None:
-            loader = 'custom'
-
-        assert loader in ['full', 'neighbor', 'link_neighbor', 'custom']
-
-        if input_train_edges is None:
-            raise NotImplementedError(f"'{self.__class__.__name__}' cannot "
-                                      f"yet infer input edges automatically")
-
-        if loader == 'full' and batch_size != 1:
-            warnings.warn(f"Re-setting 'batch_size' to 1 in "
-                          f"'{self.__class__.__name__}' for loader='full' "
-                          f"(got '{batch_size}')")
-            batch_size = 1
-
-        if loader == 'full' and num_workers != 0:
-            warnings.warn(f"Re-setting 'num_workers' to 0 in "
-                          f"'{self.__class__.__name__}' for loader='full' "
-                          f"(got '{num_workers}')")
-            num_workers = 0
-
-        if loader == 'full' and kwargs.get('sampler') is not None:
-            warnings.warn("'sampler' option is not supported for "
-                          "loader='full'")
-            kwargs.pop('sampler', None)
-
-        if loader == 'full' and kwargs.get('batch_sampler') is not None:
-            warnings.warn("'batch_sampler' option is not supported for "
-                          "loader='full'")
-            kwargs.pop('sampler', None)
-
         super().__init__(
+            data=data,
             has_val=input_val_edges is not None,
             has_test=input_test_edges is not None,
-            batch_size=batch_size,
-            num_workers=num_workers,
+            loader=loader,
+            graph_sampler=link_sampler,
+            eval_loader_kwargs=eval_loader_kwargs,
             **kwargs,
         )
 
-        if loader == 'full':
-            if kwargs.get('pin_memory', False):
-                warnings.warn(f"Re-setting 'pin_memory' to 'False' in "
-                              f"'{self.__class__.__name__}' for loader='full' "
-                              f"(got 'True')")
-            self.kwargs['pin_memory'] = False
-
-        self.data = data
-        self.loader = loader
-
-        # Determine sampler and loader arguments ##############################
-
-        if loader in ['neighbor', 'link_neighbor']:
-            # Define a new `NeighborSampler` that can be re-used across
-            # different data loaders.
-            sampler_kwargs, self.loader_kwargs = split_kwargs(
-                self.kwargs,
-                NeighborSampler,
-            )
-            sampler_kwargs.setdefault('share_memory', num_workers > 0)
-
-            # TODO Consider renaming to `self.link_sampler`
-            self.neighbor_sampler = NeighborSampler(data, **sampler_kwargs)
-
-        elif link_sampler is not None:
-            _, self.loader_kwargs = split_kwargs(
-                self.kwargs,
-                link_sampler.__class__,
-            )
-            self.neighbor_sampler = link_sampler
-
-        else:
-            self.loader_kwargs = self.kwargs
-
-        # Determine validation sampler and loader arguments ###################
-
-        self.eval_loader_kwargs = copy.copy(self.loader_kwargs)
-        if eval_loader_kwargs is not None:
-            # If the user wants to override certain values during evaluation,
-            # we shallow-copy the sampler and update its attributes.
-            if hasattr(self, 'neighbor_sampler'):
-                self.eval_neighbor_sampler = copy.copy(self.neighbor_sampler)
-
-                eval_sampler_kwargs, eval_loader_kwargs = split_kwargs(
-                    eval_loader_kwargs,
-                    self.neighbor_sampler.__class__,
-                )
-                for key, value in eval_sampler_kwargs.items():
-                    setattr(self.eval_neighbor_sampler, key, value)
-
-            self.eval_loader_kwargs.update(eval_loader_kwargs)
-
-        elif hasattr(self, 'neighbor_sampler'):
-            self.eval_neighbor_sampler = self.neighbor_sampler
-
-        self.eval_loader_kwargs.pop('sampler', None)
-        self.eval_loader_kwargs.pop('batch_sampler', None)
-
         self.input_train_edges = input_train_edges
         self.input_train_labels = input_train_labels
         self.input_train_time = input_train_time
+        self.input_train_id: OptTensor = None
+
         self.input_val_edges = input_val_edges
         self.input_val_labels = input_val_labels
         self.input_val_time = input_val_time
+        self.input_val_id: OptTensor = None
+
         self.input_test_edges = input_test_edges
         self.input_test_labels = input_test_labels
         self.input_test_time = input_test_time
+        self.input_test_id: OptTensor = None
+
         self.input_pred_edges = input_pred_edges
         self.input_pred_labels = input_pred_labels
         self.input_pred_time = input_pred_time
-
-        # Can be overriden to set input indices of the `LinkLoader`:
-        self.input_train_id: OptTensor = None
-        self.input_val_id: OptTensor = None
-        self.input_test_id: OptTensor = None
         self.input_pred_id: OptTensor = None
 
-    def prepare_data(self):
-        """"""
-        if self.loader == 'full':
-            try:
-                num_devices = self.trainer.num_devices
-            except AttributeError:
-                # PyTorch Lightning < 1.6 backward compatibility:
-                num_devices = self.trainer.num_processes
-                num_devices = max(num_devices, self.trainer.num_gpus)
-
-            if num_devices > 1:
-                raise ValueError(
-                    f"'{self.__class__.__name__}' with loader='full' requires "
-                    f"training on a single device")
-        super().prepare_data()
-
     def dataloader(
         self,
         input_edges: InputEdges,
@@ -751,86 +643,64 @@
         **kwargs,
     ) -> DataLoader:
         if self.loader == 'full':
-            warnings.filterwarnings('ignore', '.*does not have many workers.*')
-            warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
+            return self.full_dataloader(**kwargs)
 
-            return torch.utils.data.DataLoader(
-                [self.data],
-                collate_fn=lambda xs: xs[0],
-                **kwargs,
-            )
+        assert link_sampler is not None
 
-        else:
-            if link_sampler is None:
-                warnings.warn("No 'link_sampler' specified. Falling back to "
-                              "using the default training sampler.")
Falling back to " - "using the default training sampler.") - link_sampler = self.neighbor_sampler - - return LinkLoader( - self.data, - link_sampler=link_sampler, - edge_label_index=input_edges, - edge_label=input_labels, - edge_label_time=input_time, - input_id=input_id, - **kwargs, - ) + return LinkLoader( + self.data, + link_sampler=link_sampler, + edge_label_index=input_edges, + edge_label=input_labels, + edge_label_time=input_time, + input_id=input_id, + **kwargs, + ) def train_dataloader(self) -> DataLoader: - """""" - shuffle = self.loader_kwargs.get('sampler', None) is None - shuffle &= self.loader_kwargs.get('batch_sampler', None) is None - return self.dataloader( self.input_train_edges, self.input_train_labels, self.input_train_time, self.input_train_id, - link_sampler=getattr(self, 'neighbor_sampler', None), - shuffle=shuffle, + link_sampler=getattr(self, 'graph_sampler', None), + shuffle=self.train_shuffle, **self.loader_kwargs, ) def val_dataloader(self) -> DataLoader: - """""" return self.dataloader( self.input_val_edges, self.input_val_labels, self.input_val_time, self.input_val_id, - link_sampler=getattr(self, 'eval_neighbor_sampler', None), + link_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def test_dataloader(self) -> DataLoader: - """""" return self.dataloader( self.input_test_edges, self.input_test_labels, self.input_test_time, self.input_test_id, - link_sampler=getattr(self, 'eval_neighbor_sampler', None), + link_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) def predict_dataloader(self) -> DataLoader: - """""" return self.dataloader( self.input_pred_edges, self.input_pred_labels, self.input_pred_time, self.input_pred_id, - link_sampler=getattr(self, 'eval_neighbor_sampler', None), + link_sampler=getattr(self, 'eval_graph_sampler', None), shuffle=False, **self.eval_loader_kwargs, ) - def __repr__(self) -> str: - kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs) - return f'{self.__class__.__name__}({kwargs})' - ###############################################################################