diff --git a/CHANGELOG.md b/CHANGELOG.md index 6133fa51b7..d9ebc3baaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug when loading tabular data for prediction without a target field / column ([#1114](https://github.com/PyTorchLightning/lightning-flash/pull/1114)) +- Fixed a bug when loading prediction data for graph classification without targets ([#1121](https://github.com/PyTorchLightning/lightning-flash/pull/1121)) + ### Removed ## [0.6.0] - 2021-13-12 diff --git a/flash/core/data/io/classification_input.py b/flash/core/data/io/classification_input.py index 59fe8c43d9..53291b492c 100644 --- a/flash/core/data/io/classification_input.py +++ b/flash/core/data/io/classification_input.py @@ -47,9 +47,7 @@ def load_target_metadata( target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` rather than inferring from the targets. """ - if target_formatter is None: - if targets is None: - raise ValueError("`targets` must be provided if `target_formatter` is `None`.") + if target_formatter is None and targets is not None: classification_state = self.get_state(ClassificationState) if classification_state is not None: labels, num_classes = classification_state.labels, classification_state.num_classes @@ -60,10 +58,11 @@ def load_target_metadata( else: self.target_formatter = target_formatter - self.multi_label = self.target_formatter.multi_label - self.labels = self.target_formatter.labels - self.num_classes = self.target_formatter.num_classes - self.set_state(ClassificationState(self.labels, self.num_classes)) + if getattr(self, "target_formatter", None) is not None: + self.multi_label = self.target_formatter.multi_label + self.labels = self.target_formatter.labels + self.num_classes = self.target_formatter.num_classes + self.set_state(ClassificationState(self.labels, self.num_classes)) def format_target(self, target: Any) -> Any: """Format a single target according to the previously computed target format and metadata. @@ -74,6 +73,6 @@ def format_target(self, target: Any) -> Any: Returns: The formatted target. """ - if hasattr(self, "target_formatter"): + if getattr(self, "target_formatter", None) is not None: return self.target_formatter(target) return target diff --git a/flash/core/data/utilities/classification.py b/flash/core/data/utilities/classification.py index 3f422d0009..ebff506813 100644 --- a/flash/core/data/utilities/classification.py +++ b/flash/core/data/utilities/classification.py @@ -132,6 +132,7 @@ class SingleLabelTargetFormatter(TargetFormatter): binary: ClassVar[Optional[bool]] = False def __post_init__(self): + self.num_classes = len(self.labels) if self.num_classes is None else self.num_classes self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)} def format(self, target: Any) -> Any: @@ -423,7 +424,7 @@ def _get_target_details( tokens = [_strip(token) for token in tokens] labels = list(sorted_alphanumeric(set(tokens))) - num_classes = len(labels) + num_classes = None return labels, num_classes diff --git a/flash/core/data/utilities/samples.py b/flash/core/data/utilities/samples.py index 57a27ce072..70a2bdf8db 100644 --- a/flash/core/data/utilities/samples.py +++ b/flash/core/data/utilities/samples.py @@ -36,7 +36,9 @@ def to_sample(input: Any) -> Dict[str, Any]: if isinstance(input, dict) and DataKeys.INPUT in input: return input if _is_list_like(input) and len(input) == 2: - return {DataKeys.INPUT: input[0], DataKeys.TARGET: input[1]} + if input[1] is not None: + return {DataKeys.INPUT: input[0], DataKeys.TARGET: input[1]} + input = input[0] return {DataKeys.INPUT: input} diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index 46a2b25842..e87bf30eb1 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -18,14 +18,21 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input +from flash.core.data.utilities.classification import TargetFormatter +from flash.core.utilities.imports import _GRAPH_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.graph.classification.input import GraphClassificationDatasetInput from flash.graph.classification.input_transform import GraphClassificationInputTransform +# Skip doctests if requirements aren't available +if not _GRAPH_AVAILABLE: + __doctest_skip__ = ["GraphClassificationData", "GraphClassificationData.*"] + class GraphClassificationData(DataModule): - """Data module for graph classification tasks.""" + """The ``GraphClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of + classmethods for loading data for graph classification.""" input_transform_cls = GraphClassificationInputTransform @@ -40,12 +47,135 @@ def from_datasets( val_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, + target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = GraphClassificationDatasetInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "GraphClassificationData": + """Load the :class:`~flash.graph.classification.data.GraphClassificationData` from PyTorch Dataset objects. + + The Dataset objects should be one of the following: + + * A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(PyTorch Geometric Data object, target)`` + * A PyTorch Dataset where the ``__getitem__`` returns a dict: + ``{"input": PyTorch Geometric Data object, "target": target}`` + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide `. + + Args: + train_dataset: The Dataset to use when training. + val_dataset: The Dataset to use when validating. + test_dataset: The Dataset to use when testing. + predict_dataset: The Dataset to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to + control how targets are handled. If ``None`` then no formatting will be applied to targets. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.graph.classification.data.GraphClassificationData`. + + Examples + ________ + + A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(PyTorch Geometric Data object, target)``: + + .. doctest:: + + >>> import torch + >>> from torch.utils.data import Dataset + >>> from torch_geometric.data import Data + >>> from flash import Trainer + >>> from flash.graph import GraphClassificationData, GraphClassifier + >>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter + >>> + >>> class CustomDataset(Dataset): + ... def __init__(self, targets=None): + ... self.targets = targets + ... def __getitem__(self, index): + ... edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) + ... x = torch.tensor([[-1], [0], [1]], dtype=torch.float) + ... data = Data(x=x, edge_index=edge_index) + ... if self.targets is not None: + ... return data, self.targets[index] + ... return data + ... def __len__(self): + ... return len(self.targets) if self.targets is not None else 3 + ... + >>> datamodule = GraphClassificationData.from_datasets( + ... train_dataset=CustomDataset(["cat", "dog", "cat"]), + ... predict_dataset=CustomDataset(), + ... target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]), + ... batch_size=2, + ... ) + >>> datamodule.num_features + 1 + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['cat', 'dog'] + >>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + A PyTorch Dataset where the ``__getitem__`` returns a dict: + ``{"input": PyTorch Geometric Data object, "target": target}``: + + .. doctest:: + + >>> import torch # noqa: F811 + >>> from torch.utils.data import Dataset + >>> from torch_geometric.data import Data # noqa: F811 + >>> from flash import Trainer + >>> from flash.graph import GraphClassificationData, GraphClassifier + >>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter + >>> + >>> class CustomDataset(Dataset): + ... def __init__(self, targets=None): + ... self.targets = targets + ... def __getitem__(self, index): + ... edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) + ... x = torch.tensor([[-1], [0], [1]], dtype=torch.float) + ... data = Data(x=x, edge_index=edge_index) + ... if self.targets is not None: + ... return {"input": data, "target": self.targets[index]} + ... return {"input": data} + ... def __len__(self): + ... return len(self.targets) if self.targets is not None else 3 + ... + >>> datamodule = GraphClassificationData.from_datasets( + ... train_dataset=CustomDataset(["cat", "dog", "cat"]), + ... predict_dataset=CustomDataset(), + ... target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]), + ... batch_size=2, + ... ) + >>> datamodule.num_features + 1 + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['cat', 'dog'] + >>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + """ ds_kw = dict( + target_formatter=target_formatter, data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, @@ -61,6 +191,7 @@ def from_datasets( @property def num_features(self): + """The number of features per node in the graphs contained in this ``GraphClassificationData``.""" n_cls_train = getattr(self.train_dataset, "num_features", None) n_cls_val = getattr(self.val_dataset, "num_features", None) n_cls_test = getattr(self.test_dataset, "num_features", None) diff --git a/flash/graph/classification/input.py b/flash/graph/classification/input.py index 420fb04152..0273f4b350 100644 --- a/flash/graph/classification/input.py +++ b/flash/graph/classification/input.py @@ -11,37 +11,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Mapping +from typing import Any, Dict, Mapping, Optional from torch.utils.data import Dataset -from flash.core.data.data_module import DatasetInput from flash.core.data.io.classification_input import ClassificationInputMixin -from flash.core.data.io.input import DataKeys +from flash.core.data.io.input import DataKeys, Input +from flash.core.data.utilities.classification import TargetFormatter +from flash.core.data.utilities.samples import to_sample from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires if _GRAPH_AVAILABLE: - from torch_geometric.data import Data - from torch_geometric.data import Dataset as TorchGeometricDataset - from torch_geometric.data import InMemoryDataset + from torch_geometric.data import Data, InMemoryDataset -class GraphClassificationDatasetInput(DatasetInput, ClassificationInputMixin): +def _get_num_features(sample: Dict[str, Any]) -> Optional[int]: + """Get the number of features per node in the given dataset.""" + data = sample[DataKeys.INPUT] + data = data[0] if isinstance(data, tuple) else data + return getattr(data, "num_node_features", None) + + +class GraphClassificationDatasetInput(Input, ClassificationInputMixin): @requires("graph") - def load_data(self, dataset: Dataset) -> Dataset: + def load_data(self, dataset: Dataset, target_formatter: Optional[TargetFormatter] = None) -> Dataset: if not self.predicting: - if isinstance(dataset, TorchGeometricDataset): - self.num_features = dataset.num_features + self.num_features = _get_num_features(self.load_sample(dataset[0])) + + if isinstance(dataset, InMemoryDataset): + self.load_target_metadata([sample.y for sample in dataset], target_formatter) + else: + self.load_target_metadata(None, target_formatter) - if isinstance(dataset, InMemoryDataset): - self.load_target_metadata([sample.y for sample in dataset]) - else: - self.num_classes = dataset.num_classes + if hasattr(dataset, "num_classes"): + self.num_classes = dataset.num_classes return dataset def load_sample(self, sample: Any) -> Mapping[str, Any]: if isinstance(sample, Data): - sample = {DataKeys.INPUT: sample, DataKeys.TARGET: sample.y} + sample = (sample, sample.y) + sample = to_sample(sample) + if DataKeys.TARGET in sample: sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET]) - return sample - return super().load_sample(sample) + return sample diff --git a/flash/graph/classification/input_transform.py b/flash/graph/classification/input_transform.py index 55b7638710..1a78f0d098 100644 --- a/flash/graph/classification/input_transform.py +++ b/flash/graph/classification/input_transform.py @@ -18,6 +18,7 @@ from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform +from flash.core.data.utilities.samples import to_sample from flash.core.utilities.imports import _GRAPH_AVAILABLE if _GRAPH_AVAILABLE: @@ -37,19 +38,21 @@ class PyGTransformAdapter: transform: Callable[[Data], Data] - def __call__(self, x): + def __call__(self, x: Dict[str, Any]): data = x[DataKeys.INPUT] - data.y = x[DataKeys.TARGET] + data.y = x.get(DataKeys.TARGET, None) data = self.transform(data) - return {DataKeys.INPUT: data, DataKeys.TARGET: data.y} + return to_sample((data, data.y)) class GraphClassificationInputTransform(InputTransform): @staticmethod def _pyg_collate(samples: List[Dict[str, Any]]) -> Dict[str, Any]: inputs = Batch.from_data_list([sample[DataKeys.INPUT] for sample in samples]) - targets = default_collate([sample[DataKeys.TARGET] for sample in samples]) - return {DataKeys.INPUT: inputs, DataKeys.TARGET: targets} + if DataKeys.TARGET in samples[0]: + targets = default_collate([sample[DataKeys.TARGET] for sample in samples]) + return {DataKeys.INPUT: inputs, DataKeys.TARGET: targets} + return {DataKeys.INPUT: inputs} def collate(self) -> Callable: return self._pyg_collate