Skip to content

Commit

Permalink
minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fpaissan committed Nov 22, 2023
1 parent 31a8d1e commit 8f441ff
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
38 changes: 21 additions & 17 deletions micromind/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ def on_train_end(self):
logger.info(f"Removed temporary folder {self.experiment_folder}.")
shutil.rmtree(self.experiment_folder)

def eval(self):
for m self.modules:
self.modules[m].eval()

def train(
self,
epochs: int = 1,
Expand Down Expand Up @@ -418,14 +422,13 @@ def train(
self.opt.step()

for m in self.metrics:
if not m.eval_only and (e + 1) % m.eval_period == 0:
if (self.current_epoch + 1) % m.eval_period == 0 and not m.eval_only:
m(model_out, batch, Stage.train, self.device)

running_train = {
"train_" + m.name: m.reduce(Stage.train)
for m in self.metrics
if not m.eval_only
}
running_train = {}
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0 and not m.eval_only:
running_train["train_" + m.name] = m.reduce(Stage.train)

running_train.update({"train_loss": loss_epoch / (idx + 1)})

Expand All @@ -436,11 +439,11 @@ def train(

pbar.close()

train_metrics = {
"train_" + m.name: m.reduce(Stage.train, True)
for m in self.metrics
if not m.eval_only
}
train_metrics = {}
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0 and not m.eval_only:
train_metrics["train_" + m.name] = m.reduce(Stage.train, True)

train_metrics.update({"train_loss": loss_epoch / (idx + 1)})

if "val" in datasets:
Expand Down Expand Up @@ -500,12 +503,13 @@ def validate(self) -> Dict:
if self.debug and idx > 10:
break

if (self.current_epoch + 1) % m.eval_period == 0:
val_metrics = {
"val_" + m.name: m.reduce(Stage.val, True) for m in self.metrics
}
else:
val_metrics = {}
val_metrics = {}
for m in self.metrics:
if (self.current_epoch + 1) % m.eval_period == 0:
val_metrics = {
"val_" + m.name: m.reduce(Stage.val, True)
}

val_metrics.update({"val_loss": loss_epoch / (idx + 1)})

pbar.close()
Expand Down
4 changes: 2 additions & 2 deletions recipes/objection_detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,10 @@ def mAP(self, pred, batch):

hparams = parse_arguments()
m = YOLO(m_cfg, hparams=hparams)
mAP = Metric("mAP", m.mAP)
mAP = Metric("mAP", m.mAP, eval_only=False, eval_period=2)

m.train(
epochs=50,
epochs=2,
datasets={"train": train_loader, "val": val_loader},
metrics=[mAP],
debug=hparams.debug,
Expand Down

0 comments on commit 8f441ff

Please sign in to comment.