diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py
index c9a5f361b3..8bd20dabad 100644
--- a/flash/image/style_transfer/data.py
+++ b/flash/image/style_transfer/data.py
@@ -11,93 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import functools
-from typing import Any, Callable, Collection, Dict, Optional, Sequence, Union
+from typing import Any, Collection, Dict, Optional, Sequence, Type
 
 import numpy as np
 import torch
-from torch import nn
 
-from flash.core.data.data_module import DataModule
-from flash.core.data.io.input import DataKeys, InputFormat
-from flash.core.data.io.input_transform import InputTransform
-from flash.core.data.transforms import ApplyToKeys
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+from flash.core.data.data_pipeline import DataPipelineState
+from flash.core.data.io.input_base import Input
+from flash.core.data.new_data_module import DataModule
 from flash.core.utilities.stages import RunningStage
+from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
 from flash.image.classification.input import ImageClassificationFilesInput, ImageClassificationFolderInput
-from flash.image.data import ImageFilesInput, ImageNumpyInput, ImageTensorInput
-
-if _TORCHVISION_AVAILABLE:
-    from torchvision import transforms as T
+from flash.image.data import ImageNumpyInput, ImageTensorInput
+from flash.image.style_transfer.input_transform import StyleTransferInputTransform
 
 __all__ = ["StyleTransferInputTransform", "StyleTransferData"]
 
 
-def _apply_to_input(
-    default_transforms_fn, keys: Union[Sequence[DataKeys], DataKeys]
-) -> Callable[..., Dict[str, ApplyToKeys]]:
-    @functools.wraps(default_transforms_fn)
-    def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
-        default_transforms = default_transforms_fn(*args, **kwargs)
-        if not default_transforms:
-            return default_transforms
-
-        return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()}
-
-    return wrapper
-
-
-class StyleTransferInputTransform(InputTransform):
-    def __init__(
-        self,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        val_transform: Optional[Dict[str, Callable]] = None,
-        test_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-    ):
-        if isinstance(image_size, int):
-            image_size = (image_size, image_size)
-
-        self.image_size = image_size
-
-        super().__init__(
-            train_transform=train_transform,
-            val_transform=val_transform,
-            test_transform=test_transform,
-            predict_transform=predict_transform,
-            inputs={
-                InputFormat.FILES: ImageFilesInput,
-                InputFormat.FOLDERS: ImageClassificationFolderInput,
-                InputFormat.NUMPY: ImageNumpyInput,
-                InputFormat.TENSORS: ImageTensorInput,
-            },
-            default_input=InputFormat.FILES,
-        )
-
-    def get_state_dict(self) -> Dict[str, Any]:
-        return {**self.transforms, "image_size": self.image_size}
-
-    @classmethod
-    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
-        return cls(**state_dict)
-
-    @functools.partial(_apply_to_input, keys=DataKeys.INPUT)
-    def default_transforms(self) -> Optional[Dict[str, Callable]]:
-        if self.training:
-            return dict(
-                per_sample_transform=T.ToTensor(),
-                per_sample_transform_on_device=nn.Sequential(
-                    T.Resize(self.image_size),
-                    T.CenterCrop(self.image_size),
-                ),
-            )
-        if self.predicting:
-            return dict(per_sample_transform=T.Compose([T.Resize(self.image_size), T.ToTensor()]))
-
-        # Style transfer doesn't support a validation or test phase, so we return nothing here
-        return None
-
-
 class StyleTransferData(DataModule):
     input_transform_cls = StyleTransferInputTransform
 
@@ -106,19 +36,22 @@ def from_files(
         cls,
         train_files: Optional[Sequence[str]] = None,
         predict_files: Optional[Sequence[str]] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageClassificationFilesInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
     ) -> "StyleTransferData":
+
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageFilesInput(RunningStage.TRAINING, train_files),
-            predict_dataset=ImageClassificationFilesInput(RunningStage.PREDICTING, predict_files),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_files, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )
 
@@ -127,19 +60,22 @@ def from_folders(
         cls,
         train_folder: Optional[str] = None,
        predict_folder: Optional[str] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageClassificationFolderInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
    ) -> "StyleTransferData":
+
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageClassificationFolderInput(RunningStage.TRAINING, train_folder),
-            predict_dataset=ImageClassificationFolderInput(RunningStage.PREDICTING, predict_folder),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )
 
@@ -148,19 +84,22 @@ def from_numpy(
         cls,
         train_data: Optional[Collection[np.ndarray]] = None,
         predict_data: Optional[Collection[np.ndarray]] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageNumpyInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
    ) -> "StyleTransferData":
+
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageNumpyInput(RunningStage.TRAINING, train_data),
-            predict_dataset=ImageNumpyInput(RunningStage.PREDICTING, predict_data),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )
 
@@ -169,18 +108,20 @@ def from_tensors(
         cls,
         train_data: Optional[Collection[torch.Tensor]] = None,
         predict_data: Optional[Collection[torch.Tensor]] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageTensorInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
    ) -> "StyleTransferData":
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageTensorInput(RunningStage.TRAINING, train_data),
-            predict_dataset=ImageTensorInput(RunningStage.PREDICTING, predict_data),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )
diff --git a/flash/image/style_transfer/input_transform.py b/flash/image/style_transfer/input_transform.py
new file mode 100644
index 0000000000..b80fb605a0
--- /dev/null
+++ b/flash/image/style_transfer/input_transform.py
@@ -0,0 +1,30 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Callable
+
+from flash.core.data.input_transform import InputTransform
+from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as T
+
+
+@dataclass
+class StyleTransferInputTransform(InputTransform):
+
+    image_size: int = 256
+
+    def input_per_sample_transform(self) -> Callable:
+        return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.CenterCrop(self.image_size)])
diff --git a/flash_examples/style_transfer.py b/flash_examples/style_transfer.py
index 047b044e71..a36d287f1b 100644
--- a/flash_examples/style_transfer.py
+++ b/flash_examples/style_transfer.py
@@ -22,7 +22,7 @@
 # 1. Create the DataModule
 download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "./data")
 
-datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017")
+datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017", batch_size=1)
 
 # 2. Build the task
 model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))
@@ -37,7 +37,8 @@
         "data/coco128/images/train2017/000000000625.jpg",
         "data/coco128/images/train2017/000000000626.jpg",
         "data/coco128/images/train2017/000000000629.jpg",
-    ]
+    ],
+    batch_size=3,
 )
 predictions = trainer.predict(model, datamodule=datamodule)
 print(predictions)
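
Note: a minimal usage sketch of the refactored API shown in this diff. The image_size of 512 is an illustrative assumption (the dataclass defaults to 256), and passing it through transform_kwargs relies on the ds_kw forwarding shown above; batch_size is passed explicitly, matching the updated example.

import os

import flash
from flash.core.data.utils import download_data
from flash.image import StyleTransfer, StyleTransferData

download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "./data")

# image_size is a field on the StyleTransferInputTransform dataclass; transform_kwargs
# is forwarded to the input/transform construction per the ds_kw dict in the diff.
datamodule = StyleTransferData.from_folders(
    train_folder="data/coco128/images/train2017",
    transform_kwargs={"image_size": 512},  # assumed value for illustration
    batch_size=1,
)

model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))
trainer = flash.Trainer(max_epochs=2)
trainer.fit(model, datamodule=datamodule)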