
feat: added optimizers and options
beniz committed Apr 19, 2022
1 parent dc2a00f commit 505cac2
Showing 12 changed files with 54 additions and 30 deletions.
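
Every model below makes the same change: it stops constructing torch.optim.Adam directly and instead calls an optimizer factory that train.py attaches to the options object, dispatching on the new --train_optim flag. A minimal, self-contained sketch of that wiring; the SimpleNamespace, the placeholder network, and the hyperparameter values are illustrative stand-ins, not repository code:

    import torch
    import torch.nn as nn
    from types import SimpleNamespace


    def optim(opt, params, lr, betas):
        # Same dispatch as the optim() helper added to train.py below,
        # shortened here to two of its three branches.
        if opt.train_optim == "radam":
            return torch.optim.RAdam(params, lr, betas)
        return torch.optim.Adam(params, lr, betas)


    # Illustrative stand-in for the parsed options object.
    opt = SimpleNamespace(train_optim="radam", train_G_lr=2e-4,
                          train_beta1=0.9, train_beta2=0.999, optim=optim)
    netG = nn.Linear(8, 8)  # placeholder "generator"

    # The call shape every model switches to in the diffs below.
    optimizer_G = opt.optim(opt, netG.parameters(),
                            lr=opt.train_G_lr,
                            betas=(opt.train_beta1, opt.train_beta2))
    print(type(optimizer_G).__name__)  # -> RAdam
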
12 changes: 8 additions & 4 deletions models/cut_model.py
@@ -159,19 +159,22 @@ def __init__(self, opt, rank):
         for nce_layer in self.nce_layers:
             self.criterionNCE.append(PatchNCELoss(opt).to(self.device))

-        self.optimizer_G = torch.optim.Adam(
+        self.optimizer_G = opt.optim(
+            opt,
             self.netG.parameters(),
             lr=opt.train_G_lr,
             betas=(opt.train_beta1, opt.train_beta2),
         )
         if opt.D_netD_global == "none":
-            self.optimizer_D = torch.optim.Adam(
+            self.optimizer_D = opt.optim(
+                opt,
                 self.netD.parameters(),
                 lr=opt.train_D_lr,
                 betas=(opt.train_beta1, opt.train_beta2),
             )
         else:
-            self.optimizer_D = torch.optim.Adam(
+            self.optimizer_D = opt.optim(
+                opt,
                 itertools.chain(
                     self.netD.parameters(), self.netD_global.parameters()
                 ),
@@ -257,7 +260,8 @@ def data_dependent_initialize(self, data):
             self.opt.alg_cut_lambda_NCE > 0.0
             and not self.opt.alg_cut_netF == "sample"
         ):
-            self.optimizer_F = torch.optim.Adam(
+            self.optimizer_F = self.opt.optim(
+                self.opt,
                 self.netF.parameters(),
                 lr=self.opt.train_G_lr,
                 betas=(self.opt.train_beta1, self.opt.train_beta2),
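
Note that the F optimizer above is created in data_dependent_initialize rather than in __init__, presumably because, as in the original CUT setup, the netF feature sampler only builds its projection MLP during the first forward pass, so its parameters do not exist earlier. A small, hypothetical sketch of that deferred-creation pattern (LazyHead is an illustrative stand-in, not the repository's netF):

    import torch
    import torch.nn as nn


    class LazyHead(nn.Module):
        """Stand-in for a netF-like module whose parameters appear on first use."""

        def __init__(self):
            super().__init__()
            self.mlp = None

        def forward(self, x):
            if self.mlp is None:
                # Build the projection head once the input size is known.
                self.mlp = nn.Linear(x.shape[-1], 16)
            return self.mlp(x)


    netF = LazyHead()
    netF(torch.randn(2, 8))  # first forward materializes the parameters
    optimizer_F = torch.optim.Adam(netF.parameters(), lr=2e-4)  # now non-empty
    print(sum(p.numel() for p in netF.parameters()))  # 8 * 16 + 16 = 144
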
6 changes: 4 additions & 2 deletions models/cut_semantic_mask_model.py
@@ -75,15 +75,17 @@ def __init__(self, opt, rank):
         )

         if self.opt.train_mask_disjoint_f_s:
-            self.optimizer_f_s = torch.optim.Adam(
+            self.optimizer_f_s = opt.optim(
+                opt,
                 itertools.chain(
                     self.netf_s_A.parameters(), self.netf_s_B.parameters()
                 ),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
             )
         else:
-            self.optimizer_f_s = torch.optim.Adam(
+            self.optimizer_f_s = opt.optim(
+                opt,
                 self.netf_s.parameters(),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
3 changes: 2 additions & 1 deletion models/cut_semantic_model.py
@@ -50,7 +50,8 @@ def __init__(self, opt, rank):
         # define loss functions
         self.criterionCLS = torch.nn.modules.CrossEntropyLoss()

-        self.optimizer_CLS = torch.optim.Adam(
+        self.optimizer_CLS = opt.optim(
+            opt,
             self.netCLS.parameters(),
             lr=opt.train_sem_lr_f_s,
             betas=(opt.train_beta1, opt.train_beta2),
9 changes: 6 additions & 3 deletions models/cycle_gan_model.py
@@ -155,19 +155,22 @@ def __init__(self, opt, rank):
         self.criterionIdt = torch.nn.L1Loss()

         # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
-        self.optimizer_G = torch.optim.Adam(
+        self.optimizer_G = opt.optim(
+            opt,
             itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
             lr=opt.train_G_lr,
             betas=(opt.train_beta1, opt.train_beta2),
         )
         if opt.D_netD_global == "none":
-            self.optimizer_D = torch.optim.Adam(
+            self.optimizer_D = opt.optim(
+                opt,
                 itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                 lr=opt.train_D_lr,
                 betas=(opt.train_beta1, opt.train_beta2),
             )
         else:
-            self.optimizer_D = torch.optim.Adam(
+            self.optimizer_D = opt.optim(
+                opt,
                 itertools.chain(
                     self.netD_A.parameters(),
                     self.netD_B.parameters(),
6 changes: 4 additions & 2 deletions models/cycle_gan_semantic_mask_model.py
@@ -83,15 +83,17 @@ def __init__(self, opt, rank):

         # initialize optimizers
         if self.opt.train_mask_disjoint_f_s:
-            self.optimizer_f_s = torch.optim.Adam(
+            self.optimizer_f_s = opt.optim(
+                opt,
                 itertools.chain(
                     self.netf_s_A.parameters(), self.netf_s_B.parameters()
                 ),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
             )
         else:
-            self.optimizer_f_s = torch.optim.Adam(
+            self.optimizer_f_s = opt.optim(
+                opt,
                 self.netf_s.parameters(),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
3 changes: 2 additions & 1 deletion models/cycle_gan_semantic_model.py
@@ -49,7 +49,8 @@ def __init__(self, opt, rank):
         self.criterionCLS = torch.nn.modules.CrossEntropyLoss()

         # initialize optimizers
-        self.optimizer_CLS = torch.optim.Adam(
+        self.optimizer_CLS = opt.optim(
+            opt,
             self.netCLS.parameters(),
             lr=opt.train_sem_lr_f_s,
             betas=(opt.train_beta1, opt.train_beta2),
6 changes: 4 additions & 2 deletions models/re_cut_semantic_mask_model.py
@@ -59,7 +59,8 @@ def __init__(self, opt):
         )
         self.model_names += ["P_B"]

-        self.optimizer_P = torch.optim.Adam(
+        self.optimizer_P = opt.optim(
+            opt,
             itertools.chain(self.netP_B.parameters()),
             lr=opt.alg_re_P_lr,
             betas=(opt.train_beta1, opt.train_beta2),
@@ -130,7 +131,8 @@ def data_dependent_initialize(self, data):
         self.loss_G.backward() # calculate gradients for G
         self.loss_P.backward() # calculate gradients for P
         if self.opt.alg_cut_lambda_NCE > 0.0:
-            self.optimizer_F = torch.optim.Adam(
+            self.optimizer_F = opt.optim(
+                opt,
                 self.netF.parameters(),
                 lr=self.opt.train_G_lr,
                 betas=(self.opt.train_beta1, self.opt.train_beta2),
3 changes: 2 additions & 1 deletion models/re_cycle_gan_semantic_mask_model.py
@@ -73,7 +73,8 @@ def __init__(self, opt):
         )
         self.model_names += ["P_A", "P_B"]

-        self.optimizer_P = torch.optim.Adam(
+        self.optimizer_P = opt.optim(
+            opt,
             itertools.chain(self.netP_A.parameters(), self.netP_B.parameters()),
             lr=opt.alg_re_P_lr,
             betas=(opt.train_beta1, opt.train_beta2),
4 changes: 2 additions & 2 deletions models/segmentation_model.py
@@ -50,8 +50,8 @@ def __init__(self, opt):
         self.criterionf_s = torch.nn.modules.NLLLoss()
         self.criterionf_s = torch.nn.modules.CrossEntropyLoss()
         # initialize optimizers
-        self.optimizer_f_s = torch.optim.Adam(
-            self.netf_s.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
+        self.optimizer_f_s = opt.optim(
+            opt, self.netf_s.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
         )
         print("f defined")
         self.optimizers = []
4 changes: 2 additions & 2 deletions models/template_model.py
@@ -74,8 +74,8 @@ def __init__(self, opt):
         self.criterionLoss = torch.nn.L1Loss()
         # define and initialize optimizers. You can define one optimizer for each network.
         # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
-        self.optimizer = torch.optim.Adam(
-            self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
+        self.optimizer = opt.optim(
+            opt, self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)
         )
         self.optimizers = [self.optimizer]

6 changes: 6 additions & 0 deletions options/train_options.py
@@ -130,6 +130,12 @@ def initialize(self, parser):
             default="latest",
             help="which epoch to load? set to latest to use latest cached model",
         )
+        parser.add_argument(
+            "--train_optim",
+            default="adam",
+            choices=["adam", "radam", "adamw"],
+            help="optimizer (adam, radam, adamw, ...)",
+        )
         parser.add_argument(
             "--train_load_iter",
             type=int,
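
With the new option in place, the optimizer can be picked at launch time, for example (only --train_optim is real here; every other flag a run needs, such as data and model options, is omitted):

    python3 train.py --train_optim radam [other training options]
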
22 changes: 12 additions & 10 deletions train.py
@@ -39,6 +39,16 @@ def setup(rank, world_size, port):
     dist.init_process_group("nccl", rank=rank, world_size=world_size)


+def optim(opt, params, lr, betas):
+    print("Using ", opt.train_optim, " as optimizer")
+    if opt.train_optim == "adam":
+        return torch.optim.Adam(params, lr, betas)
+    elif opt.train_optim == "radam":
+        return torch.optim.RAdam(params, lr, betas)
+    elif opt.train_optim == "adamw":
+        return torch.optim.AdamW(params, lr, betas)
+
+
 def signal_handler(sig, frame):
     dist.destroy_process_group()

@@ -53,6 +63,7 @@ def train_gpu(rank, world_size, opt, dataset):
         opt, rank, dataset
     ) # create a dataset given opt.dataset_mode and other options
     dataset_size = len(dataset) # get the number of images in the dataset.
+    opt.optim = optim # set optimizer
     model = create_model(opt, rank) # create a model given opt.model and other options

     if hasattr(model, "data_dependent_initialize"):
@@ -232,16 +243,7 @@ def launch_training(opt=None):
     dataset = create_dataset(opt)
     print("The number of training images = %d" % len(dataset))

-    mp.spawn(
-        train_gpu,
-        args=(
-            world_size,
-            opt,
-            dataset,
-        ),
-        nprocs=world_size,
-        join=True,
-    )
+    mp.spawn(train_gpu, args=(world_size, opt, dataset), nprocs=world_size, join=True)


 if __name__ == "__main__":
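
The optim() helper above relies on the choices list added to train_options.py to keep --train_optim within the three supported names; any other string would fall through every branch and silently return None. A sketch of an alternative, table-driven version of the same dispatch that fails loudly instead (illustrative only, not part of this commit):

    import torch

    # Table-driven variant of the dispatch; new optimizers would just add
    # entries here. This is a sketch, not the code added by this commit.
    _OPTIMIZERS = {
        "adam": torch.optim.Adam,
        "radam": torch.optim.RAdam,
        "adamw": torch.optim.AdamW,
    }


    def optim(opt, params, lr, betas):
        try:
            cls = _OPTIMIZERS[opt.train_optim]
        except KeyError:
            raise ValueError("unsupported --train_optim value: %s" % opt.train_optim)
        print("Using", opt.train_optim, "as optimizer")
        return cls(params, lr, betas)

A lookup table keeps the list of supported optimizers in one place and makes the unsupported-value case explicit.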
