Skip to content

Commit

Permalink
fix: D_global optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Dec 16, 2021
1 parent bb366b7 commit 284d8ae
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ def __init__(self, opt,rank):
self.group_G = NetworkGroup(networks_to_optimize=["G","F"],forward_functions=["forward"],backward_functions=["compute_G_loss"],loss_names_list=["loss_names_G"],optimizer=["optimizer_G"],loss_backward=["loss_G"])
self.networks_groups.append(self.group_G)

self.group_D = NetworkGroup(networks_to_optimize=["D"],forward_functions=None,backward_functions=["compute_D_loss"],loss_names_list=["loss_names_D"],optimizer=["optimizer_D"],loss_backward=["loss_D_tot"])
D_to_optimize = ["D"]
if opt.netD_global != "none":
D_to_optimize.append("D_global")
self.group_D = NetworkGroup(networks_to_optimize=D_to_optimize,forward_functions=None,backward_functions=["compute_D_loss"],loss_names_list=["loss_names_D"],optimizer=["optimizer_D"],loss_backward=["loss_D_tot"])
self.networks_groups.append(self.group_D)


Expand Down

0 comments on commit 284d8ae

Please sign in to comment.