fix(ml): allow semantic loss weight control
beniz committed Apr 27, 2021
1 parent 8d5ca5e · commit 8e552be
Showing 1 changed file with 6 additions and 2 deletions.
models/cycle_gan_semantic_model.py (8 changes: 6 additions & 2 deletions)
@@ -46,6 +46,7 @@ def modify_commandline_options(parser, is_train=True):
parser.add_argument('--cls_pretrained', action='store_true', help='whether to use a pretrained model, available for non "basic" model only')
parser.add_argument('--lr_f_s', type=float, default=0.0002, help='f_s learning rate')
parser.add_argument('--regression', action='store_true', help='if true cls will be a regressor and not a classifier')
+parser.add_argument('--lambda_sem', type=float, default=1.0, help='weight for semantic loss')
parser.add_argument('--lambda_CLS', type=float, default=1.0, help='weight for CLS loss')
parser.add_argument('--l1_regression', action='store_true', help='if true l1 loss will be used to compute regressor loss')

@@ -291,7 +292,7 @@ def backward_G(self):
self.loss_sem_AB = self.criterionCLS(self.pred_fake_B, self.input_A_label)
else:
self.loss_sem_AB = self.criterionCLS(self.pred_fake_B.squeeze(1), self.input_A_label)

#self.loss_sem_AB = self.criterionCLS(self.pred_fake_B, self.gt_pred_A)
# semantic loss BA
if hasattr(self,'input_B_label'):
@@ -312,7 +313,10 @@ def backward_G(self):
if not hasattr(self, 'loss_CLS') or self.loss_CLS.detach().item() > self.opt.semantic_threshold:
self.loss_sem_AB = 0 * self.loss_sem_AB
self.loss_sem_BA = 0 * self.loss_sem_BA


+self.loss_sem_AB *= self.opt.lambda_sem
+self.loss_sem_BA *= self.opt.lambda_sem

self.loss_G += self.loss_sem_BA + self.loss_sem_AB
(self.loss_G/self.opt.iter_size).backward()
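
The net effect on backward_G is easier to see in isolation. Below is a minimal
PyTorch sketch of the gating-plus-weighting logic this commit introduces: the
option names (semantic_threshold, lambda_sem, iter_size) follow the diff, while
the weight_semantic_losses helper, the Namespace, and the toy tensors are
illustrative assumptions rather than the repository's actual code.

    import torch
    from argparse import Namespace

    def weight_semantic_losses(loss_sem_AB, loss_sem_BA, loss_CLS, opt):
        # Gate: if the semantic classifier is still unreliable (no loss_CLS
        # yet, or its loss sits above --semantic_threshold), zero both terms.
        if loss_CLS is None or loss_CLS.detach().item() > opt.semantic_threshold:
            loss_sem_AB = 0 * loss_sem_AB
            loss_sem_BA = 0 * loss_sem_BA
        # New in this commit: scale both terms by --lambda_sem so their share
        # of the generator loss can be tuned from the command line.
        return loss_sem_AB * opt.lambda_sem, loss_sem_BA * opt.lambda_sem

    opt = Namespace(semantic_threshold=1.0, lambda_sem=0.5, iter_size=1)
    w = torch.tensor(1.0, requires_grad=True)  # stand-in generator parameter
    loss_sem_AB = w * 0.8                      # toy semantic losses
    loss_sem_BA = w * 0.6
    loss_CLS = torch.tensor(0.4)               # below the threshold: terms kept
    sem_AB, sem_BA = weight_semantic_losses(loss_sem_AB, loss_sem_BA, loss_CLS, opt)
    loss_G = sem_AB + sem_BA                   # stand-in for the full loss_G
    (loss_G / opt.iter_size).backward()        # mirrors the diff's final line
    print(w.grad)                              # tensor(0.7000) = 0.5 * (0.8 + 0.6)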

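With the option exposed, the semantic terms can be re-weighted at launch time.
A hypothetical invocation (train.py, --dataroot, and --model cycle_gan_semantic
are the conventional entry point and flags of CycleGAN-derived codebases and
are assumptions here; only --lambda_sem itself is defined by this commit):

    python train.py --dataroot ./datasets/mydata --model cycle_gan_semantic --lambda_sem 0.5

Setting --lambda_sem 0 silences the semantic losses entirely, while values
above the default of 1.0 give the semantic constraint more weight relative to
the other generator losses.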
