[WIP] add style transfer task with pystiche #262
Conversation
Hello @pmeier! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-05-17 19:48:40 UTC
This is just a very preliminary state. I've hit several roadblocks that need to be resolved:
- Models for Neural Style Transfer are trained in an unsupervised manner. Thus, I simply need a dataset that can supply images without labels / annotations (see the sketch after this list).
- Following from the point above, there is no train / val / test split. Evaluation happens after training: if there is anything like validation / testing, it is performed manually by looking at a few examples. There is no objective way to put a number on the quality of the stylization.
- The models used as the transformer have no established names. Elsewhere, a model is usually loaded by its name, which is thus not possible here. We could fall back to an author / year combination of the paper the architecture was published in (a registry sketch follows below).
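To make the first point concrete, a label-free dataset could be as minimal as this sketch (the class name and file filtering are just placeholders):

import os

from PIL import Image
from torch.utils.data import Dataset


class UnlabeledImageFolder(Dataset):
    # supplies plain images from a folder; there are no targets to return

    def __init__(self, root, transform=None):
        self.paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        image = Image.open(self.paths[index]).convert("RGB")
        return self.transform(image) if self.transform is not None else image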
I'll fix the linting errors and update the documentation, tests, and the changelog when the main part is resolved.
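Coming back to the naming point: registering architectures under an author / year key could look roughly like the following (the decorator-style registration mirrors how other Flash registries are used; johnson_alahi_li_2016 would refer to "Perceptual Losses for Real-Time Style Transfer and Super-Resolution" by Johnson, Alahi, and Li, 2016):

from torch import nn

from flash.core.registry import FlashRegistry

STYLE_TRANSFER_MODELS = FlashRegistry("style_transfer_models")


@STYLE_TRANSFER_MODELS(name="johnson_alahi_li_2016")
def johnson_alahi_li_2016(pretrained: bool = False) -> nn.Module:
    # build the transformer network from Johnson, Alahi, Li (2016)
    ...


# a task could then resolve the architecture by its author / year key:
transformer = STYLE_TRANSFER_MODELS.get("johnson_alahi_li_2016")(pretrained=True)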
Hey @pmeier, awesome that you started! You need to create a proper task. Here is pseudo code to get you started (imports added for context; Flash's Task, FlashRegistry, Serializer, and the STYLE_TRANSFER_MODELS registry are assumed to be in scope):

from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
import torchmetrics
from PIL import Image
from pystiche import enc, loss, ops
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler


class StyleTransfer(Task):

    models: FlashRegistry = STYLE_TRANSFER_MODELS

    def __init__(
        self,
        content_image: Union[Image.Image, str, np.ndarray],
        style_loss: Optional[Callable] = None,
        content_loss: Optional[Callable] = None,
        perceptual_loss: Optional[Callable] = None,
        model: Union[str, Tuple[nn.Module, int]] = "transformer",
        model_kwargs: Optional[Dict] = None,
        optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
        scheduler_kwargs: Optional[Dict[str, Any]] = None,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        learning_rate: float = 1e-3,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
    ):
        # assemble a perceptual loss from the individual content / style losses
        if perceptual_loss is None:
            content_loss = content_loss or self.default_content_loss()
            style_loss = style_loss or self.default_style_loss()
            perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
        if content_image is not None:
            perceptual_loss.set_content_image(content_image)

        if isinstance(model, tuple):
            # an instantiated (nn.Module, num_features) pair was passed in directly
            model, _ = model
        else:
            # look the architecture up by name; e.g. pretrained=True goes into model_kwargs
            model = self.models.get(model)(**(model_kwargs or {}))

        super().__init__(
            model=model,
            loss_fn=perceptual_loss,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
            metrics=metrics,
            learning_rate=learning_rate,
            serializer=serializer,
        )

        # attributes can only be set once the nn.Module is initialized
        self.perceptual_loss = perceptual_loss
        self.save_hyperparameters()

    def default_content_loss(self):
        multi_layer_encoder = enc.vgg16_multi_layer_encoder()
        content_layer = "relu2_2"
        content_encoder = multi_layer_encoder.extract_encoder(content_layer)
        content_weight = 1e5
        return ops.FeatureReconstructionOperator(
            content_encoder, score_weight=content_weight
        )

    def default_style_loss(self):
        class GramOperator(ops.GramOperator):
            def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
                # normalize the Gram matrix by the number of channels
                repr = super().enc_to_repr(enc)
                num_channels = repr.size()[1]
                return repr / num_channels

        # ideally this encoder would be shared with the content loss
        multi_layer_encoder = enc.vgg16_multi_layer_encoder()
        style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
        style_weight = 1e10
        return ops.MultiLayerEncodingOperator(
            multi_layer_encoder,
            style_layers,
            lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
            layer_weights="sum",
            score_weight=style_weight,
        )

    def forward(self, x):
        # not sure about this part: presumably the perceptual loss should
        # score the stylized output of the model
        stylized = self.model(x)
        return self.perceptual_loss(stylized)


# in finetuning:
content_image = ...
dm = StyleDataModule.from_folder(...)
model = StyleTransfer(content_image=content_image)
trainer = Trainer(...)
trainer.fit(model, dm)
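One way to resolve the uncertainty in forward above: pystiche's model-optimization examples feed the transformer output to the perceptual loss and update the content target once per batch. A sketch of a training step along those lines (assuming pystiche's set_content_image and LossDict.total(), and that the batch is a plain image tensor):

    # inside StyleTransfer
    def training_step(self, batch, batch_idx):
        # the task is unsupervised, so the batch is just a tensor of images;
        # the content target for the perceptual loss is the input itself
        self.perceptual_loss.set_content_image(batch)
        stylized = self.model(batch)
        return self.perceptual_loss(stylized).total()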
Codecov Report
@@            Coverage Diff             @@
##           master     #262      +/-   ##
==========================================
- Coverage   87.54%   87.05%   -0.50%
==========================================
  Files          73       78       +5
  Lines        3815     3970     +155
==========================================
+ Hits         3340     3456     +116
- Misses        475      514      +39
Apart from my comments / questions below, I'm wondering whether this example should be in predict or finetune.
Hey @tchaton, thanks for the commits. I have some comments below. Additionally, it looks like you have added a lot of changes that seemingly have nothing to do with this PR. Was that intentional?
LGTM
Awesome! Small comment
What does this PR do?
Add a style transfer task using pystiche as the backend.
Note: change the code block to test-code when pystiche 0.7.2 is out.
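To give an impression of the intended API, usage might eventually look like the following (class names and the from_folder constructor follow the pseudo code from the conversation above and are not final):

import flash

# hypothetical names, mirroring the pseudo code from the conversation
dm = StyleDataModule.from_folder("path/to/images")
model = StyleTransfer(content_image="path/to/content.jpg")
trainer = flash.Trainer(max_epochs=2)
trainer.fit(model, dm)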
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃