add adamW and amsgrad
iProzd committed Oct 9, 2024
1 parent e915e99 commit 85afcc1
Showing 2 changed files with 25 additions and 7 deletions.
deepmd/pt/train/training.py (15 changes: 9 additions & 6 deletions)
@@ -179,6 +179,7 @@ def __init__(
         def get_opt_param(params):
             opt_type = params.get("opt_type", "Adam")
             opt_param = {
+                "adam_amsgrad": params.get("adam_amsgrad", False),
                 "kf_blocksize": params.get("kf_blocksize", 5120),
                 "kf_start_pref_e": params.get("kf_start_pref_e", 1),
                 "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
@@ -625,14 +626,16 @@ def warm_up_linear(step, warmup_steps):
 
         # TODO add optimizers for multitask
         # author: iProzd
-        if self.opt_type == "Adam":
+        if self.opt_type in ["Adam", "AdamW"]:
+            adam_amsgrad = self.opt_param["adam_amsgrad"]
+            adam = torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
             if not self.use_auto_reduce:
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(), lr=self.lr_exp.start_lr
+                self.optimizer = adam(
+                    self.wrapper.parameters(), lr=self.lr_exp.start_lr, amsgrad=adam_amsgrad,
                 )
             else:
-                self.optimizer = torch.optim.Adam(
-                    self.wrapper.parameters(), lr=self.lr_start
+                self.optimizer = adam(
+                    self.wrapper.parameters(), lr=self.lr_start, amsgrad=adam_amsgrad,
                 )
             if optimizer_state_dict is not None and self.restart_training:
                 self.optimizer.load_state_dict(optimizer_state_dict)
@@ -738,7 +741,7 @@ def step(_step_id, task_key="Default"):
                     print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
                     fout1.write(print_str)
                     fout1.flush()
-                if self.opt_type == "Adam":
+                if self.opt_type in ["Adam", "AdamW"]:
                     if not self.use_auto_reduce:
                         cur_lr = self.scheduler.get_last_lr()[0]
                     else:
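For context, the new branch boils down to the standalone Python sketch below; the helper name build_optimizer, the placeholder model, and the default learning rate are illustrative, while the class selection and the amsgrad flag mirror the diff above. Both torch.optim.Adam and torch.optim.AdamW accept an amsgrad keyword, so enabling AMSGrad is just a constructor argument.

import torch

def build_optimizer(model, opt_type="Adam", adam_amsgrad=False, start_lr=1e-3):
    # Select the optimizer class the same way the patched branch does.
    adam = torch.optim.Adam if opt_type == "Adam" else torch.optim.AdamW
    # Both torch.optim.Adam and torch.optim.AdamW expose the `amsgrad` keyword.
    return adam(model.parameters(), lr=start_lr, amsgrad=adam_amsgrad)

# Usage (hypothetical model standing in for the trainer's wrapper):
# optimizer = build_optimizer(torch.nn.Linear(4, 1), opt_type="AdamW", adam_amsgrad=True)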
deepmd/utils/argcheck.py (17 changes: 16 additions & 1 deletion)
@@ -2833,7 +2833,22 @@ def training_args(
         Variant(
             "opt_type",
             choices=[
-                Argument("Adam", dict, [], [], optional=True),
+                Argument("Adam", dict, [
+                    Argument(
+                        "adam_amsgrad",
+                        bool,
+                        optional=True,
+                        default=False,
+                    ),
+                ], [], optional=True),
+                Argument("AdamW", dict, [
+                    Argument(
+                        "adam_amsgrad",
+                        bool,
+                        optional=True,
+                        default=False,
+                    ),
+                ], [], optional=True),
                 Argument(
                     "LKF",
                     dict,
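With the argument check extended, a training configuration can select AdamW and enable AMSGrad. A minimal illustrative fragment, written here as a Python dict: only opt_type and adam_amsgrad come from this commit; the surrounding keys and exact nesting are assumptions.

# Illustrative fragment of the "training" section of an input script;
# only "opt_type" and "adam_amsgrad" are introduced by this commit.
training_section = {
    "opt_type": "AdamW",    # "Adam" also accepts the flag below
    "adam_amsgrad": True,   # defaults to False when omitted
    # ... other training options unchanged ...
}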
