Skip to content

Commit

Permalink
fix metrics and some polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
fpaissan committed Nov 22, 2023
1 parent 445c202 commit 31a8d1e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 20 deletions.
47 changes: 31 additions & 16 deletions micromind/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Authors:
- Francesco Paissan, 2023
"""
from typing import Dict, Union, Tuple, Callable, List
from typing import Dict, Union, Tuple, Callable, List, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass
from argparse import Namespace
Expand Down Expand Up @@ -82,15 +82,23 @@ class Metric:
0.5
"""

def __init__(self, name: str, fn: Callable, reduction="mean"):
def __init__(
self,
name: str,
fn: Callable,
reduction: Optional[str] = "mean",
eval_only: Optional[bool] = False,
eval_period: Optional[int] = 1,
):
self.name = name
self.fn = fn
self.reduction = reduction
self.eval_only = eval_only
self.eval_period = eval_period

self.history = {s: [] for s in [Stage.train, Stage.val, Stage.test]}

def __call__(self, pred, batch, stage, device="cpu"):
# if pred.device != device:
# pred = pred.to(device)
dat = self.fn(pred, batch)
if dat.ndim == 0:
dat = dat.unsqueeze(0)
Expand Down Expand Up @@ -385,6 +393,7 @@ def train(
)
with self.accelerator.autocast():
for e in range(self.start_epoch, epochs):
self.current_epoch = e
pbar = tqdm(
self.datasets["train"],
unit="batches",
Expand All @@ -409,10 +418,13 @@ def train(
self.opt.step()

for m in self.metrics:
m(model_out, batch, Stage.train, self.device)
if not m.eval_only and (e + 1) % m.eval_period == 0:
m(model_out, batch, Stage.train, self.device)

running_train = {
"train_" + m.name: m.reduce(Stage.train) for m in self.metrics
"train_" + m.name: m.reduce(Stage.train)
for m in self.metrics
if not m.eval_only
}

running_train.update({"train_loss": loss_epoch / (idx + 1)})
Expand All @@ -425,7 +437,9 @@ def train(
pbar.close()

train_metrics = {
"train_" + m.name: m.reduce(Stage.train, True) for m in self.metrics
"train_" + m.name: m.reduce(Stage.train, True)
for m in self.metrics
if not m.eval_only
}
train_metrics.update({"train_loss": loss_epoch / (idx + 1)})

Expand Down Expand Up @@ -477,26 +491,29 @@ def validate(self) -> Dict:
model_out = self(batch)
loss = self.compute_loss(model_out, batch)
for m in self.metrics:
m(model_out, batch, Stage.val, self.device)
if (self.current_epoch + 1) % m.eval_period == 0:
m(model_out, batch, Stage.val, self.device)

loss_epoch += loss.item()
pbar.set_postfix(loss=loss_epoch / (idx + 1))

if self.debug and idx > 10:
break

val_metrics = {"val_" + m.name: m.reduce(Stage.val, True) for m in self.metrics}
if (self.current_epoch + 1) % m.eval_period == 0:
val_metrics = {
"val_" + m.name: m.reduce(Stage.val, True) for m in self.metrics
}
else:
val_metrics = {}
val_metrics.update({"val_loss": loss_epoch / (idx + 1)})

pbar.close()

return val_metrics

@torch.no_grad()
def test(
self,
datasets: Dict = {},
metrics: List[Metric] = []) -> None:
def test(self, datasets: Dict = {}, metrics: List[Metric] = []) -> None:
"""Runs the test steps."""
assert "test" in datasets, "Test dataloader was not specified."
self.modules.eval()
Expand Down Expand Up @@ -525,9 +542,7 @@ def test(

pbar.close()

test_metrics = {
"test_" + m.name: m.reduce(Stage.test, True) for m in metrics
}
test_metrics = {"test_" + m.name: m.reduce(Stage.test, True) for m in metrics}
test_metrics.update({"test_loss": loss_epoch / (idx + 1)})
s_out = (
"Testing "
Expand Down
2 changes: 1 addition & 1 deletion micromind/utils/yolo_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_config(file_path):
"path": path,
"train": train.as_posix(),
"val": val.as_posix(),
"test": test.as_posix(),
"test": test,
"names": config["names"],
"download": config.get("download"),
"yaml_file": file_path,
Expand Down
2 changes: 2 additions & 0 deletions recipes/objection_detection/extra_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
opencv-python
ultralytics==8.0.215
6 changes: 3 additions & 3 deletions recipes/objection_detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,12 @@ def mAP(self, pred, batch):
if __name__ == "__main__":
batch_size = 8

m_cfg, data_cfg = load_config("cfg/coco.yaml")
m_cfg, data_cfg = load_config("cfg/coco8.yaml")

mode = "train"
coco8_dataset = build_yolo_dataset(
m_cfg,
"datasets/coco/images/train2017",
"datasets/coco8/images/train",
batch_size,
data_cfg,
mode=mode,
Expand All @@ -278,7 +278,7 @@ def mAP(self, pred, batch):
mode = "val"
coco8_dataset = build_yolo_dataset(
m_cfg,
"datasets/coco/images/val2017",
"datasets/coco8/images/val",
batch_size,
data_cfg,
mode=mode,
Expand Down

0 comments on commit 31a8d1e

Please sign in to comment.