Add MemoryProfileCallback #10166 (Merged)
Commits (7, all by ShriyaPalsamudram):
- 7aff422 Add MemoryProfileCallback
- 34d7f52 Apply isort and black reformatting
- 43f7ddb Remove reference cycles, save snapshot on specific ranks
- cb2a83a Remove unnecessary imports
- 84e3c25 Apply isort and black reformatting
- 73d0459 Update docstring
- 7b5741b Merge branch 'main' into shriya/mem_profiler
The PR adds a single new file, the `MemoryProfileCallback` implementation (78 added lines):

```python
import os
from typing import List, Optional

import torch
from pytorch_lightning.callbacks.callback import Callback
from torch.utils.viz._cycles import warn_tensor_cycles

from nemo.lightning import io
from nemo.utils import logging
from nemo.utils.get_rank import get_rank


class MemoryProfileCallback(Callback, io.IOMixin):
    """
    This callback enables recording a timeline of memory allocations during training.
    The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz.
    More info about the profiles can be found at
    https://pytorch.org/blog/understanding-gpu-memory-1/.

    Args:
        dir (str): Directory to store the memory profile dump.
        warn_cycles (bool): Whether to enable reference cycle detection
            (https://pytorch.org/blog/understanding-gpu-memory-2/).
        ranks (Optional[List[int]]): List of ranks to collect the snapshot on;
            defaults to all ranks if the list is empty or omitted.

    Example:
        >>> callback = MemoryProfileCallback(dir="/mem_profile", ranks=[0])
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(self, dir: str = "/mem_profile", warn_cycles: bool = True, ranks: Optional[List[int]] = None):
        self.dir = dir
        # Avoid a mutable default argument; an empty list means "all ranks".
        self.ranks = ranks if ranks is not None else []

        os.makedirs(self.dir, exist_ok=True)
        logging.info(f"Torch memory profiles will be written to: {self.dir}")

        if warn_cycles:
            logging.info("Enabling reference cycle detector")
            warn_tensor_cycles()

    def enable_on_rank(self) -> bool:
        """Return True if the snapshot should be collected on the current rank."""
        if not self.ranks:
            return True
        return get_rank() in self.ranks

    def setup(self, trainer, pl_module, stage) -> None:
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#setup
        We use it here to start recording the memory profiler.
        """
        if trainer.max_steps > 1000:
            logging.warning(
                "Memory profiling creates snapshots during the entire training process, "
                "where every iteration increases the size of the snapshot. "
                "Try reducing trainer.max_steps to avoid running into issues."
            )

        if torch.distributed.is_initialized() and self.enable_on_rank():
            torch.cuda.memory._record_memory_history(max_entries=100000)

    def on_train_end(self, trainer, pl_module) -> None:
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
        We use it here to finish memory profiling and write the snapshot.
        """
        logging.info(
            f"on_train_end rank: {get_rank()} "
            f"mem: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB allocated / "
            f"{torch.cuda.max_memory_reserved() / 1024**3:.2f} GiB max reserved"
        )

        if torch.distributed.is_initialized() and self.enable_on_rank():
            rank = get_rank()
            _snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle"
            logging.info(f"Writing memory profile snapshot to {_snapshot_path}")
            torch.cuda.memory._dump_snapshot(_snapshot_path)
            # Passing enabled=None stops the memory history recording.
            torch.cuda.memory._record_memory_history(enabled=None)
            logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}")
```
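For context, here is a minimal sketch of how the callback might be wired into a PyTorch Lightning run; `MyLightningModule` and `MyDataModule` are hypothetical placeholders, not part of this PR:

```python
import pytorch_lightning as pl

# Hypothetical model and datamodule, used only for illustration.
from my_project import MyDataModule, MyLightningModule

# Collect a snapshot on rank 0 only, and keep the run short:
# every training step grows the snapshot.
callback = MemoryProfileCallback(dir="./mem_profile", ranks=[0])

trainer = pl.Trainer(
    max_steps=200,
    devices=2,
    strategy="ddp",  # the callback only records when torch.distributed is initialized
    callbacks=[callback],
)
trainer.fit(MyLightningModule(), datamodule=MyDataModule())
```

The resulting `memory_snapshot-rank0.pickle` can be dropped into https://pytorch.org/memory_viz; recent PyTorch releases also bundle a script that should render it locally, e.g. `python -m torch.cuda._memory_viz trace_plot memory_snapshot-rank0.pickle -o trace.html`.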
Review discussion:

> Can we move this import inside the callback to be safe?

> Isn't it better to fail early if there is an issue with this import than fail inside the callback?
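The thread does not show which import it refers to; assuming it is the `from torch.utils.viz._cycles import warn_tensor_cycles` line (the one import from a private PyTorch module), a sketch of the deferred-import alternative the first comment suggests might look like this:

```python
# Sketch only: same __init__ as in the diff above, with the private-module
# import deferred until the callback is actually constructed.
def __init__(self, dir: str = "/mem_profile", warn_cycles: bool = True, ranks=None):
    self.dir = dir
    self.ranks = ranks if ranks is not None else []
    os.makedirs(self.dir, exist_ok=True)

    if warn_cycles:
        # Importing here keeps the module importable even if
        # torch.utils.viz._cycles moves in a future PyTorch release...
        from torch.utils.viz._cycles import warn_tensor_cycles

        warn_tensor_cycles()
```

...at the cost of surfacing a broken import only when the callback is constructed, which is exactly the trade-off the reply argues against.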