Skip to content

Commit

Permalink
feat(ml): added semantic threshold option
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Apr 26, 2021
1 parent d7cade1 commit 03f33a2
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion models/cycle_gan_semantic_mask_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def backward_G(self):

# only use semantic loss when classifier has reasonably low loss
#if True:
if not hasattr(self, 'loss_f_s') or self.loss_f_s.detach().item() > 1.0:
if not hasattr(self, 'loss_f_s') or self.loss_f_s.detach().item() > self.opt.semantic_threshold:
self.loss_sem_AB = 0 * self.loss_sem_AB
self.loss_sem_BA = 0 * self.loss_sem_BA
self.loss_G += self.loss_sem_BA + self.loss_sem_AB
Expand Down
2 changes: 1 addition & 1 deletion models/cycle_gan_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def backward_G(self):

# only use semantic loss when classifier has reasonably low loss
#if True:
if not hasattr(self, 'loss_CLS') or self.loss_CLS.detach().item() > 1.0:
if not hasattr(self, 'loss_CLS') or self.loss_CLS.detach().item() > self.opt.semantic_threshold:
self.loss_sem_AB = 0 * self.loss_sem_AB
self.loss_sem_BA = 0 * self.loss_sem_BA

Expand Down
1 change: 1 addition & 0 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def initialize(self, parser):
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
parser.add_argument('--semantic_nclasses',default=10,type=int,help='number of classes of the semantic loss classifier')
parser.add_argument('--semantic_threshold',default=1.0,type=float,help='threshold of the semantic classifier loss below with semantic loss is applied')
parser.add_argument('--display_networks', action='store_true',help='Set True if you want to display networks on port 8000')
parser.add_argument('--compute_fid', action='store_true')
parser.add_argument('--fid_every', type=int, default=1000)
Expand Down

0 comments on commit 03f33a2

Please sign in to comment.