From ddfcc97500fc2b63cff2f6f2e3b321c9e957bc3c Mon Sep 17 00:00:00 2001 From: pnsuau Date: Mon, 14 Nov 2022 12:18:36 +0000 Subject: [PATCH] feat: add SRC and hDCE losses --- models/cut_model.py | 91 ++++++++++++++-- models/modules/NCE/SRC.py | 100 ++++++++++++++++++ .../{patchnce.py => modules/NCE/base_NCE.py} | 65 +++++++----- models/modules/NCE/hDCE.py | 38 +++++++ models/modules/NCE/monce.py | 33 ++++++ models/modules/NCE/patchnce.py | 10 ++ models/{ => modules/NCE}/sinkhorn.py | 0 models/monce.py | 76 ------------- 8 files changed, 302 insertions(+), 111 deletions(-) create mode 100644 models/modules/NCE/SRC.py rename models/{patchnce.py => modules/NCE/base_NCE.py} (68%) create mode 100644 models/modules/NCE/hDCE.py create mode 100644 models/modules/NCE/monce.py create mode 100644 models/modules/NCE/patchnce.py rename models/{ => modules/NCE}/sinkhorn.py (100%) delete mode 100644 models/monce.py diff --git a/models/cut_model.py b/models/cut_model.py index da81071d6..70cfa5746 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -6,8 +6,10 @@ from . import gan_networks from .modules import loss -from .patchnce import PatchNCELoss -from .monce import MoNCELoss +from .modules.NCE.patchnce import PatchNCELoss +from .modules.NCE.monce import MoNCELoss +from .modules.NCE.hDCE import PatchHDCELoss +from .modules.NCE.SRC import SRC_Loss from util.iter_calculator import IterCalculator from util.network_group import NetworkGroup @@ -37,6 +39,22 @@ def modify_commandline_options(parser, is_train=True): default=1.0, help="weight for NCE loss: NCE(G(X), X)", ) + parser.add_argument( + "--alg_cut_lambda_SRC", + type=float, + default=0.0, + help="weight for SRC (semantic relation consistency) loss: NCE(G(X), X)", + ) + parser.add_argument( + "--alg_cut_HDCE_gamma", + type=float, + default=1.0, + ) + parser.add_argument( + "--alg_cut_HDCE_gamma_min", + type=float, + default=1.0, + ) parser.add_argument( "--alg_cut_nce_idt", type=util.str2bool, @@ -77,7 +95,7 @@ def modify_commandline_options(parser, is_train=True): "--alg_cut_nce_loss", type=str, default="monce", - choices=["patchnce", "monce"], + choices=["patchnce", "monce", "SRC_hDCE"], help="CUT contrastice loss", ) parser.add_argument( @@ -209,6 +227,13 @@ def __init__(self, opt, rank): self.criterionNCE.append(PatchNCELoss(opt).to(self.device)) elif opt.alg_cut_nce_loss == "monce": self.criterionNCE.append(MoNCELoss(opt).to(self.device)) + elif opt.alg_cut_nce_loss == "SRC_hDCE": + self.criterionNCE.append(PatchHDCELoss(opt).to(self.device)) + + if opt.alg_cut_nce_loss == "SRC_hDCE": + self.criterionR = [] + for nce_layer in self.nce_layers: + self.criterionR.append(SRC_Loss(opt).to(self.device)) if self.opt.alg_cut_MSE_idt: self.criterionIdt = torch.nn.L1Loss() @@ -490,13 +515,32 @@ def forward_E(self): def compute_G_loss_cut(self): """Calculate NCE loss for the generator""" + # Fake losses + feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, self.fake_B) + + if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE": + self.loss_G_SRC, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool) + else: + self.loss_G_SRC = 0.0 + weight = None + if self.opt.alg_cut_lambda_NCE > 0.0: - self.loss_G_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B) + self.loss_G_NCE = self.calculate_NCE_loss(feat_q_pool, feat_k_pool, weight) else: - self.loss_G_NCE, self.loss_NCE_bd = 0.0, 0.0 + self.loss_G_NCE = 0.0 + + # Identity losses + feat_q_pool, feat_k_pool = self.calculate_feats(self.real_B, self.idt_B) + if 
self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE":
+            self.loss_G_SRC_Y, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool)
+        else:
+            self.loss_G_SRC_Y = 0.0
+            weight = None

         if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_NCE > 0.0:
-            self.loss_G_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
+            self.loss_G_NCE_Y = self.calculate_NCE_loss(
+                feat_q_pool, feat_k_pool, weight
+            )
             loss_NCE_both = (self.loss_G_NCE + self.loss_G_NCE_Y) * 0.5
         else:
             loss_NCE_both = self.loss_G_NCE
@@ -521,8 +565,7 @@ def compute_E_loss(self):
         else:
             self.loss_G_z = 0

-    def calculate_NCE_loss(self, src, tgt):
-        n_layers = len(self.nce_layers)
+    def calculate_feats(self, src, tgt):
         if hasattr(self.netG_A, "module"):
             netG_A = self.netG_A.module
         else:
@@ -561,13 +604,39 @@ def calculate_NCE_loss(self, src, tgt):
         )

         feat_q_pool, _ = self.netF(feat_q, self.opt.alg_cut_num_patches, sample_ids)
+        return feat_q_pool, feat_k_pool
+
+    def calculate_NCE_loss(self, feat_q_pool, feat_k_pool, weights):
+        if weights is None:
+            weights = [None for k in range(len(feat_q_pool))]
+        n_layers = len(self.nce_layers)
         total_nce_loss = 0.0
-        for f_q, f_k, crit, nce_layer in zip(
-            feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers
+        for f_q, f_k, crit, nce_layer, weight in zip(
+            feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers, weights
         ):
             loss = (
-                crit(f_q, f_k, current_batch=src.shape[0]) * self.opt.alg_cut_lambda_NCE
+                crit(
+                    feat_q=f_q,
+                    feat_k=f_k,
+                    current_batch=self.get_current_batch_size(),
+                    weight=weight,
+                )
+                * self.opt.alg_cut_lambda_NCE
             )
+
             total_nce_loss += loss.mean()

         return total_nce_loss / n_layers
+
+    def calculate_R_loss(self, feat_q_pool, feat_k_pool, only_weight=False, epoch=None):
+        n_layers = len(self.nce_layers)
+
+        total_SRC_loss = 0.0
+        weights = []
+        for f_q, f_k, crit, nce_layer in zip(
+            feat_q_pool, feat_k_pool, self.criterionR, self.nce_layers
+        ):
+            loss_SRC, weight = crit(f_q, f_k, only_weight, epoch)
+            total_SRC_loss += loss_SRC * self.opt.alg_cut_lambda_SRC
+            weights.append(weight)
+        return total_SRC_loss / n_layers, weights
diff --git a/models/modules/NCE/SRC.py b/models/modules/NCE/SRC.py
new file mode 100644
index 000000000..a8d80cee6
--- /dev/null
+++ b/models/modules/NCE/SRC.py
@@ -0,0 +1,100 @@
+from packaging import version
+import torch
+from torch import nn
+
+
+class Normalize(nn.Module):
+    def __init__(self, power=2):
+        super(Normalize, self).__init__()
+        self.power = power
+
+    def forward(self, x):
+        norm = x.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power)
+        out = x.div(norm + 1e-7)
+        return out
+
+
+class SRC_Loss(nn.Module):
+    def __init__(self, opt):
+        super().__init__()
+        self.opt = opt
+        self.mask_dtype = (
+            torch.uint8
+            if version.parse(torch.__version__) < version.parse("1.2.0")
+            else torch.bool
+        )
+
+        self.opt.use_curriculum = False
+
+    def forward(self, feat_q, feat_k, only_weight=False, epoch=None):
+        """
+        :param feat_q: target
+        :param feat_k: source
+        :return: SRC loss, weights for hDCE
+        """
+
+        batchSize = feat_q.shape[0]
+        dim = feat_q.shape[1]
+        feat_k = feat_k.detach()
+
+        if self.opt.alg_cut_nce_includes_all_negatives_from_minibatch:
+            # reshape features as if they are all negatives of minibatch of size 1.
+            batch_dim_for_bmm = 1
+        else:
+            batch_dim_for_bmm = self.opt.train_batch_size
+
+        feat_k = Normalize()(feat_k)
+        feat_q = Normalize()(feat_q)
+
+        ## SRC
+        feat_q_v = feat_q.view(batch_dim_for_bmm, -1, dim)
+        feat_k_v = feat_k.view(batch_dim_for_bmm, -1, dim)
+
+        num_patches = feat_q_v.size(1)
+
+        spatial_q = torch.bmm(feat_q_v, feat_q_v.transpose(2, 1))
+        spatial_k = torch.bmm(feat_k_v, feat_k_v.transpose(2, 1))
+
+        weight_seed = spatial_k.clone().detach()
+        diagonal = torch.eye(
+            num_patches, device=feat_k_v.device, dtype=self.mask_dtype
+        )[None, :, :]
+
+        HDCE_gamma = self.opt.alg_cut_HDCE_gamma
+        if self.opt.use_curriculum:
+            HDCE_gamma = HDCE_gamma + (self.opt.alg_cut_HDCE_gamma_min - HDCE_gamma) * (
+                epoch
+            ) / (self.opt.n_epochs + self.opt.n_epochs_decay)
+            if (self.opt.step_gamma) & (epoch > self.opt.step_gamma_epoch):
+                HDCE_gamma = 1
+
+        ## weights by semantic relation
+        weight_seed.masked_fill_(diagonal, -10.0)
+        weight_out = nn.Softmax(dim=2)(weight_seed.clone() / HDCE_gamma).detach()
+        wmax_out, _ = torch.max(weight_out, dim=2, keepdim=True)
+        weight_out /= wmax_out
+
+        if only_weight:
+            return 0, weight_out
+
+        spatial_q = nn.Softmax(dim=1)(spatial_q)
+        spatial_k = nn.Softmax(dim=1)(spatial_k).detach()
+
+        loss_src = self.get_jsd(spatial_q, spatial_k)
+
+        return loss_src, weight_out
+
+    def get_jsd(self, p1, p2):
+        """
+        :param p1: n X C
+        :param p2: n X C
+        :return: scalar JSD value
+        """
+        m = 0.5 * (p1 + p2)
+        out = 0.5 * (
+            nn.KLDivLoss(reduction="sum", log_target=True)(torch.log(m), torch.log(p1))
+            + nn.KLDivLoss(reduction="sum", log_target=True)(
+                torch.log(m), torch.log(p2)
+            )
+        )
+        return out
diff --git a/models/patchnce.py b/models/modules/NCE/base_NCE.py
similarity index 68%
rename from models/patchnce.py
rename to models/modules/NCE/base_NCE.py
index 36e3b5031..fc23a47ba 100644
--- a/models/patchnce.py
+++ b/models/modules/NCE/base_NCE.py
@@ -1,9 +1,9 @@
-from packaging import version
 import torch
 from torch import nn
+from packaging import version


-class PatchNCELoss(nn.Module):
+class BaseNCELoss(nn.Module):
     def __init__(self, opt):
         super().__init__()
         self.opt = opt
@@ -14,16 +14,8 @@ def __init__(self, opt):
             else torch.bool
         )

-    def forward(self, feat_q, feat_k, current_batch):
-        batchSize = feat_q.shape[0]
-        dim = feat_q.shape[1]
-        feat_k = feat_k.detach()
-
-        # pos logit
-        l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
-        l_pos = l_pos.view(batchSize, 1)
-
-        # neg logit
+    def forward(self, feat_q, feat_k, current_batch, **unused_args):
+        self.dim = feat_q.shape[1]

         # Should the negatives from the other samples of a minibatch be utilized?
         # In CUT and FastCUT, we found that it's best to only include negatives
@@ -34,15 +26,47 @@ def forward(self, feat_q, feat_k, current_batch):
         # Therefore, we will include the negatives from the entire minibatch.
         if self.opt.alg_cut_nce_includes_all_negatives_from_minibatch:
            # reshape features as if they are all negatives of minibatch of size 1.
- batch_dim_for_bmm = 1 + self.batch_dim_for_bmm = 1 else: - batch_dim_for_bmm = current_batch + self.batch_dim_for_bmm = current_batch + + # Positive logits + l_pos = self.compute_pos_logit(feat_q, feat_k) + + # neg logit + l_neg = self.compute_neg_logit(feat_q, feat_k) + + loss = self.compute_loss(l_pos, l_neg) + + return loss + + def compute_loss(self, l_pos, l_neg): + out = torch.cat((l_pos, l_neg), dim=1) / self.opt.alg_cut_nce_T + + loss = self.cross_entropy_loss( + out, torch.zeros(out.size(0), dtype=torch.long, device=out.device) + ) + return loss + + def compute_pos_logit(self, feat_q, feat_k): + batchSize = feat_q.shape[0] + feat_k = feat_k.detach() + l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1)) + l_pos = l_pos.view(batchSize, 1) + return l_pos + + def compute_l_neg_curbatch(self, feat_q, feat_k): + """Returns negative examples""" + dim = feat_q.shape[1] # reshape features to batch size - feat_q = feat_q.view(batch_dim_for_bmm, -1, dim) - feat_k = feat_k.view(batch_dim_for_bmm, -1, dim) + feat_q = feat_q.view(self.batch_dim_for_bmm, -1, dim) + feat_k = feat_k.view(self.batch_dim_for_bmm, -1, dim) npatches = feat_q.size(1) l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1).contiguous()) + return l_neg_curbatch, npatches + def compute_neg_logit(self, feat_q, feat_k): + l_neg_curbatch, npatches = self.compute_l_neg_curbatch(feat_q, feat_k) # diagonal entries are similarity between same features, and hence meaningless. # just fill the diagonal with very small number, which is exp(-10) and almost zero diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[ @@ -50,11 +74,4 @@ def forward(self, feat_q, feat_k, current_batch): ] l_neg_curbatch.masked_fill_(diagonal, -10.0) l_neg = l_neg_curbatch.view(-1, npatches) - - out = torch.cat((l_pos, l_neg), dim=1) / self.opt.alg_cut_nce_T - - loss = self.cross_entropy_loss( - out, torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device) - ) - - return loss + return l_neg diff --git a/models/modules/NCE/hDCE.py b/models/modules/NCE/hDCE.py new file mode 100644 index 000000000..95b69ca13 --- /dev/null +++ b/models/modules/NCE/hDCE.py @@ -0,0 +1,38 @@ +from packaging import version +import torch +from torch import nn + +from .base_NCE import BaseNCELoss + + +class PatchHDCELoss(BaseNCELoss): + def __init__(self, opt): + super().__init__(opt) + + def forward(self, feat_q, feat_k, current_batch, weight): + self.weight = weight + return super().forward(feat_q, feat_k, current_batch) + + def compute_l_neg_curbatch(self, feat_q, feat_k): + l_neg_curbatch, npatches = super().compute_l_neg_curbatch(feat_q, feat_k) + # weighted by semantic relation + if self.weight is not None: + l_neg_curbatch *= self.weight + return l_neg_curbatch, npatches + + def compute_loss(self, l_pos, l_neg): + logits = (l_neg - l_pos) / self.opt.alg_cut_nce_T + v = torch.logsumexp(logits, dim=1) + loss_vec = torch.exp(v - v.detach()) + + # for monitoring + out_dummy = torch.cat((l_pos, l_neg), dim=1) / self.opt.alg_cut_nce_T + + CELoss_dummy = self.cross_entropy_loss( + out_dummy, + torch.zeros(out_dummy.size(0), dtype=torch.long, device=out_dummy.device), + ) + + loss = loss_vec - 1 + CELoss_dummy.detach() + + return loss diff --git a/models/modules/NCE/monce.py b/models/modules/NCE/monce.py new file mode 100644 index 000000000..585b3b976 --- /dev/null +++ b/models/modules/NCE/monce.py @@ -0,0 +1,33 @@ +from packaging import version +import torch +from torch import nn +import math +from .sinkhorn import OT 
+import numpy as np +import torch.nn.functional as F + +from .base_NCE import BaseNCELoss + + +class MoNCELoss(BaseNCELoss): + def __init__(self, opt): + super().__init__(opt) + + def compute_l_neg_curbatch(self, feat_q, feat_k): + eps = 1.0 # default + cost_type = "hard" # default for cut + neg_term_weight = 1.0 # default + + ot_q = feat_q.view(self.batch_dim_for_bmm, -1, self.dim) + ot_k = feat_k.view(self.batch_dim_for_bmm, -1, self.dim).detach() + f = OT(ot_q, ot_k, eps=eps, max_iter=50, cost_type=cost_type) + f = ( + f.permute(0, 2, 1) * (self.opt.alg_cut_num_patches - 1) * neg_term_weight + + 1e-8 + ) + + l_neg_curbatch, npatches = super().compute_l_neg_curbatch(feat_q, feat_k) + + l_neg_curbatch = l_neg_curbatch + torch.log(f) * self.opt.alg_cut_nce_T + + return l_neg_curbatch, npatches diff --git a/models/modules/NCE/patchnce.py b/models/modules/NCE/patchnce.py new file mode 100644 index 000000000..75d0b90fc --- /dev/null +++ b/models/modules/NCE/patchnce.py @@ -0,0 +1,10 @@ +from packaging import version +import torch +from torch import nn + +from .base_NCE import BaseNCELoss + + +class PatchNCELoss(BaseNCELoss): + def __init__(self, opt): + super().__init__(opt) diff --git a/models/sinkhorn.py b/models/modules/NCE/sinkhorn.py similarity index 100% rename from models/sinkhorn.py rename to models/modules/NCE/sinkhorn.py diff --git a/models/monce.py b/models/monce.py deleted file mode 100644 index 4423921ca..000000000 --- a/models/monce.py +++ /dev/null @@ -1,76 +0,0 @@ -from packaging import version -import torch -from torch import nn -import math -from .sinkhorn import OT -import numpy as np -import torch.nn.functional as F - - -class Normalize(nn.Module): - def __init__(self, power=2): - super(Normalize, self).__init__() - self.power = power - - def forward(self, x): - norm = x.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power) - out = x.div(norm + 1e-7) - return out - - -class MoNCELoss(nn.Module): - def __init__(self, opt): - super().__init__() - self.opt = opt - self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction="none") - self.mask_dtype = ( - torch.uint8 - if version.parse(torch.__version__) < version.parse("1.2.0") - else torch.bool - ) - self.l2_norm = Normalize(2) - - def forward(self, feat_q, feat_k, current_batch): - eps = 1.0 # default - cost_type = "hard" # default for cut - neg_term_weight = 1.0 # default - - batchSize = feat_q.shape[0] - dim = feat_q.shape[1] - # Therefore, we will include the negatives from the entire minibatch. 
- if self.opt.alg_cut_nce_includes_all_negatives_from_minibatch: - batch_dim_for_bmm = 1 - else: - batch_dim_for_bmm = current_batch - - # if self.loss_type == 'MoNCE': - ot_q = feat_q.view(batch_dim_for_bmm, -1, dim) - ot_k = feat_k.view(batch_dim_for_bmm, -1, dim).detach() - f = OT(ot_q, ot_k, eps=eps, max_iter=50, cost_type=cost_type) - f = ( - f.permute(0, 2, 1) * (self.opt.alg_cut_num_patches - 1) * neg_term_weight - + 1e-8 - ) - - feat_k = feat_k.detach() - l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1)) - l_pos = l_pos.view(batchSize, 1) - - feat_q = feat_q.view(batch_dim_for_bmm, -1, dim) - feat_k = feat_k.view(batch_dim_for_bmm, -1, dim) - npatches = feat_q.size(1) - l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1)) - # if self.loss_type == 'MoNCE': - l_neg_curbatch = l_neg_curbatch + torch.log(f) * self.opt.alg_cut_nce_T - - diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[ - None, :, : - ] - l_neg_curbatch.masked_fill_(diagonal, -10.0) - l_neg = l_neg_curbatch.view(-1, npatches) - - out = torch.cat((l_pos, l_neg), dim=1) / self.opt.alg_cut_nce_T - loss = self.cross_entropy_loss( - out, torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device) - ) - return loss
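
Illustrative sketch (reviewer note, not part of the patch): the toy script below shows how the two pieces added here are meant to interact. SRC_Loss derives per-patch weights from the source self-similarity and penalizes the JSD between the query and key similarity maps, while PatchHDCELoss re-weights its negative logits with those weights before applying the logsumexp-based surrogate. Shapes, the single-image assumption, and the names num_patches, dim, tau and gamma are stand-ins chosen for this example; the real code takes its patch features from netF and its hyper-parameters from the alg_cut_* options added above.

# Toy walk-through of the SRC + hDCE interaction (assumed shapes, not repo code).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_patches, dim, tau, gamma = 8, 16, 0.07, 1.0

feat_q = F.normalize(torch.randn(num_patches, dim), dim=1)  # patches from G(X) (target)
feat_k = F.normalize(torch.randn(num_patches, dim), dim=1)  # patches from X (source)

# SRC weights: softmax of the source self-similarity (diagonal excluded),
# rescaled so the largest weight per row is 1 (mirrors weight_out in SRC_Loss).
sim_k = feat_k @ feat_k.t()
sim_k.fill_diagonal_(-10.0)
weight = torch.softmax(sim_k / gamma, dim=1).detach()
weight = weight / weight.max(dim=1, keepdim=True).values

# SRC loss: Jensen-Shannon divergence between the softmaxed self-similarity
# maps of query and key patches (mirrors get_jsd, KL taken against the mixture).
p_q = torch.softmax(feat_q @ feat_q.t(), dim=1)
p_k = torch.softmax(feat_k @ feat_k.t(), dim=1).detach()
m = 0.5 * (p_q + p_k)
loss_src = 0.5 * (
    F.kl_div(m.log(), p_q, reduction="sum")
    + F.kl_div(m.log(), p_k, reduction="sum")
)

# hDCE loss: positives on the diagonal, negatives re-weighted by `weight`,
# then the exp(v - v.detach()) surrogate plus a detached CE term for monitoring,
# mirroring PatchHDCELoss.compute_loss.
l_pos = (feat_q * feat_k).sum(dim=1, keepdim=True)  # (P, 1) positive logits
l_neg = (feat_q @ feat_k.t()) * weight               # (P, P) weighted negatives
l_neg.fill_diagonal_(-10.0)                           # mask self-similarity

logits = (l_neg - l_pos) / tau
v = torch.logsumexp(logits, dim=1)
surrogate = torch.exp(v - v.detach()) - 1.0           # gradient-carrying term
ce_monitor = F.cross_entropy(                         # value reported in logs
    torch.cat([l_pos, l_neg], dim=1) / tau,
    torch.zeros(num_patches, dtype=torch.long),
    reduction="none",
)
loss_hdce = (surrogate + ce_monitor.detach()).mean()

print(float(loss_src), float(loss_hdce))

In the patch itself this wiring goes through calculate_R_loss, which returns the per-layer weights, and calculate_NCE_loss, which forwards them to the criterion; selecting --alg_cut_nce_loss SRC_hDCE activates both terms and --alg_cut_lambda_SRC scales the SRC contribution.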