diff --git a/flash/core/model.py b/flash/core/model.py index 78c907fc6c..d2cc3c2bef 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -245,22 +245,11 @@ def on_fit_end(self) -> None: self.data_pipeline._detach_from_model(self) super().on_fit_end() - @staticmethod - def _sanetize_funcs(obj: Any) -> Any: - if hasattr(obj, "__dict__"): - for k, v in obj.__dict__.items(): - if isinstance(v, Callable): - obj.__dict__[k] = inspect.unwrap(v) - return obj - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # TODO: Is this the best way to do this? or should we also use some kind of hparams here? # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: - self._preprocess = self._sanetize_funcs(self._preprocess) checkpoint['data_pipeline'] = self.data_pipeline - # todo (tchaton) re-wrap visualization super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: diff --git a/flash/core/utils.py b/flash/core/utils.py index 353124ef94..7676218c71 100644 --- a/flash/core/utils.py +++ b/flash/core/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Mapping, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Sequence, Type, Union def get_callable_name(fn_or_class: Union[Callable, object]) -> str: @@ -25,3 +25,15 @@ def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Map return {get_callable_name(f): f for f in fn} elif callable(fn): return {get_callable_name(fn): fn} + + +def _is_overriden(method_name: str, instance: object, parent: Type[object]) -> bool: + """ + Cropped Version of + https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py + """ + + if not hasattr(instance, method_name): + return False + + return getattr(instance, method_name).__code__ != getattr(parent, method_name).__code__ diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 5652496c10..498b67a33d 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -19,6 +19,7 @@ from pytorch_lightning.utilities.warning_utils import rank_zero_warn from torch.utils.data import Dataset +from flash.data.callback import ControlFlow from flash.data.process import Preprocess from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext @@ -82,6 +83,12 @@ def preprocess(self) -> Optional[Preprocess]: if self.data_pipeline is not None: return self.data_pipeline._preprocess_pipeline + @property + def control_flow_callback(self) -> Optional[ControlFlow]: + preprocess = self.preprocess + if preprocess is not None: + return ControlFlow(preprocess.callbacks) + def _call_load_data(self, data: Any) -> Iterable: parameters = signature(self.load_data).parameters if len(parameters) > 1 and self.DATASET_KEY in parameters: @@ -124,7 +131,10 @@ def __getitem__(self, index: int) -> Any: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") if self.load_sample: with self._load_sample_context: - return self._call_load_sample(self.preprocessed_data[index]) + data: Any = self._call_load_sample(self.preprocessed_data[index]) + if self.control_flow_callback: + self.control_flow_callback.on_load_sample(data, self.running_stage) + return data return self.preprocessed_data[index] def __len__(self) -> int: diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py new file mode 100644 index 0000000000..2dcd95fc27 --- /dev/null +++ b/flash/data/base_viz.py @@ -0,0 +1,111 @@ +from contextlib import contextmanager +from typing import Any, Dict, List, Sequence + +from pytorch_lightning.trainer.states import RunningStage +from torch import Tensor + +from flash.core.utils import _is_overriden +from flash.data.callback import FlashCallback +from flash.data.process import Preprocess +from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX + + +class BaseViz(FlashCallback): + """ + This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations. + It is disabled by default. + + batches: Dict = {"train": {"to_tensor_transform": [], ...}, ...} + + """ + + def __init__(self, enabled: bool = False): + self.batches = {k: {} for k in _STAGES_PREFIX.values()} + self.enabled = enabled + self._preprocess = None + + def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("load_sample", []) + store["load_sample"].append(sample) + + def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("pre_tensor_transform", []) + store["pre_tensor_transform"].append(sample) + + def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("to_tensor_transform", []) + store["to_tensor_transform"].append(sample) + + def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("post_tensor_transform", []) + store["post_tensor_transform"].append(sample) + + def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("per_batch_transform", []) + store["per_batch_transform"].append(batch) + + def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("collate", []) + store["collate"].append(batch) + + def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("per_sample_transform_on_device", []) + store["per_sample_transform_on_device"].append(samples) + + def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("per_batch_transform_on_device", []) + store["per_batch_transform_on_device"].append(batch) + + @contextmanager + def enable(self): + self.enabled = True + yield + self.enabled = False + + def attach_to_datamodule(self, datamodule) -> None: + datamodule.viz = self + + def attach_to_preprocess(self, preprocess: Preprocess) -> None: + preprocess.callbacks = [self] + self._preprocess = preprocess + + def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None: + """ + This function is a hook for users to override with their visualization on a batch. + """ + for func_name in _PREPROCESS_FUNCS: + hook_name = f"show_{func_name}" + if _is_overriden(hook_name, self, BaseViz): + getattr(self, hook_name)(batch[func_name], running_stage) + + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + pass + + def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + pass + + def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + pass + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + pass + + def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + pass + + def show_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + pass + + def show_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: + pass + + def show_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + pass diff --git a/flash/data/batch.py b/flash/data/batch.py index 9c7cce304e..3758d78a66 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -18,6 +18,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor +from flash.data.callback import ControlFlow from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext if TYPE_CHECKING: @@ -43,6 +44,7 @@ def __init__( ): super().__init__() self.preprocess = preprocess + self.callback = ControlFlow(self.preprocess.callbacks) self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) self.to_tensor_transform = convert_to_modules(to_tensor_transform) self.post_tensor_transform = convert_to_modules(post_tensor_transform) @@ -58,9 +60,11 @@ def forward(self, sample: Any) -> Any: with self._current_stage_context: with self._pre_tensor_transform_context: sample = self.pre_tensor_transform(sample) + self.callback.on_pre_tensor_transform(sample, self.stage) with self._to_tensor_transform_context: sample = self.to_tensor_transform(sample) + self.callback.on_to_tensor_transform(sample, self.stage) if self.assert_contains_tensor: if not _contains_any_tensor(sample): @@ -71,6 +75,7 @@ def forward(self, sample: Any) -> Any: with self._post_tensor_transform_context: sample = self.post_tensor_transform(sample) + self.callback.on_post_tensor_transform(sample, self.stage) return sample @@ -112,10 +117,11 @@ def __init__( per_batch_transform: Callable, stage: RunningStage, apply_per_sample_transform: bool = True, - on_device: bool = False + on_device: bool = False, ): super().__init__() self.preprocess = preprocess + self.callback = ControlFlow(self.preprocess.callbacks) self.collate_fn = convert_to_modules(collate_fn) self.per_sample_transform = convert_to_modules(per_sample_transform) self.per_batch_transform = convert_to_modules(per_batch_transform) @@ -123,25 +129,36 @@ def __init__( self.stage = stage self.on_device = on_device - extension = f"{'on_device' if self.on_device else ''}" + extension = f"{'_on_device' if self.on_device else ''}" self._current_stage_context = CurrentRunningStageContext(stage, preprocess) - self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess) + self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", preprocess) self._collate_context = CurrentFuncContext("collate", preprocess) - self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess) + self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess) def forward(self, samples: Sequence[Any]) -> Any: with self._current_stage_context: if self.apply_per_sample_transform: with self._per_sample_transform_context: - samples = [self.per_sample_transform(sample) for sample in samples] - samples = type(samples)(samples) + _samples = [] + for sample in samples: + sample = self.per_sample_transform(sample) + if self.on_device: + self.callback.on_per_sample_transform_on_device(sample, self.stage) + _samples.append(sample) + + samples = type(_samples)(_samples) with self._collate_context: samples = self.collate_fn(samples) + self.callback.on_collate(samples, self.stage) with self._per_batch_transform_context: samples = self.per_batch_transform(samples) + if self.on_device: + self.callback.on_per_batch_transform_on_device(samples, self.stage) + else: + self.callback.on_per_batch_transform(samples, self.stage) return samples def __str__(self) -> str: diff --git a/flash/data/callback.py b/flash/data/callback.py new file mode 100644 index 0000000000..b253b77d91 --- /dev/null +++ b/flash/data/callback.py @@ -0,0 +1,67 @@ +from typing import Any, List, Sequence + +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.states import RunningStage +from torch import Tensor + + +class FlashCallback(Callback): + + def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: + """Called once a sample has been loaded using ``load_sample``.""" + + def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + """Called once ``pre_tensor_transform`` have been applied to a sample.""" + + def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + """Called once ``to_tensor_transform`` have been applied to a sample.""" + + def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: + """Called once ``post_tensor_transform`` have been applied to a sample.""" + + def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + """Called once ``per_batch_transform`` have been applied to a batch.""" + + def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + """Called once ``collate`` have been applied to a sequence of samples.""" + + def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None: + """Called once ``per_sample_transform_on_device`` have been applied to a sample.""" + + def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + """Called once ``per_batch_transform_on_device`` have been applied to a sample.""" + + +class ControlFlow(FlashCallback): + + def __init__(self, callbacks: List[FlashCallback]): + self._callbacks = callbacks + + def run_for_all_callbacks(self, *args, method_name: str, **kwargs): + if self._callbacks: + for cb in self._callbacks: + getattr(cb, method_name)(*args, **kwargs) + + def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_load_sample") + + def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_pre_tensor_transform") + + def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_to_tensor_transform") + + def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_post_tensor_transform") + + def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform") + + def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(batch, running_stage, method_name="on_collate") + + def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform_on_device") + + def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform_on_device") diff --git a/flash/data/data_module.py b/flash/data/data_module.py index f7c2e8f6d2..4b9608cbda 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -13,7 +13,7 @@ # limitations under the License. import os import platform -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Union import pytorch_lightning as pl import torch @@ -24,7 +24,9 @@ from torch.utils.data.dataset import Subset from flash.data.auto_dataset import AutoDataset +from flash.data.base_viz import BaseViz from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess +from flash.data.utils import _STAGES_PREFIX class DataModule(pl.LightningDataModule): @@ -53,7 +55,7 @@ def __init__( test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, batch_size: int = 1, - num_workers: Optional[int] = None, + num_workers: Optional[int] = 0, ) -> None: super().__init__() @@ -83,10 +85,67 @@ def __init__( self._preprocess = None self._postprocess = None + self._viz = None # this may also trigger data preloading self.set_running_stages() + @property + def viz(self) -> BaseViz: + return self._viz or DataModule.configure_vis() + + @viz.setter + def viz(self, viz: BaseViz) -> None: + self._viz = viz + + @staticmethod + def configure_vis(*args, **kwargs) -> BaseViz: + return BaseViz() + + def show(self, batch: Dict[str, Any], stage: RunningStage) -> None: + """ + This function is a hook for users to override with their visualization on a batch. + """ + self.viz.show(batch, stage) + + def _show_batch(self, stage: RunningStage, reset: bool = True) -> None: + """ + This function is used to handle transforms profiling for batch visualization. + """ + iter_name = f"_{stage}_iter" + + def _reset_iterator() -> Iterable[Any]: + dataloader_fn = getattr(self, f"{stage}_dataloader") + iterator = iter(dataloader_fn()) + setattr(self, iter_name, iterator) + return iterator + + if not hasattr(self, iter_name): + _reset_iterator() + + iter_dataloader = getattr(self, iter_name) + with self.viz.enable(): + try: + _ = next(iter_dataloader) + except StopIteration: + iter_dataloader = _reset_iterator() + _ = next(iter_dataloader) + self.show(self.viz.batches[stage], stage) + if reset: + self.viz.batches[stage] = {} + + def show_train_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset) + + def show_val_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.VALIDATING], reset=reset) + + def show_test_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.TESTING], reset=reset) + + def show_predict_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.PREDICTING], reset=reset) + @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: if isinstance(dataset, Subset): @@ -320,6 +379,9 @@ def from_load_data_inputs( else: data_pipeline = cls(**kwargs).data_pipeline + viz_callback = cls.configure_vis() + viz_callback.attach_to_preprocess(data_pipeline._preprocess_pipeline) + train_dataset = cls._generate_dataset_if_possible( train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline ) @@ -341,4 +403,5 @@ def from_load_data_inputs( ) datamodule._preprocess = data_pipeline._preprocess_pipeline datamodule._postprocess = data_pipeline._postprocess_pipeline + viz_callback.attach_to_datamodule(datamodule) return datamodule diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 7cce7fc04a..b6e400e0c9 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -26,7 +26,7 @@ from flash.data.auto_dataset import AutoDataset from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential from flash.data.process import Postprocess, Preprocess -from flash.data.utils import _STAGES_PREFIX +from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX if TYPE_CHECKING: from flash.core.model import Task @@ -129,17 +129,7 @@ def forward(self, samples: Sequence[Any]): """ - PREPROCESS_FUNCS = { - "load_data", - "load_sample", - "pre_tensor_transform", - "to_tensor_transform", - "post_tensor_transform", - "per_batch_transform", - "per_sample_transform_on_device", - "per_batch_transform_on_device", - "collate", - } + PREPROCESS_FUNCS = _PREPROCESS_FUNCS def __init__(self, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None) -> None: self._preprocess_pipeline = preprocess or Preprocess() diff --git a/flash/data/process.py b/flash/data/process.py index 8fd16e5359..1512ec1d28 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -13,7 +13,7 @@ # limitations under the License. import os from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import torch from pytorch_lightning.trainer.states import RunningStage @@ -22,6 +22,7 @@ from torch.utils.data._utils.collate import default_collate from flash.data.batch import default_uncollate +from flash.data.callback import FlashCallback from flash.data.utils import convert_to_modules @@ -117,6 +118,8 @@ def __init__( if not hasattr(self, "_skip_mutual_check"): self._skip_mutual_check = False + self._callbacks: List[FlashCallback] = [] + @property def skip_mutual_check(self) -> bool: return self._skip_mutual_check @@ -150,6 +153,17 @@ def current_transform(self) -> Callable: def from_state(cls, state: PreprocessState) -> 'Preprocess': return cls(**vars(state)) + @property + def callbacks(self) -> List['FlashCallback']: + if not hasattr(self, "_callbacks"): + self._callbacks: List[FlashCallback] = [] + return self._callbacks + + @callbacks.setter + def callbacks(self, callbacks: List['FlashCallback']): + _callbacks = [c for c in callbacks if c not in self._callbacks] + self._callbacks.extend(_callbacks) + @classmethod def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Any: """Loads entire data from Dataset""" @@ -201,7 +215,6 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: return batch -@dataclass(unsafe_hash=True) class Postprocess(Properties, torch.nn.Module): def __init__(self, save_path: Optional[str] = None): diff --git a/flash/data/utils.py b/flash/data/utils.py index f3928861ba..8436001d5a 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -32,6 +32,18 @@ } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} +_PREPROCESS_FUNCS = { + "load_data", + "load_sample", + "pre_tensor_transform", + "to_tensor_transform", + "post_tensor_transform", + "per_batch_transform", + "per_sample_transform_on_device", + "per_batch_transform_on_device", + "collate", +} + class CurrentRunningStageContext: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index d66c9bb355..8e6ad5c8c7 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -176,6 +176,9 @@ def to_tensor_transform(self, sample: Any) -> Any: def post_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) + # todo: (tchaton) `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive + # `skip_mutual_check` is used to skip the checks as the information are provided from the transforms directly + # Need to properly set the `collate` depending on user provided transforms def per_batch_transform(self, sample: Any) -> Any: return self.common_step(sample) @@ -220,6 +223,7 @@ def __init__( predict_dataset=predict_dataset, batch_size=batch_size, num_workers=num_workers, + **kwargs, ) self._num_classes = None @@ -468,6 +472,7 @@ def from_filepaths( folder/cat_asd932_.png Args: + train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. train_labels: Sequence of labels for training dataset. Defaults to ``None``. val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. @@ -484,6 +489,7 @@ def from_filepaths( seed: Used for the train/val splits. Returns: + ImageClassificationData: The constructed data module. """ # enable passing in a string which loads all files in that folder as a list diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 4a4032ac5d..b85ecfe3e5 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -26,6 +26,7 @@ val_folder="data/hymenoptera_data/val/", test_folder="data/hymenoptera_data/test/", ) + # 3. Build the model model = ImageClassifier(num_classes=datamodule.num_classes) diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index 273b5aa870..4030e90b35 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -11,18 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List import pytest from pytorch_lightning.trainer.states import RunningStage from flash.data.auto_dataset import AutoDataset +from flash.data.callback import FlashCallback from flash.data.data_pipeline import DataPipeline -from flash.data.process import Postprocess, Preprocess +from flash.data.process import Preprocess class _AutoDatasetTestPreprocess(Preprocess): def __init__(self, with_dset: bool): + self._callbacks: List[FlashCallback] = [] self.load_data_count = 0 self.load_sample_count = 0 self.load_sample_with_dataset_count = 0 diff --git a/tests/data/test_callback.py b/tests/data/test_callback.py new file mode 100644 index 0000000000..0bc47a91cd --- /dev/null +++ b/tests/data/test_callback.py @@ -0,0 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence, Tuple +from unittest import mock +from unittest.mock import ANY, call, MagicMock, Mock + +import torch +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.trainer.states import RunningStage +from torch import Tensor + +from flash.core.model import Task +from flash.core.trainer import Trainer +from flash.data.data_module import DataModule +from flash.data.process import Preprocess + + +@mock.patch("torch.save") # need to mock torch.save or we get pickle error +def test_flash_callback(_, tmpdir): + """Test the callback hook system for fit.""" + + callback_mock = MagicMock() + + inputs = [[torch.rand(1), torch.rand(1)]] + dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm.preprocess.callbacks += [callback_mock] + + _ = next(iter(dm.train_dataloader())) + + assert callback_mock.method_calls == [ + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_pre_tensor_transform(ANY, RunningStage.TRAINING), + call.on_to_tensor_transform(ANY, RunningStage.TRAINING), + call.on_post_tensor_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + ] + + class CustomModel(Task): + + def __init__(self): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=1, + progress_bar_refresh_rate=0, + ) + dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm.preprocess.callbacks += [callback_mock] + trainer.fit(CustomModel(), datamodule=dm) + + assert callback_mock.method_calls == [ + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_pre_tensor_transform(ANY, RunningStage.TRAINING), + call.on_to_tensor_transform(ANY, RunningStage.TRAINING), + call.on_post_tensor_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + call.on_load_sample(ANY, RunningStage.VALIDATING), + call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_to_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_post_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_collate(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING), + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_pre_tensor_transform(ANY, RunningStage.TRAINING), + call.on_to_tensor_transform(ANY, RunningStage.TRAINING), + call.on_post_tensor_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING), + call.on_load_sample(ANY, RunningStage.VALIDATING), + call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_to_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_post_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_collate(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING) + ] diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index ed7ebe60b9..b0556be3dd 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, Callable, Dict, List, Optional, Tuple from unittest import mock @@ -44,31 +43,19 @@ def __len__(self) -> int: return 5 -class CustomModel(Task): - - def __init__(self, postprocess: Optional[Postprocess] = None): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._postprocess = postprocess - - def train_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - -class CustomDataModule(DataModule): - - def __init__(self): - super().__init__( - train_dataset=DummyDataset(), - val_dataset=DummyDataset(), - test_dataset=DummyDataset(), - predict_dataset=DummyDataset(), - ) - - @pytest.mark.parametrize("use_preprocess", [False, True]) @pytest.mark.parametrize("use_postprocess", [False, True]) def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir): + class CustomModel(Task): + + def __init__(self, postprocess: Optional[Postprocess] = None): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + self._postprocess = postprocess + + def train_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + class SubPreprocess(Preprocess): pass @@ -82,7 +69,7 @@ class SubPostprocess(Postprocess): assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess) assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess) - model = CustomModel(Postprocess()) + model = CustomModel(postprocess=Postprocess()) model.data_pipeline = data_pipeline assert isinstance(model._preprocess, SubPreprocess if use_preprocess else Preprocess) assert isinstance(model._postprocess, SubPostprocess if use_postprocess else Postprocess) @@ -287,6 +274,15 @@ def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): def test_detach_preprocessing_from_model(tmpdir): + class CustomModel(Task): + + def __init__(self, postprocess: Optional[Postprocess] = None): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + self._postprocess = postprocess + + def train_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + preprocess = CustomPreprocess() data_pipeline = DataPipeline(preprocess) model = CustomModel() @@ -334,9 +330,37 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_attaching_datapipeline_to_model(tmpdir): - preprocess = TestPreprocess() + preprocess = Preprocess() data_pipeline = DataPipeline(preprocess) + class CustomModel(Task): + + _postprocess = Postprocess() + + def __init__(self): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def test_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def train_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + + def val_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + + def test_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + + def predict_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + class TestModel(CustomModel): stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] @@ -430,11 +454,10 @@ def on_fit_end(self) -> None: assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert self.predict_step == self._saved_predict_step - datamodule = CustomDataModule() - datamodule._data_pipeline = data_pipeline model = TestModel() + model.data_pipeline = data_pipeline trainer = Trainer(fast_dev_run=True) - trainer.fit(model, datamodule=datamodule) + trainer.fit(model) trainer.test(model) trainer.predict(model) @@ -497,11 +520,20 @@ def __init__(self): self.test_post_tensor_transform_called = False self.predict_load_data_called = False + @staticmethod + def fn_train_load_data() -> Tuple: + return ( + 0, + 1, + 2, + 3, + ) + def train_load_data(self, sample) -> LamdaDummyDataset: assert self.training assert self.current_fn == "load_data" self.train_load_data_called = True - return LamdaDummyDataset(lambda: (0, 1, 2, 3)) + return LamdaDummyDataset(self.fn_train_load_data) def train_pre_tensor_transform(self, sample: Any) -> Any: assert self.training @@ -557,11 +589,15 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: assert torch.equal(batch["b"], tensor([1, 2])) return [False] + @staticmethod + def fn_test_load_data() -> List[torch.Tensor]: + return [torch.rand(1), torch.rand(1)] + def test_load_data(self, sample) -> LamdaDummyDataset: assert self.testing assert self.current_fn == "load_data" self.test_load_data_called = True - return LamdaDummyDataset(lambda: [torch.rand(1), torch.rand(1)]) + return LamdaDummyDataset(self.fn_test_load_data) def test_to_tensor_transform(self, sample: Any) -> Tensor: assert self.testing @@ -575,11 +611,15 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor: self.test_post_tensor_transform_called = True return sample + @staticmethod + def fn_predict_load_data() -> List[str]: + return (["a", "b"]) + def predict_load_data(self, sample) -> LamdaDummyDataset: assert self.predicting assert self.current_fn == "load_data" self.predict_load_data_called = True - return LamdaDummyDataset(lambda: (["a", "b"])) + return LamdaDummyDataset(self.fn_predict_load_data) class TestPreprocessTransformations2(TestPreprocessTransformations): @@ -589,33 +629,35 @@ def val_to_tensor_transform(self, sample: Any) -> Tensor: return {"a": tensor(sample["a"]), "b": tensor(sample["b"])} -def test_datapipeline_transformations(tmpdir): +class CustomModel(Task): - class CustomModel(Task): + def __init__(self): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - def __init__(self): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + def training_step(self, batch, batch_idx): + assert batch is None - def training_step(self, batch, batch_idx): - assert batch is None + def validation_step(self, batch, batch_idx): + assert batch is False - def validation_step(self, batch, batch_idx): - assert batch is False + def test_step(self, batch, batch_idx): + assert len(batch) == 2 + assert batch[0].shape == torch.Size([2, 1]) - def test_step(self, batch, batch_idx): - assert len(batch) == 2 - assert batch[0].shape == torch.Size([2, 1]) + def predict_step(self, batch, batch_idx, dataloader_idx): + assert batch[0][0] == 'a' + assert batch[0][1] == 'a' + assert batch[1][0] == 'b' + assert batch[1][1] == 'b' + return tensor([0, 0, 0]) - def predict_step(self, batch, batch_idx, dataloader_idx): - assert batch[0][0] == 'a' - assert batch[0][1] == 'a' - assert batch[1][0] == 'b' - assert batch[1][1] == 'b' - return tensor([0, 0, 0]) - class CustomDataModule(DataModule): +class CustomDataModule(DataModule): - preprocess_cls = TestPreprocessTransformations + preprocess_cls = TestPreprocessTransformations + + +def test_datapipeline_transformations(tmpdir): datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0) diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py new file mode 100644 index 0000000000..e23009c386 --- /dev/null +++ b/tests/data/test_data_viz.py @@ -0,0 +1,134 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any, List, Sequence + +import numpy as np +import torch +from PIL import Image +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer.states import RunningStage + +from flash.data.base_viz import BaseViz +from flash.data.utils import _STAGES_PREFIX +from flash.vision import ImageClassificationData + + +def _rand_image(): + return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8")) + + +def test_base_viz(tmpdir): + + seed_everything(42) + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a" / "a_1.png") + _rand_image().save(tmpdir / "a" / "a_2.png") + + _rand_image().save(tmpdir / "b" / "a_1.png") + _rand_image().save(tmpdir / "b" / "a_2.png") + + class CustomBaseViz(BaseViz): + + show_load_sample_called = False + show_pre_tensor_transform_called = False + show_to_tensor_transform_called = False + show_post_tensor_transform_called = False + show_collate_called = False + per_batch_transform_called = False + + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + self.show_load_sample_called = True + + def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + self.show_pre_tensor_transform_called = True + + def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + self.show_to_tensor_transform_called = True + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + self.show_post_tensor_transform_called = True + + def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + self.show_collate_called = True + + def show_per_batch_transform(self, batch: Sequence, running_stage: RunningStage) -> None: + self.per_batch_transform_called = True + + def reset(self): + self.show_load_sample_called = False + self.show_pre_tensor_transform_called = False + self.show_to_tensor_transform_called = False + self.show_post_tensor_transform_called = False + self.show_collate_called = False + self.per_batch_transform_called = False + + class CustomImageClassificationData(ImageClassificationData): + + @staticmethod + def configure_vis(*args, **kwargs) -> CustomBaseViz: + return CustomBaseViz(*args, **kwargs) + + dm = CustomImageClassificationData.from_filepaths( + train_filepaths=[tmpdir / "a", tmpdir / "b"], + train_labels=[0, 1], + val_filepaths=[tmpdir / "a", tmpdir / "b"], + val_labels=[0, 1], + test_filepaths=[tmpdir / "a", tmpdir / "b"], + test_labels=[0, 1], + predict_filepaths=[tmpdir / "a", tmpdir / "b"], + batch_size=2, + num_workers=0, + ) + + for stage in _STAGES_PREFIX.values(): + + for _ in range(10): + getattr(dm, f"show_{stage}_batch")(reset=False) + + is_predict = stage == "predict" + + def extract_data(data): + if not is_predict: + return data[0][0] + return data[0] + + assert isinstance(extract_data(dm.viz.batches[stage]["load_sample"]), Image.Image) + if not is_predict: + assert isinstance(dm.viz.batches[stage]["load_sample"][0][1], int) + + assert isinstance(extract_data(dm.viz.batches[stage]["to_tensor_transform"]), torch.Tensor) + if not is_predict: + assert isinstance(dm.viz.batches[stage]["to_tensor_transform"][0][1], int) + + assert extract_data(dm.viz.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196]) + if not is_predict: + assert dm.viz.batches[stage]["collate"][0][1].shape == torch.Size([2]) + + generated = extract_data(dm.viz.batches[stage]["per_batch_transform"]).shape + assert generated == torch.Size([2, 3, 196, 196]) + if not is_predict: + assert dm.viz.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2]) + + assert dm.viz.show_load_sample_called + assert dm.viz.show_pre_tensor_transform_called + assert dm.viz.show_to_tensor_transform_called + assert dm.viz.show_post_tensor_transform_called + assert dm.viz.show_collate_called + assert dm.viz.per_batch_transform_called + dm.viz.reset() diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index e90a5f838b..ab5cd66c8c 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -178,9 +178,7 @@ def test_from_folders(tmpdir): _rand_image().save(train_dir / "b" / "1.png") _rand_image().save(train_dir / "b" / "2.png") - img_data = ImageClassificationData.from_folders( - train_dir, train_transform=None, loader=_dummy_image_loader, batch_size=1 - ) + img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) data = next(iter(img_data.train_dataloader())) imgs, labels = data assert imgs.shape == (1, 3, 196, 196)