
feat(pt): support disp_training and time_training in pt (deepmodeling#3775)


## Summary by CodeRabbit

- **New Features**
  - Introduced options to display training progress and to log training times, making it easier to follow and time model training runs (a brief usage sketch follows the changed-file summary below).


iProzd authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 4469977 commit dbe979d
Showing 1 changed file with 28 additions and 2 deletions.
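
Both new flags are read from the same `training_params` dict as the existing `save_ckpt`/`save_freq` options and default to `True` when omitted. Below is a minimal sketch of how they could be set; only `disp_training` and `time_training` come from this commit, the other entries and values are assumed context for illustration:

```python
# Hypothetical training_params dict as consumed by deepmd/pt/train/training.py;
# only "disp_training" and "time_training" are added by this commit.
training_params = {
    "disp_freq": 100,          # interval (in steps) for display and timing logs
    "save_freq": 1000,
    "save_ckpt": "model.ckpt",
    "disp_training": True,     # new: enable the per-disp_freq evaluation/loss logging
    "time_training": True,     # new: enable wall-time logging and the final average
}

# The trainer picks the flags up with a default of True (see the first hunk below):
display_in_training = training_params.get("disp_training", True)
timing_in_training = training_params.get("time_training", True)
```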
30 changes: 28 additions & 2 deletions deepmd/pt/train/training.py
@@ -139,6 +139,8 @@ def __init__(
         self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
         self.save_freq = training_params.get("save_freq", 1000)
         self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
+        self.display_in_training = training_params.get("disp_training", True)
+        self.timing_in_training = training_params.get("time_training", True)
         self.lcurve_should_print_header = True
 
         def get_opt_param(params):
@@ -811,7 +813,7 @@ def fake_model():
                 raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
             # Log and persist
-            if _step_id % self.disp_freq == 0:
+            if self.display_in_training and _step_id % self.disp_freq == 0:
                 self.wrapper.eval()
 
                 def log_loss_train(_loss, _more_loss, _task_key="Default"):
@@ -922,13 +924,18 @@ def log_loss_valid(_task_key="Default"):
                 current_time = time.time()
                 train_time = current_time - self.t0
                 self.t0 = current_time
-                if self.rank == 0:
+                if self.rank == 0 and self.timing_in_training:
                     log.info(
                         format_training_message(
                             batch=_step_id,
                             wall_time=train_time,
                         )
                     )
+                # the first training time is not accurate
+                if (
+                    _step_id + 1
+                ) > self.disp_freq or self.num_steps < 2 * self.disp_freq:
+                    self.total_train_time += train_time
 
                 if fout:
                     if self.lcurve_should_print_header:
@@ -964,6 +971,7 @@ def log_loss_valid(_task_key="Default"):
                     writer.add_scalar(f"{task_key}/{item}", more_loss[item], _step_id)
 
         self.t0 = time.time()
+        self.total_train_time = 0.0
         for step_id in range(self.num_steps):
             if step_id < self.start_step:
                 continue
@@ -995,6 +1003,24 @@ def log_loss_valid(_task_key="Default"):
             with open("checkpoint", "w") as f:
                 f.write(str(self.latest_model))
 
+        if self.timing_in_training and self.num_steps // self.disp_freq > 0:
+            if self.num_steps >= 2 * self.disp_freq:
+                log.info(
+                    "average training time: %.4f s/batch (exclude first %d batches)",
+                    self.total_train_time
+                    / (
+                        self.num_steps // self.disp_freq * self.disp_freq
+                        - self.disp_freq
+                    ),
+                    self.disp_freq,
+                )
+            else:
+                log.info(
+                    "average training time: %.4f s/batch",
+                    self.total_train_time
+                    / (self.num_steps // self.disp_freq * self.disp_freq),
+                )
+
         if JIT:
             pth_model_path = (
                 "frozen_model.pth"  # We use .pth to denote the frozen model
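
For reference, the new summary line averages the accumulated wall time over the timed steps and, whenever at least two display intervals were run (`num_steps >= 2 * disp_freq`), skips the first interval; the in-code comment notes the first measurement is not accurate, presumably due to warm-up. A standalone sketch of the same arithmetic with made-up numbers, not taken from any real run:

```python
# Standalone sketch of the averaging logic added above, using hypothetical numbers.
disp_freq = 100                         # steps between timing logs
num_steps = 1000                        # total training steps
interval_times = [30.0] + [12.0] * 9    # per-interval wall times; the first is inflated

if num_steps >= 2 * disp_freq:
    # mirror the accumulation condition: drop the first interval
    total_train_time = sum(interval_times[1:])
    timed_steps = num_steps // disp_freq * disp_freq - disp_freq
else:
    total_train_time = sum(interval_times)
    timed_steps = num_steps // disp_freq * disp_freq

print(f"average training time: {total_train_time / timed_steps:.4f} s/batch")
# prints: average training time: 0.1200 s/batch
```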
