feat: optimization eps value control
beniz committed Oct 11, 2023
1 parent 124ac0f commit 0556987
Showing 6 changed files with 20 additions and 4 deletions.
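For context: eps is the small constant added to the denominator of Adam-family updates so the step stays finite when the second-moment estimate is near zero; making it configurable is useful, for example, in mixed-precision training, where a larger eps is often needed for stability. A minimal, hand-rolled sketch of where the value enters (illustration only, not code from this repository):

import torch

def adam_style_step(param, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, t=1):
    # eps keeps the denominator away from zero when v_hat is tiny
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (v_hat.sqrt() + eps)
    return param, m, v

p, g = torch.zeros(3), torch.tensor([1e-4, 1.0, 10.0])
m, v = torch.zeros(3), torch.zeros(3)
p, m, v = adam_style_step(p, g, m, v, eps=1e-8)
print(p)  # a larger eps shrinks the step on coordinates with a tiny gradient history
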
3 changes: 3 additions & 0 deletions models/base_model.py
@@ -221,6 +221,7 @@ def init_semantic_cls(self, opt):
lr=opt.train_sem_lr_f_s,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)

if opt.train_cls_regression:
@@ -323,6 +324,7 @@ def init_semantic_mask(self, opt):
lr=opt.train_sem_lr_f_s,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)
else:
self.optimizer_f_s = opt.optim(
@@ -331,6 +333,7 @@ def init_semantic_mask(self, opt):
lr=opt.train_sem_lr_f_s,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)

self.optimizers.append(self.optimizer_f_s)
4 changes: 4 additions & 0 deletions models/cut_model.py
@@ -254,6 +254,7 @@ def __init__(self, opt, rank):
lr=opt.train_G_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)
if self.opt.model_multimodal:
self.criterionZ = torch.nn.L1Loss()
@@ -263,6 +264,7 @@ def __init__(self, opt, rank):
lr=opt.train_G_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)

if len(self.discriminators_names) > 0:
@@ -283,6 +285,7 @@ def __init__(self, opt, rank):
lr=opt.train_D_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)

self.optimizers.append(self.optimizer_G)
@@ -426,6 +429,7 @@ def data_dependent_initialize(self, data):
lr=self.opt.train_G_lr,
betas=(self.opt.train_beta1, self.opt.train_beta2),
weight_decay=self.opt.train_optim_weight_decay,
+ eps=self.opt.train_optim_eps,
)
self.optimizers.append(self.optimizer_F)
2 changes: 2 additions & 0 deletions models/cycle_gan_model.py
@@ -136,6 +136,7 @@ def __init__(self, opt, rank):
lr=opt.train_G_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)

D_parameters = itertools.chain(
@@ -151,6 +152,7 @@ def __init__(self, opt, rank):
lr=opt.train_D_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)

self.optimizers.append(self.optimizer_G)
1 change: 1 addition & 0 deletions models/palette_model.py
@@ -344,6 +344,7 @@ def __init__(self, opt, rank):
lr=opt.train_G_lr,
betas=(opt.train_beta1, opt.train_beta2),
weight_decay=opt.train_optim_weight_decay,
+ eps=opt.train_optim_eps,
)
self.optimizers.append(self.optimizer_G)
6 changes: 6 additions & 0 deletions options/train_options.py
@@ -179,6 +179,12 @@ def initialize(self, parser):
default=0.0,
help="weight decay for optimizer",
)
+ parser.add_argument(
+ "--train_optim_eps",
+ type=float,
+ default=1e-8,
+ help="epsilon for optimizer",
+ )
parser.add_argument(
"--train_load_iter",
type=int,
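The new option is a plain float flag defaulting to 1e-8, the usual Adam default. A quick standalone check of how it parses, independent of the rest of the options file:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_optim_eps",
    type=float,
    default=1e-8,
    help="epsilon for optimizer",
)
opt = parser.parse_args(["--train_optim_eps", "1e-7"])
print(opt.train_optim_eps)  # 1e-07
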
8 changes: 4 additions & 4 deletions train.py
@@ -51,16 +51,16 @@ def setup(rank, world_size, port):
dist.init_process_group("nccl", rank=rank, world_size=world_size)


- def optim(opt, params, lr, betas, weight_decay):
+ def optim(opt, params, lr, betas, weight_decay, eps):
print("Using ", opt.train_optim, " as optimizer")
if opt.train_optim == "adam":
- return torch.optim.Adam(params, lr, betas, weight_decay=weight_decay)
+ return torch.optim.Adam(params, lr, betas, weight_decay=weight_decay, eps=eps)
elif opt.train_optim == "radam":
- return torch.optim.RAdam(params, lr, betas, weight_decay=weight_decay)
+ return torch.optim.RAdam(params, lr, betas, weight_decay=weight_decay, eps=eps)
elif opt.train_optim == "adamw":
if weight_decay == 0.0:
weight_decay = 0.01  # default value
- return torch.optim.AdamW(params, lr, betas, weight_decay=weight_decay)
+ return torch.optim.AdamW(params, lr, betas, weight_decay=weight_decay, eps=eps)
elif opt.train_optim == "lion":
return Lion(params, lr, betas, weight_decay)
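The models above call opt.optim(...) with eps=opt.train_optim_eps, so this factory is presumably bound onto the options object elsewhere in the code. A minimal sketch of that pattern; the binding step below is an assumption for illustration, not the repository's actual wiring:

import argparse
import torch

def optim(opt, params, lr, betas, weight_decay, eps):
    # mirrors the updated factory: eps is now forwarded to the optimizer
    if opt.train_optim == "adam":
        return torch.optim.Adam(params, lr, betas, weight_decay=weight_decay, eps=eps)
    raise ValueError(f"unsupported optimizer: {opt.train_optim}")

parser = argparse.ArgumentParser()
parser.add_argument("--train_optim", default="adam")
parser.add_argument("--train_optim_eps", type=float, default=1e-8)
opt = parser.parse_args([])  # defaults only, for illustration
opt.optim = optim            # assumed hook; the real attachment point may differ

net = torch.nn.Linear(4, 2)
optimizer = opt.optim(
    opt,
    net.parameters(),
    lr=2e-4,
    betas=(0.9, 0.999),
    weight_decay=0.0,
    eps=opt.train_optim_eps,
)
print(type(optimizer).__name__, optimizer.defaults["eps"])  # Adam 1e-08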
