Skip to content

Commit

Permalink
fix test - reduce metric
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Dec 28, 2020
1 parent f15dca2 commit abbe7ec
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions tests/trainer/logging_tests/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,19 +862,18 @@ def test_metric_are_properly_reduced(tmpdir):
class TestingModel(BoringModel):
def __init__(self, *args, **kwargs):
super().__init__()
self.train_acc = pl.metrics.Accuracy()
self.val_acc = pl.metrics.Accuracy()

def training_step(self, batch, batch_idx):
self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
return super().training_step(batch, batch_idx)
output = super().training_step(batch, batch_idx)
self.log("train_loss", output["loss"])
return output

def validation_step(self, batch, batch_idx):
preds = torch.tensor(0, device=self.device)
targets = torch.tensor(1, device=self.device)
preds = torch.tensor([[0.9, 0.1]], device=self.device)
targets = torch.tensor([1], device=self.device)
if batch_idx < 8:
targets = preds
preds = torch.tensor([[0.1, 0.9]], device=self.device)
self.val_acc(preds, targets)
self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
return super().validation_step(batch, batch_idx)
Expand All @@ -899,4 +898,4 @@ def validation_step(self, batch, batch_idx):
trainer.fit(model)

assert trainer.callback_metrics["val_acc"] == 8 / 32.
assert "train_acc" in trainer.callback_metrics
assert "train_loss" in trainer.callback_metrics

0 comments on commit abbe7ec

Please sign in to comment.