Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Docstrings for GraphClassificationData (#1121)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jan 19, 2022
1 parent 68c3d70 commit 4997226
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 33 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug when loading tabular data for prediction without a target field / column ([#1114](https://github.com/PyTorchLightning/lightning-flash/pull/1114))

- Fixed a bug when loading prediction data for graph classification without targets ([#1121](https://github.com/PyTorchLightning/lightning-flash/pull/1121))

### Removed

## [0.6.0] - 2021-12-13
Expand Down
15 changes: 7 additions & 8 deletions flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def load_target_metadata(
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
rather than inferring from the targets.
"""
if target_formatter is None:
if targets is None:
raise ValueError("`targets` must be provided if `target_formatter` is `None`.")
if target_formatter is None and targets is not None:
classification_state = self.get_state(ClassificationState)
if classification_state is not None:
labels, num_classes = classification_state.labels, classification_state.num_classes
Expand All @@ -60,10 +58,11 @@ def load_target_metadata(
else:
self.target_formatter = target_formatter

self.multi_label = self.target_formatter.multi_label
self.labels = self.target_formatter.labels
self.num_classes = self.target_formatter.num_classes
self.set_state(ClassificationState(self.labels, self.num_classes))
if getattr(self, "target_formatter", None) is not None:
self.multi_label = self.target_formatter.multi_label
self.labels = self.target_formatter.labels
self.num_classes = self.target_formatter.num_classes
self.set_state(ClassificationState(self.labels, self.num_classes))

def format_target(self, target: Any) -> Any:
    """Format a single target according to the previously computed target format and metadata.

    Args:
        target: The target to format.

    Returns:
        The formatted target.
    """
    # A formatter is only attached when targets were available at load time;
    # without one (e.g. prediction data without targets) pass the target through unchanged.
    if getattr(self, "target_formatter", None) is not None:
        return self.target_formatter(target)
    return target
3 changes: 2 additions & 1 deletion flash/core/data/utilities/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class SingleLabelTargetFormatter(TargetFormatter):
binary: ClassVar[Optional[bool]] = False

def __post_init__(self):
    """Finalize formatter state: infer ``num_classes`` and build the label index."""
    # Derive the class count from the label list unless it was given explicitly.
    if self.num_classes is None:
        self.num_classes = len(self.labels)
    # Reverse lookup (label -> position) used when formatting raw targets.
    self.label_to_idx = dict(map(reversed, enumerate(self.labels)))

def format(self, target: Any) -> Any:
Expand Down Expand Up @@ -423,7 +424,7 @@ def _get_target_details(

tokens = [_strip(token) for token in tokens]
labels = list(sorted_alphanumeric(set(tokens)))
num_classes = len(labels)
num_classes = None
return labels, num_classes


Expand Down
4 changes: 3 additions & 1 deletion flash/core/data/utilities/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def to_sample(input: Any) -> Dict[str, Any]:
if isinstance(input, dict) and DataKeys.INPUT in input:
return input
if _is_list_like(input) and len(input) == 2:
return {DataKeys.INPUT: input[0], DataKeys.TARGET: input[1]}
if input[1] is not None:
return {DataKeys.INPUT: input[0], DataKeys.TARGET: input[1]}
input = input[0]
return {DataKeys.INPUT: input}


Expand Down
133 changes: 132 additions & 1 deletion flash/graph/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,21 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.graph.classification.input import GraphClassificationDatasetInput
from flash.graph.classification.input_transform import GraphClassificationInputTransform

# Skip doctests if requirements aren't available
if not _GRAPH_AVAILABLE:
__doctest_skip__ = ["GraphClassificationData", "GraphClassificationData.*"]


class GraphClassificationData(DataModule):
"""Data module for graph classification tasks."""
"""The ``GraphClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
classmethods for loading data for graph classification."""

input_transform_cls = GraphClassificationInputTransform

Expand All @@ -40,12 +47,135 @@ def from_datasets(
val_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = GraphClassificationDatasetInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs,
) -> "GraphClassificationData":
"""Load the :class:`~flash.graph.classification.data.GraphClassificationData` from PyTorch Dataset objects.
The Dataset objects should be one of the following:
* A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(PyTorch Geometric Data object, target)``
* A PyTorch Dataset where the ``__getitem__`` returns a dict:
``{"input": PyTorch Geometric Data object, "target": target}``
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_dataset: The Dataset to use when training.
val_dataset: The Dataset to use when validating.
test_dataset: The Dataset to use when testing.
predict_dataset: The Dataset to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. If ``None`` then no formatting will be applied to targets.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.graph.classification.data.GraphClassificationData`.
Examples
________
A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(PyTorch Geometric Data object, target)``:
.. doctest::
>>> import torch
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
... def __init__(self, targets=None):
... self.targets = targets
... def __getitem__(self, index):
... edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
... x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
... data = Data(x=x, edge_index=edge_index)
... if self.targets is not None:
... return data, self.targets[index]
... return data
... def __len__(self):
... return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
... train_dataset=CustomDataset(["cat", "dog", "cat"]),
... predict_dataset=CustomDataset(),
... target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
... batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
A PyTorch Dataset where the ``__getitem__`` returns a dict:
``{"input": PyTorch Geometric Data object, "target": target}``:
.. doctest::
>>> import torch # noqa: F811
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data # noqa: F811
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
... def __init__(self, targets=None):
... self.targets = targets
... def __getitem__(self, index):
... edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
... x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
... data = Data(x=x, edge_index=edge_index)
... if self.targets is not None:
... return {"input": data, "target": self.targets[index]}
... return {"input": data}
... def __len__(self):
... return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
... train_dataset=CustomDataset(["cat", "dog", "cat"]),
... predict_dataset=CustomDataset(),
... target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
... batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""

ds_kw = dict(
target_formatter=target_formatter,
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
Expand All @@ -61,6 +191,7 @@ def from_datasets(

@property
def num_features(self):
"""The number of features per node in the graphs contained in this ``GraphClassificationData``."""
n_cls_train = getattr(self.train_dataset, "num_features", None)
n_cls_val = getattr(self.val_dataset, "num_features", None)
n_cls_test = getattr(self.test_dataset, "num_features", None)
Expand Down
43 changes: 26 additions & 17 deletions flash/graph/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping
from typing import Any, Dict, Mapping, Optional

from torch.utils.data import Dataset

from flash.core.data.data_module import DatasetInput
from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.data.utilities.samples import to_sample
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires

if _GRAPH_AVAILABLE:
from torch_geometric.data import Data
from torch_geometric.data import Dataset as TorchGeometricDataset
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data, InMemoryDataset


class GraphClassificationDatasetInput(DatasetInput, ClassificationInputMixin):
def _get_num_features(sample: Dict[str, Any]) -> Optional[int]:
    """Get the number of features per node in the given dataset."""
    entry = sample[DataKeys.INPUT]
    # The input may be a ``(data, target)`` tuple; unwrap to the data object.
    if isinstance(entry, tuple):
        entry = entry[0]
    # ``None`` when the object exposes no node-feature count.
    return getattr(entry, "num_node_features", None)


class GraphClassificationDatasetInput(Input, ClassificationInputMixin):
    """Loads graphs (and optional targets) for classification from a PyTorch Dataset."""

    @requires("graph")
    def load_data(self, dataset: Dataset, target_formatter: Optional[TargetFormatter] = None) -> Dataset:
        """Record dataset metadata (feature count, target format) before loading.

        Args:
            dataset: The dataset to load graphs from.
            target_formatter: Optionally provide a
                :class:`~flash.core.data.utilities.classification.TargetFormatter` rather than
                inferring one from the targets.

        Returns:
            The dataset, unchanged.
        """
        if not self.predicting:
            # Probe the first sample to discover the per-node feature count.
            self.num_features = _get_num_features(self.load_sample(dataset[0]))

            if isinstance(dataset, InMemoryDataset):
                # In-memory datasets expose their samples (and ``y`` targets) directly.
                self.load_target_metadata([sample.y for sample in dataset], target_formatter)
            else:
                # Targets are not enumerable up front; metadata is inferred lazily.
                self.load_target_metadata(None, target_formatter)

            # NOTE(review): prefer the dataset's own class count when it provides one —
            # nesting assumed inside the non-predicting branch, matching the prior code; confirm.
            if hasattr(dataset, "num_classes"):
                self.num_classes = dataset.num_classes
        return dataset

    def load_sample(self, sample: Any) -> Mapping[str, Any]:
        """Convert a raw dataset item into a sample dict, formatting the target if present."""
        if isinstance(sample, Data):
            # Bare PyG ``Data`` objects carry their target in ``y`` (may be ``None``).
            sample = (sample, sample.y)
        # ``to_sample`` drops the target key when the target is ``None``.
        sample = to_sample(sample)
        if DataKeys.TARGET in sample:
            sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
        return sample
13 changes: 8 additions & 5 deletions flash/graph/classification/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.utilities.samples import to_sample
from flash.core.utilities.imports import _GRAPH_AVAILABLE

if _GRAPH_AVAILABLE:
Expand All @@ -37,19 +38,21 @@ class PyGTransformAdapter:

transform: Callable[[Data], Data]

def __call__(self, x: Dict[str, Any]):
    """Apply the wrapped PyG transform to a sample dict.

    Args:
        x: A sample dict containing the input ``Data`` object and, optionally, a target.

    Returns:
        The transformed sample dict (target key omitted when no target exists).
    """
    data = x[DataKeys.INPUT]
    # Predict-time samples have no target key; ``get`` leaves ``y`` as ``None``.
    data.y = x.get(DataKeys.TARGET, None)
    data = self.transform(data)
    # ``to_sample`` drops the target key again when ``y`` is ``None``.
    return to_sample((data, data.y))


class GraphClassificationInputTransform(InputTransform):
@staticmethod
def _pyg_collate(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate sample dicts into one batched sample.

    Targets are only collated when present — they are absent at predict time,
    so an input-only batch is emitted in that case.
    """
    inputs = Batch.from_data_list([sample[DataKeys.INPUT] for sample in samples])
    if DataKeys.TARGET in samples[0]:
        targets = default_collate([sample[DataKeys.TARGET] for sample in samples])
        return {DataKeys.INPUT: inputs, DataKeys.TARGET: targets}
    return {DataKeys.INPUT: inputs}

def collate(self) -> Callable:
    """Expose the PyG-aware collate function used to batch graph samples."""
    return type(self)._pyg_collate
Expand Down

0 comments on commit 4997226

Please sign in to comment.