diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py
index bf783871cfe7..10adaa496dfa 100644
--- a/nemo/lightning/pytorch/callbacks/memory_profiler.py
+++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py
@@ -11,7 +11,7 @@ class MemoryProfileCallback(Callback, io.IOMixin):
     """
-    This callback enables recording a timeline of memory allocations during training. 
+    This callback enables recording a timeline of memory allocations during training.
     The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz
 
     More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/).
@@ -30,7 +30,6 @@ def __init__(self, dir: str = "/mem_profile"):
         os.makedirs(self.dir, exist_ok=True)
         logging.info(f"Torch memory profiles will be written to: {self.dir},")
-
 
     def setup(self, trainer, pl_module, stage) -> None:
         """PyTorch Lightning hook:
         https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
@@ -40,14 +39,15 @@ def setup(self, trainer, pl_module, stage) -> None:
         if torch.distributed.is_initialized():
             torch.cuda.memory._record_memory_history(max_entries=100000)
-
 
     def on_train_end(self, trainer, pl_module) -> None:
         """PyTorch Lightning hook:
         https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
 
         We use it here to finish memory profiling and write the snapshot.
         """
-        logging.info(f"on_train_batch_end rank: {torch.distributed.get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}")
+        logging.info(
+            f"on_train_batch_end rank: {torch.distributed.get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}"
+        )
 
         if torch.distributed.is_initialized():
             rank = torch.distributed.get_rank()
@@ -55,4 +55,4 @@ def on_train_end(self, trainer, pl_module) -> None:
             logging.info(f"Writing memory profile snapshot to {_snapshot_path}")
             torch.cuda.memory._dump_snapshot(f"{_snapshot_path}")
             torch.cuda.memory._record_memory_history(enabled=None)
-            logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}")
\ No newline at end of file
+            logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}")
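
For context, here is a minimal standalone sketch of the profiling flow this callback wraps, assuming PyTorch >= 2.1 on a machine with a CUDA device. The torch.cuda.memory._record_memory_history and _dump_snapshot calls are the same private APIs used in the diff above; the toy model and snapshot filename are illustrative only, not part of the change:

```python
import torch

# Start recording an allocation timeline, keeping up to 100k events
# (the same max_entries value MemoryProfileCallback.setup passes).
torch.cuda.memory._record_memory_history(max_entries=100000)

# Any CUDA work done while recording is captured; a toy model stands
# in for a real training loop here.
model = torch.nn.Linear(1024, 1024).cuda()
for _ in range(3):
    out = model(torch.randn(64, 1024, device="cuda"))
    out.sum().backward()

# Write the .pickle profile, then stop recording; passing enabled=None
# disables history collection, mirroring on_train_end above.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```

The resulting memory_snapshot.pickle can then be loaded at https://pytorch.org/memory_viz for interactive inspection, as noted in the class docstring.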