Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[WIP] add style transfer task with pystiche #262

Merged
merged 72 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
413a530
add style transfer task with pystiche
pmeier May 5, 2021
be5a893
address review comments
pmeier May 10, 2021
398bbf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2021
61c9480
fix type hint
pmeier May 10, 2021
0a70150
allow passing style_image by path
pmeier May 10, 2021
f6b2fcc
add batch_size
pmeier May 10, 2021
edf0dff
add data_module based on image classification
pmeier May 11, 2021
4ffc734
add internal pre / post-processing
pmeier May 11, 2021
9bb45bc
bail out if val / test step is performed
pmeier May 11, 2021
e82a94c
update example
pmeier May 11, 2021
d939fad
move example from predict to finetuning
pmeier May 11, 2021
5b10dbb
remove metrics from task
pmeier May 11, 2021
2e6901a
flake8
pmeier May 11, 2021
3c16bc0
Merge branch 'master' into style-transfer
pmeier May 11, 2021
eeed004
remove unused imports
pmeier May 11, 2021
7d38a5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2021
e844db4
remove grayscale handling
pmeier May 11, 2021
a932f86
address review comments and small fixes
pmeier May 11, 2021
fb30c16
Merge branch 'master' into style-transfer
pmeier May 11, 2021
ba091cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2021
e399d50
streamline apply_to_input
pmeier May 12, 2021
dc80a21
fix hyper parameters saving
pmeier May 12, 2021
9f7fd41
implement custom step
pmeier May 12, 2021
eabf49b
cleanup
pmeier May 12, 2021
54ae632
Merge branch 'master' into style-transfer
pmeier May 12, 2021
361074b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2021
464a26c
add explanation to not supported phases
pmeier May 14, 2021
d78989e
temporarily use unreleased pystiche version
pmeier May 14, 2021
1b2e6e3
add missing transforms in preprocess
pmeier May 14, 2021
0feaf7a
introduce multi layer encoders as backbones
pmeier May 14, 2021
c41e38c
refactor task
pmeier May 14, 2021
de62996
add explanation for modified gram operator
pmeier May 14, 2021
36ce2da
Merge branch 'master' into style-transfer
pmeier May 14, 2021
02a0c29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
081cb48
Merge branch 'master' into style-transfer
pmeier May 16, 2021
21d0817
streamline default transforms
pmeier May 16, 2021
f868489
add disabled test for finetuning example
pmeier May 16, 2021
39baffd
add documentation skeleton
pmeier May 16, 2021
7660f51
update changelog
pmeier May 16, 2021
46ff9d0
update
tchaton May 17, 2021
e230a2c
update
tchaton May 17, 2021
40fd08b
update
tchaton May 17, 2021
278a874
update
tchaton May 17, 2021
740bb22
update
tchaton May 17, 2021
3e2ad57
update
tchaton May 17, 2021
753dfd7
update
tchaton May 17, 2021
a49474a
update
tchaton May 17, 2021
55888c3
update
tchaton May 17, 2021
71b63f4
update
tchaton May 17, 2021
4ab9aae
update
tchaton May 17, 2021
4b1e7b9
update
tchaton May 17, 2021
5c3e72a
update
tchaton May 17, 2021
b2b3132
update
tchaton May 17, 2021
8d23b95
Merge branch 'master' into style-transfer
tchaton May 17, 2021
fa7c304
change skipif
tchaton May 17, 2021
35d5702
Merge branch 'style-transfer' of https://github.com/pmeier/lightning-…
tchaton May 17, 2021
ed1574f
update
tchaton May 17, 2021
f173788
update
tchaton May 17, 2021
472cb92
update
tchaton May 17, 2021
7ac8234
Merge branch 'master' into style-transfer
tchaton May 17, 2021
d2ab928
fix image size for preprocess
pmeier May 17, 2021
1a7819c
fix style transfer requirements
pmeier May 17, 2021
a4c7f0a
update
tchaton May 17, 2021
82f2005
Merge branch 'style-transfer' of https://github.com/pmeier/lightning-…
tchaton May 17, 2021
559446b
update doc
tchaton May 17, 2021
22ee835
Merge remote-tracking branch 'pmeier/style-transfer' into style-transfer
pmeier May 17, 2021
a3b95d5
fix style transfer requirements
pmeier May 17, 2021
b6e459b
Merge remote-tracking branch 'pmeier/style-transfer' into style-transfer
pmeier May 17, 2021
b8d93be
update
tchaton May 17, 2021
77db374
Merge branch 'style-transfer' of https://github.com/pmeier/lightning-…
tchaton May 17, 2021
4fbf11c
add reference to pystiche
pmeier May 17, 2021
31efb53
remove unnecessary import
pmeier May 17, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
_PYTORCHVIDEO_AVAILABLE = _module_available("pytorchvideo")
_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
_PYSTICHE_AVAILABLE = _module_available("pystiche")
1 change: 1 addition & 0 deletions flash/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
from .style_transfer import *
pmeier marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions flash/vision/style_transfer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .data import *
from .model import *
1 change: 1 addition & 0 deletions flash/vision/style_transfer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TODO
pmeier marked this conversation as resolved.
Show resolved Hide resolved
90 changes: 90 additions & 0 deletions flash/vision/style_transfer/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from torch import nn
from torch.nn.functional import interpolate

__all__ = ["Transformer"]


class Interpolate(nn.Module):
pmeier marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, scale_factor=1.0, mode="nearest"):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode

def forward(self, input):
return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)

def extra_repr(self):
extras = []
if self.scale_factor:
extras.append(f"scale_factor={self.scale_factor}")
if self.mode != "nearest":
extras.append(f"mode={self.mode}")
return ", ".join(extras)


class Conv(nn.Module):
pmeier marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
upsample=False,
norm=True,
activation=True,
):
super().__init__()
self.upsample = Interpolate(scale_factor=stride) if upsample else None
self.pad = nn.ReflectionPad2d(kernel_size // 2)
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size, stride=1 if upsample else stride
)
self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
self.activation = nn.ReLU() if activation else None

def forward(self, input):
if self.upsample:
input = self.upsample(input)

output = self.conv(self.pad(input))

if self.norm:
output = self.norm(output)
if self.activation:
output = self.activation(output)

return output


class Residual(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = Conv(channels, channels, kernel_size=3)
self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)

def forward(self, input):
output = self.conv2(self.conv1(input))
return output + input


class Transformer(nn.Module):
pmeier marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
Conv(3, 32, kernel_size=9),
Conv(32, 64, kernel_size=3, stride=2),
Conv(64, 128, kernel_size=3, stride=2),
Residual(128),
Residual(128),
Residual(128),
Residual(128),
Residual(128),
)
self.decoder = nn.Sequential(
Conv(128, 64, kernel_size=3, stride=2, upsample=True),
Conv(64, 32, kernel_size=3, stride=2, upsample=True),
Conv(32, 3, kernel_size=9, norm=False, activation=False),
)

def forward(self, input):
return self.decoder(self.encoder(input))
47 changes: 47 additions & 0 deletions flash_examples/predict/style_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import sys

import torch

from flash.utils.imports import _PYSTICHE_AVAILABLE

if _PYSTICHE_AVAILABLE:
from pystiche import enc, loss, ops
else:
print("Please, run `pip install pystiche`")
sys.exit(0)

multi_layer_encoder = enc.vgg16_multi_layer_encoder()

content_layer = "relu2_2"
content_encoder = multi_layer_encoder.extract_encoder(content_layer)
content_weight = 1e5
content_loss = ops.FeatureReconstructionOperator(
content_encoder, score_weight=content_weight
)


class GramOperator(ops.GramOperator):
def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
repr = super().enc_to_repr(enc)
num_channels = repr.size()[1]
return repr / num_channels
pmeier marked this conversation as resolved.
Show resolved Hide resolved


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = ops.MultiLayerEncodingOperator(
pmeier marked this conversation as resolved.
Show resolved Hide resolved
multi_layer_encoder,
style_layers,
lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
layer_weights="sum",
score_weight=style_weight,
)

# TODO: this needs to be moved to the device to be trained on
# TODO: we need to register a style image here
perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
pmeier marked this conversation as resolved.
Show resolved Hide resolved


def loss_fn(image):
perceptual_loss.set_content_image(image)
return float(perceptual_loss(image))
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ pycocotools>=2.0.2 ; python_version >= "3.7"
kornia==0.5.0
pytorchvideo
matplotlib # used by the visualisation callback
pystiche>=0.7.1