diff --git a/mmgen/models/losses/__init__.py b/mmgen/models/losses/__init__.py
index d39cb4d22..140ef8d07 100644
--- a/mmgen/models/losses/__init__.py
+++ b/mmgen/models/losses/__init__.py
@@ -6,7 +6,7 @@
                              r1_gradient_penalty_loss)
 from .gan_loss import GANLoss
 from .gen_auxiliary_loss import (CLIPLoss, FaceIdLoss,
-                                 GeneratorPathRegularizer,
+                                 GeneratorPathRegularizer, PerceptualLoss,
                                  gen_path_regularizer)
 from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
                              GaussianKLDLoss, L1Loss, MSELoss,
@@ -18,5 +18,5 @@
     'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
     'gaussian_kld', 'GaussianKLDLoss', 'DiscretizedGaussianLogLikelihoodLoss',
     'DDPMVLBLoss', 'discretized_gaussian_log_likelihood', 'FaceIdLoss',
-    'CLIPLoss'
+    'CLIPLoss', 'PerceptualLoss'
 ]
diff --git a/mmgen/models/losses/gen_auxiliary_loss.py b/mmgen/models/losses/gen_auxiliary_loss.py
index 890a42cab..42a5492c4 100644
--- a/mmgen/models/losses/gen_auxiliary_loss.py
+++ b/mmgen/models/losses/gen_auxiliary_loss.py
@@ -4,8 +4,12 @@
 import torch.autograd as autograd
 import torch.distributed as dist
 import torch.nn as nn
+import torchvision.models.vgg as vgg
+from mmcv.runner import load_checkpoint
 
 from mmgen.models.builder import MODULES, build_module
+from mmgen.utils import get_root_logger
+from .pixelwise_loss import l1_loss, mse_loss
 
 
 def gen_path_regularizer(generator,
@@ -532,3 +536,332 @@ def loss_name():
         str: The name of this loss item.
         """
         return 'clip_loss'
+
+
+class PerceptualVGG(nn.Module):
+    """VGG network used in calculating perceptual loss.
+
+    In this implementation, we allow users to choose whether to use
+    normalization on the input feature and which type of vgg network to
+    use. Note that the pretrained path must match the vgg type.
+
+    Args:
+        layer_name_list (list[str]): According to the names in this list,
+            the forward function will return the corresponding features.
+            This list contains the name of each layer in `vgg.features`.
+            An example of this list is ['4', '10'].
+        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image.
+            Importantly, the input feature must be in the range [0, 1].
+            Default: True.
+        pretrained (str): Path for pretrained weights. Default:
+            'torchvision://vgg19'.
+    """
+
+    def __init__(self,
+                 layer_name_list,
+                 vgg_type='vgg19',
+                 use_input_norm=True,
+                 pretrained='torchvision://vgg19'):
+        super().__init__()
+        if pretrained.startswith('torchvision://'):
+            assert vgg_type in pretrained
+        self.layer_name_list = layer_name_list
+        self.use_input_norm = use_input_norm
+
+        # get vgg model and load pretrained vgg weight
+        # remove _vgg from attributes to avoid `find_unused_parameters` bug
+        _vgg = getattr(vgg, vgg_type)()
+        self.init_weights(_vgg, pretrained)
+        num_layers = max(map(int, layer_name_list)) + 1
+        assert len(_vgg.features) >= num_layers
+        # only borrow layers that will be used from _vgg to avoid unused params
+        self.vgg_layers = _vgg.features[:num_layers]
+
+        if self.use_input_norm:
+            # the mean is for image with range [0, 1]
+            self.register_buffer(
+                'mean',
+                torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+            # the std is for image with range [0, 1]
+            self.register_buffer(
+                'std',
+                torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+        for v in self.vgg_layers.parameters():
+            v.requires_grad = False
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            dict: Features of the selected layers, keyed by layer name.
+ """ + + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for name, module in self.vgg_layers.named_children(): + x = module(x) + if name in self.layer_name_list: + output[name] = x.clone() + return output + + def init_weights(self, model, pretrained): + """Init weights. + + Args: + model (nn.Module): Models to be inited. + pretrained (str): Path for pretrained weights. + """ + logger = get_root_logger() + load_checkpoint(model, pretrained, logger=logger) + + +@MODULES.register_module() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + .. code-block:: python + :caption: Code from StaticUnconditionalGAN, train_step + :linenos: + + data_dict_ = dict( + gen=self.generator, + disc=self.discriminator, + disc_pred_fake=disc_pred_fake, + disc_pred_real=disc_pred_real, + fake_imgs=fake_imgs, + real_imgs=real_imgs, + iteration=curr_iter, + batch_size=batch_size) + + But in this loss, we may need to provide ``pred`` and ``target`` as input. + Thus, an example of the ``data_info`` is: + + .. code-block:: python + :linenos: + + data_info = dict( + pred='fake_imgs', + target='real_imgs', + layer_weights={ + '4': 1., + '9': 1., + '18': 1.}, + ) + + Then, the module will automatically construct this mapping from the input + data dictionary. + + Args: + data_info (dict, optional): Dictionary contains the mapping between + loss input args and data dictionary. If ``None``, this module will + directly pass the input data to the loss function. + Defaults to None. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_mse'. + layers_weights (dict): The weight for each layer of vgg feature for + perceptual loss. Here is an example: {'4': 1., '9': 1., '18': 1.}, + which means the 5th, 10th and 18th feature layer will be + extracted with weight 1.0 in calculating losses. Defaults to + '{'4': 1., '9': 1., '18': 1.}'. + layers_weights_style (dict): The weight for each layer of vgg feature + for style loss. If set to 'None', the weights are set equal to + the weights for perceptual loss. Default: None. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 1.0. + norm_img (bool): If True, the image will be normed to [0, 1]. Note that + this is different from the `use_input_norm` which norm the input in + in forward function of vgg according to the statistics of dataset. + Importantly, the input image must be in range [-1, 1]. + pretrained (str): Path for pretrained weights. Default: + 'torchvision://vgg19'. + criterion (str): Criterion type. Options are 'l1' and 'mse'. + Default: 'l1'. + split_style_loss (bool): Whether return a separate style loss item. + Options are True and False. Default: False + """ + + def __init__(self, + data_info=None, + loss_name='loss_perceptual', + layer_weights={ + '4': 1., + '9': 1., + '18': 1. 
+                 },
+                 layer_weights_style=None,
+                 vgg_type='vgg19',
+                 use_input_norm=True,
+                 perceptual_weight=1.0,
+                 style_weight=1.0,
+                 norm_img=True,
+                 pretrained='torchvision://vgg19',
+                 criterion='l1',
+                 split_style_loss=False):
+        super().__init__()
+        self.data_info = data_info
+        self._loss_name = loss_name
+        self.norm_img = norm_img
+        self.perceptual_weight = perceptual_weight
+        self.style_weight = style_weight
+        self.layer_weights = layer_weights
+        self.layer_weights_style = layer_weights_style
+        self.split_style_loss = split_style_loss
+
+        self.vgg = PerceptualVGG(
+            layer_name_list=list(self.layer_weights.keys()),
+            vgg_type=vgg_type,
+            use_input_norm=use_input_norm,
+            pretrained=pretrained)
+
+        if self.layer_weights_style is not None and \
+                self.layer_weights_style != self.layer_weights:
+            self.vgg_style = PerceptualVGG(
+                layer_name_list=list(self.layer_weights_style.keys()),
+                vgg_type=vgg_type,
+                use_input_norm=use_input_norm,
+                pretrained=pretrained)
+        else:
+            self.layer_weights_style = self.layer_weights
+            self.vgg_style = None
+
+        criterion = criterion.lower()
+        if criterion == 'l1':
+            self.criterion = l1_loss
+        elif criterion == 'mse':
+            self.criterion = mse_loss
+        else:
+            raise NotImplementedError(
+                f'{criterion} criterion is not supported in'
+                ' this version.')
+
+    def forward(self, *args, **kwargs):
+        """Forward function.
+
+        If ``self.data_info`` is not ``None``, a dictionary containing all
+        of the data and necessary modules should be passed into this
+        function. If this dictionary is given as a non-keyword argument,
+        it should be offered as the first argument. If you are using
+        keyword arguments, please name it as `outputs_dict`.
+
+        If ``self.data_info`` is ``None``, the arguments or keyword
+        arguments will be directly passed to the loss function,
+        ``perceptual_loss``.
+
+        Args:
+            pred (Tensor): Input tensor with shape (n, c, h, w).
+            target (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        # use data_info to build computational path
+        if self.data_info is not None:
+            # parse the args and kwargs
+            if len(args) == 1:
+                assert isinstance(args[0], dict), (
+                    'You should offer a dictionary containing network outputs '
+                    'for building up computational graph of this loss module.')
+                outputs_dict = args[0]
+            elif 'outputs_dict' in kwargs:
+                assert len(args) == 0, (
+                    'If the outputs dict is given in keyword arguments, no'
+                    ' further non-keyword arguments should be offered.')
+                outputs_dict = kwargs.pop('outputs_dict')
+            else:
+                raise NotImplementedError(
+                    'Cannot parse the arguments passed to this loss module.'
+                    ' Please check the usage of this module.')
+            # link the outputs with loss input args according to self.data_info
+            loss_input_dict = {
+                k: outputs_dict[v]
+                for k, v in self.data_info.items()
+            }
+            kwargs.update(loss_input_dict)
+            return self.perceptual_loss(**kwargs)
+        else:
+            # if the mapping has not been defined, this module will just
+            # directly return the loss as usual.
+            return self.perceptual_loss(*args, **kwargs)
+
+    def perceptual_loss(self, pred, target):
+        if self.norm_img:
+            pred = (pred + 1.) * 0.5
+            target = (target + 1.) * 0.5
+        # extract vgg features
+        pred_features = self.vgg(pred)
+        target_features = self.vgg(target.detach())
+
+        # calculate perceptual loss
+        if self.perceptual_weight > 0:
+            percep_loss = 0
+            for k in pred_features.keys():
+                percep_loss += self.criterion(
+                    pred_features[k],
+                    target_features[k],
+                    weight=self.layer_weights[k])
+            percep_loss *= self.perceptual_weight
+        else:
+            percep_loss = 0.
+
+        # calculate style loss
+        if self.style_weight > 0:
+            if self.vgg_style is not None:
+                pred_features = self.vgg_style(pred)
+                target_features = self.vgg_style(target.detach())
+
+            style_loss = 0
+            for k in pred_features.keys():
+                style_loss += self.criterion(
+                    self._gram_mat(pred_features[k]),
+                    self._gram_mat(
+                        target_features[k])) * self.layer_weights_style[k]
+            style_loss *= self.style_weight
+        else:
+            style_loss = 0.
+
+        if self.split_style_loss:
+            return percep_loss, style_loss
+        else:
+            return percep_loss + style_loss
+
+    def _gram_mat(self, x):
+        """Calculate Gram matrix.
+
+        Args:
+            x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+        Returns:
+            torch.Tensor: Gram matrix.
+        """
+        (n, c, h, w) = x.size()
+        features = x.view(n, c, w * h)
+        features_t = features.transpose(1, 2)
+        gram = features.bmm(features_t) / (c * h * w)
+        return gram
+
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss
+        items by a simple sum operation. In addition, if you want this
+        loss item to be included into the backward graph, `loss_` must be
+        the prefix of the name.
+
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
diff --git a/tests/test_losses/test_gen_auxiliary_loss.py b/tests/test_losses/test_gen_auxiliary_loss.py
index 9c9aa219d..91eed2343 100644
--- a/tests/test_losses/test_gen_auxiliary_loss.py
+++ b/tests/test_losses/test_gen_auxiliary_loss.py
@@ -2,8 +2,10 @@
 import pytest
 import torch
 
+from mmgen.models.architectures.pix2pix import UnetGenerator
 from mmgen.models.architectures.stylegan import StyleGANv2Generator
-from mmgen.models.losses import GeneratorPathRegularizer
+from mmgen.models.losses import GeneratorPathRegularizer, PerceptualLoss
+from mmgen.models.losses.pixelwise_loss import l1_loss, mse_loss
 
 
 class TestPathRegularizer:
@@ -61,3 +63,138 @@ def test_path_regularizer_cuda(self):
 
         with pytest.raises(AssertionError):
             _ = pl(1., 2, outputs_dict=output_dict)
+
+
+class TestPerceptualLoss:
+
+    @classmethod
+    def setup_class(cls):
+        cls.data_info = dict(pred='fake_imgs', target='real_imgs')
+        cls.gen = UnetGenerator(3, 3)
+
+    def test_perceptual_loss_cpu(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        perceptual_loss = PerceptualLoss(data_info=self.data_info)
+        loss_perceptual = perceptual_loss(
+            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
+        assert loss_perceptual.shape == ()
+        assert id(perceptual_loss.criterion) == id(l1_loss)
+
+    def test_only_perceptual_loss(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        perceptual_loss = PerceptualLoss(
+            data_info=self.data_info, style_weight=0)
+        loss_percep = perceptual_loss(dict(fake_imgs=pred, real_imgs=target))
+        assert loss_percep.shape == ()
+        assert perceptual_loss.style_weight == 0
+
+    def test_only_style_loss(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        perceptual_loss = PerceptualLoss(
+            data_info=self.data_info, perceptual_weight=0)
+        loss_style = perceptual_loss(dict(fake_imgs=pred, real_imgs=target))
+        assert loss_style.shape == ()
+        assert perceptual_loss.perceptual_weight == 0
+
+    def test_with_different_layer_weights(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        layer_weights = {'1': 1., '2': 2., '3': 3.}
+        perceptual_loss = PerceptualLoss(
+            data_info=self.data_info, layer_weights=layer_weights)
+        loss_perceptual = perceptual_loss(
+            dict(fake_imgs=pred, real_imgs=target))
+        assert loss_perceptual.shape == ()
+        assert perceptual_loss.layer_weights == layer_weights and \
+            perceptual_loss.layer_weights_style == layer_weights
+
+    def test_with_different_perceptual_and_style_layers(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        layer_weights = {'1': 1., '2': 2., '3': 3.}
+        layer_weights_style = {'4': 4., '5': 5., '6': 6.}
+        perceptual_loss = PerceptualLoss(
+            data_info=self.data_info,
+            layer_weights=layer_weights,
+            layer_weights_style=layer_weights_style)
+        loss_perceptual = perceptual_loss(
+            dict(fake_imgs=pred, real_imgs=target))
+        assert loss_perceptual.shape == ()
+        assert perceptual_loss.layer_weights == layer_weights and \
+            perceptual_loss.layer_weights_style == layer_weights_style
+
+    def test_mse_criterion(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        perceptual_loss = PerceptualLoss(
+            data_info=self.data_info, criterion='mse')
+        loss_perceptual = perceptual_loss(
+            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
+        assert loss_perceptual.shape == ()
+        assert id(perceptual_loss.criterion) == id(mse_loss)
+
+    def test_vgg16(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        perceptual_loss = PerceptualLoss(
+            data_info=self.data_info,
+            vgg_type='vgg16',
+            pretrained='torchvision://vgg16')
+        loss_perceptual = perceptual_loss(
+            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
+        assert loss_perceptual.shape == ()
+        # TODO: need to check whether vgg16 is loaded
+        # assert perceptual_loss.vgg
+
+    def test_split_style_loss(self):
+        unknown_h, unknown_w = (32, 32)
+        weight = torch.zeros(1, 1, 64, 64)
+        weight[0, 0, :unknown_h, :unknown_w] = 1
+        pred = weight.clone()
+        target = weight.clone() * 2
+
+        perceptual_loss = PerceptualLoss(
+            data_info=self.data_info, split_style_loss=True)
+        loss_percep, loss_style = perceptual_loss(
+            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
+        assert loss_percep.shape == () and loss_style.shape == ()
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+    def test_perceptual_loss_cuda(self):
+        pred = torch.rand([2, 3, 256, 256]).cuda()
+        target = torch.rand_like(pred).cuda()
+        perceptual_loss = PerceptualLoss(data_info=self.data_info).cuda()
+        loss_perceptual = perceptual_loss(
+            outputs_dict=dict(fake_imgs=pred, real_imgs=target))
+        assert loss_perceptual.shape == ()
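
A minimal usage sketch of the new loss (not part of the diff; it assumes the `MODULES` registry flow via `build_module`, which is imported in `gen_auxiliary_loss.py`, and borrows the `data_info` mapping and scalar-shape check from the tests above):

import torch

from mmgen.models.builder import build_module

# Build the loss from a config dict; 'type' selects the registered class.
loss_cfg = dict(
    type='PerceptualLoss',
    data_info=dict(pred='fake_imgs', target='real_imgs'),
    layer_weights={'4': 1., '9': 1., '18': 1.},
    perceptual_weight=1.0,
    style_weight=1.0,
    criterion='l1')
perceptual_loss = build_module(loss_cfg)

# norm_img=True by default, so images are expected in [-1, 1].
fake_imgs = torch.rand(2, 3, 64, 64) * 2 - 1
real_imgs = torch.rand(2, 3, 64, 64) * 2 - 1

# data_info maps the loss args (pred/target) onto the outputs dict keys.
loss = perceptual_loss(
    outputs_dict=dict(fake_imgs=fake_imgs, real_imgs=real_imgs))
assert loss.shape == ()  # scalar loss, as checked in the tests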