diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index f74c4769bf..af6e48191d 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -656,7 +656,6 @@ def step(_step_id, task_key="Default") -> None:
             # PyTorch Profiler
             if self.enable_profiler or self.profiling:
                 prof.step()
-            self.wrapper.train()
             if isinstance(self.lr_exp, dict):
                 _lr = self.lr_exp[task_key]
             else:
@@ -682,12 +681,11 @@ def step(_step_id, task_key="Default") -> None:
                 )
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(
-                        self.wrapper.parameters(), self.gradient_max_norm
+                    torch.nn.utils.clip_grad_norm_(
+                        self.wrapper.parameters(),
+                        self.gradient_max_norm,
+                        error_if_nonfinite=True,
                     )
-                    if not torch.isfinite(grad_norm).all():
-                        # check local gradnorm single GPU case, trigger NanDetector
-                        raise FloatingPointError("gradients are Nan/Inf")
                 with torch.device("cpu"):
                     self.optimizer.step()
                     self.scheduler.step()
@@ -766,7 +764,7 @@ def fake_model():
             if self.display_in_training and (
                 display_step_id % self.disp_freq == 0 or display_step_id == 1
             ):
-                self.wrapper.eval()
+                self.wrapper.eval()  # Will set to train mode before finishing validation
 
                 def log_loss_train(_loss, _more_loss, _task_key="Default"):
                     results = {}
@@ -872,6 +870,7 @@ def log_loss_valid(_task_key="Default"):
                                     learning_rate=None,
                                 )
                             )
+                self.wrapper.train()
 
                 current_time = time.time()
                 train_time = current_time - self.t0
@@ -927,6 +926,7 @@ def log_loss_valid(_task_key="Default"):
                         f"{task_key}/{item}", more_loss[item], display_step_id
                     )
 
+        self.wrapper.train()
         self.t0 = time.time()
         self.total_train_time = 0.0
         for step_id in range(self.num_steps):
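For reference, `torch.nn.utils.clip_grad_norm_` with `error_if_nonfinite=True` raises a `RuntimeError` when the total gradient norm is NaN/Inf, which covers the manual `torch.isfinite(grad_norm)` check removed above. Below is a minimal standalone sketch of that behavior; the toy model and inputs are made up for illustration and are not part of the patch.

```python
import torch

# Toy setup (assumed names, for illustration only).
model = torch.nn.Linear(4, 1)
# Multiplying the loss by inf forces non-finite gradients for the demo.
loss = model(torch.randn(2, 4)).sum() * float("inf")
loss.backward()

try:
    torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=1.0,
        error_if_nonfinite=True,  # raises RuntimeError if the total norm is NaN/Inf
    )
except RuntimeError as err:
    print(f"non-finite gradient norm detected: {err}")
```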