Skip to content

Commit

Permalink
Merge branch 'dev' of github.com:micromind-toolkit/micromind into ref…
Browse files Browse the repository at this point in the history
…actor_yolo
  • Loading branch information
fpaissan committed Nov 22, 2023
2 parents 06daa3f + 1e31955 commit 445c202
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions micromind/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,10 @@ def validate(self) -> Dict:
return val_metrics

@torch.no_grad()
def test(self, datasets: Dict = {}) -> None:
def test(
self,
datasets: Dict = {},
metrics: List[Metric] = []) -> None:
"""Runs the test steps."""
assert "test" in datasets, "Test dataloader was not specified."
self.modules.eval()
Expand All @@ -511,11 +514,10 @@ def test(self, datasets: Dict = {}) -> None:
for idx, batch in enumerate(pbar):
if isinstance(batch, list):
batch = [b.to(self.device) for b in batch]
self.opt.zero_grad()

model_out = self(batch)
loss = self.compute_loss(model_out, batch)
for m in self.metrics:
for m in metrics:
m(model_out, batch, Stage.test, self.device)

loss_epoch += loss.item()
Expand All @@ -524,7 +526,7 @@ def test(self, datasets: Dict = {}) -> None:
pbar.close()

test_metrics = {
"test_" + m.name: m.reduce(Stage.test, True) for m in self.metrics
"test_" + m.name: m.reduce(Stage.test, True) for m in metrics
}
test_metrics.update({"test_loss": loss_epoch / (idx + 1)})
s_out = (
Expand Down

0 comments on commit 445c202

Please sign in to comment.