Skip to content

Commit

Permalink
fix test for torch 2.3
Browse files Browse the repository at this point in the history
  • Loading branch information
boeddeker committed Jun 14, 2024
1 parent 1429669 commit 5c1fab9
Showing 1 changed file with 45 additions and 16 deletions.
61 changes: 45 additions & 16 deletions tests/test_train/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,11 @@ def test_released_tensors():
dt_dataset = dt_dataset[:2]

class ReleaseTestHook(pt.train.hooks.Hook):
def get_all_tensors(self):
def __init__(self, global_tensors):
self.global_tensors = global_tensors

@staticmethod
def get_all_tensors():
import gc
tensors = []
for obj in gc.get_objects():
Expand Down Expand Up @@ -607,16 +611,22 @@ def show_referrers_type(cls, obj, depth, ignore=list()):
ignore=ignore + [referrers, o, obj]
):
l.append(textwrap.indent(s, ' '*4))
else:
l.append('... cycle ...')
class c:
magenta = '\033[35m'
reset = '\033[0m'
cyan = '\033[36m'

if inspect.isframe(obj):
frame_info = inspect.getframeinfo(obj, context=1)
if frame_info.function == 'show_referrers_type':
pass
else:
info = f' {frame_info.function}, {frame_info.filename}:{frame_info.lineno}'
info = f' {frame_info.function}, {c.magenta}{frame_info.filename}{c.reset}:{c.magenta}{frame_info.lineno}{c.reset}'
l.append(f'Frame: {type(obj)} {info}')
else:
l.append(str(type(obj)) + str(obj)[:80].replace('\n', ' '))
l.append(str(type(obj)) + str(obj)[:160].replace('\n', ' '))
return l

def pre_step(self, trainer: 'pt.Trainer'):
Expand All @@ -643,19 +653,20 @@ def pre_step(self, trainer: 'pt.Trainer'):
]

import textwrap
print(len(all_tensors), len(parameters), len(optimizer_tensors))
- print(len(all_tensors), len(parameters), len(optimizer_tensors))

def format_(name, tensors):
s = textwrap.indent("\n".join(map(str, all_tensors)), " "*8)
return f'{name}: {len(tensors)}\n{s}\n'

assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads), (
assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads) + len(self.global_tensors), (
f'pre_step\n'
f'{summary}\n'
f'all_tensors: {len(all_tensors)}\n'
+ textwrap.indent("\n".join(map(str, all_tensors)), " "*8) + f'\n'
f'parameters: {len(parameters)}\n'
+ textwrap.indent("\n".join(map(str, parameters)), " "*8) + f'\n'
f'parameters: {len(grads)}\n'
+ textwrap.indent("\n".join(map(str, grads)), " "*8) + f'\n'
f'optimizer_tensors: {len(optimizer_tensors)}\n'
+ textwrap.indent("\n".join(map(str, optimizer_tensors)), " "*8) + f'\n'
+ format_('all_tensors', all_tensors)
+ format_('parameters', parameters)
+ format_('optimizer_tensors', optimizer_tensors)
+ format_('grads', grads)
+ format_('global_tensors', self.global_tensors)
)

def post_step(
Expand All @@ -665,12 +676,30 @@ def post_step(
parameters = list(trainer.model.parameters())
assert len(all_tensors) > len(parameters), ('post_step', all_tensors, parameters)


print('pre TemporaryDirectory', ReleaseTestHook.get_all_tensors())

try:
# Between Torch 2.1.2 and 2.3.1 someone created _nt_view_dummy,
# which is the only Tensor in torch, that is created with an import
# of torch code.
# For some unknown reason the Adam optimizer triggers this import
# with the __init__ call.
# Do it here manually to be able to find all "global" tensors.
from torch.nested._internal.nested_tensor import _nt_view_dummy
except Exception:
pass

global_tensors = ReleaseTestHook.get_all_tensors()

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)

model = Model()
optimizer = pt.optimizer.Adam()
t = pt.Trainer(
Model(),
optimizer=pt.optimizer.Adam(),
model=model,
optimizer=optimizer,
storage_dir=str(tmp_dir),
stop_trigger=(1, 'epoch'),
summary_trigger=(1, 'epoch'),
Expand All @@ -679,7 +708,7 @@ def post_step(
t.register_validation_hook(
validation_iterator=dt_dataset, max_checkpoints=None
)
t.register_hook(ReleaseTestHook()) # This hook will do the tests
t.register_hook(ReleaseTestHook(global_tensors)) # This hook will do the tests
t.train(tr_dataset)


Expand Down

0 comments on commit 5c1fab9

Please sign in to comment.