From dfdb979acad523eb776db2b0c776320dd032d6a6 Mon Sep 17 00:00:00 2001 From: Shriya Palsamudram Date: Thu, 15 Aug 2024 14:49:02 -0700 Subject: [PATCH] Add MemoryProfileCallback Signed-off-by: Shriya Palsamudram --- nemo/lightning/pytorch/callbacks/__init__.py | 3 +- .../pytorch/callbacks/memory_profiler.py | 58 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 nemo/lightning/pytorch/callbacks/memory_profiler.py diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py index ee0e777d739e..eb32e6564116 100644 --- a/nemo/lightning/pytorch/callbacks/__init__.py +++ b/nemo/lightning/pytorch/callbacks/__init__.py @@ -1,3 +1,4 @@ +from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform from nemo.lightning.pytorch.callbacks.nsys import NsysCallback @@ -6,4 +7,4 @@ from nemo.lightning.pytorch.callbacks.progress import MegatronProgressBar -__all__ = ["ModelCheckpoint", "ModelTransform", "PEFT", "NsysCallback", "MegatronProgressBar", "PreemptionCallback"] +__all__ = ["MemoryProfileCallback", "ModelCheckpoint", "ModelTransform", "PEFT", "NsysCallback", "MegatronProgressBar", "PreemptionCallback"] diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py new file mode 100644 index 000000000000..bf783871cfe7 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -0,0 +1,58 @@ +import os +from typing import List, Optional + +import torch +from pytorch_lightning.callbacks.callback import Callback +from nemo.lightning import io + +from nemo.utils import logging +from nemo.utils.get_rank import get_rank + + +class MemoryProfileCallback(Callback, io.IOMixin): + """ + This callback enables recording a timeline of memory allocations during training. + The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz + + More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/). + + Args: + dir (Optional[str]): Directory to store the memory profile dump + + Example: + >>> callback = MemoryProfileCallback(dir="/mem_profile") + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__(self, dir: str = "/mem_profile"): + + self.dir = dir + os.makedirs(self.dir, exist_ok=True) + logging.info(f"Torch memory profiles will be written to: {self.dir},") + + + def setup(self, trainer, pl_module, stage) -> None: + """PyTorch Lightning hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end + We use it here to start recording the memory profiler. + """ + + if torch.distributed.is_initialized(): + torch.cuda.memory._record_memory_history(max_entries=100000) + + + def on_train_end(self, trainer, pl_module) -> None: + """PyTorch Lightning hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end + We use it here to finish memory profiling and write the snapshot. + """ + + logging.info(f"on_train_batch_end rank: {torch.distributed.get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}") + + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + _snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle" + logging.info(f"Writing memory profile snapshot to {_snapshot_path}") + torch.cuda.memory._dump_snapshot(f"{_snapshot_path}") + torch.cuda.memory._record_memory_history(enabled=None) + logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}") \ No newline at end of file