diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index 70e4465cbe..179a00a051 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -37,6 +37,10 @@ jobs:
           python-version: 3.8
           requires: 'latest'
           topic: 'text'
+        - os: ubuntu-20.04
+          python-version: 3.8
+          requires: 'latest'
+          topic: 'image_style_transfer'

     # Timeout: https://stackoverflow.com/a/59076067/4521646
     timeout-minutes: 35
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f383676bf..588523b910 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState ([#229](https://github.com/PyTorchLightning/lightning-flash/pull/229))
 - Added Semantic Segmentation task ([#239](https://github.com/PyTorchLightning/lightning-flash/pull/239) [#287](https://github.com/PyTorchLightning/lightning-flash/pull/287) [#290](https://github.com/PyTorchLightning/lightning-flash/pull/290))
 - Added Object detection prediction example ([#283](https://github.com/PyTorchLightning/lightning-flash/pull/283))
+- Added Style Transfer task and accompanying finetuning and prediction examples ([#262](https://github.com/PyTorchLightning/lightning-flash/pull/262))

 ### Changed
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 49de343ea3..abdd2b6f8d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -30,6 +30,7 @@ Lightning Flash
    reference/object_detection
    reference/video_classification
    reference/semantic_segmentation
+   reference/style_transfer

 .. toctree::
    :maxdepth: 1
diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst
new file mode 100644
index 0000000000..87495070cf
--- /dev/null
+++ b/docs/source/reference/style_transfer.rst
@@ -0,0 +1,82 @@
+##############
+Style Transfer
+##############
+
+********
+The task
+********
+
+The Neural Style Transfer Task is an optimization method which extracts the style from an image and applies it to another image while preserving its content.
+The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.
+
+.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg
+    :alt: style_transfer_example
+
+Lightning Flash :class:`~flash.image.style_transfer.StyleTransfer` and
+:class:`~flash.image.style_transfer.StyleTransferData` internally rely on `pystiche <https://github.com/pystiche/pystiche>`_ as
+backend.
+
+------
+
+***
+Fit
+***
+
+First, import :class:`~flash.image.style_transfer.StyleTransfer`
+and :class:`~flash.image.style_transfer.StyleTransferData` from Flash.
+
+.. testcode:: style_transfer
+
+    import flash
+    from flash.core.data.utils import download_data
+    from flash.image.style_transfer import StyleTransfer, StyleTransferData
+    import pystiche
+
+
+Then, download some content images and create a :class:`~flash.image.style_transfer.StyleTransferData` DataModule.
+
+.. testcode:: style_transfer
+
+    download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
+
+    data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)
+
+
+Select a style image and pass it to the ``StyleTransfer`` task.
+
+.. testcode:: style_transfer
+
+    style_image = pystiche.demo.images()["paint"].read(size=256)
+
+    model = StyleTransfer(style_image)
+
+Finally, create a Flash :class:`~flash.core.trainer.Trainer` and pass it the model and datamodule.
+
+.. testcode:: style_transfer
+
+    trainer = flash.Trainer(max_epochs=2)
+    trainer.fit(model, data_module)
+
+.. testoutput::
+    :hide:
+
+    ...
+
+
+------
+
+*************
+API reference
+*************
+
+StyleTransfer
+-------------
+
+.. autoclass:: flash.image.StyleTransfer
+    :members:
+    :exclude-members: forward
+
+StyleTransferData
+-----------------
+
+.. autoclass:: flash.image.StyleTransferData
diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py
index d768050c5d..0446ba308a 100644
--- a/flash/core/data/data_pipeline.py
+++ b/flash/core/data/data_pipeline.py
@@ -180,7 +180,8 @@ def _resolve_function_hierarchy(
         if object_type is None:
             object_type = Preprocess

-        prefixes = ['']
+        prefixes = []
+
         if stage in (RunningStage.TRAINING, RunningStage.TUNING):
             prefixes += ['train', 'fit']
         elif stage == RunningStage.VALIDATING:
@@ -190,9 +191,11 @@ def _resolve_function_hierarchy(
         elif stage == RunningStage.PREDICTING:
             prefixes += ['predict']

+        prefixes += [None]
+
         for prefix in prefixes:
             if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix):
-                return f'{prefix}_{function_name}'
+                return function_name if prefix is None else f'{prefix}_{function_name}'

         return function_name
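The `data_pipeline.py` hunk above changes how hook names are resolved: the bare hook name used to be tried first (via the `''` prefix) and is now tried last (via the `None` sentinel), so a stage-specific override such as `train_load_data` wins over a plain `load_data`. A self-contained sketch of the new resolution order, with a simple `overridden` set standing in for Flash's `_is_overriden` check and plain stage strings standing in for `RunningStage`:

```python
from typing import Set


def resolve_hook_name(function_name: str, stage: str, overridden: Set[str]) -> str:
    # Stage-specific prefixes are tried first; the bare hook name is the final fallback.
    prefixes: list = {
        "train": ["train", "fit"],
        "validate": ["val", "fit"],
        "test": ["test"],
        "predict": ["predict"],
    }[stage]
    prefixes += [None]

    for prefix in prefixes:
        candidate = function_name if prefix is None else f"{prefix}_{function_name}"
        if candidate in overridden:
            return candidate
    return function_name


# A train-specific override now shadows the generic hook ...
assert resolve_hook_name("load_data", "train", {"load_data", "train_load_data"}) == "train_load_data"
# ... and with no override at all, the bare name is returned unchanged.
assert resolve_hook_name("load_data", "predict", set()) == "load_data"
```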
diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py
index 1392bcae02..0858e9a26a 100644
--- a/flash/core/data/data_source.py
+++ b/flash/core/data/data_source.py
@@ -298,14 +298,10 @@ def generate_dataset(
         mock_dataset = typing.cast(AutoDataset, MockDataset())
         with CurrentRunningStageFuncContext(running_stage, "load_data", self):
-            load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
-                self, DataPipeline._resolve_function_hierarchy(
-                    "load_data",
-                    self,
-                    running_stage,
-                    DataSource,
-                )
+            resolved_func_name = DataPipeline._resolve_function_hierarchy(
+                "load_data", self, running_stage, DataSource
             )
+            load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self, resolved_func_name)
             parameters = signature(load_data).parameters
             if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                 data = load_data(data, mock_dataset)
diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py
index d80ec15e69..cb48f495e7 100644
--- a/flash/core/data/transforms.py
+++ b/flash/core/data/transforms.py
@@ -53,6 +53,11 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
             return result
         return x

+    def __repr__(self):
+        keys = self.keys[0] if len(self.keys) == 1 else self.keys
+        transform = list(self.children())
+        return f"{self.__class__.__name__}(keys={keys}, transform={transform})"
+

 class KorniaParallelTransforms(nn.Sequential):
     """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
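For illustration, the new `__repr__` renders the wrapped keys and child transforms of an `ApplyToKeys` instance. Roughly (the inner transform's repr comes from torchvision, so the exact string varies by version):

```python
from torchvision import transforms as T

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys

transform = ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(256))
print(repr(transform))
# e.g. ApplyToKeys(keys=<DefaultDataKeys.INPUT: 'input'>, transform=[Resize(size=256, ...)])
```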
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index b798b891a9..e5999ac778 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -74,10 +74,13 @@ def _compare_version(package: str, op, version) -> bool:
 _PYTORCHVIDEO_AVAILABLE = _module_available("pytorchvideo")
 _MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
 _TRANSFORMERS_AVAILABLE = _module_available("transformers")
+_PYSTICHE_AVAILABLE = _module_available("pystiche")

 if Version:
     _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
+    _PYSTICHE_GREATER_EQUAL_0_7_2 = _compare_version("pystiche", operator.ge, "0.7.2")

+_IMAGE_STYLE_TRANSFER = _PYSTICHE_AVAILABLE
 _TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE
 _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE
 _VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE
diff --git a/flash/image/__init__.py b/flash/image/__init__.py
index 1ebc555b8f..87bd06cd10 100644
--- a/flash/image/__init__.py
+++ b/flash/image/__init__.py
@@ -3,3 +3,4 @@
 from flash.image.detection import ObjectDetectionData, ObjectDetector
 from flash.image.embedding import ImageEmbedder
 from flash.image.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
+from flash.image.style_transfer import StyleTransfer, StyleTransferData, StyleTransferPreprocess
diff --git a/flash/image/style_transfer/__init__.py b/flash/image/style_transfer/__init__.py
new file mode 100644
index 0000000000..a1b7e4ca80
--- /dev/null
+++ b/flash/image/style_transfer/__init__.py
@@ -0,0 +1,3 @@
+from flash.image.style_transfer.backbone import STYLE_TRANSFER_BACKBONES
+from flash.image.style_transfer.data import StyleTransferData, StyleTransferPreprocess
+from flash.image.style_transfer.model import StyleTransfer
diff --git a/flash/image/style_transfer/backbone.py b/flash/image/style_transfer/backbone.py
new file mode 100644
index 0000000000..021c4a3ec7
--- /dev/null
+++ b/flash/image/style_transfer/backbone.py
@@ -0,0 +1,28 @@
+import re
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _PYSTICHE_AVAILABLE
+
+STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")
+
+__all__ = ["STYLE_TRANSFER_BACKBONES"]
+
+if _PYSTICHE_AVAILABLE:
+
+    from pystiche import enc
+
+    MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$")
+
+    for mle_fn in dir(enc):
+        match = MLE_FN_PATTERN.match(mle_fn)
+        if not match:
+            continue
+
+        STYLE_TRANSFER_BACKBONES(
+            fn=lambda mle_fn=mle_fn: (getattr(enc, mle_fn)(), None),
+            name=match.group("name"),
+            namespace="image/style_transfer",
+            package="pystiche",
+        )
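The default-argument binding in the registration loop above (`lambda mle_fn=mle_fn: ...`) is deliberate: a bare closure would capture the loop variable by reference, and every backbone entry would resolve to the last `*_multi_layer_encoder` factory found in `dir(enc)`. A standalone illustration of the pitfall:

```python
# A bare closure sees the loop variable's final value ...
fns = [lambda: name for name in ("alexnet", "vgg16")]
assert [fn() for fn in fns] == ["vgg16", "vgg16"]

# ... while a default argument is bound at definition time.
fns = [lambda name=name: name for name in ("alexnet", "vgg16")]
assert [fn() for fn in fns] == ["alexnet", "vgg16"]
```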
diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py
new file mode 100644
index 0000000000..2b50f9cb9a
--- /dev/null
+++ b/flash/image/style_transfer/data.py
@@ -0,0 +1,127 @@
+import functools
+import pathlib
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from torch import nn
+
+from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Preprocess
+from flash.core.data.transforms import ApplyToKeys
+from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+from flash.image.classification import ImageClassificationData
+from flash.image.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource
+from flash.image.style_transfer.utils import raise_not_supported
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as T
+
+__all__ = ["StyleTransferPreprocess", "StyleTransferData"]
+
+
+def _apply_to_input(
+    default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], DefaultDataKeys]
+) -> Callable[..., Dict[str, ApplyToKeys]]:
+
+    @functools.wraps(default_transforms_fn)
+    def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
+        default_transforms = default_transforms_fn(*args, **kwargs)
+        if not default_transforms:
+            return default_transforms
+
+        return {hook: ApplyToKeys(keys, transform) for hook, transform in default_transforms.items()}
+
+    return wrapper
+
+
+class StyleTransferPreprocess(Preprocess):
+
+    def __init__(
+        self,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        image_size: Union[int, Tuple[int, int]] = 256,
+    ):
+        if val_transform:
+            raise_not_supported("validation")
+        if test_transform:
+            raise_not_supported("test")
+
+        if isinstance(image_size, int):
+            image_size = (image_size, image_size)
+
+        self.image_size = image_size
+
+        super().__init__(
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_sources={
+                DefaultDataSources.FILES: ImagePathsDataSource(),
+                DefaultDataSources.FOLDERS: ImagePathsDataSource(),
+                DefaultDataSources.NUMPY: ImageNumpyDataSource(),
+                DefaultDataSources.TENSORS: ImageTensorDataSource(),
+            },
+            default_data_source=DefaultDataSources.FILES,
+        )
+
+    def get_state_dict(self) -> Dict[str, Any]:
+        return {**self.transforms, "image_size": self.image_size}
+
+    @classmethod
+    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
+        return cls(**state_dict)
+
+    @functools.partial(_apply_to_input, keys=DefaultDataKeys.INPUT)
+    def default_transforms(self) -> Optional[Dict[str, Callable]]:
+        if self.training:
+            return dict(
+                to_tensor_transform=T.ToTensor(),
+                per_sample_transform_on_device=nn.Sequential(
+                    T.Resize(self.image_size),
+                    T.CenterCrop(self.image_size),
+                ),
+            )
+        elif self.predicting:
+            return dict(
+                pre_tensor_transform=T.Resize(self.image_size),
+                to_tensor_transform=T.ToTensor(),
+            )
+        # Style transfer doesn't support a validation or test phase, so we return nothing here
+        return None
+
+
+class StyleTransferData(ImageClassificationData):
+    preprocess_cls = StyleTransferPreprocess
+
+    @classmethod
+    def from_folders(
+        cls,
+        train_folder: Optional[Union[str, pathlib.Path]] = None,
+        predict_folder: Optional[Union[str, pathlib.Path]] = None,
+        train_transform: Optional[Union[str, Dict]] = None,
+        predict_transform: Optional[Union[str, Dict]] = None,
+        preprocess: Optional[Preprocess] = None,
+        **kwargs: Any,
+    ) -> "StyleTransferData":
+
+        if any(param in kwargs for param in ("val_folder", "val_transform")):
+            raise_not_supported("validation")
+
+        if any(param in kwargs for param in ("test_folder", "test_transform")):
+            raise_not_supported("test")
+
+        preprocess = preprocess or cls.preprocess_cls(
+            train_transform=train_transform,
+            predict_transform=predict_transform,
+        )
+
+        return cls.from_data_source(
+            DefaultDataSources.FOLDERS,
+            train_data=train_folder,
+            predict_data=predict_folder,
+            preprocess=preprocess,
+            **kwargs,
+        )
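A usage sketch for the `StyleTransferData` API above: a custom training resolution goes through `StyleTransferPreprocess`, and the unsupported phases fail fast (the folder paths here are placeholders):

```python
from flash.image.style_transfer import StyleTransferData, StyleTransferPreprocess

# Train at 512x512 instead of the default 256 by supplying a custom preprocess.
datamodule = StyleTransferData.from_folders(
    train_folder="data/coco128/images",
    preprocess=StyleTransferPreprocess(image_size=512),
    batch_size=4,
)

# Validation and test phases are rejected up front.
try:
    StyleTransferData.from_folders(train_folder="data/coco128/images", val_folder="data/val")
except RuntimeError as err:
    print(err)  # Style transfer does not support a validation phase, ...
```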
diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py
new file mode 100644
index 0000000000..e689725cce
--- /dev/null
+++ b/flash/image/style_transfer/model.py
@@ -0,0 +1,175 @@
+from typing import Any, cast, Dict, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union
+
+import torch
+from torch import nn
+from torch.optim.lr_scheduler import _LRScheduler
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.process import Serializer
+from flash.core.model import Task
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _IMAGE_STYLE_TRANSFER
+from flash.image.style_transfer import STYLE_TRANSFER_BACKBONES
+
+if _IMAGE_STYLE_TRANSFER:
+    import pystiche.demo
+    from pystiche import enc, loss, ops
+    from pystiche.image import read_image
+else:
+
+    class enc:
+        Encoder = None
+        MultiLayerEncoder = None
+
+    class ops:
+        EncodingComparisonOperator = None
+        FeatureReconstructionOperator = None
+        MultiLayerEncodingOperator = None
+
+    class loss:
+
+        class PerceptualLoss:
+            pass
+
+
+from flash.image.style_transfer.utils import raise_not_supported
+
+__all__ = ["StyleTransfer"]
+
+
+class StyleTransfer(Task):
+    """Task that transfers the style from one image onto another.
+
+    Example::
+
+        from flash.image.style_transfer import StyleTransfer
+
+        model = StyleTransfer(image_style)
+
+    Args:
+        style_image: Image or path to an image to derive the style from.
+        model: The model used by the style transfer task.
+        backbone: A string or model used to compute the style loss.
+        content_layer: Which layer from the backbone to extract the content loss from.
+        content_weight: The weight associated with the content loss. A lower value will lose content over style.
+        style_layers: Layers from the backbone to derive the style loss from.
+        style_weight: The weight associated with the style loss.
+        optimizer: Optimizer to use for training the model.
+        optimizer_kwargs: Optimizer keyword arguments.
+        scheduler: Scheduler to use for training the model.
+        scheduler_kwargs: Scheduler keyword arguments.
+        learning_rate: Learning rate to use for training, defaults to ``1e-3``.
+        serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
+    """
+
+    backbones: FlashRegistry = STYLE_TRANSFER_BACKBONES
+
+    def __init__(
+        self,
+        style_image: Optional[Union[str, torch.Tensor]] = None,
+        model: Optional[nn.Module] = None,
+        backbone: str = "vgg16",
+        content_layer: str = "relu2_2",
+        content_weight: float = 1e5,
+        style_layers: Union[Sequence[str], str] = ("relu1_2", "relu2_2", "relu3_3", "relu4_3"),
+        style_weight: float = 1e10,
+        optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
+        learning_rate: float = 1e-3,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+    ):
+
+        if not _IMAGE_STYLE_TRANSFER:
+            raise ModuleNotFoundError("Please, pip install -e '.[image_style_transfer]'")
+
+        self.save_hyperparameters(ignore="style_image")
+
+        if style_image is None:
+            style_image = self.default_style_image()
+        elif isinstance(style_image, str):
+            style_image = read_image(style_image)
+
+        if model is None:
+            model = pystiche.demo.transformer()
+
+        if not isinstance(style_layers, (List, Tuple)):
+            style_layers = (style_layers, )
+
+        perceptual_loss = self._get_perceptual_loss(
+            backbone=backbone,
+            content_layer=content_layer,
+            content_weight=content_weight,
+            style_layers=style_layers,
+            style_weight=style_weight,
+        )
+        perceptual_loss.set_style_image(style_image)
+
+        super().__init__(
+            model=model,
+            loss_fn=perceptual_loss,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
+            learning_rate=learning_rate,
+            serializer=serializer,
+        )
+
+        self.perceptual_loss = perceptual_loss
+
+    def default_style_image(self) -> torch.Tensor:
+        return pystiche.demo.images()["paint"].read(size=256)
+
+    @staticmethod
+    def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.EncodingComparisonOperator:
+        # The official PyTorch examples as well as the reference implementation of the original author contain an
+        # oversight: they normalize the representation twice by the number of channels. To be compatible with them,
+        # we do the same here.
+        class GramOperator(ops.GramOperator):
+
+            def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
+                repr = super().enc_to_repr(enc)
+                num_channels = repr.size()[1]
+                return repr / num_channels
+
+        return GramOperator(encoder, score_weight=score_weight)
+
+    def _get_perceptual_loss(
+        self,
+        *,
+        backbone: str,
+        content_layer: str,
+        content_weight: float,
+        style_layers: Sequence[str],
+        style_weight: float,
+    ) -> loss.PerceptualLoss:
+        mle, _ = cast(enc.MultiLayerEncoder, self.backbones.get(backbone)())
+        content_loss = ops.FeatureReconstructionOperator(
+            mle.extract_encoder(content_layer), score_weight=content_weight
+        )
+        style_loss = ops.MultiLayerEncodingOperator(
+            mle,
+            style_layers,
+            lambda encoder, layer_weight: self._modified_gram_loss(encoder, score_weight=layer_weight),
+            layer_weights="sum",
+            score_weight=style_weight,
+        )
+        return loss.PerceptualLoss(content_loss, style_loss)
+
+    def training_step(self, batch: Any, batch_idx: int) -> Any:
+        input_image = batch[DefaultDataKeys.INPUT]
+        self.perceptual_loss.set_content_image(input_image)
+        output_image = self(input_image)
+        return self.perceptual_loss(output_image).total()
+
+    def validation_step(self, batch: Any, batch_idx: int) -> NoReturn:
+        raise_not_supported("validation")
+
+    def test_step(self, batch: Any, batch_idx: int) -> NoReturn:
+        raise_not_supported("test")
+
+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Any:
+        input_image = batch[DefaultDataKeys.INPUT]
+        return self(input_image)
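The double normalization that `_modified_gram_loss` mirrors is easier to see in isolation. A minimal plain-PyTorch sketch, independent of pystiche's `GramOperator` (the single-normalization variant is shown only for contrast):

```python
import torch


def gram_matrix(features: torch.Tensor, normalize_twice: bool = True) -> torch.Tensor:
    # features: (batch, channels, height, width) activations from one encoder layer.
    b, c, h, w = features.size()
    flat = features.view(b, c, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)
    if normalize_twice:
        # The PyTorch examples divide by the number of channels a second time;
        # the GramOperator subclass above reproduces that for compatibility.
        gram = gram / c
    return gram


features = torch.rand(1, 64, 32, 32)
assert torch.allclose(gram_matrix(features), gram_matrix(features, normalize_twice=False) / 64)
```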
diff --git a/flash/image/style_transfer/utils.py b/flash/image/style_transfer/utils.py
new file mode 100644
index 0000000000..dc34ba6d4f
--- /dev/null
+++ b/flash/image/style_transfer/utils.py
@@ -0,0 +1,10 @@
+from typing import NoReturn
+
+__all__ = ["raise_not_supported"]
+
+
+def raise_not_supported(phase: str) -> NoReturn:
+    raise RuntimeError(
+        f"Style transfer does not support a {phase} phase, "
+        "since there is no metric to objectively determine the quality of a stylization."
+    )
diff --git a/flash_examples/finetuning/style_transfer.py b/flash_examples/finetuning/style_transfer.py
new file mode 100644
index 0000000000..31b49ad6b2
--- /dev/null
+++ b/flash_examples/finetuning/style_transfer.py
@@ -0,0 +1,26 @@
+import sys
+
+import flash
+from flash.core.data.utils import download_data
+from flash.core.utilities.imports import _PYSTICHE_AVAILABLE
+
+if _PYSTICHE_AVAILABLE:
+    import pystiche.demo
+
+    from flash.image.style_transfer import StyleTransfer, StyleTransferData
+else:
+    print("Please, run `pip install pystiche`")
+    sys.exit(1)
+
+download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
+
+data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)
+
+style_image = pystiche.demo.images()["paint"].read(size=256)
+
+model = StyleTransfer(style_image)
+
+trainer = flash.Trainer(max_epochs=2)
+trainer.fit(model, data_module)
+
+trainer.save_checkpoint("style_transfer_model.pt")
diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py
index 41bd89654a..3938ffe17a 100644
--- a/flash_examples/predict/semantic_segmentation.py
+++ b/flash_examples/predict/semantic_segmentation.py
@@ -24,7 +24,9 @@
 )

 # 2. Load the model from a checkpoint
-model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
+model = SemanticSegmentation.load_from_checkpoint(
+    "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
+)
 model.serializer = SegmentationLabels(visualize=True)

 # 3. Predict what's on a few images and visualize!
diff --git a/flash_examples/predict/style_transfer.py b/flash_examples/predict/style_transfer.py
new file mode 100644
index 0000000000..bd8eb6041f
--- /dev/null
+++ b/flash_examples/predict/style_transfer.py
@@ -0,0 +1,49 @@
+import sys
+
+import numpy as np
+import torch
+from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
+
+import flash
+from flash.core.data.utils import download_data
+from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYSTICHE_AVAILABLE
+from flash.image.style_transfer import StyleTransfer, StyleTransferData
+
+if not _PYSTICHE_AVAILABLE:
+    print("Please, run `pip install pystiche`")
+    sys.exit(1)
+
+
+class StyleTransferWriter(BasePredictionWriter):
+
+    def __init__(self) -> None:
+        super().__init__("batch")
+
+    def write_on_batch_end(
+        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
+    ) -> None:
+        """
+        Implement the logic to save a given batch of predictions.
+        torch.save({"preds": prediction, "batch_indices": batch_indices}, "prediction_{batch_idx}.pt")
+        """
+
+
+download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
+
+model = StyleTransfer.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/style_transfer_model.pt")
+
+datamodule = StyleTransferData.from_folders(predict_folder="data/coco128/images/train2017", batch_size=4)
+
+trainer = flash.Trainer(max_epochs=2, callbacks=StyleTransferWriter(), limit_predict_batches=1)
+predictions = trainer.predict(model, datamodule=datamodule)
+
+# display the first stylized image.
+image_prediction = torch.stack(predictions[0])[0].numpy()
+
+if _MATPLOTLIB_AVAILABLE and not flash._IS_TESTING:
+    import matplotlib.pyplot as plt
+
+    image = np.moveaxis(image_prediction, 0, 2)
+    image -= image.min()
+    image /= image.max()
+    plt.imshow(image)
+    plt.show()
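The predict example above registers a `StyleTransferWriter` but leaves `write_on_batch_end` as a stub. A minimal implementation along the lines its docstring suggests (the output filename pattern is an arbitrary choice):

```python
import torch
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter


class StyleTransferWriter(BasePredictionWriter):

    def __init__(self) -> None:
        super().__init__(write_interval="batch")

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # Persist each stylized batch together with the indices it came from.
        torch.save({"preds": prediction, "batch_indices": batch_indices}, f"prediction_{batch_idx}.pt")
```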
diff --git a/requirements/datatype_image_style_transfer.txt b/requirements/datatype_image_style_transfer.txt
new file mode 100644
index 0000000000..e536cae01f
--- /dev/null
+++ b/requirements/datatype_image_style_transfer.txt
@@ -0,0 +1 @@
+pystiche>=0.7.2
diff --git a/setup.py b/setup.py
index df687a01f5..a7e53d8826 100644
--- a/setup.py
+++ b/setup.py
@@ -32,11 +32,14 @@ def _load_py_module(fname, pkg="flash"):
     "text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_text.txt"),
     "tabular": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_tabular.txt"),
     "image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image.txt"),
+    "image_style_transfer": setup_tools._load_requirements(
+        path_dir=_PATH_REQUIRE, file_name="datatype_image_style_transfer.txt"
+    ),
     "video": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video.txt"),
 }

 # remove possible duplicate.
-extras["vision"] = list(set(extras["image"] + extras["video"]))
+extras["vision"] = list(set(extras["image"] + extras["video"] + extras["image_style_transfer"]))
 extras["dev"] = list(set(extras["vision"] + extras["tabular"] + extras["text"] + extras["image"]))
 extras["dev-test"] = list(set(extras["test"] + extras["dev"]))
 extras["all"] = list(set(extras["dev"] + extras["docs"]))
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 7564a98941..901b874c85 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -22,6 +22,7 @@
 from flash.core.utilities.imports import (
     _IMAGE_AVAILABLE,
+    _PYSTICHE_GREATER_EQUAL_0_7_2,
     _TABULAR_AVAILABLE,
     _TEXT_AVAILABLE,
     _TORCHVISION_GREATER_EQUAL_0_9,
@@ -94,6 +95,11 @@ def run_test(filepath):
         "translation.py",
         marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed")
     ),
+    pytest.param(
+        "finetuning",
+        "style_transfer.py",
+        marks=pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="pystiche>=0.7.2 is not installed")
+    ),
     pytest.param(
         "predict",
         "image_classification.py",
diff --git a/tests/vision/__init__.py b/tests/image/__init__.py
similarity index 100%
rename from tests/vision/__init__.py
rename to tests/image/__init__.py
diff --git a/tests/vision/classification/__init__.py b/tests/image/classification/__init__.py
similarity index 100%
rename from tests/vision/classification/__init__.py
rename to tests/image/classification/__init__.py
diff --git a/tests/vision/classification/test_data.py b/tests/image/classification/test_data.py
similarity index 100%
rename from tests/vision/classification/test_data.py
rename to tests/image/classification/test_data.py
diff --git a/tests/vision/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py
similarity index 100%
rename from tests/vision/classification/test_data_model_integration.py
rename to tests/image/classification/test_data_model_integration.py
diff --git a/tests/vision/classification/test_model.py b/tests/image/classification/test_model.py
similarity index 100%
rename from tests/vision/classification/test_model.py
rename to tests/image/classification/test_model.py
diff --git a/tests/vision/detection/__init__.py b/tests/image/detection/__init__.py
similarity index 100%
rename from tests/vision/detection/__init__.py
rename to tests/image/detection/__init__.py
diff --git a/tests/vision/detection/test_data.py b/tests/image/detection/test_data.py
similarity index 100%
rename from tests/vision/detection/test_data.py
rename to tests/image/detection/test_data.py
diff --git a/tests/vision/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py
similarity index 96%
rename from tests/vision/detection/test_data_model_integration.py
rename to tests/image/detection/test_data_model_integration.py
index 2e195c1b9f..428a053b75 100644
--- a/tests/vision/detection/test_data_model_integration.py
+++ b/tests/image/detection/test_data_model_integration.py
@@ -26,7 +26,7 @@
 Image = None

 if _COCO_AVAILABLE:
-    from tests.vision.detection.test_data import _create_synth_coco_dataset
+    from tests.image.detection.test_data import _create_synth_coco_dataset


 @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="pycocotools is not installed for testing")
diff --git a/tests/vision/detection/test_model.py b/tests/image/detection/test_model.py
similarity index 100%
rename from tests/vision/detection/test_model.py
rename to tests/image/detection/test_model.py
diff --git a/tests/vision/segmentation/__init__.py b/tests/image/segmentation/__init__.py
similarity index 100%
rename from tests/vision/segmentation/__init__.py
rename to tests/image/segmentation/__init__.py
diff --git a/tests/vision/segmentation/test_data.py b/tests/image/segmentation/test_data.py
similarity index 100%
rename from tests/vision/segmentation/test_data.py
rename to tests/image/segmentation/test_data.py
diff --git a/tests/vision/segmentation/test_model.py b/tests/image/segmentation/test_model.py
similarity index 100%
rename from tests/vision/segmentation/test_model.py
rename to tests/image/segmentation/test_model.py
diff --git a/tests/vision/segmentation/test_serialization.py b/tests/image/segmentation/test_serialization.py
similarity index 100%
rename from tests/vision/segmentation/test_serialization.py
rename to tests/image/segmentation/test_serialization.py
diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py
new file mode 100644
index 0000000000..fbcdd6c7ad
--- /dev/null
+++ b/tests/image/style_transfer/test_model.py
@@ -0,0 +1,22 @@
+import pytest
+
+from flash.core.utilities.imports import _IMAGE_STYLE_TRANSFER, _PYSTICHE_GREATER_EQUAL_0_7_2
+from flash.image.style_transfer import StyleTransfer
+
+
+@pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="image style transfer libraries aren't installed.")
+def test_style_transfer_task():
+
+    model = StyleTransfer(
+        backbone="vgg11", content_layer="relu1_2", content_weight=10, style_layers="relu1_2", style_weight=11
+    )
+    assert model.perceptual_loss.content_loss.encoder.layer == "relu1_2"
+    assert model.perceptual_loss.content_loss.score_weight == 10
+    assert "relu1_2" in [n for n, m in model.perceptual_loss.style_loss.named_modules()]
+    assert model.perceptual_loss.style_loss.score_weight == 11
+
+
+@pytest.mark.skipif(_IMAGE_STYLE_TRANSFER, reason="image style transfer libraries are installed.")
+def test_style_transfer_task_import():
+    with pytest.raises(ModuleNotFoundError, match=r"\[image_style_transfer\]"):
+        StyleTransfer()
diff --git a/tests/vision/test_backbones.py b/tests/image/test_backbones.py
similarity index 100%
rename from tests/vision/test_backbones.py
rename to tests/image/test_backbones.py