diff --git a/models/base_gan_model.py b/models/base_gan_model.py index 82f4aa47e..8b69a7a7e 100644 --- a/models/base_gan_model.py +++ b/models/base_gan_model.py @@ -114,7 +114,7 @@ def __init__(self, opt, rank): else: self.use_depth = False - if "sam" in opt.D_netDs: + if "sam" in opt.D_netDs or opt.data_refined_mask: self.use_sam = True self.netfreeze_sam, self.predictor_sam = init_sam_net( opt.model_type_sam, self.opt.D_weight_sam, self.device diff --git a/models/base_model.py b/models/base_model.py index 291121e71..bd48f77f4 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -1084,8 +1084,11 @@ def one_hot(self, tensor): def compute_fake_real_masks(self): fake_mask = self.netf_s(self.real_A) fake_mask = F.gumbel_softmax(fake_mask, tau=1.0, hard=True, dim=1) - real_mask = self.netf_s(self.real_B) + real_mask = self.netf_s( + self.real_B + ) # f_s(B) is a good approximation of the real mask when task is easy real_mask = F.gumbel_softmax(real_mask, tau=1.0, hard=True, dim=1) + setattr(self, "fake_mask_B_inv", fake_mask.argmax(dim=1)) setattr(self, "real_mask_B_inv", real_mask.argmax(dim=1)) setattr(self, "fake_mask_B", fake_mask) @@ -1133,7 +1136,16 @@ def compute_f_s_loss(self): f_s = self.netf_s_B else: f_s = self.netf_s - label_B = self.input_B_label_mask + + if self.opt.data_refined_mask: + # get mask with sam instead of label from self.real_B and self.input_B_ref_bbox + self.label_sam_B = ( + predict_sam(self.real_B, self.predictor_sam, self.input_B_ref_bbox) + > 0.0 + ) + label_B = self.label_sam_B.long() + else: + label_B = self.input_B_label_mask pred_B = f_s(self.real_B) self.loss_f_s += self.criterionf_s(pred_B, label_B) # .squeeze(1)) diff --git a/models/cut_model.py b/models/cut_model.py index 0a4f87395..ef00b5fc2 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -453,6 +453,8 @@ def data_dependent_initialize_semantic_mask(self, data): if "mask" in self.opt.D_netDs: visual_names_seg_B += ["real_mask_B_inv", "fake_mask_B_inv"] + if self.opt.data_refined_mask: + visual_names_seg_B += ["label_sam_B"] self.visual_names += [visual_names_seg_A, visual_names_seg_B]