Add MemoryProfileCallback #10166

Merged 7 commits on Aug 23, 2024
2 changes: 2 additions & 0 deletions nemo/lightning/pytorch/callbacks/__init__.py
@@ -1,4 +1,5 @@
from nemo.lightning.pytorch.callbacks.ddp_parity_checker import DdpParityChecker
from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.callbacks.nsys import NsysCallback
@@ -8,6 +9,7 @@
from nemo.lightning.pytorch.callbacks.progress_printer import ProgressPrinter

__all__ = [
"MemoryProfileCallback",
"ModelCheckpoint",
"ModelTransform",
"PEFT",
78 changes: 78 additions & 0 deletions nemo/lightning/pytorch/callbacks/memory_profiler.py
@@ -0,0 +1,78 @@
import os

import torch
from pytorch_lightning.callbacks.callback import Callback
from torch.utils.viz._cycles import warn_tensor_cycles
Collaborator: Can we move this import inside the callback to be safe?

Collaborator (author): Isn't it better to fail early if there is an issue with this import than fail inside the callback?
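For illustration, a minimal, hypothetical sketch of the deferred-import alternative discussed in this thread (not the merged code, which keeps the module-level import so that an incompatible torch build fails at import time rather than mid-training):

import logging


class LazyCycleWarningSketch:
    """Illustrative only: defers the private torch.utils.viz import until cycle detection is requested."""

    def __init__(self, warn_cycles: bool = True):
        if warn_cycles:
            # A local import means a removed or renamed private API only breaks
            # runs that actually enable reference cycle detection.
            from torch.utils.viz._cycles import warn_tensor_cycles

            logging.info("Enabling reference cycle detector")
            warn_tensor_cycles()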


from nemo.lightning import io
from nemo.utils import logging
from nemo.utils.get_rank import get_rank


class MemoryProfileCallback(Callback, io.IOMixin):
    """
    This callback enables recording a timeline of memory allocations during training.
    The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz
    More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/).

    Args:
        dir (Optional[str]): Directory to store the memory profile dump
        warn_cycles (Optional[bool]): Whether to enable [reference cycle detection](https://pytorch.org/blog/understanding-gpu-memory-2/)
        ranks (Optional[list[int]]): List of ranks to collect snapshots on; defaults to all ranks if the list is empty

    Example:
        >>> callback = MemoryProfileCallback(dir="/mem_profile", ranks=[0])
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(self, dir: str = "/mem_profile", warn_cycles=True, ranks=[]):
        self.dir = dir
        self.ranks = ranks

        os.makedirs(self.dir, exist_ok=True)
        logging.info(f"Torch memory profiles will be written to: {self.dir}")

        if warn_cycles:
            logging.info("Enabling reference cycle detector")
            warn_tensor_cycles()

    def enable_on_rank(self) -> bool:
        """Return True if the current rank should be profiled (all ranks when `ranks` is empty)."""
        if not self.ranks:
            return True
        return get_rank() in self.ranks

    def setup(self, trainer, pl_module, stage) -> None:
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#setup
        We use it here to start recording the memory profiler.
        """

        if trainer.max_steps > 1000:
            logging.warning(
                "Memory profiling creates snapshots during the entire training process, "
                "where every iteration increases the size of the snapshot. "
                "Try reducing trainer.max_steps to avoid running into issues."
            )

        if torch.distributed.is_initialized() and self.enable_on_rank():
            torch.cuda.memory._record_memory_history(max_entries=100000)

    def on_train_end(self, trainer, pl_module) -> None:
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
        We use it here to finish memory profiling and write the snapshot.
        """

        logging.info(
            f"on_train_end rank: {get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024} GiB"
        )

        if torch.distributed.is_initialized() and self.enable_on_rank():
            rank = get_rank()
            _snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle"
            logging.info(f"Writing memory profile snapshot to {_snapshot_path}")
            torch.cuda.memory._dump_snapshot(_snapshot_path)
            torch.cuda.memory._record_memory_history(enabled=None)
            logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}")