From 4ab64363b972258f76dabca1b8f4bc2f13b99dc2 Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Mon, 13 May 2024 10:45:17 +0800
Subject: [PATCH] feat(pt): support disp_training and time_training in pt (#3775)

## Summary by CodeRabbit

- **New Features**
  - Introduced the `disp_training` and `time_training` options to toggle training-progress display and per-batch timing logs during training.

---
 deepmd/pt/train/training.py | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 655b729f8c..4056b30d87 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -139,6 +139,8 @@ def __init__(
         self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
         self.save_freq = training_params.get("save_freq", 1000)
         self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
+        self.display_in_training = training_params.get("disp_training", True)
+        self.timing_in_training = training_params.get("time_training", True)
         self.lcurve_should_print_header = True

         def get_opt_param(params):
@@ -811,7 +813,7 @@ def fake_model():
                 raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

             # Log and persist
-            if _step_id % self.disp_freq == 0:
+            if self.display_in_training and _step_id % self.disp_freq == 0:
                 self.wrapper.eval()

                 def log_loss_train(_loss, _more_loss, _task_key="Default"):
@@ -922,13 +924,18 @@ def log_loss_valid(_task_key="Default"):
                 current_time = time.time()
                 train_time = current_time - self.t0
                 self.t0 = current_time
-                if self.rank == 0:
+                if self.rank == 0 and self.timing_in_training:
                     log.info(
                         format_training_message(
                             batch=_step_id,
                             wall_time=train_time,
                         )
                     )
+                # the first training time is not accurate
+                if (
+                    _step_id + 1
+                ) > self.disp_freq or self.num_steps < 2 * self.disp_freq:
+                    self.total_train_time += train_time

                 if fout:
                     if self.lcurve_should_print_header:
@@ -964,6 +971,7 @@ def log_loss_valid(_task_key="Default"):
                        writer.add_scalar(f"{task_key}/{item}", more_loss[item], _step_id)

        self.t0 = time.time()
+        self.total_train_time = 0.0
        for step_id in range(self.num_steps):
            if step_id < self.start_step:
                continue
@@ -995,6 +1003,24 @@ def log_loss_valid(_task_key="Default"):
                    with open("checkpoint", "w") as f:
                        f.write(str(self.latest_model))

+        if self.timing_in_training and self.num_steps // self.disp_freq > 0:
+            if self.num_steps >= 2 * self.disp_freq:
+                log.info(
+                    "average training time: %.4f s/batch (exclude first %d batches)",
+                    self.total_train_time
+                    / (
+                        self.num_steps // self.disp_freq * self.disp_freq
+                        - self.disp_freq
+                    ),
+                    self.disp_freq,
+                )
+            else:
+                log.info(
+                    "average training time: %.4f s/batch",
+                    self.total_train_time
+                    / (self.num_steps // self.disp_freq * self.disp_freq),
+                )
+
        if JIT:
            pth_model_path = (
                "frozen_model.pth"  # We use .pth to denote the frozen model
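Usage note (not part of the patch): both new keys default to `True` and are read from the same `training_params` dict as the other options touched here, which corresponds to the `training` section of the input. Below is a minimal sketch of how they might be set and read; the surrounding dict layout and the numeric values are illustrative assumptions, only the two `.get()` calls mirror the lines added to `__init__`.

```python
# Sketch of a "training" section with the new keys set explicitly
# (values are hypothetical, for illustration only).
training_params = {
    "numb_steps": 100000,
    "disp_freq": 100,
    "disp_training": True,   # print loss/validation lines every disp_freq steps
    "time_training": True,   # log per-batch wall time and a final average time
}

# Reading mirrors the lines this patch adds to Trainer.__init__:
display_in_training = training_params.get("disp_training", True)
timing_in_training = training_params.get("time_training", True)
print(display_in_training, timing_in_training)
```

With `time_training` enabled, the final "average training time" log excludes the first `disp_freq` batches whenever `num_steps >= 2 * disp_freq`, since the first timed interval includes warm-up and is not representative.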