From 037cf3f3add2a24ae7c5ecc412c8f1c7669a21ea Mon Sep 17 00:00:00 2001
From: Chun Cai
Date: Thu, 28 Nov 2024 06:08:42 +0800
Subject: [PATCH] perf: optimize training loop (#4426)

Improvements to the training process:

* [`deepmd/pt/train/training.py`](diffhunk://#diff-a90c90dc0e6a17fbe2e930f91182805b83260484c9dc1cfac3331378ffa34935R659): Stop setting the model to training mode on every step. Profiling shows that `self.wrapper.train()` recursively sets the mode on all submodules and takes a noticeable amount of time, so the wrapper is now switched to eval mode only for validation display and switched back to train mode afterwards (see the sketch after the diff).
* [`deepmd/pt/train/training.py`](diffhunk://#diff-a90c90dc0e6a17fbe2e930f91182805b83260484c9dc1cfac3331378ffa34935L686-L690): Pass `error_if_nonfinite=True` to the gradient clipping call and remove the manual check for non-finite gradients and the associated exception raising (see the example after the diff).

## Summary by CodeRabbit

- **New Features**
  - Improved training loop with enhanced error handling and control flow.
  - Updated gradient clipping logic for better error detection.
  - Refined logging functionality for training and validation results.

- **Bug Fixes**
  - Prevented redundant training-mode calls by adding conditional checks.

- **Documentation**
  - Clarified method logic in the `Trainer` class without changing method signatures.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 deepmd/pt/train/training.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index f74c4769bf..af6e48191d 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -656,7 +656,6 @@ def step(_step_id, task_key="Default") -> None:
             # PyTorch Profiler
             if self.enable_profiler or self.profiling:
                 prof.step()
-            self.wrapper.train()
             if isinstance(self.lr_exp, dict):
                 _lr = self.lr_exp[task_key]
             else:
@@ -682,12 +681,11 @@ def step(_step_id, task_key="Default") -> None:
                 )
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(
-                        self.wrapper.parameters(), self.gradient_max_norm
+                    torch.nn.utils.clip_grad_norm_(
+                        self.wrapper.parameters(),
+                        self.gradient_max_norm,
+                        error_if_nonfinite=True,
                     )
-                    if not torch.isfinite(grad_norm).all():
-                        # check local gradnorm single GPU case, trigger NanDetector
-                        raise FloatingPointError("gradients are Nan/Inf")
                 with torch.device("cpu"):
                     self.optimizer.step()
                     self.scheduler.step()
@@ -766,7 +764,7 @@ def fake_model():
             if self.display_in_training and (
                 display_step_id % self.disp_freq == 0 or display_step_id == 1
             ):
-                self.wrapper.eval()
+                self.wrapper.eval()  # Will be set back to train mode before finishing validation
 
                 def log_loss_train(_loss, _more_loss, _task_key="Default"):
                     results = {}
@@ -872,6 +870,7 @@ def log_loss_valid(_task_key="Default"):
                                 learning_rate=None,
                             )
                         )
+                self.wrapper.train()
                 current_time = time.time()
                 train_time = current_time - self.t0
                 self.t0 = current_time
@@ -927,6 +926,7 @@ def log_loss_valid(_task_key="Default"):
                         f"{task_key}/{item}", more_loss[item], display_step_id
                     )
 
+        self.wrapper.train()
         self.t0 = time.time()
         self.total_train_time = 0.0
         for step_id in range(self.num_steps):
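For reference, below is a minimal sketch of the train/eval toggling pattern the first change adopts: set training mode once before the loop and switch modes only around validation, instead of calling `.train()` on every step. The names `model`, `optimizer`, `train_loader`, and `validate` are illustrative placeholders, not the deepmd `Trainer` API.

```python
import torch

def train_loop(model, optimizer, train_loader, validate, disp_freq=100):
    """Toggle train/eval mode only around validation, not on every step."""
    model.train()  # recursive over all submodules, so call it once up front
    for step, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if step % disp_freq == 0:
            model.eval()  # switch only when validation actually runs
            with torch.no_grad():
                validate(model)
            model.train()  # restore train mode before the next training step
```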
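Likewise, a minimal sketch of the new clipping behavior, with an illustrative helper name: `error_if_nonfinite=True` makes `torch.nn.utils.clip_grad_norm_` itself raise a `RuntimeError` when the total gradient norm is NaN/Inf (available since PyTorch 1.9), so the separate `torch.isfinite(grad_norm).all()` check and the `FloatingPointError` are no longer needed.

```python
import torch

def clip_gradients(model: torch.nn.Module, gradient_max_norm: float) -> None:
    """Clip gradients by global norm, failing fast on a non-finite norm."""
    if gradient_max_norm > 0.0:
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            gradient_max_norm,
            error_if_nonfinite=True,  # raise RuntimeError on a NaN/Inf total norm
        )
```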