Skip to content

Commit

Permalink
fix: cls net was created with f_s options
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau authored and beniz committed Jan 16, 2023
1 parent cc846a5 commit 70e9009
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 47 deletions.
5 changes: 3 additions & 2 deletions data/unaligned_labeled_cls_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ def __init__(self, opt):
self.dir_A, opt.data_max_dataset_size
) # load images from '/path/to/data/trainA' as well as labels
self.A_label = np.array(self.A_label)

else:

self.A_img_paths, self.A_label = make_labeled_path_dataset(
self.dir_A, "/paths.txt", opt.data_max_dataset_size
) # load images from '/path/to/data/trainA/paths.txt' as well as labels
self.A_label = np.array(self.A_label, dtype=np.float32)

# print('A_label',self.A_label)
if opt.train_sem_use_label_B:
if not os.path.isfile(self.dir_B + "/paths.txt"):
self.B_img_paths, self.B_label = make_labeled_dataset(
Expand All @@ -70,7 +71,7 @@ def __init__(self, opt):
self.transform_A = get_transform(self.opt, grayscale=(self.input_nc == 1))
self.transform_B = get_transform(self.opt, grayscale=(self.output_nc == 1))

self.semantic_nclasses = self.opt.f_s_semantic_nclasses
self.semantic_nclasses = self.opt.cls_semantic_nclasses

def get_img(
self,
Expand Down
38 changes: 0 additions & 38 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,44 +702,6 @@ def compute_G_loss_semantic_cls(self):
if hasattr(self, "fake_A"):
self.compute_G_loss_semantic_cls_generic(domain_fake="A")

def compute_G_loss_semantic_cls_old(self):
"""Calculate semantic class loss for G"""

# semantic class loss AB
if not self.opt.train_cls_regression:
self.loss_G_sem_cls_AB = self.criterionCLS(
self.pred_cls_fake_B, self.input_A_label_cls
)
else:
self.loss_G_sem_cls_AB = self.criterionCLS(
self.pred_cls_fake_B.squeeze(1), self.input_A_label_cls
)
if (
not hasattr(self, "loss_CLS")
or self.loss_CLS > self.opt.f_s_semantic_threshold
):
self.loss_G_sem_cls_AB = 0 * self.loss_G_sem_cls_AB
self.loss_G_sem_cls_AB *= self.opt.train_sem_cls_lambda
self.loss_G_tot += self.loss_G_sem_cls_AB

# semantic class loss BA
if hasattr(self, "fake_A"):
if not self.opt.train_cls_regression:
self.loss_G_sem_cls_BA = self.criterionCLS(
self.pred_cls_fake_A, self.input_B_label_cls
)
else:
self.loss_G_sem_cls_BA = self.criterionCLS(
self.pred_cls_fake_A.squeeze(1), self.input_B_label_cls
)
if (
not hasattr(self, "loss_CLS")
or self.loss_CLS > self.opt.f_s_semantic_threshold
):
self.loss_G_sem_cls_BA = 0 * self.loss_G_sem_cls_BA
self.loss_G_sem_cls_BA *= self.opt.train_sem_cls_lambda
self.loss_G_tot += self.loss_G_sem_cls_BA

def compute_G_loss_semantic_mask(self):
self.compute_G_loss_semantic_mask_generic(domain_fake="B")

Expand Down
8 changes: 6 additions & 2 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,14 +462,18 @@ def set_input_semantic_mask(self, data):
def set_input_semantic_cls(self, data):
if "A_label_cls" in data:
if not self.opt.train_cls_regression:
self.input_A_label_cls = data["A_label_cls"].to(self.device)
self.input_A_label_cls = (
data["A_label_cls"].to(torch.long).to(self.device)
)
else:
self.input_A_label_cls = (
data["A_label_cls"].to(torch.float).to(device=self.device)
)
if "B_label_cls" in data:
if not self.opt.train_cls_regression:
self.input_B_label_cls = data["B_label_cls"].to(self.device)
self.input_B_label_cls = (
data["B_label_cls"].to(torch.long).to(self.device)
)
else:
self.input_B_label_cls = (
data["B_label_cls"].to(torch.float).to(device=self.device)
Expand Down
10 changes: 5 additions & 5 deletions models/semantic_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

def define_C(
model_output_nc,
f_s_nf,
cls_nf,
data_crop_size,
f_s_semantic_nclasses,
cls_semantic_nclasses,
train_sem_cls_template,
model_init_type,
model_init_gain,
Expand All @@ -25,12 +25,12 @@ def define_C(
):
img_size = data_crop_size
if train_sem_cls_template == "basic":
netC = Classifier(model_output_nc, f_s_nf, f_s_semantic_nclasses, img_size)
netC = Classifier(model_output_nc, cls_nf, cls_semantic_nclasses, img_size)
else:
netC = torch_model(
model_output_nc,
f_s_nf,
f_s_semantic_nclasses,
cls_nf,
cls_semantic_nclasses,
img_size,
train_sem_cls_template,
train_sem_cls_pretrained,
Expand Down

0 comments on commit 70e9009

Please sign in to comment.