feat: add SRC and hDCE losses
pnsuau authored and beniz committed Dec 6, 2022
1 parent b4c3cfd commit ddfcc97
Showing 8 changed files with 302 additions and 111 deletions.
91 changes: 80 additions & 11 deletions models/cut_model.py
@@ -6,8 +6,10 @@
from . import gan_networks

from .modules import loss
from .patchnce import PatchNCELoss
from .monce import MoNCELoss
from .modules.NCE.patchnce import PatchNCELoss
from .modules.NCE.monce import MoNCELoss
from .modules.NCE.hDCE import PatchHDCELoss
from .modules.NCE.SRC import SRC_Loss

from util.iter_calculator import IterCalculator
from util.network_group import NetworkGroup
@@ -37,6 +39,22 @@ def modify_commandline_options(parser, is_train=True):
default=1.0,
help="weight for NCE loss: NCE(G(X), X)",
)
parser.add_argument(
"--alg_cut_lambda_SRC",
type=float,
default=0.0,
help="weight for SRC (semantic relation consistency) loss: NCE(G(X), X)",
)
parser.add_argument(
"--alg_cut_HDCE_gamma",
type=float,
default=1.0,
help="softmax temperature gamma used to compute the semantic-relation weights for the hDCE loss",
)
parser.add_argument(
"--alg_cut_HDCE_gamma_min",
type=float,
default=1.0,
help="minimum gamma reached when a curriculum schedule over epochs is used",
)
parser.add_argument(
"--alg_cut_nce_idt",
type=util.str2bool,
@@ -77,7 +95,7 @@ def modify_commandline_options(parser, is_train=True):
"--alg_cut_nce_loss",
type=str,
default="monce",
choices=["patchnce", "monce"],
choices=["patchnce", "monce", "SRC_hDCE"],
help="CUT contrastice loss",
)
parser.add_argument(
@@ -209,6 +227,13 @@ def __init__(self, opt, rank):
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
elif opt.alg_cut_nce_loss == "monce":
self.criterionNCE.append(MoNCELoss(opt).to(self.device))
elif opt.alg_cut_nce_loss == "SRC_hDCE":
self.criterionNCE.append(PatchHDCELoss(opt).to(self.device))

if opt.alg_cut_nce_loss == "SRC_hDCE":
self.criterionR = []
for nce_layer in self.nce_layers:
self.criterionR.append(SRC_Loss(opt).to(self.device))

if self.opt.alg_cut_MSE_idt:
self.criterionIdt = torch.nn.L1Loss()
@@ -490,13 +515,32 @@ def forward_E(self):
def compute_G_loss_cut(self):
"""Calculate NCE loss for the generator"""

# Fake losses
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, self.fake_B)

if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE":
self.loss_G_SRC, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool)
else:
self.loss_G_SRC = 0.0
weight = None

if self.opt.alg_cut_lambda_NCE > 0.0:
self.loss_G_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
self.loss_G_NCE = self.calculate_NCE_loss(feat_q_pool, feat_k_pool, weight)
else:
self.loss_G_NCE, self.loss_NCE_bd = 0.0, 0.0
self.loss_G_NCE = 0.0

# Identity losses
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_B, self.idt_B)
if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE":
self.loss_G_SRC_Y, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool)
else:
self.loss_G_SRC_Y = 0.0
weight = None

if self.opt.alg_cut_nce_idt and self.opt.alg_cut_lambda_NCE > 0.0:
self.loss_G_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
self.loss_G_NCE_Y = self.calculate_NCE_loss(
feat_q_pool, feat_k_pool, weight
)
loss_NCE_both = (self.loss_G_NCE + self.loss_G_NCE_Y) * 0.5
else:
loss_NCE_both = self.loss_G_NCE
@@ -521,8 +565,7 @@ def compute_E_loss(self):
else:
self.loss_G_z = 0

def calculate_NCE_loss(self, src, tgt):
n_layers = len(self.nce_layers)
def calculate_feats(self, src, tgt):
if hasattr(self.netG_A, "module"):
netG_A = self.netG_A.module
else:
@@ -561,13 +604,39 @@ def calculate_NCE_loss(self, src, tgt):
)
feat_q_pool, _ = self.netF(feat_q, self.opt.alg_cut_num_patches, sample_ids)

return feat_q_pool, feat_k_pool

def calculate_NCE_loss(self, feat_q_pool, feat_k_pool, weights):
if weights is None:
weights = [None for k in range(len(feat_q_pool))]
n_layers = len(self.nce_layers)
total_nce_loss = 0.0
for f_q, f_k, crit, nce_layer in zip(
feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers
for f_q, f_k, crit, nce_layer, weight in zip(
feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers, weights
):
loss = (
crit(f_q, f_k, current_batch=src.shape[0]) * self.opt.alg_cut_lambda_NCE
crit(
feat_q=f_q,
feat_k=f_k,
current_batch=self.get_current_batch_size(),
weight=weight,
)
* self.opt.alg_cut_lambda_NCE
)

total_nce_loss += loss.mean()

return total_nce_loss / n_layers

def calculate_R_loss(self, feat_q_pool, feat_k_pool, only_weight=False, epoch=None):
n_layers = len(self.nce_layers)

total_SRC_loss = 0.0
weights = []
for f_q, f_k, crit, nce_layer in zip(
feat_q_pool, feat_k_pool, self.criterionR, self.nce_layers
):
loss_SRC, weight = crit(f_q, f_k, only_weight, epoch)
total_SRC_loss += loss_SRC * self.opt.alg_cut_lambda_SRC
weights.append(weight)
return total_SRC_loss / n_layers, weights
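
For illustration only (this snippet is not part of the commit), the new per-layer flow can be exercised in isolation: compute the SRC loss and its relation weights on a pool of patch features, then pass those weights to the hDCE loss, mirroring calculate_R_loss followed by calculate_NCE_loss above. The opt namespace is a hypothetical stand-in that only sets the option fields visible in this diff, and the repository root is assumed to be on PYTHONPATH.

import torch
from argparse import Namespace

from models.modules.NCE.SRC import SRC_Loss
from models.modules.NCE.hDCE import PatchHDCELoss

# Hypothetical stand-in for the option parser; only fields read by the two
# losses shown in this diff are set here.
opt = Namespace(
    alg_cut_nce_includes_all_negatives_from_minibatch=False,
    train_batch_size=2,
    alg_cut_HDCE_gamma=1.0,
    alg_cut_nce_T=0.07,
)

num_patches, dim = 256, 256  # CUT defaults: 256 patches, 256-dim features
feat_q = torch.randn(opt.train_batch_size * num_patches, dim)  # patches from G(X)
feat_k = torch.randn(opt.train_batch_size * num_patches, dim)  # patches from X

# SRC: semantic-relation consistency loss plus the weights reused by hDCE.
loss_src, weight = SRC_Loss(opt)(feat_q, feat_k)

# hDCE: contrastive loss with negatives re-weighted by the SRC weights.
loss_hdce = PatchHDCELoss(opt)(
    feat_q, feat_k, current_batch=opt.train_batch_size, weight=weight
)
print(loss_src.item(), loss_hdce.mean().item())
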
100 changes: 100 additions & 0 deletions models/modules/NCE/SRC.py
@@ -0,0 +1,100 @@
from packaging import version
import torch
from torch import nn


class Normalize(nn.Module):
def __init__(self, power=2):
super(Normalize, self).__init__()
self.power = power

def forward(self, x):
norm = x.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power)
out = x.div(norm + 1e-7)
return out


class SRC_Loss(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.mask_dtype = (
torch.uint8
if version.parse(torch.__version__) < version.parse("1.2.0")
else torch.bool
)

self.opt.use_curriculum = False

def forward(self, feat_q, feat_k, only_weight=False, epoch=None):
"""
:param feat_q: target
:param feat_k: source
:return: SRC loss, weights for hDCE
"""

batchSize = feat_q.shape[0]
dim = feat_q.shape[1]
feat_k = feat_k.detach()

if self.opt.alg_cut_nce_includes_all_negatives_from_minibatch:
# reshape features as if they are all negatives of minibatch of size 1.
batch_dim_for_bmm = 1
else:
batch_dim_for_bmm = self.opt.train_batch_size

feat_k = Normalize()(feat_k)
feat_q = Normalize()(feat_q)

## SRC
feat_q_v = feat_q.view(batch_dim_for_bmm, -1, dim)
feat_k_v = feat_k.view(batch_dim_for_bmm, -1, dim)

num_patches = feat_q_v.size(1)

spatial_q = torch.bmm(feat_q_v, feat_q_v.transpose(2, 1))
spatial_k = torch.bmm(feat_k_v, feat_k_v.transpose(2, 1))

weight_seed = spatial_k.clone().detach()
diagonal = torch.eye(
num_patches, device=feat_k_v.device, dtype=self.mask_dtype
)[None, :, :]

HDCE_gamma = self.opt.alg_cut_HDCE_gamma
if self.opt.use_curriculum:
HDCE_gamma = HDCE_gamma + (self.opt.alg_cut_HDCE_gamma_min - HDCE_gamma) * (
epoch
) / (self.opt.n_epochs + self.opt.n_epochs_decay)
if self.opt.step_gamma and (epoch > self.opt.step_gamma_epoch):
HDCE_gamma = 1

## weights by semantic relation
weight_seed.masked_fill_(diagonal, -10.0)
weight_out = nn.Softmax(dim=2)(weight_seed.clone() / HDCE_gamma).detach()
wmax_out, _ = torch.max(weight_out, dim=2, keepdim=True)
weight_out /= wmax_out

if only_weight:
return 0, weight_out

spatial_q = nn.Softmax(dim=1)(spatial_q)
spatial_k = nn.Softmax(dim=1)(spatial_k).detach()

loss_src = self.get_jsd(spatial_q, spatial_k)

return loss_src, weight_out

def get_jsd(self, p1, p2):
"""
:param p1: n X C
:param p2: n X C
:return: scalar JSD value (sum-reduced)
"""
m = 0.5 * (p1 + p2)
out = 0.5 * (
nn.KLDivLoss(reduction="sum", log_target=True)(torch.log(m), torch.log(p1))
+ nn.KLDivLoss(reduction="sum", log_target=True)(
torch.log(m), torch.log(p2)
)
)
return out
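
In compact form, the forward pass above computes the following, with per-image matrices of L2-normalized patch features and similarity matrices S_q = F_q F_q^\top, S_k = F_k F_k^\top (softmax taken over the patch axis, JSD sum-reduced):

\mathcal{L}_{\mathrm{SRC}} = \mathrm{JSD}\big(\mathrm{softmax}(S_q) \,\|\, \mathrm{softmax}(S_k)\big),
\qquad
w_{ij} = \frac{\mathrm{softmax}_j\!\big(S_{k,ij}/\gamma\big)}{\max_{j}\,\mathrm{softmax}_j\!\big(S_{k,ij}/\gamma\big)},

where the diagonal of S_k is masked out before the softmax, \gamma is alg_cut_HDCE_gamma, and the weights w (one row per query patch) are what calculate_R_loss hands to the hDCE loss as hard-negative weights.
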
65 changes: 41 additions & 24 deletions models/patchnce.py → models/modules/NCE/base_NCE.py
@@ -1,9 +1,9 @@
from packaging import version
import torch
from torch import nn
from packaging import version


class PatchNCELoss(nn.Module):
class BaseNCELoss(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
@@ -14,16 +14,8 @@ def __init__(self, opt):
else torch.bool
)

def forward(self, feat_q, feat_k, current_batch):
batchSize = feat_q.shape[0]
dim = feat_q.shape[1]
feat_k = feat_k.detach()

# pos logit
l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
l_pos = l_pos.view(batchSize, 1)

# neg logit
def forward(self, feat_q, feat_k, current_batch, **unused_args):
self.dim = feat_q.shape[1]

# Should the negatives from the other samples of a minibatch be utilized?
# In CUT and FastCUT, we found that it's best to only include negatives
@@ -34,27 +26,52 @@ def forward(self, feat_q, feat_k, current_batch):
# Therefore, we will include the negatives from the entire minibatch.
if self.opt.alg_cut_nce_includes_all_negatives_from_minibatch:
# reshape features as if they are all negatives of minibatch of size 1.
batch_dim_for_bmm = 1
self.batch_dim_for_bmm = 1
else:
batch_dim_for_bmm = current_batch
self.batch_dim_for_bmm = current_batch

# Positive logits
l_pos = self.compute_pos_logit(feat_q, feat_k)

# neg logit
l_neg = self.compute_neg_logit(feat_q, feat_k)

loss = self.compute_loss(l_pos, l_neg)

return loss

def compute_loss(self, l_pos, l_neg):
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.alg_cut_nce_T

loss = self.cross_entropy_loss(
out, torch.zeros(out.size(0), dtype=torch.long, device=out.device)
)
return loss

def compute_pos_logit(self, feat_q, feat_k):
batchSize = feat_q.shape[0]
feat_k = feat_k.detach()
l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
l_pos = l_pos.view(batchSize, 1)
return l_pos

def compute_l_neg_curbatch(self, feat_q, feat_k):
"""Returns negative examples"""
dim = feat_q.shape[1]
# reshape features to batch size
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
feat_q = feat_q.view(self.batch_dim_for_bmm, -1, dim)
feat_k = feat_k.view(self.batch_dim_for_bmm, -1, dim)
npatches = feat_q.size(1)
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1).contiguous())
return l_neg_curbatch, npatches

def compute_neg_logit(self, feat_q, feat_k):
l_neg_curbatch, npatches = self.compute_l_neg_curbatch(feat_q, feat_k)
# diagonal entries are similarity between same features, and hence meaningless.
# just fill the diagonal with very small number, which is exp(-10) and almost zero
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[
None, :, :
]
l_neg_curbatch.masked_fill_(diagonal, -10.0)
l_neg = l_neg_curbatch.view(-1, npatches)

out = torch.cat((l_pos, l_neg), dim=1) / self.opt.alg_cut_nce_T

loss = self.cross_entropy_loss(
out, torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device)
)

return loss
return l_neg
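
The refactored models/modules/NCE/patchnce.py that cut_model.py now imports from is not shown in this excerpt; with the hooks above, the standard PatchNCE behaviour is exactly the base class, so a minimal sketch of that module could be as small as:

from .base_NCE import BaseNCELoss


class PatchNCELoss(BaseNCELoss):
    """Standard patch-wise NCE loss; positive/negative logits and the
    cross-entropy objective all come from BaseNCELoss unchanged."""

    def __init__(self, opt):
        super().__init__(opt)

Subclasses such as PatchHDCELoss (next file) customize the behaviour by overriding compute_l_neg_curbatch and compute_loss instead of rewriting the whole forward pass.
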
38 changes: 38 additions & 0 deletions models/modules/NCE/hDCE.py
@@ -0,0 +1,38 @@
from packaging import version
import torch
from torch import nn

from .base_NCE import BaseNCELoss


class PatchHDCELoss(BaseNCELoss):
def __init__(self, opt):
super().__init__(opt)

def forward(self, feat_q, feat_k, current_batch, weight):
self.weight = weight
return super().forward(feat_q, feat_k, current_batch)

def compute_l_neg_curbatch(self, feat_q, feat_k):
l_neg_curbatch, npatches = super().compute_l_neg_curbatch(feat_q, feat_k)
# weighted by semantic relation
if self.weight is not None:
l_neg_curbatch *= self.weight
return l_neg_curbatch, npatches

def compute_loss(self, l_pos, l_neg):
logits = (l_neg - l_pos) / self.opt.alg_cut_nce_T
v = torch.logsumexp(logits, dim=1)
loss_vec = torch.exp(v - v.detach())

# for monitoring
out_dummy = torch.cat((l_pos, l_neg), dim=1) / self.opt.alg_cut_nce_T

CELoss_dummy = self.cross_entropy_loss(
out_dummy,
torch.zeros(out_dummy.size(0), dtype=torch.long, device=out_dummy.device),
)

loss = loss_vec - 1 + CELoss_dummy.detach()

return loss
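
A small sketch (not part of the diff) of the gradient trick in compute_loss above: exp(v - v.detach()) always evaluates to 1, but its gradient is the gradient of v itself, so the backward pass optimizes logsumexp((l_neg - l_pos) / nce_T) while the detached cross-entropy term only makes the reported loss value readable for monitoring.

import torch

l_pos = torch.tensor([[2.0]], requires_grad=True)
l_neg = torch.tensor([[0.5, 1.0, -0.3]], requires_grad=True)
nce_T = 0.07

v = torch.logsumexp((l_neg - l_pos) / nce_T, dim=1)
surrogate = torch.exp(v - v.detach())
print(surrogate)               # tensor([1.], grad_fn=...) -- constant value
surrogate.sum().backward()
print(l_pos.grad, l_neg.grad)  # non-zero: the gradient of v flows through
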