Skip to content

Commit

Permalink
cosmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
fpaissan committed Nov 22, 2023
1 parent 8f441ff commit 083d164
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions micromind/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,7 @@ def on_train_end(self):
shutil.rmtree(self.experiment_folder)

def eval(self):
for m self.modules:
self.modules[m].eval()
self.modules.eval()

def train(
self,
Expand Down Expand Up @@ -422,12 +421,16 @@ def train(
self.opt.step()

for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0 and not m.eval_only:
if (
self.current_epoch + 1
) % m.eval_period == 0 and not m.eval_only:
m(model_out, batch, Stage.train, self.device)

running_train = {}
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0 and not m.eval_only:
if (
self.current_epoch + 1
) % m.eval_period == 0 and not m.eval_only:
running_train["train_" + m.name] = m.reduce(Stage.train)

running_train.update({"train_loss": loss_epoch / (idx + 1)})
Expand All @@ -441,7 +444,9 @@ def train(

train_metrics = {}
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0 and not m.eval_only:
if (
self.current_epoch + 1
) % m.eval_period == 0 and not m.eval_only:
train_metrics["train_" + m.name] = m.reduce(Stage.train, True)

train_metrics.update({"train_loss": loss_epoch / (idx + 1)})
Expand Down Expand Up @@ -506,9 +511,7 @@ def validate(self) -> Dict:
val_metrics = {}
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0:
val_metrics = {
"val_" + m.name: m.reduce(Stage.val, True)
}
val_metrics = {"val_" + m.name: m.reduce(Stage.val, True)}

val_metrics.update({"val_loss": loss_epoch / (idx + 1)})

Expand Down

0 comments on commit 083d164

Please sign in to comment.