
Data Pipeline V2: Move Style Transfer to new DataModule (#1043)
tchaton authored Dec 9, 2021
1 parent f33b6b8 commit 1ddd556
Showing 3 changed files with 95 additions and 123 deletions.
183 changes: 62 additions & 121 deletions flash/image/style_transfer/data.py
@@ -11,93 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import functools
-from typing import Any, Callable, Collection, Dict, Optional, Sequence, Union
+from typing import Any, Collection, Dict, Optional, Sequence, Type

 import numpy as np
 import torch
-from torch import nn

-from flash.core.data.data_module import DataModule
-from flash.core.data.io.input import DataKeys, InputFormat
-from flash.core.data.io.input_transform import InputTransform
-from flash.core.data.transforms import ApplyToKeys
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+from flash.core.data.data_pipeline import DataPipelineState
+from flash.core.data.io.input_base import Input
+from flash.core.data.new_data_module import DataModule
 from flash.core.utilities.stages import RunningStage
+from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
 from flash.image.classification.input import ImageClassificationFilesInput, ImageClassificationFolderInput
-from flash.image.data import ImageFilesInput, ImageNumpyInput, ImageTensorInput
-
-if _TORCHVISION_AVAILABLE:
-    from torchvision import transforms as T
+from flash.image.data import ImageNumpyInput, ImageTensorInput
+from flash.image.style_transfer.input_transform import StyleTransferInputTransform

 __all__ = ["StyleTransferInputTransform", "StyleTransferData"]


-def _apply_to_input(
-    default_transforms_fn, keys: Union[Sequence[DataKeys], DataKeys]
-) -> Callable[..., Dict[str, ApplyToKeys]]:
-    @functools.wraps(default_transforms_fn)
-    def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
-        default_transforms = default_transforms_fn(*args, **kwargs)
-        if not default_transforms:
-            return default_transforms
-
-        return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()}
-
-    return wrapper
-
-
-class StyleTransferInputTransform(InputTransform):
-    def __init__(
-        self,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        val_transform: Optional[Dict[str, Callable]] = None,
-        test_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-    ):
-        if isinstance(image_size, int):
-            image_size = (image_size, image_size)
-
-        self.image_size = image_size
-
-        super().__init__(
-            train_transform=train_transform,
-            val_transform=val_transform,
-            test_transform=test_transform,
-            predict_transform=predict_transform,
-            inputs={
-                InputFormat.FILES: ImageFilesInput,
-                InputFormat.FOLDERS: ImageClassificationFolderInput,
-                InputFormat.NUMPY: ImageNumpyInput,
-                InputFormat.TENSORS: ImageTensorInput,
-            },
-            default_input=InputFormat.FILES,
-        )
-
-    def get_state_dict(self) -> Dict[str, Any]:
-        return {**self.transforms, "image_size": self.image_size}
-
-    @classmethod
-    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
-        return cls(**state_dict)
-
-    @functools.partial(_apply_to_input, keys=DataKeys.INPUT)
-    def default_transforms(self) -> Optional[Dict[str, Callable]]:
-        if self.training:
-            return dict(
-                per_sample_transform=T.ToTensor(),
-                per_sample_transform_on_device=nn.Sequential(
-                    T.Resize(self.image_size),
-                    T.CenterCrop(self.image_size),
-                ),
-            )
-        if self.predicting:
-            return dict(per_sample_transform=T.Compose([T.Resize(self.image_size), T.ToTensor()]))
-        # Style transfer doesn't support a validation or test phase, so we return nothing here
-        return None
-
-
 class StyleTransferData(DataModule):
     input_transform_cls = StyleTransferInputTransform

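Note on the removed helper: `_apply_to_input` existed only to wrap each transform returned by `default_transforms` in `ApplyToKeys(DataKeys.INPUT, ...)`, so the transform ran on the image and not on the rest of the sample dict. The new `InputTransform` base class (used by the dataclass added below) scopes hooks by name instead, so a method called `input_per_sample_transform` applies to `DataKeys.INPUT` without any wrapping. A minimal sketch of the old-style wrapping, for comparison only:

    # Sketch of the key-scoping that _apply_to_input automated in the old code.
    # The new input_-prefixed hooks make this explicit wrapping unnecessary.
    from torchvision import transforms as T

    from flash.core.data.io.input import DataKeys
    from flash.core.data.transforms import ApplyToKeys

    # Old style: wrap the transform so it only sees sample[DataKeys.INPUT].
    per_sample_transform = ApplyToKeys(DataKeys.INPUT, T.ToTensor())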
@@ -106,19 +36,22 @@ def from_files(
         cls,
         train_files: Optional[Sequence[str]] = None,
         predict_files: Optional[Sequence[str]] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageClassificationFilesInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
     ) -> "StyleTransferData":
+
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageFilesInput(RunningStage.TRAINING, train_files),
-            predict_dataset=ImageClassificationFilesInput(RunningStage.PREDICTING, predict_files),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_files, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )

@@ -127,19 +60,22 @@ def from_folders(
         cls,
         train_folder: Optional[str] = None,
         predict_folder: Optional[str] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageClassificationFolderInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
     ) -> "StyleTransferData":
+
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageClassificationFolderInput(RunningStage.TRAINING, train_folder),
-            predict_dataset=ImageClassificationFolderInput(RunningStage.PREDICTING, predict_folder),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )

@@ -148,19 +84,22 @@ def from_numpy(
         cls,
         train_data: Optional[Collection[np.ndarray]] = None,
         predict_data: Optional[Collection[np.ndarray]] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageNumpyInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
     ) -> "StyleTransferData":
+
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageNumpyInput(RunningStage.TRAINING, train_data),
-            predict_dataset=ImageNumpyInput(RunningStage.PREDICTING, predict_data),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )

@@ -169,18 +108,20 @@ def from_tensors(
         cls,
         train_data: Optional[Collection[torch.Tensor]] = None,
         predict_data: Optional[Collection[torch.Tensor]] = None,
-        train_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None,
-        image_size: int = 256,
-        **data_module_kwargs: Any,
+        train_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        predict_transform: INPUT_TRANSFORM_TYPE = StyleTransferInputTransform,
+        input_cls: Type[Input] = ImageTensorInput,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any
     ) -> "StyleTransferData":
+        ds_kw = dict(
+            data_pipeline_state=DataPipelineState(),
+            transform_kwargs=transform_kwargs,
+            input_transforms_registry=cls.input_transforms_registry,
+        )
+
         return cls(
-            ImageTensorInput(RunningStage.TRAINING, train_data),
-            predict_dataset=ImageTensorInput(RunningStage.PREDICTING, predict_data),
-            input_transform=cls.input_transform_cls(
-                train_transform,
-                predict_transform=predict_transform,
-                image_size=image_size,
-            ),
+            input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw),
+            predict_input=input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
             **data_module_kwargs,
         )
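All four constructors now follow the same shape: one `Input` object per running stage, with the transform attached to the `Input` rather than to a monolithic `InputTransform` on the `DataModule`. A minimal usage sketch of the new API; whether `transform_kwargs` entries such as `image_size` are forwarded to the dataclass fields of `StyleTransferInputTransform` is an assumption based on the signatures above:

    # Minimal sketch of the new constructor API, assuming transform_kwargs is
    # forwarded to the StyleTransferInputTransform dataclass fields (image_size).
    from flash.image import StyleTransferData

    datamodule = StyleTransferData.from_folders(
        train_folder="data/coco128/images/train2017",
        transform_kwargs={"image_size": 512},  # overrides the 256 default
        batch_size=4,
    )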
30 changes: 30 additions & 0 deletions flash/image/style_transfer/input_transform.py
@@ -0,0 +1,30 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Callable
+
+from flash.core.data.input_transform import InputTransform
+from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as T
+
+
+@dataclass
+class StyleTransferInputTransform(InputTransform):
+
+    image_size: int = 256
+
+    def input_per_sample_transform(self) -> Callable:
+        return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.CenterCrop(self.image_size)])
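Because the transform is now a plain dataclass that overrides hooks by method name, customizing it means subclassing rather than passing dicts of callables. An illustrative sketch of such a subclass; the `RandomCrop` variant here is hypothetical and not part of this commit:

    # Illustrative subclass of the new dataclass transform (not part of this
    # commit): swaps CenterCrop for RandomCrop, inheriting everything else.
    from dataclasses import dataclass
    from typing import Callable

    from torchvision import transforms as T

    from flash.image.style_transfer.input_transform import StyleTransferInputTransform

    @dataclass
    class RandomCropStyleTransferInputTransform(StyleTransferInputTransform):
        image_size: int = 256

        def input_per_sample_transform(self) -> Callable:
            return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.RandomCrop(self.image_size)])

Such a subclass could then be passed as `train_transform=RandomCropStyleTransferInputTransform` to any of the `from_*` constructors above.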
5 changes: 3 additions & 2 deletions flash_examples/style_transfer.py
@@ -22,7 +22,7 @@
 # 1. Create the DataModule
 download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "./data")

-datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017")
+datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017", batch_size=1)

 # 2. Build the task
 model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))
@@ -37,7 +37,8 @@
"data/coco128/images/train2017/000000000625.jpg",
"data/coco128/images/train2017/000000000626.jpg",
"data/coco128/images/train2017/000000000629.jpg",
]
],
batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
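Both call sites in the example now pass `batch_size` explicitly, which suggests the new `DataModule` no longer assumes a default. A sketch of the updated training half of the example end to end; the imports are inferred from the identifiers visible in the diff:

    # End-to-end sketch of the updated example; imports inferred from the
    # identifiers visible above (download_data, StyleTransferData, flash).
    import os

    import flash
    from flash.core.data.utils import download_data
    from flash.image import StyleTransfer, StyleTransferData

    download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "./data")

    datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017", batch_size=1)
    model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

    trainer = flash.Trainer(max_epochs=2)
    trainer.fit(model, datamodule=datamodule)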