Skip to content

Commit

Permalink
Fixed #2143 and many more :) (#3855)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored and awaelchli committed Oct 5, 2020
1 parent 31e97be commit 02c66ca
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,8 @@ def write_dict(self, predictions_dict, filename='predictions.pt'):


def weighted_mean(result, weights):
if not isinstance(result, torch.Tensor):
result = torch.tensor(result)
weights = weights.to(result.device)[:result.size(0)]
numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
result = numerator / weights.sum().float()
Expand Down
34 changes: 34 additions & 0 deletions tests/trainer/logging/test_eval_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,40 @@ def validation_epoch_end(self, outputs):
assert len(trainer.dev_debugger.logged_metrics) == max_epochs


def test_eval_float_logging(tmpdir):
"""
Tests that only training_step can be used
"""
os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(BoringModel):

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log('a', 12.0)
return {"x": loss}

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
row_log_interval=1,
weights_summary=None,
)
trainer.fit(model)

# make sure all the metrics are available for callbacks
logged_metrics = set(trainer.logged_metrics.keys())
expected_logged_metrics = {
'a',
}
assert logged_metrics == expected_logged_metrics


def test_monitor_val_epoch_end(tmpdir):
epoch_min_loss_override = 0
model = SimpleModule()
Expand Down

0 comments on commit 02c66ca

Please sign in to comment.