Skip to content

Commit

Permalink
only set training mode on first entrance and exiting valid
Browse files Browse the repository at this point in the history
  • Loading branch information
caic99 authored Nov 27, 2024
1 parent 53bbbdb commit 078e848
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,6 @@ def step(_step_id, task_key="Default") -> None:
# PyTorch Profiler
if self.enable_profiler or self.profiling:
prof.step()
if not self.wrapper.training:
self.wrapper.train()
if isinstance(self.lr_exp, dict):
_lr = self.lr_exp[task_key]
else:
Expand Down Expand Up @@ -766,7 +764,7 @@ def fake_model():
if self.display_in_training and (
display_step_id % self.disp_freq == 0 or display_step_id == 1
):
self.wrapper.eval()
self.wrapper.eval() # Will set to train mode before fininshing validation

def log_loss_train(_loss, _more_loss, _task_key="Default"):
results = {}
Expand Down Expand Up @@ -872,6 +870,7 @@ def log_loss_valid(_task_key="Default"):
learning_rate=None,
)
)
self.wrapper.train()

current_time = time.time()
train_time = current_time - self.t0
Expand Down Expand Up @@ -926,7 +925,7 @@ def log_loss_valid(_task_key="Default"):
writer.add_scalar(
f"{task_key}/{item}", more_loss[item], display_step_id
)

self.wrapper.train()
self.t0 = time.time()
self.total_train_time = 0.0
for step_id in range(self.num_steps):
Expand Down

0 comments on commit 078e848

Please sign in to comment.