Skip to content

Commit

Permalink
feat: display augmented images
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Sep 29, 2021
1 parent b544b13 commit 2126253
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
21 changes: 21 additions & 0 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,16 @@ def compute_D_loss_generic(self,netD,domain_img,loss,real_name=None,fake_name=No
real = diff_augment(real,self.diff_aug_policy)
fake = diff_augment(fake,self.diff_aug_policy)

if fake_name is None:
setattr(self,"fake_"+domain_img+"_aug",fake)
else:
setattr(self,fake_name,fake)

if real_name is None:
setattr(self,"real_"+domain_img+"_aug",real)
else:
setattr(self,real_name,real)

loss = loss.compute_loss_D(netD, real, fake)
return loss

Expand All @@ -441,6 +451,17 @@ def compute_G_loss_GAN_generic(self,netD,domain_img,loss,real_name=None,fake_nam

real = diff_augment(real,self.diff_aug_policy)
fake = diff_augment(fake,self.diff_aug_policy)

if fake_name is None:
setattr(self,"fake_"+domain_img+"_aug",fake)
else:
setattr(self,fake_name,fake)

if real_name is None:
setattr(self,"real_"+domain_img+"_aug",real)
else:
setattr(self,real_name,real)


loss = loss.compute_loss_G(netD, real, fake)
return loss
Expand Down
8 changes: 8 additions & 0 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,19 @@ def __init__(self, opt,rank):

visual_names_A = ['real_A', 'fake_B']
visual_names_B = ['real_B']



self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

if opt.nce_idt and self.isTrain:
visual_names_B += ['idt_B']
self.visual_names = [visual_names_A,visual_names_B]

if self.opt.diff_aug_policy != '':
self.visual_names.append(['fake_B_aug'])
self.visual_names.append(['real_B_aug'])

if self.isTrain:
self.model_names = ['G', 'F', 'D']
if opt.netD_global != "none":
Expand Down
5 changes: 5 additions & 0 deletions models/cycle_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def __init__(self, opt,rank):

self.visual_names = [visual_names_A , visual_names_B] # combine visualizations for A and B
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.

if self.opt.diff_aug_policy != '':
self.visual_names.append(['real_A_aug','fake_B_aug'])
self.visual_names.append(['real_B_aug','fake_A_aug'])

if self.isTrain:
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
if opt.netD_global != "none":
Expand Down

0 comments on commit 2126253

Please sign in to comment.