From 7aff422f603ac3ee4c50abd043c8646670842f87 Mon Sep 17 00:00:00 2001 From: Shriya Palsamudram Date: Thu, 15 Aug 2024 14:49:02 -0700 Subject: [PATCH 1/6] Add MemoryProfileCallback Signed-off-by: Shriya Palsamudram --- nemo/lightning/pytorch/callbacks/__init__.py | 2 + .../pytorch/callbacks/memory_profiler.py | 58 +++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 nemo/lightning/pytorch/callbacks/memory_profiler.py diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py index 00637c9d57d4..bdf73b70e74a 100644 --- a/nemo/lightning/pytorch/callbacks/__init__.py +++ b/nemo/lightning/pytorch/callbacks/__init__.py @@ -1,3 +1,4 @@ +from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform from nemo.lightning.pytorch.callbacks.nsys import NsysCallback @@ -8,6 +9,7 @@ __all__ = [ + "MemoryProfileCallback", "ModelCheckpoint", "ModelTransform", "PEFT", diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py new file mode 100644 index 000000000000..bf783871cfe7 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -0,0 +1,58 @@ +import os +from typing import List, Optional + +import torch +from pytorch_lightning.callbacks.callback import Callback +from nemo.lightning import io + +from nemo.utils import logging +from nemo.utils.get_rank import get_rank + + +class MemoryProfileCallback(Callback, io.IOMixin): + """ + This callback enables recording a timeline of memory allocations during training. + The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz + + More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/). + + Args: + dir (Optional[str]): Directory to store the memory profile dump + + Example: + >>> callback = MemoryProfileCallback(dir="/mem_profile") + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__(self, dir: str = "/mem_profile"): + + self.dir = dir + os.makedirs(self.dir, exist_ok=True) + logging.info(f"Torch memory profiles will be written to: {self.dir},") + + + def setup(self, trainer, pl_module, stage) -> None: + """PyTorch Lightning hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end + We use it here to start recording the memory profiler. + """ + + if torch.distributed.is_initialized(): + torch.cuda.memory._record_memory_history(max_entries=100000) + + + def on_train_end(self, trainer, pl_module) -> None: + """PyTorch Lightning hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end + We use it here to finish memory profiling and write the snapshot. 
+ """ + + logging.info(f"on_train_batch_end rank: {torch.distributed.get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}") + + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + _snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle" + logging.info(f"Writing memory profile snapshot to {_snapshot_path}") + torch.cuda.memory._dump_snapshot(f"{_snapshot_path}") + torch.cuda.memory._record_memory_history(enabled=None) + logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}") \ No newline at end of file From 34d7f52389964ec4b3a5a32922d371725619ec5f Mon Sep 17 00:00:00 2001 From: ShriyaPalsamudram Date: Thu, 15 Aug 2024 21:58:06 +0000 Subject: [PATCH 2/6] Apply isort and black reformatting Signed-off-by: ShriyaPalsamudram --- nemo/lightning/pytorch/callbacks/memory_profiler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py index bf783871cfe7..10adaa496dfa 100644 --- a/nemo/lightning/pytorch/callbacks/memory_profiler.py +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -11,7 +11,7 @@ class MemoryProfileCallback(Callback, io.IOMixin): """ - This callback enables recording a timeline of memory allocations during training. + This callback enables recording a timeline of memory allocations during training. The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/). @@ -30,7 +30,6 @@ def __init__(self, dir: str = "/mem_profile"): os.makedirs(self.dir, exist_ok=True) logging.info(f"Torch memory profiles will be written to: {self.dir},") - def setup(self, trainer, pl_module, stage) -> None: """PyTorch Lightning hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end @@ -40,14 +39,15 @@ def setup(self, trainer, pl_module, stage) -> None: if torch.distributed.is_initialized(): torch.cuda.memory._record_memory_history(max_entries=100000) - def on_train_end(self, trainer, pl_module) -> None: """PyTorch Lightning hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end We use it here to finish memory profiling and write the snapshot. 
""" - logging.info(f"on_train_batch_end rank: {torch.distributed.get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}") + logging.info( + f"on_train_batch_end rank: {torch.distributed.get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}" + ) if torch.distributed.is_initialized(): rank = torch.distributed.get_rank() @@ -55,4 +55,4 @@ def on_train_end(self, trainer, pl_module) -> None: logging.info(f"Writing memory profile snapshot to {_snapshot_path}") torch.cuda.memory._dump_snapshot(f"{_snapshot_path}") torch.cuda.memory._record_memory_history(enabled=None) - logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}") \ No newline at end of file + logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}") From 43f7ddb610d3485054fecd56f783e7d829e956d0 Mon Sep 17 00:00:00 2001 From: Shriya Palsamudram Date: Wed, 21 Aug 2024 13:23:36 -0700 Subject: [PATCH 3/6] Remove reference cycles, save snapshot on specific ranks Signed-off-by: Shriya Palsamudram --- .../pytorch/callbacks/memory_profiler.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py index 10adaa496dfa..ee6fcd5209ab 100644 --- a/nemo/lightning/pytorch/callbacks/memory_profiler.py +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -8,6 +8,7 @@ from nemo.utils import logging from nemo.utils.get_rank import get_rank +from torch.utils.viz._cycles import warn_tensor_cycles class MemoryProfileCallback(Callback, io.IOMixin): """ @@ -24,11 +25,24 @@ class MemoryProfileCallback(Callback, io.IOMixin): >>> trainer = Trainer(callbacks=[callback]) """ - def __init__(self, dir: str = "/mem_profile"): + def __init__(self, dir: str = "/mem_profile", warn_cycles=True, ranks=[]): self.dir = dir + self.ranks = ranks + os.makedirs(self.dir, exist_ok=True) - logging.info(f"Torch memory profiles will be written to: {self.dir},") + logging.info(f"Torch memory profiles will be written to: {self.dir}") + + if warn_cycles: + logging.info("Enabling reference cycle detector") + warn_tensor_cycles() + + + def enable_on_rank(self) -> bool: + if not self.ranks: + return True + return get_rank() in self.ranks + def setup(self, trainer, pl_module, stage) -> None: """PyTorch Lightning hook: @@ -36,9 +50,15 @@ def setup(self, trainer, pl_module, stage) -> None: We use it here to start recording the memory profiler. """ - if torch.distributed.is_initialized(): + if trainer.max_steps > 1000: + logging.warning(f"Memory profiling creates snapshots during the entire training process, \ + where every iteration increases the size of the snapshot. 
\ + Try reducing trainer.max_steps to avoid running into issues") + + if torch.distributed.is_initialized() and self.enable_on_rank(): torch.cuda.memory._record_memory_history(max_entries=100000) + def on_train_end(self, trainer, pl_module) -> None: """PyTorch Lightning hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end @@ -46,11 +66,11 @@ def on_train_end(self, trainer, pl_module) -> None: """ logging.info( - f"on_train_batch_end rank: {torch.distributed.get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}" + f"on_train_batch_end rank: {get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}" ) - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() + if torch.distributed.is_initialized() and self.enable_on_rank(): + rank = get_rank() _snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle" logging.info(f"Writing memory profile snapshot to {_snapshot_path}") torch.cuda.memory._dump_snapshot(f"{_snapshot_path}") From cb2a83a1c941f128b21d9b0306408c84a4813d37 Mon Sep 17 00:00:00 2001 From: Shriya Palsamudram Date: Wed, 21 Aug 2024 13:26:12 -0700 Subject: [PATCH 4/6] Remove unnecessary imports Signed-off-by: Shriya Palsamudram --- nemo/lightning/pytorch/callbacks/memory_profiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py index ee6fcd5209ab..c208e4ebcf7b 100644 --- a/nemo/lightning/pytorch/callbacks/memory_profiler.py +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -1,5 +1,4 @@ import os -from typing import List, Optional import torch from pytorch_lightning.callbacks.callback import Callback From 84e3c250edbd86d5531137495f79531629942aeb Mon Sep 17 00:00:00 2001 From: ShriyaPalsamudram Date: Wed, 21 Aug 2024 20:28:17 +0000 Subject: [PATCH 5/6] Apply isort and black reformatting Signed-off-by: ShriyaPalsamudram --- nemo/lightning/pytorch/callbacks/memory_profiler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py index c208e4ebcf7b..7f4af887b896 100644 --- a/nemo/lightning/pytorch/callbacks/memory_profiler.py +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -2,12 +2,12 @@ import torch from pytorch_lightning.callbacks.callback import Callback -from nemo.lightning import io +from torch.utils.viz._cycles import warn_tensor_cycles +from nemo.lightning import io from nemo.utils import logging from nemo.utils.get_rank import get_rank -from torch.utils.viz._cycles import warn_tensor_cycles class MemoryProfileCallback(Callback, io.IOMixin): """ @@ -36,13 +36,11 @@ def __init__(self, dir: str = "/mem_profile", warn_cycles=True, ranks=[]): logging.info("Enabling reference cycle detector") warn_tensor_cycles() - def enable_on_rank(self) -> bool: if not self.ranks: return True return get_rank() in self.ranks - def setup(self, trainer, pl_module, stage) -> None: """PyTorch Lightning hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end @@ -50,14 +48,15 @@ def setup(self, trainer, pl_module, stage) -> None: """ if trainer.max_steps > 1000: - logging.warning(f"Memory profiling creates snapshots during the entire training process, \ + logging.warning( + f"Memory profiling creates snapshots during 
the entire training process, \ where every iteration increases the size of the snapshot. \ - Try reducing trainer.max_steps to avoid running into issues") + Try reducing trainer.max_steps to avoid running into issues" + ) if torch.distributed.is_initialized() and self.enable_on_rank(): torch.cuda.memory._record_memory_history(max_entries=100000) - def on_train_end(self, trainer, pl_module) -> None: """PyTorch Lightning hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end From 73d045962456ff042a8adffa27568a85bc40dae2 Mon Sep 17 00:00:00 2001 From: Shriya Palsamudram Date: Wed, 21 Aug 2024 13:31:30 -0700 Subject: [PATCH 6/6] Update docstring Signed-off-by: Shriya Palsamudram --- nemo/lightning/pytorch/callbacks/memory_profiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py index 7f4af887b896..089479637f61 100644 --- a/nemo/lightning/pytorch/callbacks/memory_profiler.py +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -18,9 +18,11 @@ class MemoryProfileCallback(Callback, io.IOMixin): Args: dir (Optional[str]): Directory to store the memory profile dump + warn_cycles (Optional[bool]): Whether to enable [reference cycle detection](https://pytorch.org/blog/understanding-gpu-memory-2/) + rank (Optional[list[int]]): List of ranks to collect snapshot on, defaults to all if list is empty Example: - >>> callback = MemoryProfileCallback(dir="/mem_profile") + >>> callback = MemoryProfileCallback(dir="/mem_profile", ranks=[0]) >>> trainer = Trainer(callbacks=[callback]) """
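
Usage sketch: the class docstring above shows the minimal wiring, and the example below fills it out into a short profiling run. The Trainer arguments, model, and datamodule are placeholders for whatever the training recipe already uses; only the import and the MemoryProfileCallback construction come from this patch series.

    # Hypothetical wiring of MemoryProfileCallback into a short profiling run.
    # Everything except the callback itself is a placeholder, not part of the patches.
    from pytorch_lightning import Trainer

    from nemo.lightning.pytorch.callbacks import MemoryProfileCallback

    # Record on rank 0 only and keep the run short, since every iteration grows
    # the snapshot (see the trainer.max_steps warning added in setup()).
    mem_profile = MemoryProfileCallback(dir="/mem_profile", warn_cycles=True, ranks=[0])

    trainer = Trainer(
        max_steps=50,  # a small recorded window keeps the .pickle manageable
        callbacks=[mem_profile],
        # ... devices / strategy / precision as in the existing recipe ...
    )
    # trainer.fit(model, datamodule=datamodule)  # placeholder model and datamodule

Note that both setup() and on_train_end() are gated on torch.distributed.is_initialized(), so a single-process run without a distributed strategy records nothing and writes no snapshot. With ranks left empty, every rank writes its own memory_snapshot-rank{rank}.pickle under dir.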
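
Each snapshot can be dropped into https://pytorch.org/memory_viz for analysis, as the class docstring notes. For reproducing the same recording outside Lightning, the sketch below drives the private torch.cuda.memory hooks directly, exactly as the callback does in setup() and on_train_end(); the offline HTML conversion in the final comment follows the "Understanding GPU Memory" post linked from the docstring and relies on a private helper whose interface may change between PyTorch releases.

    # Standalone sketch of the private torch.cuda.memory APIs wrapped by the callback.
    import torch

    torch.cuda.memory._record_memory_history(max_entries=100000)  # start recording allocations

    # ... run the CUDA workload to be profiled ...

    torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")  # write the allocation timeline
    torch.cuda.memory._record_memory_history(enabled=None)      # stop recording

    # Offline conversion per the linked blog post (private, version-dependent helper):
    #   python torch/cuda/_memory_viz.py trace_plot memory_snapshot.pickle -o memory_snapshot.html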