From 284d8ae30efd290ba082a014a93df4e6fbe22a8e Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Thu, 16 Dec 2021 19:27:12 +0000 Subject: [PATCH] fix: D_global optimization --- models/cut_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/models/cut_model.py b/models/cut_model.py index 22b2f5858..cc05b3743 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -147,7 +147,10 @@ def __init__(self, opt,rank): self.group_G = NetworkGroup(networks_to_optimize=["G","F"],forward_functions=["forward"],backward_functions=["compute_G_loss"],loss_names_list=["loss_names_G"],optimizer=["optimizer_G"],loss_backward=["loss_G"]) self.networks_groups.append(self.group_G) - self.group_D = NetworkGroup(networks_to_optimize=["D"],forward_functions=None,backward_functions=["compute_D_loss"],loss_names_list=["loss_names_D"],optimizer=["optimizer_D"],loss_backward=["loss_D_tot"]) + D_to_optimize = ["D"] + if opt.netD_global != "none": + D_to_optimize.append("D_global") + self.group_D = NetworkGroup(networks_to_optimize=D_to_optimize,forward_functions=None,backward_functions=["compute_D_loss"],loss_names_list=["loss_names_D"],optimizer=["optimizer_D"],loss_backward=["loss_D_tot"]) self.networks_groups.append(self.group_D)