feat: backward while computing MSE criterion loss
pnsuau authored and beniz committed Sep 15, 2022
1 parent 543f280 commit 1b87906
Showing 2 changed files with 32 additions and 10 deletions.
37 changes: 27 additions & 10 deletions models/base_model.py
@@ -234,12 +234,18 @@ def __init__(self, opt, rank):

            self.visual_names.append(temp_visual_names_attn)

-        if self.opt.train_temporal_criterion or "temporal" in opt.D_netDs:
-            self.use_temporal = True
-        else:
-            self.use_temporal = False
+        if (
+            any("temporal" in D_name for D_name in self.opt.D_netDs)
+            and self.opt.train_temporal_criterion
+        ):
+            raise ValueError(
+                "Temporal discriminator and MSE temporal criterion cannot be used at the same time, please choose one and restart the training."
+            )

-        if self.use_temporal:
+        if (
+            any("temporal" in D_name for D_name in self.opt.D_netDs)
+            or self.opt.train_temporal_criterion
+        ):
            visual_names_temporal_real_A = []
            visual_names_temporal_real_B = []
            visual_names_temporal_fake_B = []
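The hunk above drops the self.use_temporal flag: each call site now tests the conditions explicitly, and combining a temporal discriminator with the MSE temporal criterion becomes a hard error. A small illustration of how the substring check behaves, using made-up D_netDs values that are not taken from the repository's configs:

    # Hypothetical discriminator lists, only to show how the check evaluates.
    D_netDs = ["projected_d", "temporal"]
    print(any("temporal" in D_name for D_name in D_netDs))  # True: a temporal discriminator is configured

    D_netDs = ["projected_d", "basic"]
    print(any("temporal" in D_name for D_name in D_netDs))  # False: only the MSE temporal criterion may be enabled
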
@@ -680,7 +686,7 @@ def forward_GAN(self):
                self.mask_context, size=self.real_A.shape[2:]
            )[:, 0]

-        if self.use_temporal:
+        if any("temporal" in D_name for D_name in self.opt.D_netDs):
            self.compute_temporal_fake(objective_domain="B")

        if hasattr(self, "netG_B"):
@@ -1637,17 +1643,28 @@ def one_hot(self, tensor):
        return one_hot.scatter_(1, tensor.unsqueeze(1), 1.0)

    def compute_temporal_criterion_loss_generic(self, domain):
+        origin_domain = "B" if domain == "A" else "A"
+        netG = getattr(self, "netG_" + origin_domain)
+
        loss_value = torch.zeros([], device=self.device)
-        previous_fake = getattr(self, "temporal_fake_" + domain)[:, 0].clone().detach()
+        previous_fake = netG(
+            getattr(self, "temporal_real_" + origin_domain)[:, 0]
+        ).detach()
+
+        setattr(self, "temporal_fake_%s_0" % domain, previous_fake)
+
        for i in range(1, self.opt.D_temporal_number_frames):
-            next_fake = getattr(self, "temporal_fake_" + domain)[:, i]
-            loss_value += self.criterionTemporal(previous_fake, next_fake)
+            next_fake = netG(getattr(self, "temporal_real_" + origin_domain)[:, i])
+            cur_loss = self.criterionTemporal(previous_fake, next_fake)
+            cur_loss.backward()
+            loss_value += cur_loss.detach()
            previous_fake = next_fake.clone().detach()
+            setattr(self, "temporal_fake_%s_%d" % (domain, i), previous_fake)

-        return loss_value.mean()
+        return loss_value.sum()

    def compute_temporal_criterion_loss(self):

        self.loss_G_temporal_criterion_B = (
            self.compute_temporal_criterion_loss_generic(domain="B")
            * self.opt.train_temporal_criterion_lambda
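The core change of the commit sits in compute_temporal_criterion_loss_generic: the fake frames are regenerated from the temporal_real_* tensors through the generator, each pairwise MSE term is backpropagated as soon as it is computed, and only a detached running total is kept for logging. A minimal standalone sketch of that pattern, with placeholder names (netG, real_frames, criterion) rather than the model's actual attributes:

    import torch
    import torch.nn as nn

    def temporal_mse_with_stepwise_backward(netG, real_frames, criterion):
        # real_frames: (batch, n_frames, C, H, W); gradients flow into netG
        # one frame pair at a time, so earlier graphs can be freed immediately.
        loss_total = torch.zeros([], device=real_frames.device)
        previous_fake = netG(real_frames[:, 0]).detach()  # first frame only serves as a target
        for i in range(1, real_frames.shape[1]):
            next_fake = netG(real_frames[:, i])            # keeps its graph for this step only
            cur_loss = criterion(next_fake, previous_fake)  # live tensor first; MSE is symmetric
            cur_loss.backward()                             # accumulate gradients into netG now
            loss_total += cur_loss.detach()                 # graph-free running sum for logging
            previous_fake = next_fake.detach()              # next pair compares against a constant
        return loss_total

    # Toy usage: a 1x1 conv as stand-in generator, 4 frames of 3x8x8 images.
    netG = nn.Conv2d(3, 3, kernel_size=1)
    frames = torch.rand(2, 4, 3, 8, 8)
    loss = temporal_mse_with_stepwise_backward(netG, frames, nn.MSELoss())

Because gradients accumulate in the parameters across backward() calls, the result matches backpropagating the summed loss once at the end, while only one frame's activation graph stays in memory at a time.
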
5 changes: 5 additions & 0 deletions tests/test_run_semantic_mask_online.py
@@ -70,3 +70,8 @@ def test_semantic_mask(dataroot):
    json_like_dict_c["f_s_net"] = f_s_type
    opt = TrainOptions().parse_json(json_like_dict_c)
    train.launch_training(opt)
+
+    json_like_dict_c = json_like_dict.copy()
+    json_like_dict_c["D_netDs"].remove("temporal")
+    json_like_dict_c["D_temporal_number_frames"] = 5
+    json_like_dict_c["train_temporal_criterion"] = True
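The added test variant exercises the new code path: "temporal" is removed from D_netDs and train_temporal_criterion is switched on, the only combination the new ValueError permits. For orientation, the options that drive this path look roughly like the following sketch (illustrative values, not the test's actual full configuration):

    temporal_criterion_opts = {
        "D_netDs": ["projected_d", "basic"],      # assumed example list; no temporal discriminator allowed
        "train_temporal_criterion": True,         # enables the MSE criterion between consecutive fake frames
        "train_temporal_criterion_lambda": 1.0,   # assumed weight applied in compute_temporal_criterion_loss
        "D_temporal_number_frames": 5,            # frames per temporal sample, as in the test above
    }
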
