Skip to content

Commit

Permalink
[bug] Fix Pytorch profiler with emit_nvtx (#6260)
Browse files Browse the repository at this point in the history
* resolve bug

* update changelog

* Update tests/trainer/test_trainer.py

* Update pytorch_lightning/profiler/profilers.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* resolve comments

* resolve flake8

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 5, 2021
1 parent e848542 commit 2ec67a4
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))


- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))


## [1.2.2] - 2021-03-02

### Added
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""

import cProfile
import io
import logging
Expand Down
13 changes: 11 additions & 2 deletions pytorch_lightning/profiler/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def start(self, action_name: str) -> None:

def _start(self, action_name: str) -> None:
if self.emit_nvtx:
self._create_profiler(action_name, torch.cuda.profiler.profile, enter=False)
self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True)
self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
else:
self._create_profiler(action_name, torch.autograd.profiler.profile)
Expand All @@ -215,15 +215,24 @@ def _create_profiler(self, action_name, profiler, enter=True):
profiler_args = {k: v for k, v in vars(self).items() if k in init_args}
pr = profiler(**profiler_args)
if enter:
pr = pr.__enter__()
out_pr = pr.__enter__()
if out_pr is not None:
pr = out_pr
self.profiler = pr
return self.profiler

def _stop(self, action_name: str) -> None:
if self.profiler is None:
return

self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None)

if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx):
# when running ``emit_nvtx``, PyTorch requires 2 context manager.
# The parent_profiler is being closed too.
self._parent_profiler.__exit__(None, None, None)
return

function_events = self.profiler.function_events
self.profiler = None
for name in self.running_stack:
Expand Down
1 change: 1 addition & 0 deletions tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model
nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx
16 changes: 16 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,22 @@ def test_pytorch_profiler_nested(tmpdir):
assert pa[n] == expected_[n]


@RunIf(min_gpus=1, special=True)
def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
"""
This test check emit_nvtx is correctly supported
"""
profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True)

model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
profiler=profiler,
gpus=1,
)
trainer.fit(model)


@pytest.mark.parametrize(
["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
[(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],
Expand Down

0 comments on commit 2ec67a4

Please sign in to comment.