From 168b2319a478f981a4ae3ae9418c7db018d8c33a Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 30 Mar 2021 18:30:00 +0100
Subject: [PATCH 01/30] wip

---
 flash/data/base_viz.py      | 54 ++++++++++++++++++++++++++
 flash/data/data_pipeline.py |  4 +-
 tests/data/test_base_viz.py | 77 +++++++++++++++++++++++++++++++++++++
 3 files changed, 133 insertions(+), 2 deletions(-)
 create mode 100644 flash/data/base_viz.py
 create mode 100644 tests/data/test_base_viz.py

diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py
new file mode 100644
index 0000000000..3e271afd6d
--- /dev/null
+++ b/flash/data/base_viz.py
@@ -0,0 +1,54 @@
+import functools
+from typing import Any, Callable
+
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.trainer.states import RunningStage
+
+from flash.data.data_module import DataModule
+from flash.data.data_pipeline import DataPipeline
+from flash.data.process import Postprocess, Preprocess
+
+
+class BaseViz(Callback):
+
+    def __init__(self, datamodule: DataModule):
+        self._datamodule = datamodule
+        self._wrap_preprocess()
+
+        self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}}
+
+    def _wrap_fn(
+        self,
+        fn: Callable,
+        running_stage: RunningStage,
+    ) -> Callable:
+        """
+        """
+
+        @functools.wraps(fn)
+        def wrapper(data) -> Any:
+            print(data)
+            data = fn(data)
+            print(data)
+            batches = self.batches[running_stage.value]
+            if fn.__name__ not in batches:
+                batches[fn.__name__] = []
+            batches[fn.__name__].append(data)
+            return data
+
+        return wrapper
+
+    def _wrap_functions_per_stage(self, running_stage: RunningStage):
+        preprocess = self._datamodule.data_pipeline._preprocess_pipeline
+        fn_names = {
+            k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess)
+            for k in DataPipeline.PREPROCESS_FUNCS
+        }
+        for fn_name in fn_names:
+            fn = getattr(preprocess, fn_name)
+            setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage))
+
+        self._datamodule._train_ds.load_sample = preprocess.load_sample
+
+    def _wrap_preprocess(self):
+        self._wrap_functions_per_stage(RunningStage.TRAINING)
diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py
index b50e468c50..226e6bf1ab 100644
--- a/flash/data/data_pipeline.py
+++ b/flash/data/data_pipeline.py
@@ -108,11 +108,11 @@ def forward(self, samples: Sequence[Any]):
                          post_tensor_transform
                                    │
                  ┌─────────────────┴──────────────────┐
-(move Data to main worker) -->  │                     │
+(move list to main worker) -->  │                     │
   per_sample_transform_on_device                   collate
                  │                                    │
              collate                        per_batch_transform
-                 │                                    │  <-- (move Data to main worker)
+                 │                                    │  <-- (move batch to main worker)
  per_batch_transform_on_device         per_batch_transform_on_device
                  │                                    │
                  └─────────────────┬──────────────────┘
diff --git a/tests/data/test_base_viz.py b/tests/data/test_base_viz.py
new file mode 100644
index 0000000000..20fc836f51
--- /dev/null
+++ b/tests/data/test_base_viz.py
@@ -0,0 +1,77 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from unittest import mock
+
+import numpy as np
+import pytest
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch import Tensor, tensor
+from torch.utils.data import DataLoader
+from torch.utils.data._utils.collate import default_collate
+
+from flash.core import Task
+from flash.data.auto_dataset import AutoDataset
+from flash.data.base_viz import BaseViz
+from flash.data.batch import _PostProcessor, _PreProcessor
+from flash.data.data_module import DataModule
+from flash.data.data_pipeline import _StageOrchestrator, DataPipeline
+from flash.data.process import Postprocess, Preprocess
+from flash.vision import ImageClassificationData
+
+
+def _rand_image():
+    return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8"))
+
+
+class ImageClassificationDataViz(ImageClassificationData):
+
+    def configure_vis(self):
+        if not hasattr(self, "viz"):
+            return BaseViz(self)
+        return self.viz
+
+    def show_train_batch(self):
+        self.viz = self.configure_vis()
+        _ = next(iter(self.train_dataloader()))
+
+
+def test_base_viz(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "b").mkdir()
+    _rand_image().save(tmpdir / "a" / "a_1.png")
+    _rand_image().save(tmpdir / "a" / "a_2.png")
+
+    _rand_image().save(tmpdir / "b" / "a_1.png")
+    _rand_image().save(tmpdir / "b" / "a_2.png")
+
+    img_data = ImageClassificationDataViz.from_filepaths(
+        train_filepaths=[tmpdir / "a", tmpdir / "b"],
+        train_transform=None,
+        train_labels=[0, 1],
+        batch_size=1,
+        num_workers=0,
+    )
+
+    img_data.show_train_batch()
+    assert img_data.viz.batches["train"]["load_sample"] is not None
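
[Note on the patch above] The `BaseViz` callback relies on a simple mechanism: every `Preprocess` hook is replaced by a `functools.wraps` wrapper that records the hook's output, keyed by the hook's name and the running stage, before returning it unchanged. A minimal, self-contained sketch of that pattern (the `Recorder` class below is illustrative, not part of the patch):

    import functools

    class Recorder:

        def __init__(self):
            self.batches = {"train": {}}

        def wrap(self, fn, stage: str):

            @functools.wraps(fn)
            def wrapper(data):
                out = fn(data)
                # store every output under the wrapped hook's name, e.g. "load_sample"
                self.batches[stage].setdefault(fn.__name__, []).append(out)
                return out

            return wrapper

    recorder = Recorder()

    def load_sample(x):
        return x * 2

    wrapped = recorder.wrap(load_sample, "train")
    wrapped(21)
    assert recorder.batches["train"]["load_sample"] == [42]

Because `functools.wraps` preserves `fn.__name__`, the recorded dict ends up with one entry per hook that actually ran, which is exactly what the test above asserts.
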
From cda64d3927412b870e500e6dcfb1c101b4b34687 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 31 Mar 2021 10:23:14 +0100
Subject: [PATCH 02/30] add base_viz + new features for DataPipeline

---
 flash/data/auto_dataset.py          |  36 +++----
 flash/data/base_viz.py              |   1 +
 flash/data/batch.py                 |  65 ++++++++----
 flash/data/data_module.py           |   4 +
 flash/data/data_pipeline.py         |  13 ++-
 flash/data/process.py               |  40 +++++++-
 flash/data/utils.py                 |  36 +++++++
 flash/vision/classification/data.py | 154 ++++++++--------------------
 tests/data/test_base_viz.py         |   3 +
 tests/data/test_data_pipeline.py    |  50 ++++++---
 10 files changed, 230 insertions(+), 172 deletions(-)

diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py
index be6e32038e..e42a4cf680 100644
--- a/flash/data/auto_dataset.py
+++ b/flash/data/auto_dataset.py
@@ -20,7 +20,7 @@
 from torch.utils.data import Dataset
 
 from flash.data.process import Preprocess
-from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES
+from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, set_current_stage_and_fn
 
 if TYPE_CHECKING:
     from flash.data.data_pipeline import DataPipeline
@@ -73,12 +73,18 @@ def running_stage(self, running_stage: str) -> None:
         self._running_stage = running_stage
         self._setup(running_stage)
 
+    @property
+    def _preprocess(self):
+        if self.data_pipeline is not None:
+            return self.data_pipeline._preprocess_pipeline
+
     def _call_load_data(self, data: Any) -> Iterable:
         parameters = signature(self.load_data).parameters
-        if len(parameters) > 1 and self.DATASET_KEY in parameters:
-            return self.load_data(data, self)
-        else:
-            return self.load_data(data)
+        with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_data"):
+            if len(parameters) > 1 and self.DATASET_KEY in parameters:
+                return self.load_data(data, self)
+            else:
+                return self.load_data(data)
 
     def _call_load_sample(self, sample: Any) -> Any:
         parameters = signature(self.load_sample).parameters
@@ -110,26 +116,16 @@ def _setup(self, stage: Optional[RunningStage]) -> None:
                     "The load_data function of the Autogenerated Dataset changed. "
                     "This is not expected! Preloading Data again to ensure compatibility. This may take some time."
                 )
-            with self._set_running_stage(stage):
-                self._preprocessed_data = self._call_load_data(self.data)
+            self._preprocessed_data = self._call_load_data(self.data)
             self._load_data_called = True
 
-    @contextmanager
-    def _set_running_stage(self, stage: RunningStage) -> None:
-        if self.load_data:
-            if self.data_pipeline and self.data_pipeline._preprocess_pipeline:
-                self.data_pipeline._preprocess_pipeline._running_stage = stage
-        yield
-        if self.load_data:
-            if self.data_pipeline and self.data_pipeline._preprocess_pipeline:
-                self.data_pipeline._preprocess_pipeline._running_stage = None
-
     def __getitem__(self, index: int) -> Any:
         if not self.load_sample and not self.load_data:
             raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
-        if self.load_sample:
-            return self._call_load_sample(self._preprocessed_data[index])
-        return self._preprocessed_data[index]
+        with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_sample"):
+            if self.load_sample:
+                return self._call_load_sample(self._preprocessed_data[index])
+            return self._preprocessed_data[index]
 
     def __len__(self) -> int:
         if not self.load_sample and not self.load_data:
diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py
index 3e271afd6d..fb50168d9f 100644
--- a/flash/data/base_viz.py
+++ b/flash/data/base_viz.py
@@ -48,6 +48,7 @@ def _wrap_functions_per_stage(self, running_stage: RunningStage):
             fn = getattr(preprocess, fn_name)
             setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage))
 
+        # hack until solved
         self._datamodule._train_ds.load_sample = preprocess.load_sample
 
     def _wrap_preprocess(self):
diff --git a/flash/data/batch.py b/flash/data/batch.py
index d6262b1e49..1047e85a44 100644
--- a/flash/data/batch.py
+++ b/flash/data/batch.py
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union
 
 import torch
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import Tensor
 
-from flash.data.utils import _contains_any_tensor, convert_to_modules
+from flash.data.utils import _contains_any_tensor, convert_to_modules, set_current_fn, set_current_stage
+
+if TYPE_CHECKING:
+    from flash.data.process import Preprocess
 
 
 class _Sequential(torch.nn.Module):
@@ -31,29 +34,40 @@ class _Sequential(torch.nn.Module):
 
     def __init__(
         self,
+        preprocess: 'Preprocess',
         pre_tensor_transform: Callable,
         to_tensor_transform: Callable,
         post_tensor_transform: Callable,
-        assert_contains_tensor: bool = False
+        stage: RunningStage,
+        assert_contains_tensor: bool = False,
     ):
         super().__init__()
-
+        self.preprocess = preprocess
         self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
         self.to_tensor_transform = convert_to_modules(to_tensor_transform)
         self.post_tensor_transform = convert_to_modules(post_tensor_transform)
+        self.stage = stage
         self.assert_contains_tensor = assert_contains_tensor
 
     def forward(self, sample: Any):
-        sample = self.pre_tensor_transform(sample)
-        sample = self.to_tensor_transform(sample)
-        if self.assert_contains_tensor:
-            if not _contains_any_tensor(sample):
-                raise MisconfigurationException(
-                    "When ``to_tensor_transform`` is overriden, "
-                    "``DataPipeline`` expects the outputs to be ``tensors``"
-                )
-        sample = self.post_tensor_transform(sample)
-        return sample
+        with set_current_stage(self.preprocess, self.stage):
+            with set_current_fn(self.preprocess, "pre_tensor_transform"):
+                sample = self.pre_tensor_transform(sample)
+
+            with set_current_fn(self.preprocess, "to_tensor_transform"):
+                sample = self.to_tensor_transform(sample)
+
+            if self.assert_contains_tensor:
+                if not _contains_any_tensor(sample):
+                    raise MisconfigurationException(
+                        "When ``to_tensor_transform`` is overriden, "
+                        "``DataPipeline`` expects the outputs to be ``tensors``"
+                    )
+
+            with set_current_fn(self.preprocess, "post_tensor_transform"):
+                sample = self.post_tensor_transform(sample)
+
+            return sample
 
     def __str__(self) -> str:
         repr_str = f'{self.__class__.__name__}:'
@@ -87,26 +101,37 @@ class _PreProcessor(torch.nn.Module):
 
     def __init__(
         self,
+        preprocess: 'Preprocess',
         collate_fn: Callable,
         per_sample_transform: Union[Callable, _Sequential],
         per_batch_transform: Callable,
         stage: Optional[RunningStage] = None,
         apply_per_sample_transform: bool = True,
+        on_device: bool = False
     ):
         super().__init__()
+        self.preprocess = preprocess
         self.collate_fn = convert_to_modules(collate_fn)
         self.per_sample_transform = convert_to_modules(per_sample_transform)
         self.per_batch_transform = convert_to_modules(per_batch_transform)
         self.apply_per_sample_transform = apply_per_sample_transform
         self.stage = stage
+        self.on_device = on_device
 
     def forward(self, samples: Sequence[Any]):
-        if self.apply_per_sample_transform:
-            samples = [self.per_sample_transform(sample) for sample in samples]
-        samples = type(samples)(samples)
-        samples = self.collate_fn(samples)
-        samples = self.per_batch_transform(samples)
-        return samples
+        with set_current_stage(self.preprocess, self.stage):
+
+            if self.apply_per_sample_transform:
+                with set_current_fn(self.preprocess, f"per_sample_transform_{'on_device' if self.on_device else ''}"):
+                    samples = [self.per_sample_transform(sample) for sample in samples]
+                samples = type(samples)(samples)
+
+            with set_current_fn(self.preprocess, "collate"):
+                samples = self.collate_fn(samples)
+
+            with set_current_fn(self.preprocess, f"per_batch_transform_{'on_device' if self.on_device else ''}"):
+                samples = self.per_batch_transform(samples)
+            return samples
 
     def __str__(self) -> str:
         # todo: define repr function which would take object and string attributes to be shown
diff --git a/flash/data/data_module.py b/flash/data/data_module.py
index 641eff21d7..f998c62ad1 100644
--- a/flash/data/data_module.py
+++ b/flash/data/data_module.py
@@ -87,6 +87,9 @@ def __init__(
         # this may also trigger data preloading
         self.set_running_stages()
 
+    def configure_vis(self):
+        return self
+
     @staticmethod
     def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any:
         if isinstance(dataset, Subset):
@@ -340,4 +343,5 @@ def from_load_data_inputs(
         )
         datamodule._preprocess = data_pipeline._preprocess_pipeline
         datamodule._postprocess = data_pipeline._postprocess_pipeline
+        datamodule.configure_vis()
         return datamodule
diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py
index 226e6bf1ab..2aca209f3d 100644
--- a/flash/data/data_pipeline.py
+++ b/flash/data/data_pipeline.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import inspect
 import weakref
 from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union
 
@@ -181,7 +182,7 @@ def _is_overriden_recursive(
         if not hasattr(process_obj, current_method_name):
             return DataPipeline._is_overriden_recursive(method_name, process_obj, super_obj)
 
-        current_code = getattr(process_obj, current_method_name).__code__
+        current_code = inspect.unwrap(getattr(process_obj, current_method_name)).__code__
         has_different_code = current_code != getattr(super_obj, method_name).__code__
 
         if not prefix:
@@ -257,7 +258,7 @@ def _create_collate_preprocessors(
 
         if per_batch_transform_overriden and per_sample_transform_on_device_overriden:
             raise MisconfigurationException(
-                f'{self.__class__.__name__}: `per_batch_transform` and `gpu_per_sample_transform` '
+                f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
                 f'are mutual exclusive for stage {stage}'
             )
 
@@ -282,21 +283,25 @@ def _create_collate_preprocessors(
         )
 
         worker_preprocessor = _PreProcessor(
-            worker_collate_fn,
+            self._preprocess_pipeline, worker_collate_fn,
             _Sequential(
+                self._preprocess_pipeline,
                 getattr(self._preprocess_pipeline, func_names['pre_tensor_transform']),
                 getattr(self._preprocess_pipeline, func_names['to_tensor_transform']),
                 getattr(self._preprocess_pipeline, func_names['post_tensor_transform']),
+                stage,
                 assert_contains_tensor=assert_contains_tensor,
             ), getattr(self._preprocess_pipeline, func_names['per_batch_transform']), stage
         )
         worker_preprocessor._original_collate_fn = original_collate_fn
         device_preprocessor = _PreProcessor(
+            self._preprocess_pipeline,
             device_collate_fn,
             getattr(self._preprocess_pipeline, func_names['per_sample_transform_on_device']),
             getattr(self._preprocess_pipeline, func_names['per_batch_transform_on_device']),
             stage,
-            apply_per_sample_transform=device_collate_fn != self._identity
+            apply_per_sample_transform=device_collate_fn != self._identity,
+            on_device=True,
         )
         return worker_preprocessor, device_preprocessor
diff --git a/flash/data/process.py b/flash/data/process.py
index f61220dc11..35ecda3993 100644
--- a/flash/data/process.py
+++ b/flash/data/process.py
@@ -27,7 +27,24 @@
 
 class Properties:
 
-    _running_stage = None
+    _running_stage: RunningStage = None
+    _current_fn: str = None
+
+    @property
+    def current_fn(self) -> str:
+        return self._current_fn
+
+    @current_fn.setter
+    def current_fn(self, current_fn: str):
+        self._current_fn = current_fn
+
+    @property
+    def running_stage(self) -> RunningStage:
+        return self._running_stage
+
+    @running_stage.setter
+    def running_stage(self, running_stage: RunningStage):
+        self._running_stage = running_stage
 
     @property
     def training(self) -> bool:
@@ -97,6 +114,27 @@ def __init__(
         self.test_transform = convert_to_modules(test_transform)
         self.predict_transform = convert_to_modules(predict_transform)
 
+    def _identify(self, x):
+        return x
+
+    def _get_transform(self, transform: Dict[str, Callable]):
+        if self.current_fn in transform:
+            return transform[self.current_fn]
+        return self._identify
+
+    @property
+    def current_transform(self):
+        if self.training and self.train_transform:
+            return self._get_transform(self.train_transform)
+        elif self.validating and self.val_transform:
+            return self._get_transform(self.val_transform)
+        elif self.testing and self.test_transform:
+            return self._get_transform(self.test_transform)
+        elif self.predicting and self.predict_transform:
+            return self._get_transform(self.predict_transform)
+        else:
+            return self._identify
+
     @classmethod
     def from_state(cls, state: PreprocessState) -> 'Preprocess':
         return cls(**vars(state))
diff --git a/flash/data/utils.py b/flash/data/utils.py
index 4be6d177ba..cad73d3258 100644
--- a/flash/data/utils.py
+++ b/flash/data/utils.py
@@ -14,6 +14,7 @@
 
 import os.path
 import zipfile
+from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterable, Mapping, Type
 
 import requests
@@ -32,6 +33,41 @@
 _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"}
 
 
+@contextmanager
+def set_current_stage(obj: Any, stage: RunningStage) -> None:
+    if obj is not None:
+        if getattr(obj, "_running_stage", None) == stage:
+            yield
+        else:
+            obj.running_stage = stage
+            yield
+            obj.running_stage = None
+    else:
+        yield
+
+
+@contextmanager
+def set_current_fn(obj: Any, current_fn: str) -> None:
+    if obj is not None:
+        obj.current_fn = current_fn
+        yield
+        obj.current_fn = None
+    else:
+        yield
+
+
+@contextmanager
+def set_current_stage_and_fn(obj: Any, stage: RunningStage, current_fn: str) -> None:
+    if obj is not None:
+        obj.running_stage = stage
+        obj.current_fn = current_fn
+        yield
+        obj.running_stage = None
+        obj.current_fn = None
+    else:
+        yield
+
+
 def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
     """
     Download file with progressbar
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index 37baff9440..6f9cb8bb36 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -41,7 +41,6 @@
 
 class ImageClassificationPreprocess(Preprocess):
 
-    to_tensor = torchvision_T.ToTensor()
 
     @staticmethod
     def _find_classes(dir: str) -> Tuple:
        """
@@ -112,7 +111,7 @@ def load_data(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable
         return cls._load_data_files_labels(data=data, dataset=dataset)
 
     @staticmethod
-    def load_sample(sample) -> Union[Image.Image, list]:
+    def load_sample(sample) -> Union[Image.Image]:
         # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
         if isinstance(sample, torch.Tensor):
             return sample
@@ -138,25 +137,6 @@ def predict_load_data(cls, samples: Any) -> Iterable:
             return samples
         return cls._get_predicting_files(samples)
 
-    def _convert_tensor_to_pil(self, sample):
-        # some datasets provide their data as tensors.
-        # however, it would be better to convert those data once in load_data
-        if isinstance(sample, torch.Tensor):
-            sample = to_pil_image(sample)
-        return sample
-
-    def _apply_transform(
-        self, sample: Any, transform: Union[Callable, Dict[str, Callable]], func_name: str
-    ) -> torch.Tensor:
-        if transform is not None:
-            if isinstance(transform, (Dict, ModuleDict)):
-                if func_name not in transform:
-                    return sample
-                else:
-                    transform = transform[func_name]
-            sample = transform(sample)
-        return sample
-
     def collate(self, samples: Sequence) -> Any:
         _samples = []
         # todo: Kornia transforms add batch dimension which need to be removed
@@ -168,56 +148,28 @@ def collate(self, samples: Sequence) -> Any:
             _samples.append(sample)
         return default_collate(_samples)
 
-    def common_pre_tensor_transform(self, sample: Any, transform) -> Any:
-        return self._apply_transform(sample, transform, "pre_tensor_transform")
-
-    def train_pre_tensor_transform(self, sample: Any) -> Any:
-        source, target = sample
-        return self.common_pre_tensor_transform(source, self.train_transform), target
-
-    def val_pre_tensor_transform(self, sample: Any) -> Any:
-        source, target = sample
-        return self.common_pre_tensor_transform(source, self.val_transform), target
-
-    def test_pre_tensor_transform(self, sample: Any) -> Any:
-        source, target = sample
-        return self.common_pre_tensor_transform(source, self.test_transform), target
-
-    def predict_pre_tensor_transform(self, sample: Any) -> Any:
+    def common_step(self, sample: Any) -> Any:
+        if isinstance(sample, (list, tuple)):
+            source, target = sample
+            return self.current_transform(source), target
         if isinstance(sample, torch.Tensor):
             return sample
-        return self.common_pre_tensor_transform(sample, self.predict_transform)
+        return self.current_transform(sample)
 
-    def to_tensor_transform(self, sample) -> Any:
-        source, target = sample
-        return source if isinstance(source, torch.Tensor) else self.to_tensor(source), target
+    def per_tensor_transform(self, sample: Any) -> Any:
+        return self.common_step(sample)
 
-    def predict_to_tensor_transform(self, sample) -> Any:
-        if isinstance(sample, torch.Tensor):
-            return sample
-        return self.to_tensor(sample)
-
-    def common_post_tensor_transform(self, sample: Any, transform) -> Any:
-        return self._apply_transform(sample, transform, "post_tensor_transform")
+    def to_tensor_transform(self, sample: Any) -> Any:
+        return self.common_step(sample)
 
-    def train_post_tensor_transform(self, sample: Any) -> Any:
-        source, target = sample
-        return self.common_post_tensor_transform(source, self.train_transform), target
+    def post_tensor_transform(self, sample: Any) -> Any:
+        return self.common_step(sample)
 
-    def val_post_tensor_transform(self, sample: Any) -> Any:
-        source, target = sample
-        return self.common_post_tensor_transform(source, self.val_transform), target
+    def per_batch_transform(self, sample: Any) -> Any:
+        return self.common_step(sample)
 
-    def test_post_tensor_transform(self, sample: Any) -> Any:
-        source, target = sample
-        return self.common_post_tensor_transform(source, self.test_transform), target
-
-    def predict_post_tensor_transform(self, sample: Any) -> Any:
-        return self.common_post_tensor_transform(sample, self.predict_transform)
-
-    def train_per_batch_transform_on_device(self, batch: Tuple) -> Tuple:
-        batch, target = batch
-        return self._apply_transform(batch, self.train_transform, "per_batch_transform_on_device"), target
+    def per_batch_transform_on_device(self, sample: Any) -> Any:
+        return self.common_step(sample)
 
 
 class ImageClassificationData(DataModule):
@@ -285,6 +237,7 @@ def default_train_transforms():
         if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
             # Better approach as all transforms are applied on tensor directly
             return {
+                "to_tensor_transform": torchvision_T.ToTensor(),
                 "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
                 "per_batch_transform_on_device": nn.Sequential(
                     K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
@@ -294,6 +247,7 @@ def default_train_transforms():
             from torchvision import transforms as T  # noqa F811
             return {
                 "pre_tensor_transform": nn.Sequential(T.RandomResizedCrop(image_size), T.RandomHorizontalFlip()),
+                "to_tensor_transform": torchvision_T.ToTensor(),
                 "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
             }
 
@@ -303,6 +257,7 @@ def default_val_transforms():
         if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
             # Better approach as all transforms are applied on tensor directly
             return {
+                "to_tensor_transform": torchvision_T.ToTensor(),
                 "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
                 "per_batch_transform_on_device": nn.Sequential(
                     K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
@@ -312,6 +267,7 @@ def default_val_transforms():
             from torchvision import transforms as T  # noqa F811
             return {
                 "pre_tensor_transform": T.Compose([T.RandomResizedCrop(image_size)]),
+                "to_tensor_transform": torchvision_T.ToTensor(),
                 "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
             }
 
@@ -471,23 +427,24 @@ def from_filepaths(
         test_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None,
         test_labels: Optional[Sequence] = None,
         predict_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None,
-        train_transform: Optional[Callable] = 'default',
-        val_transform: Optional[Callable] = 'default',
+        train_transform: Union[str, Dict] = 'default',
+        val_transform: Union[str, Dict] = 'default',
+        test_transform: Union[str, Dict] = 'default',
+        predict_transform: Union[str, Dict] = 'default',
         batch_size: int = 64,
         num_workers: Optional[int] = None,
         seed: Optional[int] = 42,
+        preprocess_cls: Optional[Type[Preprocess]] = None,
         **kwargs,
     ) -> 'ImageClassificationData':
         """
         Creates a ImageClassificationData object from folders of images arranged in this way: ::
-
             folder/dog_xxx.png
             folder/dog_xxy.png
             folder/dog_xxz.png
             folder/cat_123.png
             folder/cat_nsdf3.png
             folder/cat_asd932_.png
-
         Args:
             train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``.
             train_labels: Sequence of labels for training dataset. Defaults to ``None``.
             batch_size: The batchsize to use for parallelized loading. Defaults to ``64``.
             num_workers: The number of workers to use for parallelized loading.
                 Defaults to ``None`` which equals the number of available CPU threads.
             seed: Used for the train/val splits.
-
         Returns:
             ImageClassificationData: The constructed data module.
-
         Examples:
             >>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP
-
         Example when labels are in .csv file::
-
             train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id')
             val_labels = labels_from_categorical_csv(path/to/val.csv', 'my_id')
             test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id')
-
             data = ImageClassificationData.from_filepaths(
                 batch_size=2,
                 train_filepaths='path/to/train',
                 train_labels=train_labels,
                 val_filepaths='path/to/val',
                 val_labels=val_labels,
                 test_filepaths='path/to/test',
                 test_labels=test_labels,
             )
-
         """
         # enable passing in a string which loads all files in that folder as a list
         if isinstance(train_filepaths, str):
             if os.path.isdir(train_filepaths):
                 train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)]
             else:
                 train_filepaths = [train_filepaths]
+
         if isinstance(val_filepaths, str):
             if os.path.isdir(val_filepaths):
                 val_filepaths = [os.path.join(val_filepaths, x) for x in os.listdir(val_filepaths)]
             else:
                 val_filepaths = [val_filepaths]
+
         if isinstance(test_filepaths, str):
             if os.path.isdir(test_filepaths):
                 test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)]
             else:
                 test_filepaths = [test_filepaths]
-        if isinstance(predict_filepaths, str):
-            if os.path.isdir(predict_filepaths):
-                predict_filepaths = [os.path.join(predict_filepaths, x) for x in os.listdir(predict_filepaths)]
-            else:
-                predict_filepaths = [predict_filepaths]
 
-        if train_filepaths is not None and train_labels is not None:
-            train_dataset = cls._generate_dataset_if_possible(
-                list(zip(train_filepaths, train_labels)), running_stage=RunningStage.TRAINING
-            )
-        else:
-            train_dataset = None
-
-        if val_filepaths is not None and val_labels is not None:
-            val_dataset = cls._generate_dataset_if_possible(
-                list(zip(val_filepaths, val_labels)), running_stage=RunningStage.VALIDATING
-            )
-        else:
-            val_dataset = None
-
-        if test_filepaths is not None and test_labels is not None:
-            test_dataset = cls._generate_dataset_if_possible(
-                list(zip(test_filepaths, test_labels)), running_stage=RunningStage.TESTING
-            )
-        else:
-            test_dataset = None
-
-        if predict_filepaths is not None:
-            predict_dataset = cls._generate_dataset_if_possible(
-                predict_filepaths, running_stage=RunningStage.PREDICTING
-            )
-        else:
-            predict_dataset = None
+        preprocess = cls.instantiate_preprocess(
+            train_transform,
+            val_transform,
+            test_transform,
+            predict_transform,
+            preprocess_cls=preprocess_cls,
+        )
 
-        return cls(
-            train_dataset=train_dataset,
-            val_dataset=val_dataset,
-            test_dataset=test_dataset,
-            predict_dataset=predict_dataset,
-            train_transform=train_transform,
-            val_transform=val_transform,
+        return cls.from_load_data_inputs(
+            train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None,
+            val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None,
+            test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None,
+            predict_load_data_input=predict_filepaths,
             batch_size=batch_size,
             num_workers=num_workers,
-            seed=seed,
+            preprocess=preprocess,
             **kwargs
         )
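
[Note on the change above] `from_filepaths` now builds its datasets through `from_load_data_inputs` and a `Preprocess` instance instead of constructing the datasets itself. A sketch of the resulting call, mirroring the test in this PR (the directory names and labels are placeholders):

    from flash.vision import ImageClassificationData

    datamodule = ImageClassificationData.from_filepaths(
        train_filepaths=["path/to/a", "path/to/b"],   # illustrative paths
        train_labels=[0, 1],
        train_transform="default",                    # resolved to the transform dicts above
        batch_size=2,
        num_workers=0,
    )
    batch = next(iter(datamodule.train_dataloader()))
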
img_data.viz.batches["train"]["collate"] is not None + assert img_data.viz.batches["train"]["per_batch_transform"] is not None diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 7aac65b07a..c1d8ae6b62 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -498,41 +498,59 @@ def __init__(self): self.predict_load_data_called = False def train_load_data(self, sample) -> LamdaDummyDataset: + assert self.training + assert self.current_fn == "load_data" self.train_load_data_called = True return LamdaDummyDataset(lambda: (0, 1, 2, 3)) def train_pre_tensor_transform(self, sample: Any) -> Any: + assert self.training + assert self.current_fn == "pre_tensor_transform" self.train_pre_tensor_transform_called = True return sample + (5, ) def train_collate(self, samples) -> Tensor: + assert self.training + assert self.current_fn == "collate" self.train_collate_called = True return tensor([list(s) for s in samples]) def train_per_batch_transform_on_device(self, batch: Any) -> Any: + assert self.training + assert self.current_fn == "per_batch_transform_on_device" self.train_per_batch_transform_on_device_called = True assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]])) def val_load_data(self, sample, dataset) -> List[int]: + assert self.validating + assert self.current_fn == "load_data" self.val_load_data_called = True assert isinstance(dataset, AutoDataset) return list(range(5)) def val_load_sample(self, sample) -> Dict[str, Tensor]: + assert self.validating + assert self.current_fn == "load_sample" self.val_load_sample_called = True return {"a": sample, "b": sample + 1} def val_to_tensor_transform(self, sample: Any) -> Tensor: + assert self.validating + assert self.current_fn == "to_tensor_transform" self.val_to_tensor_transform_called = True return sample def val_collate(self, samples) -> Dict[str, Tensor]: + assert self.validating + assert self.current_fn == "collate" self.val_collate_called = True _count = samples[0]['a'] assert samples == [{'a': _count, 'b': _count + 1}, {'a': _count + 1, 'b': _count + 2}] return {'a': tensor([0, 1]), 'b': tensor([1, 2])} def val_per_batch_transform_on_device(self, batch: Any) -> Any: + assert self.validating + assert self.current_fn == "per_batch_transform_on_device" self.val_per_batch_transform_on_device_called = True batch = batch[0] assert torch.equal(batch["a"], tensor([0, 1])) @@ -540,18 +558,26 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: return [False] def test_load_data(self, sample) -> LamdaDummyDataset: + assert self.testing + assert self.current_fn == "load_data" self.test_load_data_called = True return LamdaDummyDataset(lambda: [torch.rand(1), torch.rand(1)]) def test_to_tensor_transform(self, sample: Any) -> Tensor: + assert self.testing + assert self.current_fn == "to_tensor_transform" self.test_to_tensor_transform_called = True return sample def test_post_tensor_transform(self, sample: Tensor) -> Tensor: + assert self.testing + assert self.current_fn == "post_tensor_transform" self.test_post_tensor_transform_called = True return sample def predict_load_data(self, sample) -> LamdaDummyDataset: + assert self.predicting + assert self.current_fn == "load_data" self.predict_load_data_called = True return LamdaDummyDataset(lambda: (["a", "b"])) @@ -563,7 +589,6 @@ def val_to_tensor_transform(self, sample: Any) -> Tensor: return {"a": tensor(sample["a"]), "b": tensor(sample["b"])} -@pytest.mark.skipif(reason="Still using DataPipeline Old API") def 
test_datapipeline_transformations(tmpdir): class CustomModel(Task): @@ -619,21 +644,20 @@ class CustomDataModule(DataModule): trainer.test(model) trainer.predict(model) - # todo (tchaton) resolve the lost reference. preprocess = model._preprocess - # assert preprocess.train_load_data_called - # assert preprocess.train_pre_tensor_transform_called - # assert preprocess.train_collate_called + assert preprocess.train_load_data_called + assert preprocess.train_pre_tensor_transform_called + assert preprocess.train_collate_called assert preprocess.train_per_batch_transform_on_device_called - # assert preprocess.val_load_data_called - # assert preprocess.val_load_sample_called - # assert preprocess.val_to_tensor_transform_called - # assert preprocess.val_collate_called + assert preprocess.val_load_data_called + assert preprocess.val_load_sample_called + assert preprocess.val_to_tensor_transform_called + assert preprocess.val_collate_called assert preprocess.val_per_batch_transform_on_device_called - # assert preprocess.test_load_data_called - # assert preprocess.test_to_tensor_transform_called - # assert preprocess.test_post_tensor_transform_called - # assert preprocess.predict_load_data_called + assert preprocess.test_load_data_called + assert preprocess.test_to_tensor_transform_called + assert preprocess.test_post_tensor_transform_called + assert preprocess.predict_load_data_called def test_is_overriden_recursive(tmpdir): From 2b2c49901eb5f1dc24c14e589ab894f3a62cce09 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 13:22:25 +0100 Subject: [PATCH 03/30] update --- flash/core/classification.py | 3 +- flash/core/model.py | 13 +++++++-- flash/data/base_viz.py | 43 +++++++++++++---------------- flash/data/data_module.py | 21 ++++++++++++-- flash/data/data_pipeline.py | 37 ++++++++++++++----------- flash/data/process.py | 11 ++++++++ flash/data/utils.py | 1 + flash/vision/classification/data.py | 37 ++++++++++++++++++++----- tests/data/test_base_viz.py | 9 ++---- tests/examples/test_scripts.py | 10 +++---- 10 files changed, 119 insertions(+), 66 deletions(-) diff --git a/flash/core/classification.py b/flash/core/classification.py index 86b4066410..4340f404b5 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Any import torch -from torch import Tensor from flash.core.model import Task from flash.data.process import Postprocess diff --git a/flash/core/model.py b/flash/core/model.py index 6cc7bcda5f..b03a424dd2 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -import os +import inspect +from copy import deepcopy from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch @@ -244,13 +245,21 @@ def on_fit_end(self) -> None: self.data_pipeline._detach_from_model(self) super().on_fit_end() + def _sanetize_funcs(self, obj: Any) -> Any: + if hasattr(obj, "__dict__"): + for k, v in obj.__dict__.items(): + if isinstance(v, Callable): + obj.__dict__[k] = inspect.unwrap(v) + return obj + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # TODO: Is this the best way to do this? or should we also use some kind of hparams here? 
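
[Note on the patch above] Taken together, the `running_stage`/`current_fn` properties and the `set_current_stage`/`set_current_fn` context managers give every `Preprocess` hook access to which stage and which hook is currently executing, which is exactly what the new test assertions verify. A sketch of how user code can rely on that contract (the `CheckedPreprocess` class is illustrative, not part of the patch):

    from flash.data.process import Preprocess

    class CheckedPreprocess(Preprocess):

        def train_load_data(self, samples):
            # the AutoDataset enters set_current_stage_and_fn(...) around this
            # call, so both properties are populated here
            assert self.training
            assert self.current_fn == "load_data"
            return samples
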
From 2b2c49901eb5f1dc24c14e589ab894f3a62cce09 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 31 Mar 2021 13:22:25 +0100
Subject: [PATCH 03/30] update

---
 flash/core/classification.py        |  3 +-
 flash/core/model.py                 | 13 +++++++--
 flash/data/base_viz.py              | 43 +++++++++++++----------
 flash/data/data_module.py           | 21 ++++++++++++--
 flash/data/data_pipeline.py         | 37 ++++++++++++++-----------
 flash/data/process.py               | 11 ++++++++
 flash/data/utils.py                 |  1 +
 flash/vision/classification/data.py | 37 ++++++++++++++++++-----
 tests/data/test_base_viz.py         |  9 ++----
 tests/examples/test_scripts.py      | 10 +++----
 10 files changed, 119 insertions(+), 66 deletions(-)

diff --git a/flash/core/classification.py b/flash/core/classification.py
index 86b4066410..4340f404b5 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Union
+from typing import Any
 
 import torch
-from torch import Tensor
 
 from flash.core.model import Task
 from flash.data.process import Postprocess
diff --git a/flash/core/model.py b/flash/core/model.py
index 6cc7bcda5f..b03a424dd2 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import os
+import inspect
+from copy import deepcopy
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
@@ -244,13 +245,21 @@ def on_fit_end(self) -> None:
             self.data_pipeline._detach_from_model(self)
         super().on_fit_end()
 
+    def _sanetize_funcs(self, obj: Any) -> Any:
+        if hasattr(obj, "__dict__"):
+            for k, v in obj.__dict__.items():
+                if isinstance(v, Callable):
+                    obj.__dict__[k] = inspect.unwrap(v)
+        return obj
+
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # TODO: Is this the best way to do this? or should we also use some kind of hparams here?
         # This may be an issue since here we create the same problems with pickle as in
         # https://pytorch.org/docs/stable/notes/serialization.html
-        if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
+        self._preprocess = self._sanetize_funcs(self._preprocess)
         checkpoint['data_pipeline'] = self.data_pipeline
+        # todo (tchaton) re-wrap visualization
         super().on_save_checkpoint(checkpoint)
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py
index fb50168d9f..40e341196e 100644
--- a/flash/data/base_viz.py
+++ b/flash/data/base_viz.py
@@ -4,42 +4,43 @@
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.trainer.states import RunningStage
 
-from flash.data.data_module import DataModule
 from flash.data.data_pipeline import DataPipeline
-from flash.data.process import Postprocess, Preprocess
+from flash.data.process import Preprocess
 
 
 class BaseViz(Callback):
 
-    def __init__(self, datamodule: DataModule):
-        self._datamodule = datamodule
-        self._wrap_preprocess()
-
+    def __init__(self, enabled: bool = False):
         self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}}
+        self.enabled = enabled
+        self._datamodule = None
+
+    def attach_to_preprocess(self, preprocess: Preprocess) -> None:
+        self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess)
+
+    def attach_to_datamodule(self, datamodule) -> None:
+        self._datamodule = datamodule
+        datamodule.viz = self
 
     def _wrap_fn(
         self,
         fn: Callable,
         running_stage: RunningStage,
     ) -> Callable:
-        """
-        """
 
         @functools.wraps(fn)
-        def wrapper(data) -> Any:
-            print(data)
-            data = fn(data)
-            print(data)
-            batches = self.batches[running_stage.value]
-            if fn.__name__ not in batches:
-                batches[fn.__name__] = []
-            batches[fn.__name__].append(data)
+        def wrapper(*args) -> Any:
+            data = fn(*args)
+            if self.enabled:
+                batches = self.batches[running_stage.value]
+                if fn.__name__ not in batches:
+                    batches[fn.__name__] = []
+                batches[fn.__name__].append(data)
             return data
 
         return wrapper
 
-    def _wrap_functions_per_stage(self, running_stage: RunningStage):
-        preprocess = self._datamodule.data_pipeline._preprocess_pipeline
+    def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: Preprocess):
         fn_names = {
             k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess)
             for k in DataPipeline.PREPROCESS_FUNCS
         }
         for fn_name in fn_names:
             fn = getattr(preprocess, fn_name)
             setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage))
-
-        # hack until solved
-        self._datamodule._train_ds.load_sample = preprocess.load_sample
-
-    def _wrap_preprocess(self):
-        self._wrap_functions_per_stage(RunningStage.TRAINING)
diff --git a/flash/data/data_module.py b/flash/data/data_module.py
index f998c62ad1..286be2b6fa 100644
--- a/flash/data/data_module.py
+++ b/flash/data/data_module.py
@@ -24,6 +24,7 @@
 from torch.utils.data.dataset import Subset
 
 from flash.data.auto_dataset import AutoDataset
+from flash.data.base_viz import BaseViz
 from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
 
 
@@ -83,12 +84,22 @@ def __init__(
         self._preprocess = None
         self._postprocess = None
+        self._viz = None
 
         # this may also trigger data preloading
         self.set_running_stages()
 
-    def configure_vis(self):
-        return self
+    @property
+    def viz(self) -> BaseViz:
+        return self._viz or DataModule.configure_vis()
+
+    @viz.setter
+    def viz(self, viz: BaseViz) -> None:
+        self._viz = viz
+
+    @classmethod
+    def configure_vis(cls) -> BaseViz:
+        return BaseViz()
 
     @staticmethod
     def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any:
@@ -322,6 +333,10 @@ def from_load_data_inputs(
             )
         else:
             data_pipeline = cls(**kwargs).data_pipeline
+
+        viz_callback = cls.configure_vis()
+        viz_callback.attach_to_preprocess(data_pipeline._preprocess_pipeline)
+
         train_dataset = cls._generate_dataset_if_possible(
             train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
         )
@@ -343,5 +358,5 @@ def from_load_data_inputs(
         )
         datamodule._preprocess = data_pipeline._preprocess_pipeline
         datamodule._postprocess = data_pipeline._postprocess_pipeline
-        datamodule.configure_vis()
+        viz_callback.attach_to_datamodule(datamodule)
         return datamodule
diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py
index 2aca209f3d..40f9d48be8 100644
--- a/flash/data/data_pipeline.py
+++ b/flash/data/data_pipeline.py
@@ -18,6 +18,7 @@
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities import imports
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data._utils.collate import default_collate, default_convert
 from torch.utils.data.dataloader import DataLoader
@@ -240,23 +241,27 @@ def _create_collate_preprocessors(
         if collate_fn is None:
             collate_fn = default_collate
 
+        preprocess = self._preprocess_pipeline
+
         func_names = {
-            k: self._resolve_function_hierarchy(k, self._preprocess_pipeline, stage, Preprocess)
+            k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess)
             for k in self.PREPROCESS_FUNCS
         }
 
-        if self._is_overriden_recursive("collate", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]):
-            collate_fn = getattr(self._preprocess_pipeline, func_names["collate"])
+        if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]):
+            collate_fn = getattr(preprocess, func_names["collate"])
 
         per_batch_transform_overriden = self._is_overriden_recursive(
-            "per_batch_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
+            "per_batch_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
         )
 
         per_sample_transform_on_device_overriden = self._is_overriden_recursive(
-            "per_sample_transform_on_device", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
+            "per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
         )
 
-        if per_batch_transform_overriden and per_sample_transform_on_device_overriden:
+        skip_mutual_check = preprocess.skip_mutual_check
+
+        if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden):
             raise MisconfigurationException(
                 f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
                 f'are mutual exclusive for stage {stage}'
             )
@@ -279,26 +284,26 @@ def _create_collate_preprocessors(
         ) else worker_collate_fn
 
         assert_contains_tensor = self._is_overriden_recursive(
-            "to_tensor_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
+            "to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
         )
 
         worker_preprocessor = _PreProcessor(
-            self._preprocess_pipeline, worker_collate_fn,
+            preprocess, worker_collate_fn,
             _Sequential(
-                self._preprocess_pipeline,
-                getattr(self._preprocess_pipeline, func_names['pre_tensor_transform']),
-                getattr(self._preprocess_pipeline, func_names['to_tensor_transform']),
-                getattr(self._preprocess_pipeline, func_names['post_tensor_transform']),
+                preprocess,
+                getattr(preprocess, func_names['pre_tensor_transform']),
+                getattr(preprocess, func_names['to_tensor_transform']),
+                getattr(preprocess, func_names['post_tensor_transform']),
                 stage,
                 assert_contains_tensor=assert_contains_tensor,
-            ), getattr(self._preprocess_pipeline, func_names['per_batch_transform']), stage
+            ), getattr(preprocess, func_names['per_batch_transform']), stage
         )
         worker_preprocessor._original_collate_fn = original_collate_fn
         device_preprocessor = _PreProcessor(
-            self._preprocess_pipeline,
+            preprocess,
             device_collate_fn,
-            getattr(self._preprocess_pipeline, func_names['per_sample_transform_on_device']),
-            getattr(self._preprocess_pipeline, func_names['per_batch_transform_on_device']),
+            getattr(preprocess, func_names['per_sample_transform_on_device']),
+            getattr(preprocess, func_names['per_batch_transform_on_device']),
             stage,
             apply_per_sample_transform=device_collate_fn != self._identity,
             on_device=True,
         )
         return worker_preprocessor, device_preprocessor
diff --git a/flash/data/process.py b/flash/data/process.py
index 35ecda3993..62b23cc4a0 100644
--- a/flash/data/process.py
+++ b/flash/data/process.py
@@ -114,6 +114,17 @@ def __init__(
         self.test_transform = convert_to_modules(test_transform)
         self.predict_transform = convert_to_modules(predict_transform)
 
+        if not hasattr(self, "_skip_mutual_check"):
+            self._skip_mutual_check = False
+
+    @property
+    def skip_mutual_check(self) -> bool:
+        return self._skip_mutual_check
+
+    @skip_mutual_check.setter
+    def skip_mutual_check(self, skip_mutual_check: bool) -> None:
+        self._skip_mutual_check = skip_mutual_check
+
     def _identify(self, x):
         return x
 
diff --git a/flash/data/utils.py b/flash/data/utils.py
index cad73d3258..4b7fec9122 100644
--- a/flash/data/utils.py
+++ b/flash/data/utils.py
@@ -33,6 +33,7 @@
 _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"}
 
 
+# todo (tchaton) convert to class
 @contextmanager
 def set_current_stage(obj: Any, stage: RunningStage) -> None:
     if obj is not None:
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index 6f9cb8bb36..d09f467155 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
 
 import torch
+import torchvision
 from PIL import Image
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -23,7 +24,7 @@
 from torch.nn.modules import ModuleDict
 from torch.utils.data import Dataset
 from torch.utils.data._utils.collate import default_collate
-from torchvision import transforms as torchvision_T
+from torchvision import transforms
 from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset
 from torchvision.transforms.functional import to_pil_image
 
@@ -42,6 +43,11 @@
 
 class ImageClassificationPreprocess(Preprocess):
 
+    # this assignment is used to skip the assert that `per_batch_transform` and `per_sample_transform_on_device`
+    # are mutually exclusive on the DataPipeline internals
+    _skip_mutual_check = True
+    to_tensor = torchvision.transforms.ToTensor()
+
     @staticmethod
     def _find_classes(dir: str) -> Tuple:
         """
@@ -152,7 +158,7 @@ def common_step(self, sample: Any) -> Any:
         if isinstance(sample, (list, tuple)):
             source, target = sample
             return self.current_transform(source), target
-        if isinstance(sample, torch.Tensor):
+        elif isinstance(sample, torch.Tensor):
             return sample
         return self.current_transform(sample)
 
@@ -160,6 +166,13 @@ def per_tensor_transform(self, sample: Any) -> Any:
         return self.common_step(sample)
 
     def to_tensor_transform(self, sample: Any) -> Any:
+        if self.current_transform == self._identify:
+            if isinstance(sample, (list, tuple)):
+                source, target = sample
+                return self.to_tensor(source), target
+            elif isinstance(sample, torch.Tensor):
+                return sample
+            return self.to_tensor(sample)
         return self.common_step(sample)
 
     def post_tensor_transform(self, sample: Any) -> Any:
@@ -168,6 +181,9 @@ def post_tensor_transform(self, sample: Any) -> Any:
     def per_batch_transform(self, sample: Any) -> Any:
         return self.common_step(sample)
 
+    def per_sample_transform_on_device(self, sample: Any) -> Any:
+        return self.common_step(sample)
+
     def per_batch_transform_on_device(self, sample: Any) -> Any:
         return self.common_step(sample)
 
@@ -229,6 +245,11 @@ def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[
                 "Transform should be a dict. "
                 f"Here are the available keys for your transforms: {DataPipeline.PREPROCESS_FUNCS}."
             )
+        if "per_batch_transform" in transform and "per_sample_transform_on_device" in transform:
+            raise MisconfigurationException(
+                f'{transform}: `per_batch_transform` and `per_sample_transform_on_device` '
+                f'are mutually exclusive.'
+            )
         return transform
 
     @staticmethod
@@ -237,7 +258,7 @@ def default_train_transforms():
         if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
             # Better approach as all transforms are applied on tensor directly
             return {
-                "to_tensor_transform": torchvision_T.ToTensor(),
+                "to_tensor_transform": torchvision.transforms.ToTensor(),
                 "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
                 "per_batch_transform_on_device": nn.Sequential(
                     K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
@@ -247,7 +268,7 @@ def default_train_transforms():
             from torchvision import transforms as T  # noqa F811
             return {
                 "pre_tensor_transform": nn.Sequential(T.RandomResizedCrop(image_size), T.RandomHorizontalFlip()),
-                "to_tensor_transform": torchvision_T.ToTensor(),
+                "to_tensor_transform": torchvision.transforms.ToTensor(),
                 "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
             }
 
@@ -257,7 +278,7 @@ def default_val_transforms():
         if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
             # Better approach as all transforms are applied on tensor directly
             return {
-                "to_tensor_transform": torchvision_T.ToTensor(),
+                "to_tensor_transform": torchvision.transforms.ToTensor(),
                 "post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
                 "per_batch_transform_on_device": nn.Sequential(
                     K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
@@ -267,7 +288,7 @@ def default_val_transforms():
             from torchvision import transforms as T  # noqa F811
             return {
                 "pre_tensor_transform": T.Compose([T.RandomResizedCrop(image_size)]),
-                "to_tensor_transform": torchvision_T.ToTensor(),
+                "to_tensor_transform": torchvision.transforms.ToTensor(),
                 "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
             }
 
@@ -323,7 +344,9 @@ def instantiate_preprocess(
         )
 
         preprocess_cls = preprocess_cls or cls.preprocess_cls
-        return preprocess_cls(train_transform, val_transform, test_transform, predict_transform)
+        preprocess = preprocess_cls(train_transform, val_transform, test_transform, predict_transform)
+        # todo (tchaton) add check on mutually exclusive transforms
+        return preprocess
 
     @classmethod
     def _resolve_transforms(
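
[Note on the patch above] This commit inverts the ownership: instead of the callback wrapping a finished `DataModule`, the `DataModule` now creates a `BaseViz` through `configure_vis()`, and `from_load_data_inputs` attaches it to the preprocess and then to the datamodule. A sketch of how the pieces fit together under that flow (the `MyDataModule` subclass is illustrative, not part of the patch):

    from flash.data.base_viz import BaseViz
    from flash.data.data_module import DataModule

    class MyDataModule(DataModule):

        @classmethod
        def configure_vis(cls) -> BaseViz:
            # called by `from_load_data_inputs`, which then calls
            # attach_to_preprocess(...) and attach_to_datamodule(...)
            return BaseViz(enabled=True)  # recording is off by default

Once such a datamodule is constructed, drawing a batch from `train_dataloader()` fills `datamodule.viz.batches["train"]` with one entry per wrapped hook that ran, as the updated test below exercises via `show_train_batch`.
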
diff --git a/tests/data/test_base_viz.py b/tests/data/test_base_viz.py
index c153903a76..a4e49a3caf 100644
--- a/tests/data/test_base_viz.py
+++ b/tests/data/test_base_viz.py
@@ -44,14 +44,10 @@ def _rand_image():
 
 class ImageClassificationDataViz(ImageClassificationData):
 
-    def configure_vis(self):
-        if not hasattr(self, "viz"):
-            return BaseViz(self)
-        return self.viz
-
     def show_train_batch(self):
-        self.viz = self.configure_vis()
+        self.viz.enabled = True
         _ = next(iter(self.train_dataloader()))
+        self.viz.enabled = False
 
 
 def test_base_viz(tmpdir):
@@ -67,7 +63,6 @@ def test_base_viz(tmpdir):
 
     img_data = ImageClassificationDataViz.from_filepaths(
         train_filepaths=[tmpdir / "a", tmpdir / "b"],
-        train_transform=None,
         train_labels=[0, 1],
         batch_size=1,
         num_workers=0,
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 669baee5a1..3db3e9384b 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -61,11 +61,11 @@ def run_test(filepath):
         ("finetuning", "tabular_classification.py"),
         ("finetuning", "text_classification.py"),  # TODO: takes too long
         # ("finetuning", "translation.py"),  # TODO: takes too long.
-        ("predict", "image_classification.py"),
-        ("predict", "tabular_classification.py"),
-        ("predict", "text_classification.py"),
-        ("predict", "image_embedder.py"),
-        ("predict", "summarization.py"),  # TODO: takes too long
+        #("predict", "image_classification.py"),
+        #("predict", "tabular_classification.py"),
+        #("predict", "text_classification.py"),
+        #("predict", "image_embedder.py"),
+        #("predict", "summarization.py"),  # TODO: takes too long
         # ("predict", "translate.py"),  # TODO: takes too long
     ]
 )
From 6db6b1cb90a2a457cabfa553a93c97ee8291f870 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 31 Mar 2021 13:22:52 +0100
Subject: [PATCH 04/30] resolve flake8

---
 tests/examples/test_scripts.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 3db3e9384b..e8bab588f3 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -61,11 +61,11 @@ def run_test(filepath):
         ("finetuning", "tabular_classification.py"),
         ("finetuning", "text_classification.py"),  # TODO: takes too long
         # ("finetuning", "translation.py"),  # TODO: takes too long.
-        #("predict", "image_classification.py"),
-        #("predict", "tabular_classification.py"),
-        #("predict", "text_classification.py"),
-        #("predict", "image_embedder.py"),
-        #("predict", "summarization.py"),  # TODO: takes too long
+        # ("predict", "image_classification.py"),
+        # ("predict", "tabular_classification.py"),
+        # ("predict", "text_classification.py"),
+        # ("predict", "image_embedder.py"),
+        # ("predict", "summarization.py"),  # TODO: takes too long
         # ("predict", "translate.py"),  # TODO: takes too long
     ]
 )
From f61deeacb9b9d1b4766a8ee2411aedfdec8fa2e5 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 31 Mar 2021 17:23:28 +0100
Subject: [PATCH 05/30] update

---
 flash/data/auto_dataset.py          | 26 +++++----
 flash/data/batch.py                 | 31 ++++++----
 flash/data/process.py               |  8 +--
 flash/data/utils.py                 | 90 ++++++++++++++++---------
 flash/vision/classification/data.py |  2 +
 tests/examples/test_scripts.py      |  2 +-
 6 files changed, 98 insertions(+), 61 deletions(-)

diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py
index e42a4cf680..7a1bc0455b 100644
--- a/flash/data/auto_dataset.py
+++ b/flash/data/auto_dataset.py
@@ -20,7 +20,7 @@
 from torch.utils.data import Dataset
 
 from flash.data.process import Preprocess
-from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, set_current_stage_and_fn
+from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext
 
 if TYPE_CHECKING:
     from flash.data.data_pipeline import DataPipeline
@@ -68,9 +68,13 @@
     def running_stage(self) -> Optional[RunningStage]:
         return self._running_stage
 
     @running_stage.setter
-    def running_stage(self, running_stage: str) -> None:
+    def running_stage(self, running_stage: RunningStage) -> None:
         if self._running_stage != running_stage or (not self._running_stage):
             self._running_stage = running_stage
+            self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self._preprocess)
+            self._load_sample_context = CurrentRunningStageFuncContext(
+                self._running_stage, "load_sample", self._preprocess
+            )
             self._setup(running_stage)
 
     @property
@@ -80,11 +84,10 @@ def _preprocess(self):
 
     def _call_load_data(self, data: Any) -> Iterable:
         parameters = signature(self.load_data).parameters
-        with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_data"):
-            if len(parameters) > 1 and self.DATASET_KEY in parameters:
-                return self.load_data(data, self)
-            else:
-                return self.load_data(data)
+        if len(parameters) > 1 and self.DATASET_KEY in parameters:
+            return self.load_data(data, self)
+        else:
+            return self.load_data(data)
 
     def _call_load_sample(self, sample: Any) -> Any:
         parameters = signature(self.load_sample).parameters
 
                     "The load_data function of the Autogenerated Dataset changed. "
                     "This is not expected! Preloading Data again to ensure compatibility. This may take some time."
                 )
-            self._preprocessed_data = self._call_load_data(self.data)
+            with self._load_data_context:
+                self._preprocessed_data = self._call_load_data(self.data)
             self._load_data_called = True
 
     def __getitem__(self, index: int) -> Any:
         if not self.load_sample and not self.load_data:
             raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
-        with set_current_stage_and_fn(self._preprocess, self._running_stage, "load_sample"):
-            if self.load_sample:
+        if self.load_sample:
+            with self._load_sample_context:
                 return self._call_load_sample(self._preprocessed_data[index])
-            return self._preprocessed_data[index]
+        return self._preprocessed_data[index]
 
     def __len__(self) -> int:
         if not self.load_sample and not self.load_data:
diff --git a/flash/data/batch.py b/flash/data/batch.py
index 1047e85a44..3b7bc70ab4 100644
--- a/flash/data/batch.py
+++ b/flash/data/batch.py
@@ -18,7 +18,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import Tensor
 
-from flash.data.utils import _contains_any_tensor, convert_to_modules, set_current_fn, set_current_stage
+from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext
 
 if TYPE_CHECKING:
     from flash.data.process import Preprocess
@@ -49,12 +49,17 @@
         self.stage = stage
         self.assert_contains_tensor = assert_contains_tensor
 
+        self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)
+        self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", preprocess)
+        self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess)
+        self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess)
+
     def forward(self, sample: Any):
-        with set_current_stage(self.preprocess, self.stage):
-            with set_current_fn(self.preprocess, "pre_tensor_transform"):
+        with self._current_stage_context:
+            with self._pre_tensor_transform_context:
                 sample = self.pre_tensor_transform(sample)
 
-            with set_current_fn(self.preprocess, "to_tensor_transform"):
+            with self._to_tensor_transform_context:
                 sample = self.to_tensor_transform(sample)
 
             if self.assert_contains_tensor:
                 if not _contains_any_tensor(sample):
                     raise MisconfigurationException(
                         "When ``to_tensor_transform`` is overriden, "
                         "``DataPipeline`` expects the outputs to be ``tensors``"
                     )
 
-            with set_current_fn(self.preprocess, "post_tensor_transform"):
+            with self._post_tensor_transform_context:
                 sample = self.post_tensor_transform(sample)
 
             return sample
 
         collate_fn: Callable,
         per_sample_transform: Union[Callable, _Sequential],
         per_batch_transform: Callable,
-        stage: Optional[RunningStage] = None,
+        stage: RunningStage,
         apply_per_sample_transform: bool = True,
         on_device: bool = False
     ):
@@ -118,18 +123,24 @@
         self.stage = stage
         self.on_device = on_device
 
+        extension = f"{'on_device' if self.on_device else ''}"
+        self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
+        self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess)
+        self._collate_context = CurrentFuncContext("collate", preprocess)
+        self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess)
+
     def forward(self, samples: Sequence[Any]):
-        with set_current_stage(self.preprocess, self.stage):
+        with self._current_stage_context:
 
             if self.apply_per_sample_transform:
-                with set_current_fn(self.preprocess, f"per_sample_transform_{'on_device' if self.on_device else ''}"):
+                with self._per_sample_transform_context:
                     samples = [self.per_sample_transform(sample) for sample in samples]
                 samples = type(samples)(samples)
 
-            with set_current_fn(self.preprocess, "collate"):
+            with self._collate_context:
                 samples = self.collate_fn(samples)
 
-            with set_current_fn(self.preprocess, f"per_batch_transform_{'on_device' if self.on_device else ''}"):
+            with self._per_batch_transform_context:
                 samples = self.per_batch_transform(samples)
             return samples
diff --git a/flash/data/process.py b/flash/data/process.py
index 62b23cc4a0..4dc1fea4c0 100644
--- a/flash/data/process.py
+++ b/flash/data/process.py
@@ -27,11 +27,11 @@
 
 class Properties:
 
-    _running_stage: RunningStage = None
-    _current_fn: str = None
+    _running_stage: Optional[RunningStage] = None
+    _current_fn: Optional[str] = None
 
     @property
-    def current_fn(self) -> str:
+    def current_fn(self) -> Optional[str]:
         return self._current_fn
 
     @current_fn.setter
     def current_fn(self, current_fn: str):
         self._current_fn = current_fn
 
     @property
-    def running_stage(self) -> RunningStage:
+    def running_stage(self) -> Optional[RunningStage]:
         return self._running_stage
diff --git a/flash/data/utils.py b/flash/data/utils.py
index 4b7fec9122..f3928861ba 100644
--- a/flash/data/utils.py
+++ b/flash/data/utils.py
@@ -14,7 +14,7 @@
 
 import os.path
 import zipfile
-from contextlib import contextmanager
+from contextlib import ContextDecorator, contextmanager
 from typing import Any, Callable, Dict, Iterable, Mapping, Type
 
 import requests
@@ -33,40 +33,60 @@
 _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"}
 
 
class CurrentRunningStageContext:

    def __init__(self, running_stage: RunningStage, obj: Any, reset: bool = True):
        self._running_stage = running_stage
        self._obj = obj
        self._reset = reset

    def __enter__(self):
        if self._obj is not None:
            if getattr(self._obj, "running_stage", None) != self._running_stage:
                self._obj.running_stage = self._running_stage
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self._obj is not None and self._reset:
            self._obj.running_stage = None


class CurrentFuncContext:

    def __init__(self, current_fn: str, obj: Any):
        self._current_fn = current_fn
        self._obj = obj

    def __enter__(self):
        if self._obj is not None:
            if getattr(self._obj, "current_fn", None) != self._current_fn:
                self._obj.current_fn = self._current_fn
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self._obj is not None:
            self._obj.current_fn = None


class CurrentRunningStageFuncContext:

    def __init__(self, running_stage: RunningStage, current_fn: str, obj: Any):
        self._running_stage = running_stage
        self._current_fn = current_fn
        self._obj = obj

    def __enter__(self):
        if
self._obj is not None: + if getattr(self._obj, "running_stage", None) != self._running_stage: + self._obj.running_stage = self._running_stage + if getattr(self._obj, "current_fn", None) != self._current_fn: + self._obj.current_fn = self._current_fn + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._obj is not None: + self._obj.running_stage = None + self._obj.current_fn = None def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index d09f467155..97497c799b 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -169,6 +169,8 @@ def to_tensor_transform(self, sample: Any) -> Any: if self.current_transform == self._identify: if isinstance(sample, (list, tuple)): source, target = sample + if isinstance(source, torch.Tensor): + return source, target return self.to_tensor(source), target elif isinstance(sample, torch.Tensor): return sample diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index e8bab588f3..d733768d65 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -53,7 +53,7 @@ def run_test(filepath): @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "folder,file", + "folder, file", [ # ("finetuning", "image_classification.py"), # ("finetuning", "object_detection.py"), # TODO: takes too long. From ffaa7c7668ba836786e50aab3e83e0d9f52158f2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 18:16:00 +0100 Subject: [PATCH 06/30] resolve tests --- flash/data/data_pipeline.py | 2 +- flash/vision/classification/data.py | 11 ++++------- tests/examples/test_scripts.py | 10 +++++----- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 40f9d48be8..f28d882a22 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -259,7 +259,7 @@ def _create_collate_preprocessors( "per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) - skip_mutual_check = preprocess.skip_mutual_check + skip_mutual_check = getattr(preprocess, "skip_mutual_check", False) if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden): raise MisconfigurationException( diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 97497c799b..39bd1a2023 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -18,15 +18,11 @@ import torch import torchvision from PIL import Image -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn -from torch.nn.modules import ModuleDict from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate -from torchvision import transforms from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset -from torchvision.transforms.functional import to_pil_image from flash.data.auto_dataset import AutoDataset from flash.data.data_module import DataModule @@ -158,11 +154,9 @@ def common_step(self, sample: Any) -> Any: if isinstance(sample, (list, tuple)): source, target = sample return self.current_transform(source), target - elif isinstance(sample, torch.Tensor): - return sample return 
self.current_transform(sample) - def per_tensor_transform(self, sample: Any) -> Any: + def pre_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) def to_tensor_transform(self, sample: Any) -> Any: @@ -175,6 +169,8 @@ def to_tensor_transform(self, sample: Any) -> Any: elif isinstance(sample, torch.Tensor): return sample return self.to_tensor(sample) + if isinstance(sample, torch.Tensor): + return sample return self.common_step(sample) def post_tensor_transform(self, sample: Any) -> Any: @@ -537,5 +533,6 @@ def from_filepaths( batch_size=batch_size, num_workers=num_workers, preprocess=preprocess, + seed=seed, **kwargs ) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index d733768d65..f04af08d54 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -55,16 +55,16 @@ def run_test(filepath): @pytest.mark.parametrize( "folder, file", [ - # ("finetuning", "image_classification.py"), + ("finetuning", "image_classification.py"), # ("finetuning", "object_detection.py"), # TODO: takes too long. # ("finetuning", "summarization.py"), # TODO: takes too long. ("finetuning", "tabular_classification.py"), - ("finetuning", "text_classification.py"), # TODO: takes too long + # ("finetuning", "text_classification.py"), # TODO: takes too long # ("finetuning", "translation.py"), # TODO: takes too long. - # ("predict", "image_classification.py"), - # ("predict", "tabular_classification.py"), + ("predict", "image_classification.py"), + ("predict", "tabular_classification.py"), # ("predict", "text_classification.py"), - # ("predict", "image_embedder.py"), + ("predict", "image_embedder.py"), # ("predict", "summarization.py"), # TODO: takes too long # ("predict", "translate.py"), # TODO: takes too long ] From 596a523998feab60c90bc284e2aaf115503dcfb4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 18:20:49 +0100 Subject: [PATCH 07/30] update --- flash/data/base_viz.py | 50 ------------------------- flash/data/data_module.py | 18 --------- tests/data/test_base_viz.py | 75 ------------------------------------- 3 files changed, 143 deletions(-) delete mode 100644 flash/data/base_viz.py delete mode 100644 tests/data/test_base_viz.py diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py deleted file mode 100644 index 40e341196e..0000000000 --- a/flash/data/base_viz.py +++ /dev/null @@ -1,50 +0,0 @@ -import functools -from typing import Any, Callable - -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.trainer.states import RunningStage - -from flash.data.data_pipeline import DataPipeline -from flash.data.process import Preprocess - - -class BaseViz(Callback): - - def __init__(self, enabled: bool = False): - self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}} - self.enabled = enabled - self._datamodule = None - - def attach_to_preprocess(self, preprocess: Preprocess) -> None: - self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess) - - def attach_to_datamodule(self, datamodule) -> None: - self._datamodule = datamodule - datamodule.viz = self - - def _wrap_fn( - self, - fn: Callable, - running_stage: RunningStage, - ) -> Callable: - - @functools.wraps(fn) - def wrapper(*args) -> Any: - data = fn(*args) - if self.enabled: - batches = self.batches[running_stage.value] - if fn.__name__ not in batches: - batches[fn.__name__] = [] - batches[fn.__name__].append(data) - return data - - return wrapper - - def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: 
Preprocess): - fn_names = { - k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess) - for k in DataPipeline.PREPROCESS_FUNCS - } - for fn_name in fn_names: - fn = getattr(preprocess, fn_name) - setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage)) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 286be2b6fa..f7c2e8f6d2 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -24,7 +24,6 @@ from torch.utils.data.dataset import Subset from flash.data.auto_dataset import AutoDataset -from flash.data.base_viz import BaseViz from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -84,23 +83,10 @@ def __init__( self._preprocess = None self._postprocess = None - self._viz = None # this may also trigger data preloading self.set_running_stages() - @property - def viz(self) -> BaseViz: - return self._viz or DataModule.configure_vis() - - @viz.setter - def viz(self, viz: BaseViz) -> None: - self._viz = viz - - @classmethod - def configure_vis(cls) -> BaseViz: - return BaseViz() - @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: if isinstance(dataset, Subset): @@ -334,9 +320,6 @@ def from_load_data_inputs( else: data_pipeline = cls(**kwargs).data_pipeline - viz_callback = cls.configure_vis() - viz_callback.attach_to_preprocess(data_pipeline._preprocess_pipeline) - train_dataset = cls._generate_dataset_if_possible( train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline ) @@ -358,5 +341,4 @@ def from_load_data_inputs( ) datamodule._preprocess = data_pipeline._preprocess_pipeline datamodule._postprocess = data_pipeline._postprocess_pipeline - viz_callback.attach_to_datamodule(datamodule) return datamodule diff --git a/tests/data/test_base_viz.py b/tests/data/test_base_viz.py deleted file mode 100644 index a4e49a3caf..0000000000 --- a/tests/data/test_base_viz.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple -from unittest import mock - -import numpy as np -import pytest -import torch -import torchvision.transforms as T -from PIL import Image -from pytorch_lightning import Trainer -from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor, tensor -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate - -from flash.core import Task -from flash.data.auto_dataset import AutoDataset -from flash.data.base_viz import BaseViz -from flash.data.batch import _PostProcessor, _PreProcessor -from flash.data.data_module import DataModule -from flash.data.data_pipeline import _StageOrchestrator, DataPipeline -from flash.data.process import Postprocess, Preprocess -from flash.vision import ImageClassificationData - - -def _rand_image(): - return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8")) - - -class ImageClassificationDataViz(ImageClassificationData): - - def show_train_batch(self): - self.viz.enabled = True - _ = next(iter(self.train_dataloader())) - self.viz.enabled = False - - -def test_base_viz(tmpdir): - tmpdir = Path(tmpdir) - - (tmpdir / "a").mkdir() - (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "a" / "a_2.png") - - _rand_image().save(tmpdir / "b" / "a_1.png") - _rand_image().save(tmpdir / "b" / "a_2.png") - - img_data = ImageClassificationDataViz.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - batch_size=1, - num_workers=0, - ) - - img_data.show_train_batch() - assert img_data.viz.batches["train"]["load_sample"] is not None - assert img_data.viz.batches["train"]["to_tensor_transform"] is not None - assert img_data.viz.batches["train"]["collate"] is not None - assert img_data.viz.batches["train"]["per_batch_transform"] is not None From 2fdefbedd0e66861a4ba4cc9fe1ed486b2a79683 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 18:34:06 +0100 Subject: [PATCH 08/30] wip --- flash/data/base_viz.py | 50 ++++++++++++++++++++++++++++++++ flash/data/data_module.py | 18 ++++++++++++ tests/data/test_data_viz.py | 57 +++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 flash/data/base_viz.py create mode 100644 tests/data/test_data_viz.py diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py new file mode 100644 index 0000000000..40e341196e --- /dev/null +++ b/flash/data/base_viz.py @@ -0,0 +1,50 @@ +import functools +from typing import Any, Callable + +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.states import RunningStage + +from flash.data.data_pipeline import DataPipeline +from flash.data.process import Preprocess + + +class BaseViz(Callback): + + def __init__(self, enabled: bool = False): + self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}} + self.enabled = enabled + self._datamodule = None + + def attach_to_preprocess(self, preprocess: Preprocess) -> None: + self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess) + + def attach_to_datamodule(self, datamodule) -> None: + self._datamodule = datamodule + datamodule.viz = self + + def _wrap_fn( + self, + fn: Callable, + running_stage: RunningStage, + ) -> Callable: + + @functools.wraps(fn) + def wrapper(*args) -> Any: + data = fn(*args) + if self.enabled: + batches = 
self.batches[running_stage.value] + if fn.__name__ not in batches: + batches[fn.__name__] = [] + batches[fn.__name__].append(data) + return data + + return wrapper + + def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: Preprocess): + fn_names = { + k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess) + for k in DataPipeline.PREPROCESS_FUNCS + } + for fn_name in fn_names: + fn = getattr(preprocess, fn_name) + setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage)) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index f7c2e8f6d2..286be2b6fa 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -24,6 +24,7 @@ from torch.utils.data.dataset import Subset from flash.data.auto_dataset import AutoDataset +from flash.data.base_viz import BaseViz from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess @@ -83,10 +84,23 @@ def __init__( self._preprocess = None self._postprocess = None + self._viz = None # this may also trigger data preloading self.set_running_stages() + @property + def viz(self) -> BaseViz: + return self._viz or DataModule.configure_vis() + + @viz.setter + def viz(self, viz: BaseViz) -> None: + self._viz = viz + + @classmethod + def configure_vis(cls) -> BaseViz: + return BaseViz() + @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: if isinstance(dataset, Subset): @@ -320,6 +334,9 @@ def from_load_data_inputs( else: data_pipeline = cls(**kwargs).data_pipeline + viz_callback = cls.configure_vis() + viz_callback.attach_to_preprocess(data_pipeline._preprocess_pipeline) + train_dataset = cls._generate_dataset_if_possible( train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline ) @@ -341,4 +358,5 @@ def from_load_data_inputs( ) datamodule._preprocess = data_pipeline._preprocess_pipeline datamodule._postprocess = data_pipeline._postprocess_pipeline + viz_callback.attach_to_datamodule(datamodule) return datamodule diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py new file mode 100644 index 0000000000..45c118bcb0 --- /dev/null +++ b/tests/data/test_data_viz.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import Path + +import numpy as np +from PIL import Image + +from flash.vision import ImageClassificationData + + +def _rand_image(): + return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8")) + + +class ImageClassificationDataViz(ImageClassificationData): + + def show_train_batch(self): + self.viz.enabled = True + _ = next(iter(self.train_dataloader())) + self.viz.enabled = False + + +def test_base_viz(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a" / "a_1.png") + _rand_image().save(tmpdir / "a" / "a_2.png") + + _rand_image().save(tmpdir / "b" / "a_1.png") + _rand_image().save(tmpdir / "b" / "a_2.png") + + img_data = ImageClassificationDataViz.from_filepaths( + train_filepaths=[tmpdir / "a", tmpdir / "b"], + train_labels=[0, 1], + batch_size=1, + num_workers=0, + ) + + img_data.show_train_batch() + assert img_data.viz.batches["train"]["load_sample"] is not None + assert img_data.viz.batches["train"]["to_tensor_transform"] is not None + assert img_data.viz.batches["train"]["collate"] is not None + assert img_data.viz.batches["train"]["per_batch_transform"] is not None From 43814412571b308dbe218fe63a35e2044e8c6439 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 18:50:37 +0100 Subject: [PATCH 09/30] update --- flash/vision/classification/data.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 39bd1a2023..53ef54b7da 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -460,43 +460,34 @@ def from_filepaths( ) -> 'ImageClassificationData': """ Creates a ImageClassificationData object from folders of images arranged in this way: :: + folder/dog_xxx.png folder/dog_xxy.png folder/dog_xxz.png folder/cat_123.png folder/cat_nsdf3.png folder/cat_asd932_.png + Args: + train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. train_labels: Sequence of labels for training dataset. Defaults to ``None``. val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. val_labels: Sequence of labels for validation dataset. Defaults to ``None``. test_filepaths: String or sequence of file paths for test dataset. Defaults to ``None``. test_labels: Sequence of labels for test dataset. Defaults to ``None``. - train_transform: Transforms for training dataset. Defaults to ``default``, which loads imagenet transforms. + train_transform: Transforms for training dataset. Defaults to ``default``, + which loads imagenet transforms. val_transform: Transforms for validation and testing dataset. Defaults to ``default``, which loads imagenet transforms. batch_size: The batchsize to use for parallel loading. Defaults to ``64``. num_workers: The number of workers to use for parallelized loading. Defaults to ``None`` which equals the number of available CPU threads. seed: Used for the train/val splits. + Returns: + ImageClassificationData: The constructed data module. 
- Examples: - >>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP - Example when labels are in .csv file:: - train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id') - val_labels = labels_from_categorical_csv(path/to/val.csv', 'my_id') - test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id') - data = ImageClassificationData.from_filepaths( - batch_size=2, - train_filepaths='path/to/train', - train_labels=train_labels, - val_filepaths='path/to/val', - val_labels=val_labels, - test_filepaths='path/to/test', - test_labels=test_labels, - ) """ # enable passing in a string which loads all files in that folder as a list if isinstance(train_filepaths, str): From d572248a93747f108ce92b87279765fa4ff963f5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 19:07:37 +0100 Subject: [PATCH 10/30] resolve doc --- flash/vision/classification/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 53ef54b7da..d0f7e0cb63 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -459,7 +459,9 @@ def from_filepaths( **kwargs, ) -> 'ImageClassificationData': """ - Creates a ImageClassificationData object from folders of images arranged in this way: :: + Creates a ImageClassificationData object from folders of images arranged in this way: + + Examples:: folder/dog_xxx.png folder/dog_xxy.png From b928fc5f7506fa3e219cc93e8dc1aeadfabf8cfa Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 19:29:34 +0100 Subject: [PATCH 11/30] resolve doc --- flash/vision/classification/data.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 39bd1a2023..d0f7e0cb63 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -459,44 +459,37 @@ def from_filepaths( **kwargs, ) -> 'ImageClassificationData': """ - Creates a ImageClassificationData object from folders of images arranged in this way: :: + Creates a ImageClassificationData object from folders of images arranged in this way: + + Examples:: + folder/dog_xxx.png folder/dog_xxy.png folder/dog_xxz.png folder/cat_123.png folder/cat_nsdf3.png folder/cat_asd932_.png + Args: + train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. train_labels: Sequence of labels for training dataset. Defaults to ``None``. val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. val_labels: Sequence of labels for validation dataset. Defaults to ``None``. test_filepaths: String or sequence of file paths for test dataset. Defaults to ``None``. test_labels: Sequence of labels for test dataset. Defaults to ``None``. - train_transform: Transforms for training dataset. Defaults to ``default``, which loads imagenet transforms. + train_transform: Transforms for training dataset. Defaults to ``default``, + which loads imagenet transforms. val_transform: Transforms for validation and testing dataset. Defaults to ``default``, which loads imagenet transforms. batch_size: The batchsize to use for parallel loading. Defaults to ``64``. num_workers: The number of workers to use for parallelized loading. Defaults to ``None`` which equals the number of available CPU threads. seed: Used for the train/val splits. 
+ Returns: + ImageClassificationData: The constructed data module. - Examples: - >>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP - Example when labels are in .csv file:: - train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id') - val_labels = labels_from_categorical_csv(path/to/val.csv', 'my_id') - test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id') - data = ImageClassificationData.from_filepaths( - batch_size=2, - train_filepaths='path/to/train', - train_labels=train_labels, - val_filepaths='path/to/val', - val_labels=val_labels, - test_filepaths='path/to/test', - test_labels=test_labels, - ) """ # enable passing in a string which loads all files in that folder as a list if isinstance(train_filepaths, str): From 9381d412de7a3ddeac2905c24bebbf8160783f65 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 31 Mar 2021 19:36:14 +0100 Subject: [PATCH 12/30] update doc --- flash/vision/classification/data.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index d0f7e0cb63..6e7f99ba16 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -459,9 +459,7 @@ def from_filepaths( **kwargs, ) -> 'ImageClassificationData': """ - Creates a ImageClassificationData object from folders of images arranged in this way: - - Examples:: + Creates a ImageClassificationData object from folders of images arranged in this way: :: folder/dog_xxx.png folder/dog_xxy.png @@ -471,7 +469,6 @@ def from_filepaths( folder/cat_asd932_.png Args: - train_filepaths: String or sequence of file paths for training dataset. Defaults to ``None``. train_labels: Sequence of labels for training dataset. Defaults to ``None``. val_filepaths: String or sequence of file paths for validation dataset. Defaults to ``None``. @@ -488,7 +485,6 @@ def from_filepaths( seed: Used for the train/val splits. Returns: - ImageClassificationData: The constructed data module. 
""" # enable passing in a string which loads all files in that folder as a list From 108a7cca4d6c390615fd78ba8c73aa32d7687c4a Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 09:53:03 +0100 Subject: [PATCH 13/30] update --- tests/data/test_data_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index c1d8ae6b62..172a6793eb 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -626,7 +626,7 @@ class CustomDataModule(DataModule): batch = next(iter(datamodule.val_dataloader())) CustomDataModule.preprocess_cls = TestPreprocessTransformations2 - datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2) + datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0) batch = next(iter(datamodule.val_dataloader())) assert torch.equal(batch["a"], tensor([0, 1])) assert torch.equal(batch["b"], tensor([1, 2])) From 6da92b375dee1c9f77d1ce121330ea08ce85f982 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 10:02:25 +0100 Subject: [PATCH 14/30] update --- tests/data/test_data_pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 172a6793eb..d311011c68 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -607,7 +607,10 @@ def test_step(self, batch, batch_idx): assert batch[0].shape == torch.Size([2, 1]) def predict_step(self, batch, batch_idx, dataloader_idx): - assert batch == [('a', 'a'), ('b', 'b')] + assert batch[0][0] == 'a' + assert batch[0][1] == 'a' + assert batch[1][0] == 'b' + assert batch[1][1] == 'b' return tensor([0, 0, 0]) class CustomDataModule(DataModule): From d4cf9f585a6e0a0f0a72d3b49e841ebdefc6c68a Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 1 Apr 2021 09:21:10 +0000 Subject: [PATCH 15/30] update --- flash/data/data_module.py | 2 +- tests/data/test_data_pipeline.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 286be2b6fa..bb11595e85 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -54,7 +54,7 @@ def __init__( test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, batch_size: int = 1, - num_workers: Optional[int] = None, + num_workers: Optional[int] = 0, ) -> None: super().__init__() diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index c1d8ae6b62..d311011c68 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -607,7 +607,10 @@ def test_step(self, batch, batch_idx): assert batch[0].shape == torch.Size([2, 1]) def predict_step(self, batch, batch_idx, dataloader_idx): - assert batch == [('a', 'a'), ('b', 'b')] + assert batch[0][0] == 'a' + assert batch[0][1] == 'a' + assert batch[1][0] == 'b' + assert batch[1][1] == 'b' return tensor([0, 0, 0]) class CustomDataModule(DataModule): @@ -626,7 +629,7 @@ class CustomDataModule(DataModule): batch = next(iter(datamodule.val_dataloader())) CustomDataModule.preprocess_cls = TestPreprocessTransformations2 - datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2) + datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0) batch = next(iter(datamodule.val_dataloader())) assert torch.equal(batch["a"], tensor([0, 1])) assert 
torch.equal(batch["b"], tensor([1, 2]))

From 16deb7b0f787b6bcac15f4c942f6d724372464b2 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 1 Apr 2021 10:22:16 +0100
Subject: [PATCH 16/30] convert to staticmethod

---
 flash/core/model.py              | 3 ++-
 tests/data/test_data_pipeline.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/flash/core/model.py b/flash/core/model.py
index b03a424dd2..78c907fc6c 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -245,7 +245,8 @@ def on_fit_end(self) -> None:
         self.data_pipeline._detach_from_model(self)
         super().on_fit_end()
 
-    def _sanetize_funcs(self, obj: Any) -> Any:
+    @staticmethod
+    def _sanetize_funcs(obj: Any) -> Any:
         if hasattr(obj, "__dict__"):
             for k, v in obj.__dict__.items():
                 if isinstance(v, Callable):
diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py
index d311011c68..0154377160 100644
--- a/tests/data/test_data_pipeline.py
+++ b/tests/data/test_data_pipeline.py
@@ -617,7 +617,7 @@ class CustomDataModule(DataModule):
 
         preprocess_cls = TestPreprocessTransformations
 
-    datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2)
+    datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0)
 
     assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
     batch = next(iter(datamodule.train_dataloader()))

From 4025eb096c6fbdeb77396e274fcd41b65d3461f1 Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Thu, 1 Apr 2021 12:27:07 +0200
Subject: [PATCH 17/30] initial visualisation implementation

---
 flash/vision/__init__.py                | 2 +-
 flash/vision/classification/__init__.py | 2 +-
 flash/vision/classification/data.py     | 47 +++++++++++++++++++++++++
 3 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py
index 2023605d7b..42e0ac34aa 100644
--- a/flash/vision/__init__.py
+++ b/flash/vision/__init__.py
@@ -1,3 +1,3 @@
-from flash.vision.classification import ImageClassificationData, ImageClassifier
+from flash.vision.classification import ImageClassificationData, ImageClassifier, ImageClassificationDataViz
 from flash.vision.detection import ObjectDetectionData, ObjectDetector
 from flash.vision.embedding import ImageEmbedder
diff --git a/flash/vision/classification/__init__.py b/flash/vision/classification/__init__.py
index eaeab26233..c8f37d1f76 100644
--- a/flash/vision/classification/__init__.py
+++ b/flash/vision/classification/__init__.py
@@ -1,2 +1,2 @@
-from flash.vision.classification.data import ImageClassificationData
+from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataViz
 from flash.vision.classification.model import ImageClassifier
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index d0f7e0cb63..fecf23297d 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -529,3 +529,50 @@ def from_filepaths(
             seed=seed,
             **kwargs
         )
+
+class ImageClassificationDataViz(ImageClassificationData):
+
+    def show_train_batch(self):
+        self.viz.enabled = True
+        # fetch batch and cache data
+        _ = next(iter(self.train_dataloader()))
+        self.viz.enabled = False
+
+        from typing import List
+        import kornia as K
+        import torchvision as tv
+        from PIL import Image
+        import numpy as np
+        import matplotlib.pyplot as plt
+
+        # plot raw data
+        rows: int = 4  # change later
+        data_raw: List[Image.Image] = self.viz.batches['train']['load_sample']
+        for num, x_data in enumerate(data_raw):
+            img, label = x_data
+            plt.subplot(rows, rows, num + 1)
+            plt.title(label)
+            plt.axis('off')
+            plt.imshow(np.array(img))
+        plt.title('load_sample')
+        plt.show(block=False)
+
+        mean = torch.tensor([0.485, 0.456, 0.406])
+        std = torch.tensor([0.229, 0.224, 0.225])
+
+        # plot pre-processed data, before and after the random augmentations
+        data1, labels1 = self.viz.batches['train']['collate'][0]  # this is before random transforms
+        data2, labels2 = self.viz.batches['train']['per_batch_transform'][0]  # this should be after random transforms
+
+        data1 = K.enhance.denormalize(data1, mean, std)
+        data2 = K.enhance.denormalize(data2, mean, std)
+
+        # cast and prepare data for visualisation
+        data1_vis = K.tensor_to_image(tv.utils.make_grid(data1))
+        data2_vis = K.tensor_to_image(tv.utils.make_grid(data2))
+
+        # plot using matplotlib
+        fig, (ax1, ax2) = plt.subplots(2)
+        ax1.imshow(data1_vis)
+        ax2.imshow(data2_vis)
+        plt.show()
\ No newline at end of file

From d2076d4a800d2488f151b7c6be5277ddfe593a52 Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Thu, 1 Apr 2021 12:42:18 +0200
Subject: [PATCH 18/30] implement test case using Kornia transforms

---
 tests/data/test_data_viz.py | 51 +++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py
index 45c118bcb0..4644cee985 100644
--- a/tests/data/test_data_viz.py
+++ b/tests/data/test_data_viz.py
@@ -17,6 +17,10 @@ import numpy as np
 from PIL import Image
 
+import torch
+import kornia as K
+import torchvision.transforms as T
+
 from flash.vision import ImageClassificationData
@@ -55,3 +59,50 @@ def test_base_viz(tmpdir):
     assert img_data.viz.batches["train"]["to_tensor_transform"] is not None
     assert img_data.viz.batches["train"]["collate"] is not None
     assert img_data.viz.batches["train"]["per_batch_transform"] is not None
+
+
+def test_base_viz_kornia(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "b").mkdir()
+    _rand_image().save(tmpdir / "a" / "a_1.png")
+    _rand_image().save(tmpdir / "a" / "a_2.png")
+
+    _rand_image().save(tmpdir / "b" / "a_1.png")
+    _rand_image().save(tmpdir / "b" / "a_2.png")
+
+    # Define the augmentations pipeline
+    train_transforms = {
+        # can we just call this `preprocess` ?
+        # this is needed all the time in train, valid, etc
+        "pre_tensor_transform": T.Compose([
+            T.RandomResizedCrop(224),
+            T.ToTensor()
+        ]),
+        "post_tensor_transform": nn.Sequential(
+            # Kornia RandomResizeCrop has a bug - I'll debug with Jian.
+            # K.augmentation.RandomResizedCrop((224, 224), align_corners=True),
+            K.augmentation.Normalize(
+                torch.tensor([0.485, 0.456, 0.406]),
+                torch.tensor([0.229, 0.224, 0.225])),
+        ),
+        "per_batch_transform_on_device": nn.Sequential(
+            K.augmentation.RandomAffine(360., p=0.5),
+            K.augmentation.ColorJitter(0.2, 0.3, 0.2, 0.3, p=0.5)
+        )
+    }
+    img_data = ImageClassificationDataViz.from_filepaths(
+        train_filepaths=[tmpdir / "a", tmpdir / "b"],
+        train_labels=[0, 1],
+        batch_size=1,
+        num_workers=0,
+        train_transform=train_transforms,
+        val_transform=train_transforms,
+    )
+
+    img_data.show_train_batch()
+    assert img_data.viz.batches["train"]["load_sample"] is not None
+    assert img_data.viz.batches["train"]["to_tensor_transform"] is not None
+    assert img_data.viz.batches["train"]["collate"] is not None
+    assert img_data.viz.batches["train"]["per_batch_transform"] is not None

From ff8e1adf7161c54016f2bc0f1dd2d5029435b4c9 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 1 Apr 2021 11:42:42 +0100
Subject: [PATCH 19/30] update on comments

---
 flash/data/auto_dataset.py                | 22 ++++----
 flash/data/batch.py                       |  4 +-
 flash/data/data_pipeline.py               | 66 +++++++++++------------
 flash/data/process.py                     |  6 +--
 flash/tabular/classification/data/data.py |  2 +-
 flash/vision/classification/data.py       |  3 +-
 tests/data/test_data_pipeline.py          |  8 +--
 7 files changed, 55 insertions(+), 56 deletions(-)

diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py
index 7a1bc0455b..bc05f8c441 100644
--- a/flash/data/auto_dataset.py
+++ b/flash/data/auto_dataset.py
@@ -71,14 +71,14 @@ def running_stage(self) -> Optional[RunningStage]:
     def running_stage(self, running_stage: RunningStage) -> None:
         if self._running_stage != running_stage or (not self._running_stage):
             self._running_stage = running_stage
-            self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self._preprocess)
+            self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self.preprocess)
             self._load_sample_context = CurrentRunningStageFuncContext(
-                self._running_stage, "load_sample", self._preprocess
+                self._running_stage, "load_sample", self.preprocess
             )
             self._setup(running_stage)
 
     @property
-    def _preprocess(self):
+    def preprocess(self) -> Optional[Preprocess]:
         if self.data_pipeline is not None:
             return self.data_pipeline._preprocess_pipeline
 
@@ -102,15 +102,15 @@ def _setup(self, stage: Optional[RunningStage]) -> None:
 
         if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage:
             self.load_data = getattr(
-                self.data_pipeline._preprocess_pipeline,
+                self.data_pipeline.preprocess_pipeline,
                 self.data_pipeline._resolve_function_hierarchy(
-                    'load_data', self.data_pipeline._preprocess_pipeline, stage, Preprocess
+                    'load_data', self.data_pipeline.preprocess_pipeline, stage, Preprocess
                 )
             )
             self.load_sample = getattr(
-                self.data_pipeline._preprocess_pipeline,
+                self.data_pipeline.preprocess_pipeline,
                 self.data_pipeline._resolve_function_hierarchy(
-                    'load_sample', self.data_pipeline._preprocess_pipeline, stage, Preprocess
+                    'load_sample', self.data_pipeline.preprocess_pipeline, stage, Preprocess
                 )
             )
         if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called):
             if previous_load_data:
                 rank_zero_warn(
                     "The load_data function of the Autogenerated Dataset changed. "
                     "This is not expected! Preloading Data again to ensure compatibility. This may take some time."
) with self._load_data_context: - self._preprocessed_data = self._call_load_data(self.data) + self.preprocessed_data = self._call_load_data(self.data) self._load_data_called = True def __getitem__(self, index: int) -> Any: @@ -128,10 +128,10 @@ def __getitem__(self, index: int) -> Any: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") if self.load_sample: with self._load_sample_context: - return self._call_load_sample(self._preprocessed_data[index]) - return self._preprocessed_data[index] + return self._call_load_sample(self.preprocessed_data[index]) + return self.preprocessed_data[index] def __len__(self) -> int: if not self.load_sample and not self.load_data: raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.") - return len(self._preprocessed_data) + return len(self.preprocessed_data) diff --git a/flash/data/batch.py b/flash/data/batch.py index 3b7bc70ab4..9c7cce304e 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -54,7 +54,7 @@ def __init__( self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", preprocess) self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess) - def forward(self, sample: Any): + def forward(self, sample: Any) -> Any: with self._current_stage_context: with self._pre_tensor_transform_context: sample = self.pre_tensor_transform(sample) @@ -129,7 +129,7 @@ def __init__( self._collate_context = CurrentFuncContext("collate", preprocess) self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess) - def forward(self, samples: Sequence[Any]): + def forward(self, samples: Sequence[Any]) -> Any: with self._current_stage_context: if self.apply_per_sample_transform: diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index f28d882a22..7cce7fc04a 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -14,7 +14,7 @@ import functools import inspect import weakref -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import RunningStage @@ -101,31 +101,31 @@ def forward(self, samples: Sequence[Any]): General flow: - load_sample - │ - pre_tensor_transform - │ - to_tensor_transform - │ - post_tensor_transform - │ - ┌────────────────┴───────────────────┐ -(move list to main worker) --> │ │ - per_sample_transform_on_device collate - │ │ - collate per_batch_transform - │ │ <-- (move batch to main worker) - per_batch_transform_on_device per_batch_transform_on_device - │ │ - └─────────────────┬──────────────────┘ - │ - model.predict_step - │ - per_batch_transform - │ - uncollate - │ - per_sample_transform + load_sample + │ + pre_tensor_transform + │ + to_tensor_transform + │ + post_tensor_transform + │ + ┌────────────────┴───────────────────┐ +(move samples's sequence to main worker) --> │ │ + per_sample_transform_on_device collate + │ │ + collate per_batch_transform + │ │ <-- (move batch to main worker) + per_batch_transform_on_device per_batch_transform_on_device + │ │ + └─────────────────┬──────────────────┘ + │ + model.predict_step + │ + per_batch_transform + │ + uncollate + │ + per_sample_transform """ @@ -241,25 +241,25 @@ def _create_collate_preprocessors( if collate_fn is None: 
collate_fn = default_collate - preprocess = self._preprocess_pipeline + preprocess: Preprocess = self._preprocess_pipeline - func_names = { + func_names: Dict[str, str] = { k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) for k in self.PREPROCESS_FUNCS } if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]): - collate_fn = getattr(preprocess, func_names["collate"]) + collate_fn: Callable = getattr(preprocess, func_names["collate"]) - per_batch_transform_overriden = self._is_overriden_recursive( + per_batch_transform_overriden: bool = self._is_overriden_recursive( "per_batch_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) - per_sample_transform_on_device_overriden = self._is_overriden_recursive( + per_sample_transform_on_device_overriden: bool = self._is_overriden_recursive( "per_sample_transform_on_device", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] ) - skip_mutual_check = getattr(preprocess, "skip_mutual_check", False) + skip_mutual_check: bool = getattr(preprocess, "skip_mutual_check", False) if (not skip_mutual_check and per_batch_transform_overriden and per_sample_transform_on_device_overriden): raise MisconfigurationException( @@ -562,7 +562,7 @@ def to_dataloader( return DataLoader(self._generate_auto_dataset(data), **loader_kwargs) def __str__(self) -> str: - preprocess = self._preprocess_pipeline + preprocess: Preprocess = self._preprocess_pipeline postprocess = self._postprocess_pipeline return f"{self.__class__.__name__}(preprocess={preprocess}, postprocess={postprocess})" diff --git a/flash/data/process.py b/flash/data/process.py index 4dc1fea4c0..8fd16e5359 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -125,16 +125,16 @@ def skip_mutual_check(self) -> bool: def skip_mutual_check(self, skip_mutual_check: bool) -> None: self._skip_mutual_check = skip_mutual_check - def _identify(self, x): + def _identify(self, x: Any) -> Any: return x - def _get_transform(self, transform: Dict[str, Callable]): + def _get_transform(self, transform: Dict[str, Callable]) -> Callable: if self.current_fn in transform: return transform[self.current_fn] return self._identify @property - def current_transform(self): + def current_transform(self) -> Callable: if self.training and self.train_transform: return self._get_transform(self.train_transform) elif self.validating and self.val_transform: diff --git a/flash/tabular/classification/data/data.py b/flash/tabular/classification/data/data.py index fa62d4e6ca..58f583e524 100644 --- a/flash/tabular/classification/data/data.py +++ b/flash/tabular/classification/data/data.py @@ -349,7 +349,7 @@ def from_df( is_regression, preprocess_state=preprocess_state ) - preprocess = preprocess_cls.from_state(preprocess_state) + preprocess: Preprocess = preprocess_cls.from_state(preprocess_state) return cls.from_load_data_inputs( train_load_data_input=train_df, diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 6e7f99ba16..d66c9bb355 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -342,8 +342,7 @@ def instantiate_preprocess( ) preprocess_cls = preprocess_cls or cls.preprocess_cls - preprocess = preprocess_cls(train_transform, val_transform, test_transform, predict_transform) - # todo (tchaton) add check on mutually exclusive transforms + preprocess: Preprocess = preprocess_cls(train_transform, val_transform, test_transform, predict_transform) return preprocess @classmethod 
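
For orientation, a minimal sketch of the `Preprocess` contract these hooks define; `MyPreprocess` and its bodies are illustrative assumptions, not code from this patch series:

    from typing import Any

    from flash.data.process import Preprocess


    class MyPreprocess(Preprocess):

        def load_data(self, data: Any) -> Any:
            # runs once per stage to build the samples; the AutoDataset wraps this
            # call in the CurrentRunningStageFuncContext introduced above
            return data

        def load_sample(self, sample: Any) -> Any:
            # runs per sample, inside the dataloader worker
            return sample

        def to_tensor_transform(self, sample: Any) -> Any:
            # converts the loaded sample into tensors
            return sample

        def per_batch_transform(self, batch: Any) -> Any:
            # runs on the collated batch; _create_collate_preprocessors raises a
            # MisconfigurationException when this is overridden together with
            # per_sample_transform_on_device, unless skip_mutual_check is set
            return batch
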
diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 0154377160..ed7ebe60b9 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -130,25 +130,25 @@ def test_per_batch_transform_on_device(self, *_, **__): preprocess = CustomPreprocess() data_pipeline = DataPipeline(preprocess) - train_func_names = { + train_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } - val_func_names = { + val_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.VALIDATING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } - test_func_names = { + test_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TESTING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } - predict_func_names = { + predict_func_names: Dict[str, str] = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.PREDICTING, Preprocess ) From 84eaa68c7a6c3d9f7c34dbb3553e1027d3507e7a Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 11:48:21 +0100 Subject: [PATCH 20/30] resolve bug --- flash/data/auto_dataset.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index bc05f8c441..5652496c10 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -102,16 +102,12 @@ def _setup(self, stage: Optional[RunningStage]) -> None: if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage: self.load_data = getattr( - self.data_pipeline.preprocess_pipeline, - self.data_pipeline._resolve_function_hierarchy( - 'load_data', self.data_pipeline.preprocess_pipeline, stage, Preprocess - ) + self.preprocess, + self.data_pipeline._resolve_function_hierarchy('load_data', self.preprocess, stage, Preprocess) ) self.load_sample = getattr( - self.data_pipeline.preprocess_pipeline, - self.data_pipeline._resolve_function_hierarchy( - 'load_sample', self.data_pipeline.preprocess_pipeline, stage, Preprocess - ) + self.preprocess, + self.data_pipeline._resolve_function_hierarchy('load_sample', self.preprocess, stage, Preprocess) ) if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called): if previous_load_data: From fb25c048f2f06535c37bca456518c18cbc0449fb Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 12:14:13 +0100 Subject: [PATCH 21/30] update --- flash/data/batch.py | 6 +++--- flash/vision/classification/data.py | 1 + tests/data/test_data_viz.py | 4 +--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/flash/data/batch.py b/flash/data/batch.py index 9c7cce304e..e405f56b2b 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -123,11 +123,11 @@ def __init__( self.stage = stage self.on_device = on_device - extension = f"{'on_device' if self.on_device else ''}" + extension = f"{'_on_device' if self.on_device else ''}" self._current_stage_context = CurrentRunningStageContext(stage, preprocess) - self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess) + self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", preprocess) self._collate_context = 
CurrentFuncContext("collate", preprocess) - self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess) + self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess) def forward(self, samples: Sequence[Any]) -> Any: with self._current_stage_context: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index bb0cf24914..eb71711aee 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -176,6 +176,7 @@ def to_tensor_transform(self, sample: Any) -> Any: def post_tensor_transform(self, sample: Any) -> Any: return self.common_step(sample) + # todo bug (tchaton) where to place the collate. Need an indication. def per_batch_transform(self, sample: Any) -> Any: return self.common_step(sample) diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py index 74a9a04506..5727ec8d97 100644 --- a/tests/data/test_data_viz.py +++ b/tests/data/test_data_viz.py @@ -82,7 +82,7 @@ def test_base_viz_kornia(tmpdir): # K.augmentation.RandomResizedCrop((224, 224), align_corners=True), K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), ), - "per_batch_transform_on_device": nn.Sequential( + "per_batch_transform": nn.Sequential( K.augmentation.RandomAffine(360., p=0.5), K.augmentation.ColorJitter(0.2, 0.3, 0.2, 0.3, p=0.5) ) } @@ -96,8 +96,6 @@ def test_base_viz_kornia(tmpdir): ) img_data.show_train_batch() - import pdb - pdb.set_trace() assert img_data.viz.batches["train"]["load_sample"] is not None assert img_data.viz.batches["train"]["to_tensor_transform"] is not None assert img_data.viz.batches["train"]["collate"] is not None From d3932c9ab593f421dc84c08d17d447ada581fad3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 12:43:01 +0100 Subject: [PATCH 22/30] update --- flash/data/base_viz.py | 10 ++++++---- tests/data/test_data_viz.py | 22 ++++++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index 40e341196e..d071948f5f 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -6,14 +6,16 @@ from flash.data.data_pipeline import DataPipeline from flash.data.process import Preprocess +from flash.data.utils import _STAGES_PREFIX class BaseViz(Callback): def __init__(self, enabled: bool = False): - self.batches = {"train": {}, "val": {}, "test": {}, "predict": {}} + self.batches = {k: {} for k in _STAGES_PREFIX.values()} self.enabled = enabled self._datamodule = None + self._preprocess = None def attach_to_preprocess(self, preprocess: Preprocess) -> None: self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess) @@ -25,14 +27,13 @@ def attach_to_datamodule(self, datamodule) -> None: def _wrap_fn( self, fn: Callable, - running_stage: RunningStage, ) -> Callable: @functools.wraps(fn) def wrapper(*args) -> Any: data = fn(*args) if self.enabled: - batches = self.batches[running_stage.value] + batches = self.batches[_STAGES_PREFIX[self._preprocess.running_stage]] if fn.__name__ not in batches: batches[fn.__name__] = [] batches[fn.__name__].append(data) @@ -41,10 +42,11 @@ def wrapper(*args) -> Any: return wrapper def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: Preprocess): + self._preprocess = preprocess fn_names = { k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess) for k in DataPipeline.PREPROCESS_FUNCS } for fn_name in fn_names: fn 
= getattr(preprocess, fn_name) - setattr(preprocess, fn_name, self._wrap_fn(fn, running_stage)) + setattr(preprocess, fn_name, self._wrap_fn(fn)) diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py index 5727ec8d97..a97e54521c 100644 --- a/tests/data/test_data_viz.py +++ b/tests/data/test_data_viz.py @@ -21,6 +21,7 @@ from PIL import Image from torch import nn +from flash.data.utils import _STAGES_PREFIX from flash.vision import ImageClassificationData @@ -30,9 +31,12 @@ def _rand_image(): class ImageClassificationDataViz(ImageClassificationData): - def show_train_batch(self): + def show_batch(self): self.viz.enabled = True _ = next(iter(self.train_dataloader())) + _ = next(iter(self.val_dataloader())) + _ = next(iter(self.test_dataloader())) + _ = next(iter(self.predict_dataloader())) self.viz.enabled = False @@ -50,15 +54,21 @@ def test_base_viz(tmpdir): img_data = ImageClassificationDataViz.from_filepaths( train_filepaths=[tmpdir / "a", tmpdir / "b"], train_labels=[0, 1], + val_filepaths=[tmpdir / "a", tmpdir / "b"], + val_labels=[0, 1], + test_filepaths=[tmpdir / "a", tmpdir / "b"], + test_labels=[0, 1], + predict_filepaths=[tmpdir / "a", tmpdir / "b"], batch_size=1, num_workers=0, ) - img_data.show_train_batch() - assert img_data.viz.batches["train"]["load_sample"] is not None - assert img_data.viz.batches["train"]["to_tensor_transform"] is not None - assert img_data.viz.batches["train"]["collate"] is not None - assert img_data.viz.batches["train"]["per_batch_transform"] is not None + img_data.show_batch() + for stage in _STAGES_PREFIX.values(): + assert img_data.viz.batches[stage]["load_sample"] is not None + assert img_data.viz.batches[stage]["to_tensor_transform"] is not None + assert img_data.viz.batches[stage]["collate"] is not None + assert img_data.viz.batches[stage]["per_batch_transform"] is not None def test_base_viz_kornia(tmpdir): From f6f33b86c1a4a83abb986ec8954841a4bd9a2d7d Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 13:23:19 +0100 Subject: [PATCH 23/30] add test --- flash/data/base_viz.py | 11 ++++++ tests/data/test_data_viz.py | 75 +++++++++++++------------------------ 2 files changed, 38 insertions(+), 48 deletions(-) diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index d071948f5f..dd9a45ef0d 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -1,4 +1,5 @@ import functools +from contextlib import contextmanager from typing import Any, Callable from pytorch_lightning.callbacks import Callback @@ -10,6 +11,10 @@ class BaseViz(Callback): + """ + This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations. + It is disabled by default. 
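+
+    A rough usage sketch (it assumes this object has been attached to a datamodule as
+    ``datamodule.viz`` and that hook names follow ``DataPipeline.PREPROCESS_FUNCS``)::
+
+        with datamodule.viz.enable():
+            _ = next(iter(datamodule.train_dataloader()))
+
+        # outputs are stored per running stage, then per hook name
+        samples = datamodule.viz.batches["train"]["load_sample"]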
+ """ def __init__(self, enabled: bool = False): self.batches = {k: {} for k in _STAGES_PREFIX.values()} @@ -17,6 +22,12 @@ def __init__(self, enabled: bool = False): self._datamodule = None self._preprocess = None + @contextmanager + def enable(self): + self.enabled = True + yield + self.enabled = False + def attach_to_preprocess(self, preprocess: Preprocess) -> None: self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess) diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py index a97e54521c..78370a2679 100644 --- a/tests/data/test_data_viz.py +++ b/tests/data/test_data_viz.py @@ -19,7 +19,7 @@ import torch import torchvision.transforms as T from PIL import Image -from torch import nn +from pytorch_lightning import seed_everything from flash.data.utils import _STAGES_PREFIX from flash.vision import ImageClassificationData @@ -32,15 +32,16 @@ def _rand_image(): class ImageClassificationDataViz(ImageClassificationData): def show_batch(self): - self.viz.enabled = True - _ = next(iter(self.train_dataloader())) - _ = next(iter(self.val_dataloader())) - _ = next(iter(self.test_dataloader())) - _ = next(iter(self.predict_dataloader())) - self.viz.enabled = False + # viz needs to be enabled, so it doesn't store profile transforms during training + with self.viz.enable(): + _ = next(iter(self.train_dataloader())) + _ = next(iter(self.val_dataloader())) + _ = next(iter(self.test_dataloader())) + _ = next(iter(self.predict_dataloader())) def test_base_viz(tmpdir): + seed_everything(42) tmpdir = Path(tmpdir) (tmpdir / "a").mkdir() @@ -59,54 +60,32 @@ def test_base_viz(tmpdir): test_filepaths=[tmpdir / "a", tmpdir / "b"], test_labels=[0, 1], predict_filepaths=[tmpdir / "a", tmpdir / "b"], - batch_size=1, + batch_size=2, num_workers=0, ) img_data.show_batch() for stage in _STAGES_PREFIX.values(): - assert img_data.viz.batches[stage]["load_sample"] is not None - assert img_data.viz.batches[stage]["to_tensor_transform"] is not None - assert img_data.viz.batches[stage]["collate"] is not None - assert img_data.viz.batches[stage]["per_batch_transform"] is not None + is_predict = stage == "predict" + def extract_data(data): + if not is_predict: + return data[0][0] + return data[0] -def test_base_viz_kornia(tmpdir): - tmpdir = Path(tmpdir) + assert isinstance(extract_data(img_data.viz.batches[stage]["load_sample"]), Image.Image) + if not is_predict: + assert isinstance(img_data.viz.batches[stage]["load_sample"][0][1], int) - (tmpdir / "a").mkdir() - (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "a" / "a_2.png") + assert isinstance(extract_data(img_data.viz.batches[stage]["to_tensor_transform"]), torch.Tensor) + if not is_predict: + assert isinstance(img_data.viz.batches[stage]["to_tensor_transform"][0][1], int) - _rand_image().save(tmpdir / "b" / "a_1.png") - _rand_image().save(tmpdir / "b" / "a_2.png") - - # Define the augmentations pipeline - train_transforms = { - # can we just call this `preprocess` ? - # this is needed all the time in train, valid, etc - "pre_tensor_transform": T.Compose([T.RandomResizedCrop(224), T.ToTensor()]), - "post_tensor_transform": nn.Sequential( - # Kornia RandomResizeCrop has a bug - I'll debug with Jian. 
- # K.augmentation.RandomResizedCrop((224, 224), align_corners=True), - K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ), - "per_batch_transform": nn.Sequential( - K.augmentation.RandomAffine(360., p=0.5), K.augmentation.ColorJitter(0.2, 0.3, 0.2, 0.3, p=0.5) - ) - } - img_data = ImageClassificationDataViz.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], - train_labels=[0, 1], - batch_size=1, - num_workers=0, - train_transform=train_transforms, - valt_transform=train_transforms, - ) + assert extract_data(img_data.viz.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196]) + if not is_predict: + assert img_data.viz.batches[stage]["collate"][0][1].shape == torch.Size([2]) - img_data.show_train_batch() - assert img_data.viz.batches["train"]["load_sample"] is not None - assert img_data.viz.batches["train"]["to_tensor_transform"] is not None - assert img_data.viz.batches["train"]["collate"] is not None - assert img_data.viz.batches["train"]["per_batch_transform"] is not None + generated = extract_data(img_data.viz.batches[stage]["per_batch_transform"]).shape + assert generated == torch.Size([2, 3, 196, 196]) + if not is_predict: + assert img_data.viz.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2]) From 2de0e15e169d884b31a01e3ffcc0f67b5488f55f Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 1 Apr 2021 16:35:48 +0100 Subject: [PATCH 24/30] update --- flash/data/data_module.py | 50 +++++++++++++++++++++++++++++ flash/vision/classification/data.py | 1 + test.py | 7 ++++ tests/data/test_data_viz.py | 16 ++------- 4 files changed, 61 insertions(+), 13 deletions(-) create mode 100644 test.py diff --git a/flash/data/data_module.py b/flash/data/data_module.py index bb11595e85..cdc04b8cff 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -26,6 +26,7 @@ from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseViz from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess +from flash.data.utils import _STAGES_PREFIX class DataModule(pl.LightningDataModule): @@ -101,6 +102,55 @@ def viz(self, viz: BaseViz) -> None: def configure_vis(cls) -> BaseViz: return BaseViz() + def show_batch(self): + # viz needs to be enabled, so it doesn't store profile transforms during training + with self.viz.enable(): + _ = next(iter(self.train_dataloader())) + _ = next(iter(self.val_dataloader())) + _ = next(iter(self.test_dataloader())) + _ = next(iter(self.predict_dataloader())) + + def visualize(self, batch: Dict[str, Any], stage: RunningStage) -> None: + """ + This function is a hook for users to override with their visualization on a batch. + """ + pass + + def _show_batch(self, stage: RunningStage, reset: bool = True) -> None: + """ + This function is used to handle transforms profiling for batch visualization. 
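+
+        A rough sketch of the intended flow through the public wrappers below (``dm`` is
+        assumed to be a ``DataModule`` whose datasets are already set up)::
+
+            dm.show_train_batch()             # profile one train batch, then call ``visualize``
+            dm.show_train_batch(reset=False)  # keep the profiled batches for later inspection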
+ """ + iter_name = f"_{stage}_iter" + + def _reset_iterator(): + dataloader_fn = getattr(self, f"{stage}_dataloader") + setattr(self, iter_name, iter(dataloader_fn())) + + if not hasattr(self, iter_name): + _reset_iterator() + iter_dataloader = getattr(self, iter_name) + with self.viz.enable(): + try: + _ = next(iter_dataloader) + except StopIteration: + _reset_iterator() + _ = next(iter_dataloader) + self.visualize(self.viz.batches[stage], stage) + if reset: + self.viz.batches[stage] = {} + + def show_train_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset) + + def show_val_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.VALIDATING], reset=reset) + + def show_test_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.TESTING], reset=reset) + + def show_predict_batch(self, reset: bool = True) -> None: + self._show_batch(_STAGES_PREFIX[RunningStage.PREDICTING], reset=reset) + @staticmethod def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any: if isinstance(dataset, Subset): diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index eb71711aee..5f6c4cc0d6 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -221,6 +221,7 @@ def __init__( predict_dataset=predict_dataset, batch_size=batch_size, num_workers=num_workers, + **kwargs, ) self._num_classes = None diff --git a/test.py b/test.py new file mode 100644 index 0000000000..23256dca00 --- /dev/null +++ b/test.py @@ -0,0 +1,7 @@ +# %% +msg = "Hello World" +print(msg) + +# %% +msg = "Hello again" +print(msg) diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py index 78370a2679..ab08da43aa 100644 --- a/tests/data/test_data_viz.py +++ b/tests/data/test_data_viz.py @@ -29,17 +29,6 @@ def _rand_image(): return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8")) -class ImageClassificationDataViz(ImageClassificationData): - - def show_batch(self): - # viz needs to be enabled, so it doesn't store profile transforms during training - with self.viz.enable(): - _ = next(iter(self.train_dataloader())) - _ = next(iter(self.val_dataloader())) - _ = next(iter(self.test_dataloader())) - _ = next(iter(self.predict_dataloader())) - - def test_base_viz(tmpdir): seed_everything(42) tmpdir = Path(tmpdir) @@ -52,7 +41,7 @@ def test_base_viz(tmpdir): _rand_image().save(tmpdir / "b" / "a_1.png") _rand_image().save(tmpdir / "b" / "a_2.png") - img_data = ImageClassificationDataViz.from_filepaths( + img_data = ImageClassificationData.from_filepaths( train_filepaths=[tmpdir / "a", tmpdir / "b"], train_labels=[0, 1], val_filepaths=[tmpdir / "a", tmpdir / "b"], @@ -64,8 +53,9 @@ def test_base_viz(tmpdir): num_workers=0, ) - img_data.show_batch() for stage in _STAGES_PREFIX.values(): + + getattr(img_data, f"show_{stage}_batch")(reset=False) is_predict = stage == "predict" def extract_data(data): From 631f06f9d9b5c703a997ee967f985cbb1c4c3a86 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 6 Apr 2021 13:16:54 +0100 Subject: [PATCH 25/30] resolve tests --- flash/core/model.py | 12 +- flash/core/utils.py | 14 +- flash/data/auto_dataset.py | 12 +- flash/data/base_viz.py | 111 ++++++++++----- flash/data/batch.py | 19 ++- flash/data/callback.py | 67 +++++++++ flash/data/data_module.py | 29 ++-- flash/data/data_pipeline.py | 14 +- flash/data/process.py | 16 ++- 
flash/data/utils.py | 12 ++ .../finetuning/image_classification.py | 1 + tests/data/test_auto_dataset.py | 1 + tests/data/test_data_pipeline.py | 130 +++++++++++++----- tests/data/test_data_viz.py | 79 +++++++++-- tests/vision/classification/test_data.py | 4 +- 15 files changed, 390 insertions(+), 131 deletions(-) create mode 100644 flash/data/callback.py diff --git a/flash/core/model.py b/flash/core/model.py index 78c907fc6c..5be48c3384 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -245,22 +245,14 @@ def on_fit_end(self) -> None: self.data_pipeline._detach_from_model(self) super().on_fit_end() - @staticmethod - def _sanetize_funcs(obj: Any) -> Any: - if hasattr(obj, "__dict__"): - for k, v in obj.__dict__.items(): - if isinstance(v, Callable): - obj.__dict__[k] = inspect.unwrap(v) - return obj - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # TODO: Is this the best way to do this? or should we also use some kind of hparams here? # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: - self._preprocess = self._sanetize_funcs(self._preprocess) + # todo (tchaton): TypeError: cannot pickle '_io.TextIOWrapper' object with BaseViz Callback + self.data_pipeline._preprocess_pipeline._callbacks = [] checkpoint['data_pipeline'] = self.data_pipeline - # todo (tchaton) re-wrap visualization super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: diff --git a/flash/core/utils.py b/flash/core/utils.py index 353124ef94..040d6e28d6 100644 --- a/flash/core/utils.py +++ b/flash/core/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Dict, Mapping, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Sequence, Union def get_callable_name(fn_or_class: Union[Callable, object]) -> str: @@ -25,3 +25,15 @@ def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Map return {get_callable_name(f): f for f in fn} elif callable(fn): return {get_callable_name(fn): fn} + + +def _is_overriden(method_name: str, process_obj, super_obj: Any) -> bool: + """ + Cropped Version of + https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py + """ + + if not hasattr(process_obj, method_name): + return False + + return getattr(process_obj, method_name).__code__ != getattr(super_obj, method_name).__code__ diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 5652496c10..0f3c2cbad5 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -19,6 +19,7 @@ from pytorch_lightning.utilities.warning_utils import rank_zero_warn from torch.utils.data import Dataset +from flash.data.callback import ControlFlow from flash.data.process import Preprocess from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext @@ -82,6 +83,12 @@ def preprocess(self) -> Optional[Preprocess]: if self.data_pipeline is not None: return self.data_pipeline._preprocess_pipeline + @property + def control_flow_callback(self) -> Optional[ControlFlow]: + preprocess = self.preprocess + if preprocess is not None: + return ControlFlow(preprocess.callbacks) + def _call_load_data(self, data: Any) -> Iterable: parameters = signature(self.load_data).parameters if len(parameters) > 1 and self.DATASET_KEY in parameters: @@ -124,7 +131,10 @@ def __getitem__(self, index: int) -> Any: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") if self.load_sample: with self._load_sample_context: - return self._call_load_sample(self.preprocessed_data[index]) + data = self._call_load_sample(self.preprocessed_data[index]) + if self.control_flow_callback: + self.control_flow_callback.on_load_sample(data, self.running_stage) + return data return self.preprocessed_data[index] def __len__(self) -> int: diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index dd9a45ef0d..257dd88992 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -1,16 +1,16 @@ -import functools from contextlib import contextmanager -from typing import Any, Callable +from typing import Any, Dict, List, Sequence -from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage +from torch import Tensor -from flash.data.data_pipeline import DataPipeline +from flash.core.utils import _is_overriden +from flash.data.callback import FlashCallback from flash.data.process import Preprocess -from flash.data.utils import _STAGES_PREFIX +from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX -class BaseViz(Callback): +class BaseViz(FlashCallback): """ This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations. It is disabled by default. 
@@ -22,42 +22,89 @@ def __init__(self, enabled: bool = False): self._datamodule = None self._preprocess = None + def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("load_sample", []) + store["load_sample"].append(sample) + + def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("pre_tensor_transform", []) + store["pre_tensor_transform"].append(sample) + + def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("to_tensor_transform", []) + store["to_tensor_transform"].append(sample) + + def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("post_tensor_transform", []) + store["post_tensor_transform"].append(sample) + + def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("per_batch_transform", []) + store["per_batch_transform"].append(batch) + + def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("collate", []) + store["collate"].append(batch) + + def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("per_sample_transform_on_device", []) + store["per_sample_transform_on_device"].append(samples) + + def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + store = self.batches[_STAGES_PREFIX[running_stage]] + store.setdefault("per_batch_transform_on_device", []) + store["per_batch_transform_on_device"].append(batch) + @contextmanager def enable(self): self.enabled = True yield self.enabled = False - def attach_to_preprocess(self, preprocess: Preprocess) -> None: - self._wrap_functions_per_stage(RunningStage.TRAINING, preprocess) - def attach_to_datamodule(self, datamodule) -> None: self._datamodule = datamodule datamodule.viz = self - def _wrap_fn( - self, - fn: Callable, - ) -> Callable: + def attach_to_preprocess(self, preprocess: Preprocess) -> None: + preprocess.callbacks = [self] + self._preprocess = preprocess - @functools.wraps(fn) - def wrapper(*args) -> Any: - data = fn(*args) - if self.enabled: - batches = self.batches[_STAGES_PREFIX[self._preprocess.running_stage]] - if fn.__name__ not in batches: - batches[fn.__name__] = [] - batches[fn.__name__].append(data) - return data + def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None: + """ + This function is a hook for users to override with their visualization on a batch. 
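+
+        By default it dispatches to the per-hook ``show_*`` methods below whenever they
+        are overridden. A minimal sketch of a custom visualization (``PrintViz`` is a
+        hypothetical subclass; printing is purely illustrative)::
+
+            class PrintViz(BaseViz):
+
+                def show_load_sample(self, samples, running_stage):
+                    print(running_stage, samples)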
+ """ + for func_name in _PREPROCESS_FUNCS: + hook_name = f"show_{func_name}" + if _is_overriden(hook_name, self, BaseViz): + getattr(self, hook_name)(batch[func_name], running_stage) - return wrapper + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + pass - def _wrap_functions_per_stage(self, running_stage: RunningStage, preprocess: Preprocess): - self._preprocess = preprocess - fn_names = { - k: DataPipeline._resolve_function_hierarchy(k, preprocess, running_stage, Preprocess) - for k in DataPipeline.PREPROCESS_FUNCS - } - for fn_name in fn_names: - fn = getattr(preprocess, fn_name) - setattr(preprocess, fn_name, self._wrap_fn(fn)) + def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + pass + + def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + pass + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + pass + + def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + pass + + def show_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + pass + + def show_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: + pass + + def show_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + pass diff --git a/flash/data/batch.py b/flash/data/batch.py index e405f56b2b..a0435f9332 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -18,6 +18,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor +from flash.data.callback import ControlFlow from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext if TYPE_CHECKING: @@ -43,6 +44,7 @@ def __init__( ): super().__init__() self.preprocess = preprocess + self.callback = ControlFlow(self.preprocess.callbacks) self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) self.to_tensor_transform = convert_to_modules(to_tensor_transform) self.post_tensor_transform = convert_to_modules(post_tensor_transform) @@ -58,9 +60,11 @@ def forward(self, sample: Any) -> Any: with self._current_stage_context: with self._pre_tensor_transform_context: sample = self.pre_tensor_transform(sample) + self.callback.on_pre_tensor_transform(sample, self.stage) with self._to_tensor_transform_context: sample = self.to_tensor_transform(sample) + self.callback.on_to_tensor_transform(sample, self.stage) if self.assert_contains_tensor: if not _contains_any_tensor(sample): @@ -71,6 +75,7 @@ def forward(self, sample: Any) -> Any: with self._post_tensor_transform_context: sample = self.post_tensor_transform(sample) + self.callback.on_post_tensor_transform(sample, self.stage) return sample @@ -112,10 +117,11 @@ def __init__( per_batch_transform: Callable, stage: RunningStage, apply_per_sample_transform: bool = True, - on_device: bool = False + on_device: bool = False, ): super().__init__() self.preprocess = preprocess + self.callback = ControlFlow(self.preprocess.callbacks) self.collate_fn = convert_to_modules(collate_fn) self.per_sample_transform = convert_to_modules(per_sample_transform) self.per_batch_transform = convert_to_modules(per_batch_transform) @@ -134,14 +140,21 @@ def forward(self, samples: Sequence[Any]) -> Any: if self.apply_per_sample_transform: with self._per_sample_transform_context: - samples = [self.per_sample_transform(sample) for sample in samples] - samples = type(samples)(samples) + 
_samples = [] + for sample in samples: + sample = self.per_sample_transform(sample) + self.callback.on_pre_tensor_transform(sample, self.stage) + _samples.append(sample) + + samples = type(_samples)(_samples) with self._collate_context: samples = self.collate_fn(samples) + self.callback.on_collate(samples, self.stage) with self._per_batch_transform_context: samples = self.per_batch_transform(samples) + self.callback.on_per_batch_transform(samples, self.stage) return samples def __str__(self) -> str: diff --git a/flash/data/callback.py b/flash/data/callback.py new file mode 100644 index 0000000000..97a8c03549 --- /dev/null +++ b/flash/data/callback.py @@ -0,0 +1,67 @@ +from typing import Any, List, Sequence + +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.states import RunningStage +from torch import Tensor + + +class FlashCallback(Callback): + + def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: + """Called once a sample has been loaded.""" + + def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + """Called once an object has been transformed.""" + + def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + """Called once an object has been transformed to a tensor.""" + + def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: + """Called after `post_tensor_transform` """ + + def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + """Called after `per_batch_transform` """ + + def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + """Called after `collate` """ + + def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: + """Called after `per_sample_transform_on_device` """ + + def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + """Called after `per_batch_transform_on_device` """ + + +class ControlFlow(FlashCallback): + + def __init__(self, callbacks: List[FlashCallback]): + self._callbacks = callbacks + + def run_for_all_callbacks(self, *args, method_name: str, **kwargs): + if self._callbacks: + for cb in self._callbacks: + getattr(cb, method_name)(*args, **kwargs) + + def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_load_sample") + + def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_pre_tensor_transform") + + def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_to_tensor_transform") + + def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_post_tensor_transform") + + def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform") + + def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(batch, running_stage, method_name="on_collate") + + def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(samples, running_stage, method_name="per_sample_transform_on_device") + + def 
on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(batch, running_stage, method_name="per_batch_transform_on_device") diff --git a/flash/data/data_module.py b/flash/data/data_module.py index cdc04b8cff..4b9608cbda 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -13,7 +13,7 @@ # limitations under the License. import os import platform -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Union import pytorch_lightning as pl import torch @@ -98,23 +98,15 @@ def viz(self) -> BaseViz: def viz(self, viz: BaseViz) -> None: self._viz = viz - @classmethod - def configure_vis(cls) -> BaseViz: + @staticmethod + def configure_vis(*args, **kwargs) -> BaseViz: return BaseViz() - def show_batch(self): - # viz needs to be enabled, so it doesn't store profile transforms during training - with self.viz.enable(): - _ = next(iter(self.train_dataloader())) - _ = next(iter(self.val_dataloader())) - _ = next(iter(self.test_dataloader())) - _ = next(iter(self.predict_dataloader())) - - def visualize(self, batch: Dict[str, Any], stage: RunningStage) -> None: + def show(self, batch: Dict[str, Any], stage: RunningStage) -> None: """ This function is a hook for users to override with their visualization on a batch. """ - pass + self.viz.show(batch, stage) def _show_batch(self, stage: RunningStage, reset: bool = True) -> None: """ @@ -122,20 +114,23 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None: """ iter_name = f"_{stage}_iter" - def _reset_iterator(): + def _reset_iterator() -> Iterable[Any]: dataloader_fn = getattr(self, f"{stage}_dataloader") - setattr(self, iter_name, iter(dataloader_fn())) + iterator = iter(dataloader_fn()) + setattr(self, iter_name, iterator) + return iterator if not hasattr(self, iter_name): _reset_iterator() + iter_dataloader = getattr(self, iter_name) with self.viz.enable(): try: _ = next(iter_dataloader) except StopIteration: - _reset_iterator() + iter_dataloader = _reset_iterator() _ = next(iter_dataloader) - self.visualize(self.viz.batches[stage], stage) + self.show(self.viz.batches[stage], stage) if reset: self.viz.batches[stage] = {} diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 7cce7fc04a..b6e400e0c9 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -26,7 +26,7 @@ from flash.data.auto_dataset import AutoDataset from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential from flash.data.process import Postprocess, Preprocess -from flash.data.utils import _STAGES_PREFIX +from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX if TYPE_CHECKING: from flash.core.model import Task @@ -129,17 +129,7 @@ def forward(self, samples: Sequence[Any]): """ - PREPROCESS_FUNCS = { - "load_data", - "load_sample", - "pre_tensor_transform", - "to_tensor_transform", - "post_tensor_transform", - "per_batch_transform", - "per_sample_transform_on_device", - "per_batch_transform_on_device", - "collate", - } + PREPROCESS_FUNCS = _PREPROCESS_FUNCS def __init__(self, preprocess: Optional[Preprocess] = None, postprocess: Optional[Postprocess] = None) -> None: self._preprocess_pipeline = preprocess or Preprocess() diff --git a/flash/data/process.py b/flash/data/process.py index 8fd16e5359..33283fb701 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import torch from pytorch_lightning.trainer.states import RunningStage @@ -22,6 +22,7 @@ from torch.utils.data._utils.collate import default_collate from flash.data.batch import default_uncollate +from flash.data.callback import FlashCallback from flash.data.utils import convert_to_modules @@ -117,6 +118,8 @@ def __init__( if not hasattr(self, "_skip_mutual_check"): self._skip_mutual_check = False + self._callbacks = [] + @property def skip_mutual_check(self) -> bool: return self._skip_mutual_check @@ -150,6 +153,16 @@ def current_transform(self) -> Callable: def from_state(cls, state: PreprocessState) -> 'Preprocess': return cls(**vars(state)) + @property + def callbacks(self): + if not hasattr(self, "_callbacks"): + self._callbacks = [] + return self._callbacks + + @callbacks.setter + def callbacks(self, callbacks: List['FlashCallback']): + self._callbacks.extend(callbacks) + @classmethod def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Any: """Loads entire data from Dataset""" @@ -201,7 +214,6 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: return batch -@dataclass(unsafe_hash=True) class Postprocess(Properties, torch.nn.Module): def __init__(self, save_path: Optional[str] = None): diff --git a/flash/data/utils.py b/flash/data/utils.py index f3928861ba..8436001d5a 100644 --- a/flash/data/utils.py +++ b/flash/data/utils.py @@ -32,6 +32,18 @@ } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} +_PREPROCESS_FUNCS = { + "load_data", + "load_sample", + "pre_tensor_transform", + "to_tensor_transform", + "post_tensor_transform", + "per_batch_transform", + "per_sample_transform_on_device", + "per_batch_transform_on_device", + "collate", +} + class CurrentRunningStageContext: diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 4a4032ac5d..b85ecfe3e5 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -26,6 +26,7 @@ val_folder="data/hymenoptera_data/val/", test_folder="data/hymenoptera_data/test/", ) + # 3. Build the model model = ImageClassifier(num_classes=datamodule.num_classes) diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py index 273b5aa870..c3a0b6b79a 100644 --- a/tests/data/test_auto_dataset.py +++ b/tests/data/test_auto_dataset.py @@ -23,6 +23,7 @@ class _AutoDatasetTestPreprocess(Preprocess): def __init__(self, with_dset: bool): + self._callbacks = [] self.load_data_count = 0 self.load_sample_count = 0 self.load_sample_with_dataset_count = 0 diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index ed7ebe60b9..a482e854d3 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple from unittest import mock @@ -44,16 +44,6 @@ def __len__(self) -> int: return 5 -class CustomModel(Task): - - def __init__(self, postprocess: Optional[Postprocess] = None): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - self._postprocess = postprocess - - def train_dataloader(self) -> Any: - return DataLoader(DummyDataset()) - - class CustomDataModule(DataModule): def __init__(self): @@ -69,6 +59,15 @@ def __init__(self): @pytest.mark.parametrize("use_postprocess", [False, True]) def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir): + class CustomModel(Task): + + def __init__(self, postprocess: Optional[Postprocess] = None): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + self._postprocess = postprocess + + def train_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + class SubPreprocess(Preprocess): pass @@ -82,7 +81,7 @@ class SubPostprocess(Postprocess): assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess) assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess) - model = CustomModel(Postprocess()) + model = CustomModel(postprocess=Postprocess()) model.data_pipeline = data_pipeline assert isinstance(model._preprocess, SubPreprocess if use_preprocess else Preprocess) assert isinstance(model._postprocess, SubPostprocess if use_postprocess else Postprocess) @@ -287,6 +286,15 @@ def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): def test_detach_preprocessing_from_model(tmpdir): + class CustomModel(Task): + + def __init__(self, postprocess: Optional[Postprocess] = None): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + self._postprocess = postprocess + + def train_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + preprocess = CustomPreprocess() data_pipeline = DataPipeline(preprocess) model = CustomModel() @@ -334,9 +342,37 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_attaching_datapipeline_to_model(tmpdir): - preprocess = TestPreprocess() + preprocess = Preprocess() data_pipeline = DataPipeline(preprocess) + class CustomModel(Task): + + _postprocess = Postprocess() + + def __init__(self): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def test_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def train_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + + def val_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + + def test_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + + def predict_dataloader(self) -> Any: + return DataLoader(DummyDataset()) + class TestModel(CustomModel): stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] @@ -430,11 +466,10 @@ def on_fit_end(self) -> None: assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert self.predict_step == self._saved_predict_step - datamodule = CustomDataModule() - datamodule._data_pipeline = data_pipeline model = TestModel() + model.data_pipeline = data_pipeline trainer = Trainer(fast_dev_run=True) - trainer.fit(model, datamodule=datamodule) + 
trainer.fit(model) trainer.test(model) trainer.predict(model) @@ -497,11 +532,20 @@ def __init__(self): self.test_post_tensor_transform_called = False self.predict_load_data_called = False + @staticmethod + def fn_train_load_data() -> Tuple: + return ( + 0, + 1, + 2, + 3, + ) + def train_load_data(self, sample) -> LamdaDummyDataset: assert self.training assert self.current_fn == "load_data" self.train_load_data_called = True - return LamdaDummyDataset(lambda: (0, 1, 2, 3)) + return LamdaDummyDataset(self.fn_train_load_data) def train_pre_tensor_transform(self, sample: Any) -> Any: assert self.training @@ -557,11 +601,15 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any: assert torch.equal(batch["b"], tensor([1, 2])) return [False] + @staticmethod + def fn_test_load_data() -> List[torch.Tensor]: + return [torch.rand(1), torch.rand(1)] + def test_load_data(self, sample) -> LamdaDummyDataset: assert self.testing assert self.current_fn == "load_data" self.test_load_data_called = True - return LamdaDummyDataset(lambda: [torch.rand(1), torch.rand(1)]) + return LamdaDummyDataset(self.fn_test_load_data) def test_to_tensor_transform(self, sample: Any) -> Tensor: assert self.testing @@ -575,11 +623,15 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor: self.test_post_tensor_transform_called = True return sample + @staticmethod + def fn_predict_load_data() -> List[str]: + return (["a", "b"]) + def predict_load_data(self, sample) -> LamdaDummyDataset: assert self.predicting assert self.current_fn == "load_data" self.predict_load_data_called = True - return LamdaDummyDataset(lambda: (["a", "b"])) + return LamdaDummyDataset(self.fn_predict_load_data) class TestPreprocessTransformations2(TestPreprocessTransformations): @@ -589,33 +641,35 @@ def val_to_tensor_transform(self, sample: Any) -> Tensor: return {"a": tensor(sample["a"]), "b": tensor(sample["b"])} -def test_datapipeline_transformations(tmpdir): +class CustomModel(Task): - class CustomModel(Task): + def __init__(self): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - def __init__(self): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + def training_step(self, batch, batch_idx): + assert batch is None - def training_step(self, batch, batch_idx): - assert batch is None + def validation_step(self, batch, batch_idx): + assert batch is False - def validation_step(self, batch, batch_idx): - assert batch is False + def test_step(self, batch, batch_idx): + assert len(batch) == 2 + assert batch[0].shape == torch.Size([2, 1]) - def test_step(self, batch, batch_idx): - assert len(batch) == 2 - assert batch[0].shape == torch.Size([2, 1]) + def predict_step(self, batch, batch_idx, dataloader_idx): + assert batch[0][0] == 'a' + assert batch[0][1] == 'a' + assert batch[1][0] == 'b' + assert batch[1][1] == 'b' + return tensor([0, 0, 0]) - def predict_step(self, batch, batch_idx, dataloader_idx): - assert batch[0][0] == 'a' - assert batch[0][1] == 'a' - assert batch[1][0] == 'b' - assert batch[1][1] == 'b' - return tensor([0, 0, 0]) - class CustomDataModule(DataModule): +class CustomDataModule(DataModule): - preprocess_cls = TestPreprocessTransformations + preprocess_cls = TestPreprocessTransformations + + +def test_datapipeline_transformations(tmpdir): datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2, num_workers=0) diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py index ab08da43aa..621173d6ca 100644 --- 
a/tests/data/test_data_viz.py +++ b/tests/data/test_data_viz.py @@ -13,14 +13,15 @@ # limitations under the License. from pathlib import Path +from typing import Any, List, Sequence -import kornia as K import numpy as np import torch -import torchvision.transforms as T from PIL import Image from pytorch_lightning import seed_everything +from pytorch_lightning.trainer.states import RunningStage +from flash.data.base_viz import BaseViz from flash.data.utils import _STAGES_PREFIX from flash.vision import ImageClassificationData @@ -30,6 +31,7 @@ def _rand_image(): def test_base_viz(tmpdir): + seed_everything(42) tmpdir = Path(tmpdir) @@ -41,7 +43,48 @@ def test_base_viz(tmpdir): _rand_image().save(tmpdir / "b" / "a_1.png") _rand_image().save(tmpdir / "b" / "a_2.png") - img_data = ImageClassificationData.from_filepaths( + class CustomBaseViz(BaseViz): + + show_load_sample_called = False + show_pre_tensor_transform_called = False + show_to_tensor_transform_called = False + show_post_tensor_transform_called = False + show_collate_called = False + per_batch_transform_called = False + + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + self.show_load_sample_called = True + + def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + self.show_pre_tensor_transform_called = True + + def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + self.show_to_tensor_transform_called = True + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + self.show_post_tensor_transform_called = True + + def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None: + self.show_collate_called = True + + def show_per_batch_transform(self, batch: Sequence, running_stage: RunningStage) -> None: + self.per_batch_transform_called = True + + def reset(self): + self.show_load_sample_called = False + self.show_pre_tensor_transform_called = False + self.show_to_tensor_transform_called = False + self.show_post_tensor_transform_called = False + self.show_collate_called = False + self.per_batch_transform_called = False + + class CustomImageClassificationData(ImageClassificationData): + + @staticmethod + def configure_vis(*args, **kwargs) -> CustomBaseViz: + return CustomBaseViz(*args, **kwargs) + + dm = CustomImageClassificationData.from_filepaths( train_filepaths=[tmpdir / "a", tmpdir / "b"], train_labels=[0, 1], val_filepaths=[tmpdir / "a", tmpdir / "b"], @@ -53,9 +96,13 @@ def test_base_viz(tmpdir): num_workers=0, ) + dm.show_val_batch() + for stage in _STAGES_PREFIX.values(): - getattr(img_data, f"show_{stage}_batch")(reset=False) + for _ in range(10): + getattr(dm, f"show_{stage}_batch")(reset=False) + is_predict = stage == "predict" def extract_data(data): @@ -63,19 +110,27 @@ def extract_data(data): return data[0][0] return data[0] - assert isinstance(extract_data(img_data.viz.batches[stage]["load_sample"]), Image.Image) + assert isinstance(extract_data(dm.viz.batches[stage]["load_sample"]), Image.Image) if not is_predict: - assert isinstance(img_data.viz.batches[stage]["load_sample"][0][1], int) + assert isinstance(dm.viz.batches[stage]["load_sample"][0][1], int) - assert isinstance(extract_data(img_data.viz.batches[stage]["to_tensor_transform"]), torch.Tensor) + assert isinstance(extract_data(dm.viz.batches[stage]["to_tensor_transform"]), torch.Tensor) if not is_predict: - assert isinstance(img_data.viz.batches[stage]["to_tensor_transform"][0][1], int) + assert 
isinstance(dm.viz.batches[stage]["to_tensor_transform"][0][1], int) - assert extract_data(img_data.viz.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196]) + assert extract_data(dm.viz.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196]) if not is_predict: - assert img_data.viz.batches[stage]["collate"][0][1].shape == torch.Size([2]) + assert dm.viz.batches[stage]["collate"][0][1].shape == torch.Size([2]) - generated = extract_data(img_data.viz.batches[stage]["per_batch_transform"]).shape + generated = extract_data(dm.viz.batches[stage]["per_batch_transform"]).shape assert generated == torch.Size([2, 3, 196, 196]) if not is_predict: - assert img_data.viz.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2]) + assert dm.viz.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2]) + + assert dm.viz.show_load_sample_called + assert dm.viz.show_pre_tensor_transform_called + assert dm.viz.show_to_tensor_transform_called + assert dm.viz.show_post_tensor_transform_called + assert dm.viz.show_collate_called + assert dm.viz.per_batch_transform_called + dm.viz.reset() diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index e90a5f838b..ab5cd66c8c 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -178,9 +178,7 @@ def test_from_folders(tmpdir): _rand_image().save(train_dir / "b" / "1.png") _rand_image().save(train_dir / "b" / "2.png") - img_data = ImageClassificationData.from_folders( - train_dir, train_transform=None, loader=_dummy_image_loader, batch_size=1 - ) + img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) data = next(iter(img_data.train_dataloader())) imgs, labels = data assert imgs.shape == (1, 3, 196, 196) From bda5ff2a34c1679231c30fde7c5ff12960edf4a4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 6 Apr 2021 13:18:46 +0100 Subject: [PATCH 26/30] resolve flake8 --- tests/data/test_data_pipeline.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index a482e854d3..85e769208e 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -44,17 +44,6 @@ def __len__(self) -> int: return 5 -class CustomDataModule(DataModule): - - def __init__(self): - super().__init__( - train_dataset=DummyDataset(), - val_dataset=DummyDataset(), - test_dataset=DummyDataset(), - predict_dataset=DummyDataset(), - ) - - @pytest.mark.parametrize("use_preprocess", [False, True]) @pytest.mark.parametrize("use_postprocess", [False, True]) def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir): From 0e7416701f4137e3597b687b7d401d021d7bd30d Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 6 Apr 2021 13:25:33 +0100 Subject: [PATCH 27/30] update --- flash/data/batch.py | 8 ++++++-- flash/data/callback.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/flash/data/batch.py b/flash/data/batch.py index a0435f9332..3758d78a66 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -143,7 +143,8 @@ def forward(self, samples: Sequence[Any]) -> Any: _samples = [] for sample in samples: sample = self.per_sample_transform(sample) - self.callback.on_pre_tensor_transform(sample, self.stage) + if self.on_device: + self.callback.on_per_sample_transform_on_device(sample, self.stage) _samples.append(sample) samples = type(_samples)(_samples) @@ -154,7 +155,10 @@ def forward(self, 
samples: Sequence[Any]) -> Any: with self._per_batch_transform_context: samples = self.per_batch_transform(samples) - self.callback.on_per_batch_transform(samples, self.stage) + if self.on_device: + self.callback.on_per_batch_transform_on_device(samples, self.stage) + else: + self.callback.on_per_batch_transform(samples, self.stage) return samples def __str__(self) -> str: diff --git a/flash/data/callback.py b/flash/data/callback.py index 97a8c03549..7b3b39b4ab 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -61,7 +61,7 @@ def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: self.run_for_all_callbacks(batch, running_stage, method_name="on_collate") def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: - self.run_for_all_callbacks(samples, running_stage, method_name="per_sample_transform_on_device") + self.run_for_all_callbacks(samples, running_stage, method_name="on_per_sample_transform_on_device") def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: - self.run_for_all_callbacks(batch, running_stage, method_name="per_batch_transform_on_device") + self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform_on_device") From d0fb78d22b2515a37be69c32cf9897092ca36aad Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 6 Apr 2021 13:32:24 +0100 Subject: [PATCH 28/30] update --- flash/data/base_viz.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index 257dd88992..b2732bf60a 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -14,6 +14,9 @@ class BaseViz(FlashCallback): """ This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations. It is disabled by default. + + batches: Dict = {"train": {"to_tensor_transform": [], ...}, ...} + """ def __init__(self, enabled: bool = False): From 098d7abbdcf8096a5201806d8aab6665def0652b Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 6 Apr 2021 15:58:50 +0100 Subject: [PATCH 29/30] update --- flash/core/model.py | 3 - flash/core/utils.py | 8 +-- flash/data/auto_dataset.py | 2 +- flash/data/base_viz.py | 2 - flash/data/callback.py | 22 +++--- flash/data/process.py | 9 +-- flash/vision/__init__.py | 2 +- flash/vision/classification/__init__.py | 2 +- flash/vision/classification/data.py | 53 +------------- test.py | 7 -- tests/data/test_auto_dataset.py | 6 +- tests/data/test_callback.py | 94 +++++++++++++++++++++++++ tests/data/test_data_pipeline.py | 1 - tests/data/test_data_viz.py | 2 - 14 files changed, 124 insertions(+), 89 deletions(-) delete mode 100644 test.py create mode 100644 tests/data/test_callback.py diff --git a/flash/core/model.py b/flash/core/model.py index 5be48c3384..d2cc3c2bef 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -246,12 +246,9 @@ def on_fit_end(self) -> None: super().on_fit_end() def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # TODO: Is this the best way to do this? or should we also use some kind of hparams here? 
# This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: - # todo (tchaton): TypeError: cannot pickle '_io.TextIOWrapper' object with BaseViz Callback - self.data_pipeline._preprocess_pipeline._callbacks = [] checkpoint['data_pipeline'] = self.data_pipeline super().on_save_checkpoint(checkpoint) diff --git a/flash/core/utils.py b/flash/core/utils.py index 040d6e28d6..7676218c71 100644 --- a/flash/core/utils.py +++ b/flash/core/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Mapping, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Sequence, Type, Union def get_callable_name(fn_or_class: Union[Callable, object]) -> str: @@ -27,13 +27,13 @@ def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Map return {get_callable_name(fn): fn} -def _is_overriden(method_name: str, process_obj, super_obj: Any) -> bool: +def _is_overriden(method_name: str, instance: object, parent: Type[object]) -> bool: """ Cropped Version of https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py """ - if not hasattr(process_obj, method_name): + if not hasattr(instance, method_name): return False - return getattr(process_obj, method_name).__code__ != getattr(super_obj, method_name).__code__ + return getattr(instance, method_name).__code__ != getattr(parent, method_name).__code__ diff --git a/flash/data/auto_dataset.py b/flash/data/auto_dataset.py index 0f3c2cbad5..498b67a33d 100644 --- a/flash/data/auto_dataset.py +++ b/flash/data/auto_dataset.py @@ -131,7 +131,7 @@ def __getitem__(self, index: int) -> Any: raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.") if self.load_sample: with self._load_sample_context: - data = self._call_load_sample(self.preprocessed_data[index]) + data: Any = self._call_load_sample(self.preprocessed_data[index]) if self.control_flow_callback: self.control_flow_callback.on_load_sample(data, self.running_stage) return data diff --git a/flash/data/base_viz.py b/flash/data/base_viz.py index b2732bf60a..2dcd95fc27 100644 --- a/flash/data/base_viz.py +++ b/flash/data/base_viz.py @@ -22,7 +22,6 @@ class BaseViz(FlashCallback): def __init__(self, enabled: bool = False): self.batches = {k: {} for k in _STAGES_PREFIX.values()} self.enabled = enabled - self._datamodule = None self._preprocess = None def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: @@ -72,7 +71,6 @@ def enable(self): self.enabled = False def attach_to_datamodule(self, datamodule) -> None: - self._datamodule = datamodule datamodule.viz = self def attach_to_preprocess(self, preprocess: Preprocess) -> None: diff --git a/flash/data/callback.py b/flash/data/callback.py index 7b3b39b4ab..b253b77d91 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -8,28 +8,28 @@ class FlashCallback(Callback): def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: - """Called once a sample has been loaded.""" + """Called once a sample has been loaded using ``load_sample``.""" def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: - """Called once an object has been transformed.""" + 
"""Called once ``pre_tensor_transform`` have been applied to a sample.""" def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: - """Called once an object has been transformed to a tensor.""" + """Called once ``to_tensor_transform`` have been applied to a sample.""" def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: - """Called after `post_tensor_transform` """ + """Called once ``post_tensor_transform`` have been applied to a sample.""" def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: - """Called after `per_batch_transform` """ + """Called once ``per_batch_transform`` have been applied to a batch.""" def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: - """Called after `collate` """ + """Called once ``collate`` have been applied to a sequence of samples.""" - def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: - """Called after `per_sample_transform_on_device` """ + def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None: + """Called once ``per_sample_transform_on_device`` have been applied to a sample.""" def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: - """Called after `per_batch_transform_on_device` """ + """Called once ``per_batch_transform_on_device`` have been applied to a sample.""" class ControlFlow(FlashCallback): @@ -60,8 +60,8 @@ def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> Non def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None: self.run_for_all_callbacks(batch, running_stage, method_name="on_collate") - def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None: - self.run_for_all_callbacks(samples, running_stage, method_name="on_per_sample_transform_on_device") + def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform_on_device") def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None: self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform_on_device") diff --git a/flash/data/process.py b/flash/data/process.py index 33283fb701..1512ec1d28 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -118,7 +118,7 @@ def __init__( if not hasattr(self, "_skip_mutual_check"): self._skip_mutual_check = False - self._callbacks = [] + self._callbacks: List[FlashCallback] = [] @property def skip_mutual_check(self) -> bool: @@ -154,14 +154,15 @@ def from_state(cls, state: PreprocessState) -> 'Preprocess': return cls(**vars(state)) @property - def callbacks(self): + def callbacks(self) -> List['FlashCallback']: if not hasattr(self, "_callbacks"): - self._callbacks = [] + self._callbacks: List[FlashCallback] = [] return self._callbacks @callbacks.setter def callbacks(self, callbacks: List['FlashCallback']): - self._callbacks.extend(callbacks) + _callbacks = [c for c in callbacks if c not in self._callbacks] + self._callbacks.extend(_callbacks) @classmethod def load_data(cls, data: Any, dataset: Optional[Any] = None) -> Any: diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py index 42e0ac34aa..2023605d7b 100644 --- a/flash/vision/__init__.py +++ b/flash/vision/__init__.py @@ -1,3 +1,3 @@ -from flash.vision.classification import 
-from flash.vision.classification import ImageClassificationData, ImageClassifier, ImageClassificationDataViz
+from flash.vision.classification import ImageClassificationData, ImageClassifier
 from flash.vision.detection import ObjectDetectionData, ObjectDetector
 from flash.vision.embedding import ImageEmbedder
diff --git a/flash/vision/classification/__init__.py b/flash/vision/classification/__init__.py
index c8f37d1f76..eaeab26233 100644
--- a/flash/vision/classification/__init__.py
+++ b/flash/vision/classification/__init__.py
@@ -1,2 +1,2 @@
-from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataViz
+from flash.vision.classification.data import ImageClassificationData
 from flash.vision.classification.model import ImageClassifier
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index 5f6c4cc0d6..8e6ad5c8c7 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -176,7 +176,9 @@ def to_tensor_transform(self, sample: Any) -> Any:
     def post_tensor_transform(self, sample: Any) -> Any:
         return self.common_step(sample)

-    # todo bug (tchaton) where to place the collate. Need an indication.
+    # todo: (tchaton) `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive.
+    # `skip_mutual_check` is used to skip that check, as the information is provided directly by the transforms.
+    # The `collate` still needs to be set properly depending on the user-provided transforms.
     def per_batch_transform(self, sample: Any) -> Any:
         return self.common_step(sample)

@@ -528,52 +530,3 @@ def from_filepaths(
             seed=seed,
             **kwargs
         )
-
-
-class ImageClassificationDataViz(ImageClassificationData):
-
-    def show_train_batch(self):
-        self.viz.enabled = True
-        # fetch batch and cache data
-        _ = next(iter(self.train_dataloader()))
-        self.viz.enabled = False
-
-        from typing import List
-
-        import kornia as K
-        import matplotlib.pyplot as plt
-        import numpy as np
-        import torchvision as tv
-        from PIL import Image
-
-        # plot row data
-        rows: int = 4  # chenge later
-        data_raw: List[Image] = self.viz.batches['train']['load_sample']
-        for num, x_data in enumerate(data_raw):
-            img, label = x_data
-            plt.subplot(rows, rows, num + 1)
-            plt.title(label)
-            plt.axis('off')
-            plt.imshow(np.array(img))
-        plt.title('load_sample')
-        plt.show(block=False)
-
-        mean = torch.tensor([0.485, 0.456, 0.406])
-        std = torch.tensor([0.229, 0.224, 0.225])
-
-        # plot pre-process and after augmentations
-        data1, labels1 = self.viz.batches['train']['collate'][0]  # this is before random transforms
-        data2, labels2 = self.viz.batches['train']['per_batch_transform'][0]  # this should be after random transforms
-
-        data1 = K.enhance.denormalize(data1, mean, std)
-        data2 = K.enhance.denormalize(data2, mean, std)
-
-        # cast and prepare data for viualisation
-        data1_vis = K.tensor_to_image(tv.utils.make_grid(data1))
-        data2_vis = K.tensor_to_image(tv.utils.make_grid(data2))
-
-        # plot using matplotlib
-        fig, (ax1, ax2) = plt.subplots(2)
-        ax1.imshow(data1_vis)
-        ax2.imshow(data2_vis)
-        plt.show()
diff --git a/test.py b/test.py
deleted file mode 100644
index 23256dca00..0000000000
--- a/test.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# %%
-msg = "Hello World"
-print(msg)
-
-# %%
-msg = "Hello again"
-print(msg)
diff --git a/tests/data/test_auto_dataset.py b/tests/data/test_auto_dataset.py
index c3a0b6b79a..4030e90b35 100644
--- a/tests/data/test_auto_dataset.py
+++ b/tests/data/test_auto_dataset.py
@@ -11,19 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. +from typing import List import pytest from pytorch_lightning.trainer.states import RunningStage from flash.data.auto_dataset import AutoDataset +from flash.data.callback import FlashCallback from flash.data.data_pipeline import DataPipeline -from flash.data.process import Postprocess, Preprocess +from flash.data.process import Preprocess class _AutoDatasetTestPreprocess(Preprocess): def __init__(self, with_dset: bool): - self._callbacks = [] + self._callbacks: List[FlashCallback] = [] self.load_data_count = 0 self.load_sample_count = 0 self.load_sample_with_dataset_count = 0 diff --git a/tests/data/test_callback.py b/tests/data/test_callback.py new file mode 100644 index 0000000000..fd22cfe8e2 --- /dev/null +++ b/tests/data/test_callback.py @@ -0,0 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence, Tuple +from unittest import mock +from unittest.mock import ANY, call, MagicMock, Mock + +import torch +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.trainer.states import RunningStage +from torch import Tensor + +from flash.core.model import Task +from flash.core.trainer import Trainer +from flash.data.data_module import DataModule +from flash.data.process import Preprocess + + +@mock.patch("torch.save") # need to mock torch.save or we get pickle error +def test_flash_callback(tmpdir): + """Test the callback hook system for fit.""" + + callback_mock = MagicMock() + + inputs = [[torch.rand(1), torch.rand(1)]] + dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm.preprocess.callbacks += [callback_mock] + + _ = next(iter(dm.train_dataloader())) + + assert callback_mock.method_calls == [ + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_pre_tensor_transform(ANY, RunningStage.TRAINING), + call.on_to_tensor_transform(ANY, RunningStage.TRAINING), + call.on_post_tensor_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + ] + + class CustomModel(Task): + + def __init__(self): + super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=1, + progress_bar_refresh_rate=0, + ) + dm = DataModule.from_load_data_inputs(inputs, inputs, inputs, None, num_workers=0) + dm.preprocess.callbacks += [callback_mock] + trainer.fit(CustomModel(), datamodule=dm) + + assert callback_mock.method_calls == [ + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_pre_tensor_transform(ANY, RunningStage.TRAINING), + call.on_to_tensor_transform(ANY, RunningStage.TRAINING), + call.on_post_tensor_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + 
call.on_load_sample(ANY, RunningStage.VALIDATING), + call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_to_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_post_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_collate(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING), + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_pre_tensor_transform(ANY, RunningStage.TRAINING), + call.on_to_tensor_transform(ANY, RunningStage.TRAINING), + call.on_post_tensor_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING), + call.on_load_sample(ANY, RunningStage.VALIDATING), + call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_to_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_post_tensor_transform(ANY, RunningStage.VALIDATING), + call.on_collate(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING) + ] diff --git a/tests/data/test_data_pipeline.py b/tests/data/test_data_pipeline.py index 85e769208e..b0556be3dd 100644 --- a/tests/data/test_data_pipeline.py +++ b/tests/data/test_data_pipeline.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple from unittest import mock diff --git a/tests/data/test_data_viz.py b/tests/data/test_data_viz.py index 621173d6ca..e23009c386 100644 --- a/tests/data/test_data_viz.py +++ b/tests/data/test_data_viz.py @@ -96,8 +96,6 @@ def configure_vis(*args, **kwargs) -> CustomBaseViz: num_workers=0, ) - dm.show_val_batch() - for stage in _STAGES_PREFIX.values(): for _ in range(10): From 9bdd1791479ad3172bc7acd2035c14b64830ce35 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 6 Apr 2021 19:23:39 +0100 Subject: [PATCH 30/30] resolve test --- tests/data/test_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_callback.py b/tests/data/test_callback.py index fd22cfe8e2..0bc47a91cd 100644 --- a/tests/data/test_callback.py +++ b/tests/data/test_callback.py @@ -27,7 +27,7 @@ @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_flash_callback(tmpdir): +def test_flash_callback(_, tmpdir): """Test the callback hook system for fit.""" callback_mock = MagicMock()
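The final patch resolves a classic ``unittest.mock`` pitfall: used as a decorator, ``@mock.patch`` passes the created ``MagicMock`` as the first positional argument of the test function, so the earlier ``test_flash_callback(tmpdir)`` signature received the mock where it expected pytest's ``tmpdir`` fixture. Renaming the parameters to ``(_, tmpdir)`` restores the intended binding. A small sketch of the mechanics (the test below is illustrative only, not part of the patch):

    from unittest import mock

    import torch


    @mock.patch("torch.save")
    def test_mock_is_injected_first(mock_save, tmpdir):
        # the decorator injects the patched object first; pytest fixtures follow
        torch.save({"state": 123}, tmpdir / "checkpoint.pt")  # intercepted, nothing is written
        mock_save.assert_called_once()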