fix(pt): make PT training step idx consistent with TF (#4221)
Fix #4206.

Currently, the training step index displayed in TF and PT has different
meanings:
- In TF, step 0 means no training; step 1 means a training step has been
performed. The maximum training step is equal to the number of steps.
- In PT, step 0 means a training step has been performed. The maximum
training step is the number of steps minus 1.

This PR corrects the definition of the step index in PT and makes the two backends consistent.

One difference remains after this PR: TF shows step 0, whereas PT starts from step 1. Showing the loss at step 0 in PT would require heavy refactoring and is therefore not included in this PR.
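
As a rough sketch of the new convention (illustrative only, not the actual trainer; `num_steps` and `disp_freq` stand in for the corresponding training settings), PT now derives a 1-based `display_step_id` from the 0-based loop index, so the last displayed step equals the total number of training steps, as in TF:

```python
# Minimal sketch of the 1-based display convention (assumed values below).
num_steps = 1000  # total number of training steps
disp_freq = 100   # display frequency

for _step_id in range(num_steps):    # internal 0-based loop index
    display_step_id = _step_id + 1   # 1-based index shown to the user
    # Log the very first step and then every disp_freq steps, so the
    # final logged step equals num_steps (matching the TF convention).
    if display_step_id % disp_freq == 0 or display_step_id == 1:
        print(f"batch {display_step_id:7d}")
```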

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Improved logging for training progress, starting step count from 1 for better clarity.
  - Enhanced TensorBoard logging for consistent step tracking.

- **Bug Fixes**
  - Adjusted logging conditions to ensure the first step's results are included in the output.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
njzjz and pre-commit-ci[bot] authored Oct 16, 2024
1 parent 5050f61 commit d7d2210
Showing 1 changed file with 18 additions and 11 deletions.
deepmd/pt/train/training.py: 18 additions & 11 deletions

```diff
@@ -769,7 +769,10 @@ def fake_model():
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
             # Log and persist
-            if self.display_in_training and _step_id % self.disp_freq == 0:
+            display_step_id = _step_id + 1
+            if self.display_in_training and (
+                display_step_id % self.disp_freq == 0 or display_step_id == 1
+            ):
                 self.wrapper.eval()
 
                 def log_loss_train(_loss, _more_loss, _task_key="Default"):
@@ -821,7 +824,7 @@ def log_loss_valid(_task_key="Default"):
                     if self.rank == 0:
                         log.info(
                             format_training_message_per_task(
-                                batch=_step_id,
+                                batch=display_step_id,
                                 task_name="trn",
                                 rmse=train_results,
                                 learning_rate=cur_lr,
@@ -830,7 +833,7 @@ def log_loss_valid(_task_key="Default"):
                         if valid_results:
                             log.info(
                                 format_training_message_per_task(
-                                    batch=_step_id,
+                                    batch=display_step_id,
                                     task_name="val",
                                     rmse=valid_results,
                                     learning_rate=None,
@@ -861,7 +864,7 @@ def log_loss_valid(_task_key="Default"):
                         if self.rank == 0:
                             log.info(
                                 format_training_message_per_task(
-                                    batch=_step_id,
+                                    batch=display_step_id,
                                     task_name=_key + "_trn",
                                     rmse=train_results[_key],
                                     learning_rate=cur_lr,
@@ -870,7 +873,7 @@ def log_loss_valid(_task_key="Default"):
                             if valid_results[_key]:
                                 log.info(
                                     format_training_message_per_task(
-                                        batch=_step_id,
+                                        batch=display_step_id,
                                         task_name=_key + "_val",
                                         rmse=valid_results[_key],
                                         learning_rate=None,
@@ -883,7 +886,7 @@ def log_loss_valid(_task_key="Default"):
                 if self.rank == 0 and self.timing_in_training:
                     log.info(
                         format_training_message(
-                            batch=_step_id,
+                            batch=display_step_id,
                             wall_time=train_time,
                         )
                     )
@@ -899,7 +902,7 @@ def log_loss_valid(_task_key="Default"):
                         self.print_header(fout, train_results, valid_results)
                         self.lcurve_should_print_header = False
                     self.print_on_training(
-                        fout, _step_id, cur_lr, train_results, valid_results
+                        fout, display_step_id, cur_lr, train_results, valid_results
                     )
 
             if (
@@ -921,11 +924,15 @@ def log_loss_valid(_task_key="Default"):
                         f.write(str(self.latest_model))
 
             # tensorboard
-            if self.enable_tensorboard and _step_id % self.tensorboard_freq == 0:
-                writer.add_scalar(f"{task_key}/lr", cur_lr, _step_id)
-                writer.add_scalar(f"{task_key}/loss", loss, _step_id)
+            if self.enable_tensorboard and (
+                display_step_id % self.tensorboard_freq == 0 or display_step_id == 1
+            ):
+                writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
+                writer.add_scalar(f"{task_key}/loss", loss, display_step_id)
                 for item in more_loss:
-                    writer.add_scalar(f"{task_key}/{item}", more_loss[item], _step_id)
+                    writer.add_scalar(
+                        f"{task_key}/{item}", more_loss[item], display_step_id
+                    )
 
         self.t0 = time.time()
         self.total_train_time = 0.0
```
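
Reading the TensorBoard hunk on its own (a minimal standalone sketch, not the trainer code; the writer, `task_key`, and the scalar values below are placeholders, assuming the `torch.utils.tensorboard` SummaryWriter), the scalars are now recorded against the same 1-based step that appears in the training log and the learning-curve file:

```python
from torch.utils.tensorboard import SummaryWriter

# Placeholders for state that training.py keeps inside its step() closure.
writer = SummaryWriter(log_dir="tb_logs")
task_key = "Default"
tensorboard_freq = 100
num_steps = 1000

for _step_id in range(num_steps):
    display_step_id = _step_id + 1
    cur_lr, loss = 1e-3, 0.5  # dummy scalars for illustration
    if display_step_id % tensorboard_freq == 0 or display_step_id == 1:
        # Index scalars by the 1-based display step so TensorBoard curves
        # line up with the step numbers printed in the log.
        writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
        writer.add_scalar(f"{task_key}/loss", loss, display_step_id)

writer.close()
```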
