
Commit

Merge remote-tracking branch 'origin/master' into HEAD
rohitgr7 committed Sep 15, 2020
2 parents 0b8fbc1 + 4ed96b2 commit 874a2b4
Showing 3 changed files with 32 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -62,6 +62,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed overfit_batches which now correctly disables shuffling for the training loader. ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501))
 
+- Fixed gradient norm tracking for `row_log_interval > 1` ([#3489](https://github.com/PyTorchLightning/pytorch-lightning/pull/3489))
+
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -376,13 +376,12 @@ def on_before_backward(self, batch_idx, optimizer):
         return grad_norm_dic
 
     def _track_gradient_norm(self, batch_idx):
-        grad_norm_dic = {}
-        if batch_idx % self.trainer.row_log_interval == 0:
+        grad_norm_dict = {}
+        if (batch_idx + 1) % self.trainer.row_log_interval == 0:
             if float(self.trainer.track_grad_norm) > 0:
                 model = self.trainer.get_model()
-                grad_norm_dic = model.grad_norm(
-                    self.trainer.track_grad_norm)
-        return grad_norm_dic
+                grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm)
+        return grad_norm_dict
 
     def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics):
         # track callback metrics
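The substance of the fix is the interval condition: gradient norms were previously computed when `batch_idx % row_log_interval == 0`, i.e. on the very first batch and then every `row_log_interval` batches. Switching to `(batch_idx + 1) % row_log_interval == 0` makes the computation fire at the end of each full interval, so the number of gradient-norm logs over a run matches `global_step // row_log_interval`, which is what the new test below asserts. A minimal standalone sketch of the difference (illustration only, not Lightning internals; the constants are chosen to mirror one parametrization of the test):

```python
# Standalone illustration of the changed condition (not Lightning internals).
# Assumes batch_idx counts from 0; row_log_interval=3 and 10 steps mirror one
# parametrization of the new test below.

ROW_LOG_INTERVAL = 3
N_STEPS = 10

# Old condition: fires on batch_idx 0, 3, 6, 9 (including the very first batch).
old_hits = [i for i in range(N_STEPS) if i % ROW_LOG_INTERVAL == 0]

# New condition: fires at the end of each full interval, on batch_idx 2, 5, 8.
new_hits = [i for i in range(N_STEPS) if (i + 1) % ROW_LOG_INTERVAL == 0]

print(old_hits)  # [0, 3, 6, 9] -> 4 logs
print(new_hits)  # [2, 5, 8]    -> 3 logs

# The new count is exactly N_STEPS // ROW_LOG_INTERVAL, the value the test
# computes as trainer.global_step // row_log_interval.
assert len(new_hits) == N_STEPS // ROW_LOG_INTERVAL
```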
26 changes: 26 additions & 0 deletions tests/models/test_grad_norm.py
@@ -1,4 +1,5 @@
 import os
+from unittest.mock import patch
 
 import numpy as np
 import pytest
@@ -73,3 +74,28 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
     log, mod = [log[k] for k in common], [mod[k] for k in common]
 
     assert np.allclose(log, mod, rtol=rtol)
+
+
+@pytest.mark.parametrize("row_log_interval", [1, 2, 3])
+def test_grad_tracking_interval(tmpdir, row_log_interval):
+    """Test that gradient norms get tracked at the right interval and that the same keys get logged every time."""
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        track_grad_norm=2,
+        row_log_interval=row_log_interval,
+        max_steps=10,
+    )
+
+    with patch.object(trainer.logger, "log_metrics") as mocked:
+        model = EvalModelTemplate()
+        trainer.fit(model)
+        expected = trainer.global_step // row_log_interval
+        grad_norm_dicts = []
+        for _, kwargs in mocked.call_args_list:
+            metrics = kwargs.get("metrics", {})
+            grad_norm_dict = {k: v for k, v in metrics.items() if k.startswith("grad_")}
+            if grad_norm_dict:
+                grad_norm_dicts.append(grad_norm_dict)
+
+        assert len(grad_norm_dicts) == expected
+        assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts)
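The new test captures every `log_metrics` call by patching the trainer's logger, then filters the `grad_`-prefixed keys out of each call's `metrics` dict. A minimal standalone sketch of that capture pattern (the `Reporter` class, metric names, and values are made up for illustration; only `patch.object` and `call_args_list` are the standard-library pieces the test relies on):

```python
from unittest.mock import patch


class Reporter:
    """Made-up stand-in for a logger whose calls we want to inspect."""

    def log_metrics(self, metrics, step=None):
        pass  # real behavior is irrelevant; the method gets replaced by a mock


reporter = Reporter()

with patch.object(reporter, "log_metrics") as mocked:
    # Code under test would call the method; two calls are simulated here.
    reporter.log_metrics(metrics={"grad_2.0_norm_total": 1.5, "loss": 0.30}, step=1)
    reporter.log_metrics(metrics={"loss": 0.25}, step=2)

# call_args_list records (args, kwargs) for every call, in order.
for _, kwargs in mocked.call_args_list:
    metrics = kwargs.get("metrics", {})
    grad_metrics = {k: v for k, v in metrics.items() if k.startswith("grad_")}
    print(grad_metrics)  # first {'grad_2.0_norm_total': 1.5}, then {}
```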
