Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Perceptual Loss #471

Merged
merged 6 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmgen/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
r1_gradient_penalty_loss)
from .gan_loss import GANLoss
from .gen_auxiliary_loss import (CLIPLoss, FaceIdLoss,
GeneratorPathRegularizer,
GeneratorPathRegularizer, PerceptualLoss,
gen_path_regularizer)
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
GaussianKLDLoss, L1Loss, MSELoss,
Expand All @@ -18,5 +18,5 @@
'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
'gaussian_kld', 'GaussianKLDLoss', 'DiscretizedGaussianLogLikelihoodLoss',
'DDPMVLBLoss', 'discretized_gaussian_log_likelihood', 'FaceIdLoss',
'CLIPLoss'
'CLIPLoss', 'PerceptualLoss'
]
333 changes: 333 additions & 0 deletions mmgen/models/losses/gen_auxiliary_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import torch.autograd as autograd
import torch.distributed as dist
import torch.nn as nn
import torchvision.models.vgg as vgg
from mmcv.runner import load_checkpoint

from mmgen.models.builder import MODULES, build_module
from mmgen.utils import get_root_logger
from .pixelwise_loss import l1_loss, mse_loss


def gen_path_regularizer(generator,
Expand Down Expand Up @@ -532,3 +536,332 @@ def loss_name():
str: The name of this loss item.
"""
return 'clip_loss'


class PerceptualVGG(nn.Module):
"""VGG network used in calculating perceptual loss.

In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.

Args:
layer_name_list (list[str]): According to the name in this list,
forward function will return the corresponding features. This
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
Importantly, the input feature must in the range [0, 1].
Default: True.
pretrained (str): Path for pretrained weights. Default:
'torchvision://vgg19'
"""

def __init__(self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
pretrained='torchvision://vgg19'):
super().__init__()
if pretrained.startswith('torchvision://'):
assert vgg_type in pretrained
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm

# get vgg model and load pretrained vgg weight
# remove _vgg from attributes to avoid `find_unused_parameters` bug
_vgg = getattr(vgg, vgg_type)()
self.init_weights(_vgg, pretrained)
num_layers = max(map(int, layer_name_list)) + 1
assert len(_vgg.features) >= num_layers
# only borrow layers that will be used from _vgg to avoid unused params
self.vgg_layers = _vgg.features[:num_layers]

if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer(
'mean',
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [-1, 1]
self.register_buffer(
'std',
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

for v in self.vgg_layers.parameters():
v.requires_grad = False

def forward(self, x):
"""Forward function.

Args:
x (Tensor): Input tensor with shape (n, c, h, w).

Returns:
Tensor: Forward results.
"""

if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}

for name, module in self.vgg_layers.named_children():
x = module(x)
if name in self.layer_name_list:
output[name] = x.clone()
return output

def init_weights(self, model, pretrained):
"""Init weights.

Args:
model (nn.Module): Models to be inited.
pretrained (str): Path for pretrained weights.
"""
logger = get_root_logger()
load_checkpoint(model, pretrained, logger=logger)


@MODULES.register_module()
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.

.. code-block:: python
:caption: Code from StaticUnconditionalGAN, train_step
:linenos:

data_dict_ = dict(
gen=self.generator,
disc=self.discriminator,
disc_pred_fake=disc_pred_fake,
disc_pred_real=disc_pred_real,
fake_imgs=fake_imgs,
real_imgs=real_imgs,
iteration=curr_iter,
batch_size=batch_size)

But in this loss, we may need to provide ``pred`` and ``target`` as input.
Thus, an example of the ``data_info`` is:

.. code-block:: python
:linenos:

data_info = dict(
pred='fake_imgs',
target='real_imgs',
layer_weights={
'4': 1.,
'9': 1.,
'18': 1.},
)

Then, the module will automatically construct this mapping from the input
data dictionary.

Args:
data_info (dict, optional): Dictionary contains the mapping between
loss input args and data dictionary. If ``None``, this module will
directly pass the input data to the loss function.
Defaults to None.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_mse'.
layers_weights (dict): The weight for each layer of vgg feature for
perceptual loss. Here is an example: {'4': 1., '9': 1., '18': 1.},
which means the 5th, 10th and 18th feature layer will be
extracted with weight 1.0 in calculating losses. Defaults to
'{'4': 1., '9': 1., '18': 1.}'.
layers_weights_style (dict): The weight for each layer of vgg feature
for style loss. If set to 'None', the weights are set equal to
the weights for perceptual loss. Default: None.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and the loss will multiplied by the
weight. Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and the loss will multiplied by the weight.
Default: 1.0.
norm_img (bool): If True, the image will be normed to [0, 1]. Note that
this is different from the `use_input_norm` which norm the input in
in forward function of vgg according to the statistics of dataset.
Importantly, the input image must be in range [-1, 1].
pretrained (str): Path for pretrained weights. Default:
'torchvision://vgg19'.
criterion (str): Criterion type. Options are 'l1' and 'mse'.
Default: 'l1'.
split_style_loss (bool): Whether return a separate style loss item.
Options are True and False. Default: False
"""

def __init__(self,
data_info=None,
loss_name='loss_perceptual',
layer_weights={
'4': 1.,
'9': 1.,
'18': 1.
},
layer_weights_style=None,
vgg_type='vgg19',
use_input_norm=True,
perceptual_weight=1.0,
style_weight=1.0,
norm_img=True,
pretrained='torchvision://vgg19',
criterion='l1',
split_style_loss=False):
super().__init__()
self.data_info = data_info
self._loss_name = loss_name
self.norm_img = norm_img
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.layer_weights_style = layer_weights_style
self.split_style_loss = split_style_loss

self.vgg = PerceptualVGG(
layer_name_list=list(self.layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
pretrained=pretrained)

if self.layer_weights_style is not None and \
self.layer_weights_style != self.layer_weights:
self.vgg_style = PerceptualVGG(
layer_name_list=list(self.layer_weights_style.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
pretrained=pretrained)
else:
self.layer_weights_style = self.layer_weights
self.vgg_style = None

criterion = criterion.lower()
if criterion == 'l1':
self.criterion = l1_loss
elif criterion == 'mse':
self.criterion = mse_loss
else:
raise NotImplementedError(
f'{criterion} criterion has not been supported in'
' this version.')

def forward(self, *args, **kwargs):
"""Forward function. If ``self.data_info`` is not ``None``, a
dictionary containing all of the data and necessary modules should be
passed into this function. If this dictionary is given as a non-keyword
argument, it should be offered as the first argument. If you are using
keyword argument, please name it as `outputs_dict`.

If ``self.data_info`` is ``None``, the input argument or key-word
argument will be directly passed to loss function, ``mse_loss``.

Args:
pred (Tensor): Input tensor with shape (n, c, h, w).
target (Tensor): Ground-truth tensor with shape (n, c, h, w).

Returns:
Tensor: Forward results.
"""
# use data_info to build computational path
if self.data_info is not None:
# parse the args and kwargs
if len(args) == 1:
assert isinstance(args[0], dict), (
'You should offer a dictionary containing network outputs '
'for building up computational graph of this loss module.')
outputs_dict = args[0]
elif 'outputs_dict' in kwargs:
assert len(args) == 0, (
'If the outputs dict is given in keyworded arguments, no'
' further non-keyworded arguments should be offered.')
outputs_dict = kwargs.pop('outputs_dict')
else:
raise NotImplementedError(
'Cannot parsing your arguments passed to this loss module.'
' Please check the usage of this module')
# link the outputs with loss input args according to self.data_info
loss_input_dict = {
k: outputs_dict[v]
for k, v in self.data_info.items()
}
kwargs.update(loss_input_dict)
return self.perceptual_loss(**kwargs)
else:
# if you have not define how to build computational graph, this
# module will just directly return the loss as usual.
return self.perceptual_loss(*args, **kwargs)

def perceptual_loss(self, pred, target):
if self.norm_img:
pred = (pred + 1.) * 0.5
target = (target + 1.) * 0.5
# extract vgg features
pred_features = self.vgg(pred)
target_features = self.vgg(target.detach())

# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in pred_features.keys():
percep_loss += self.criterion(
pred_features[k],
target_features[k],
weight=self.layer_weights[k])
percep_loss *= self.perceptual_weight
else:
percep_loss = 0.

# calculate style loss
if self.style_weight > 0:
if self.vgg_style is not None:
pred_features = self.vgg_style(pred)
target_features = self.vgg_style(target.detach())

style_loss = 0
for k in pred_features.keys():
style_loss += self.criterion(
self._gram_mat(pred_features[k]),
self._gram_mat(
target_features[k])) * self.layer_weights_style[k]
style_loss *= self.style_weight
else:
style_loss = 0.

if self.split_style_loss:
return percep_loss, style_loss
else:
return percep_loss + style_loss

def _gram_mat(self, x):
"""Calculate Gram matrix.

Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).

Returns:
torch.Tensor: Gram matrix.
"""
(n, c, h, w) = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram

def loss_name(self):
"""Loss Name.

This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.

Returns:
str: The name of this loss item.
"""
return self._loss_name
Loading