Skip to content

Commit

Permalink
feat(ml): GAN mask generator with sam refined target
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Jul 24, 2023
1 parent f580d88 commit 0cd1ee9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, opt, rank):
else:
self.use_depth = False

if "sam" in opt.D_netDs:
if "sam" in opt.D_netDs or opt.data_refined_mask:
self.use_sam = True
self.netfreeze_sam, self.predictor_sam = init_sam_net(
opt.model_type_sam, self.opt.D_weight_sam, self.device
Expand Down
16 changes: 14 additions & 2 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,8 +1084,11 @@ def one_hot(self, tensor):
def compute_fake_real_masks(self):
fake_mask = self.netf_s(self.real_A)
fake_mask = F.gumbel_softmax(fake_mask, tau=1.0, hard=True, dim=1)
real_mask = self.netf_s(self.real_B)
real_mask = self.netf_s(
self.real_B
) # f_s(B) is a good approximation of the real mask when task is easy
real_mask = F.gumbel_softmax(real_mask, tau=1.0, hard=True, dim=1)

setattr(self, "fake_mask_B_inv", fake_mask.argmax(dim=1))
setattr(self, "real_mask_B_inv", real_mask.argmax(dim=1))
setattr(self, "fake_mask_B", fake_mask)
Expand Down Expand Up @@ -1133,7 +1136,16 @@ def compute_f_s_loss(self):
f_s = self.netf_s_B
else:
f_s = self.netf_s
label_B = self.input_B_label_mask

if self.opt.data_refined_mask:
# get mask with sam instead of label from self.real_B and self.input_B_ref_bbox
self.label_sam_B = (
predict_sam(self.real_B, self.predictor_sam, self.input_B_ref_bbox)
> 0.0
)
label_B = self.label_sam_B.long()
else:
label_B = self.input_B_label_mask
pred_B = f_s(self.real_B)
self.loss_f_s += self.criterionf_s(pred_B, label_B) # .squeeze(1))

Expand Down
2 changes: 2 additions & 0 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ def data_dependent_initialize_semantic_mask(self, data):

if "mask" in self.opt.D_netDs:
visual_names_seg_B += ["real_mask_B_inv", "fake_mask_B_inv"]
if self.opt.data_refined_mask:
visual_names_seg_B += ["label_sam_B"]

self.visual_names += [visual_names_seg_A, visual_names_seg_B]

Expand Down

0 comments on commit 0cd1ee9

Please sign in to comment.