Skip to content

Commit

Permalink
add_inf_detect_mode (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg authored Jan 10, 2024
1 parent 7241c72 commit 9b600d0
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 1 deletion.
45 changes: 45 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest.mock as mock

import pytest
import torch
from transformer_nuggets.utils.tracing import NanInfDetect


def test_nan():
    """Check that a NaN-producing division raises under NanInfDetect.

    Dividing a zero tensor by itself is expected to produce NaN, which the
    mode should report as a RuntimeError mentioning "returned a NaN".
    """
    zero = torch.tensor([0.0])
    with pytest.raises(RuntimeError, match="returned a NaN"), NanInfDetect():
        print(torch.div(zero, zero))


def test_inf():
    """Check that an Inf-producing multiplication raises under NanInfDetect.

    A float16 tensor is multiplied by 65537; the product is expected to
    overflow float16 to Inf, which the mode should report as a RuntimeError
    mentioning "returned an Inf".
    """
    one_half_precision = torch.tensor([1.0], dtype=torch.float16)
    with pytest.raises(RuntimeError, match="returned an Inf"), NanInfDetect():
        print(torch.mul(one_half_precision, 65537))


def test_breakpoint():
    """Check that ``do_breakpoint=True`` calls ``breakpoint()`` before raising.

    ``builtins.breakpoint`` is patched with a mock so no debugger is started.
    The NaN-producing division must still raise the usual RuntimeError, and
    the mock must have been called exactly once.
    """
    a = torch.tensor([0.0])
    with mock.patch("builtins.breakpoint") as mock_breakpoint:
        with pytest.raises(RuntimeError, match="returned a NaN"):
            with NanInfDetect(do_breakpoint=True):
                print(torch.div(a, a))
    # Asserted only after the expected exception has been caught: a statement
    # placed after the raising call inside the `pytest.raises` block would
    # never execute, so the check would silently be skipped.
    mock_breakpoint.assert_called_once()


if __name__ == "__main__":
    # Run this file's tests when executed as a script and propagate pytest's
    # exit code, so a failing run does not report success to the shell.
    raise SystemExit(pytest.main([__file__]))
44 changes: 43 additions & 1 deletion transformer_nuggets/utils/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torch.overrides
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_map, tree_map_only
from torch.utils.weak import WeakIdRef

dtype_abbrs = {
Expand Down Expand Up @@ -93,3 +93,45 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
else:
print(log_msg)
return rs


def get_error_string(func, types, args, kwargs, output_has_nan, output_has_inf):
    """Build the message describing which bad values a call produced.

    Args:
        func: The function that was called; interpolated into the message.
        types: Unused; accepted so the signature mirrors the dispatch call.
        args: Positional arguments of the call; interpolated into the message.
        kwargs: Keyword arguments of the call; interpolated into the message.
        output_has_nan: Truthy if the output contained a NaN.
        output_has_inf: Truthy if the output contained an Inf.

    Returns:
        A string of the form ``"Function f(*args, **kwargs) returned ..."``
        ending in "a NaN", "an Inf", or "a NaN and an Inf". If neither flag
        is truthy, nothing is appended after "returned ".
    """
    findings = []
    if output_has_nan:
        findings.append("a NaN")
    if output_has_inf:
        findings.append("an Inf")
    return f"Function {func}(*{args}, **{kwargs}) returned " + " and ".join(findings)


class NanInfDetect(TorchDispatchMode):
    """Dispatch mode that raises as soon as an op returns a NaN or an Inf.

    While the mode is active, each dispatched op is executed and every tensor
    in its output is checked with ``torch.isnan`` / ``torch.isinf``. Helpful
    for locating the first op that introduces a NaN or Inf.

    Example usage:
    ```Python
    >>> a = torch.tensor([0.,])
    >>> with NanInfDetect():
    >>>     print(torch.div(a, a))
    RuntimeError: Function aten.div.Tensor(*(tensor([0.]), tensor([0.])), **{}) returned a NaN
    ```

    Args:
        do_breakpoint: If True, call ``breakpoint()`` right before raising so
            the offending op can be inspected interactively.

    Raises:
        RuntimeError: From within the dispatched call, when any output tensor
            contains a NaN or an Inf.
    """

    def __init__(self, do_breakpoint: bool = False):
        super().__init__()
        self.do_breakpoint = do_breakpoint

    def __torch_dispatch__(self, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        res = func(*args, **kwargs)

        # Ops may return a single tensor, a tuple/list of tensors, or
        # non-tensor values; flatten the output and inspect only the tensor
        # leaves so every output shape is handled.
        flat_res, _ = tree_flatten(res)
        tensors = [leaf for leaf in flat_res if isinstance(leaf, torch.Tensor)]
        output_has_nan = any(bool(torch.isnan(t).any()) for t in tensors)
        output_has_inf = any(bool(torch.isinf(t).any()) for t in tensors)
        if output_has_nan or output_has_inf:
            if self.do_breakpoint:
                breakpoint()
            raise RuntimeError(
                get_error_string(func, types, args, kwargs, output_has_nan, output_has_inf)
            )
        return res

0 comments on commit 9b600d0

Please sign in to comment.