Skip to content

Commit

Permalink
Support tensor name alias for optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Stonesjtu authored Jul 31, 2024
1 parent e3151f9 commit 791174a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
19 changes: 18 additions & 1 deletion pytorch_memlab/mem_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ class MemReporter():
Parameters:
- model: an extra nn.Module can be passed to infer the name
of Tensors
- pre_collect: do a garbage collection before getting remaining
Tensors, this gives cleaner outputs.
Caution: This is an intrusive change to your original code.
"""
def __init__(self, model: Optional[torch.nn.Module] = None):
def __init__(self, model: Optional[torch.nn.Module] = None, pre_collect: bool = False):
self.tensor_name = {}
self.device_mapping = defaultdict(list)
self.device_tensor_stat = {}
# to numbering the unknown tensors
self.name_idx = 0
self.pre_collect = pre_collect

tensor_names = defaultdict(list)
if model is not None:
Expand All @@ -51,6 +55,16 @@ def _get_tensor_name(self, tensor: torch.Tensor) -> str:
self.name_idx += 1
return name

def add_optimizer(self, optimizer: torch.optim.Optimizer):
optimizer_name = optimizer.__class__.__name__
for param, states in optimizer.state.items():
param_name = self.tensor_name[id(param)]
for name, tensor in states.items():
self.tensor_name[id(tensor)] = f'{optimizer_name}.{param_name}.{name}'
# self.tensor_name[id()]
# print(states)


def collect_tensor(self):
"""Collect all tensor objects tracked by python
Expand All @@ -61,6 +75,9 @@ def collect_tensor(self):
I don't know why.
"""
#FIXME: make the grad tensor collected by gc
# Do a pre-garbage collect to eliminate python garbage objects
if self.pre_collect:
gc.collect()
objects = gc.get_objects()
tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
for t in tensors:
Expand Down
18 changes: 18 additions & 0 deletions test/test_mem_reporter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch.optim
from pytorch_memlab import MemReporter

import pytest
Expand Down Expand Up @@ -57,6 +58,23 @@ def test_reporter_tie_weight():
reporter = MemReporter(container)
reporter.report()

def test_reporter_with_optimizer():
    """Reporting after add_optimizer should name Adam's state tensors."""
    linear = torch.nn.Linear(1024, 1024)
    inp = torch.Tensor(512, 1024)
    optimizer = torch.optim.Adam(linear.parameters())

    out = linear(inp * (inp + 3) * (inp + 2)).mean()
    reporter = MemReporter(linear)
    reporter.report()
    out.backward()
    # Adam allocates its exp_avg/exp_avg_sq state lazily on the first step,
    # so step() must run before the optimizer state can be registered.
    optimizer.step()

    reporter.add_optimizer(optimizer)
    reporter.report()

@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.skipif(concentrate_mode, reason='concentrate')
def test_reporter_LSTM():
Expand Down

0 comments on commit 791174a

Please sign in to comment.