Skip to content

Commit

Permalink
Avoid a graph break in ModuleDict and ParameterDict (#8363)
Browse files Browse the repository at this point in the history
We have a graph break at the `hasattr` call in `ModuleDict.to_internal_key`.

Repro:

```python
import torch
from torch_geometric.nn.module_dict import ModuleDict

edge_type = ("a", "to", "b")

class SomeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_dict = ModuleDict({
            edge_type: torch.nn.Linear(1, 1),
        })

    def forward(self, x):
        # need to convert tuple to string in advance to avoid a graph break
        # due to pytorch/pytorch#111551
        key = ModuleDict.to_internal_key(edge_type)
        x = self.module_dict[key](x)
        return x

from torch._dynamo.utils import CompileProfiler
model = SomeModel()
with CompileProfiler() as prof:
    model = torch.compile(model)
    model(torch.randn(1, 1))
    print(prof.report())
```

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Nov 14, 2023
1 parent b1f8535 commit ccbbbdd
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357))
- Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345))
- Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344))
Expand Down
1 change: 0 additions & 1 deletion test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def test_multi_aggr_sage_conv(aggr_kwargs):
def test_compile_multi_aggr_sage_conv(device):
import torch._dynamo as dynamo

device = None
x = torch.randn(4, 8, device=device)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device)

Expand Down
11 changes: 6 additions & 5 deletions torch_geometric/nn/module_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Mapping, Optional, Tuple, Union
from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union

import torch
from torch.nn import Module
Expand All @@ -11,6 +11,8 @@
# internal representation and converts it back to `.` in the external
# representation. It also allows passing tuples as keys.
class ModuleDict(torch.nn.ModuleDict):
CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ModuleDict))

def __init__(
self,
modules: Optional[Mapping[Union[str, Tuple[str, ...]], Module]] = None,
Expand All @@ -24,14 +26,13 @@ def __init__(

@classmethod
def to_internal_key(cls, key: Key) -> str:
# ModuleDict cannot handle tuples as keys:
if isinstance(key, tuple):
if isinstance(key, tuple): # ModuleDict can't handle tuples as keys
assert len(key) > 1
key = f"<{'___'.join(key)}>"
assert isinstance(key, str)

# ModuleDict cannot handle keys that exist as class attributes:
if hasattr(cls, key):
if key in cls.CLASS_ATTRS:
key = f'<{key}>'

# ModuleDict cannot handle dots in keys:
Expand All @@ -41,7 +42,7 @@ def to_internal_key(cls, key: Key) -> str:
def to_external_key(cls, key: str) -> Key:
key = key.replace('#', '.')

if key[0] == '<' and key[-1] == '>' and hasattr(cls, key[1:-1]):
if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS:
key = key[1:-1]

if key[0] == '<' and key[-1] == '>' and '___' in key:
Expand Down
11 changes: 6 additions & 5 deletions torch_geometric/nn/parameter_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Mapping, Optional, Tuple, Union
from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union

import torch
from torch.nn import Parameter
Expand All @@ -11,6 +11,8 @@
# internal representation and converts it back to `.` in the external
# representation. It also allows passing tuples as keys.
class ParameterDict(torch.nn.ParameterDict):
CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ParameterDict))

def __init__(
self,
parameters: Optional[Mapping[Key, Parameter]] = None,
Expand All @@ -25,14 +27,13 @@ def __init__(

@classmethod
def to_internal_key(cls, key: Key) -> str:
# ParameterDict cannot handle tuples as keys:
if isinstance(key, tuple):
if isinstance(key, tuple): # ParameterDict can't handle tuples as keys
assert len(key) > 1
key = f"<{'___'.join(key)}>"
assert isinstance(key, str)

# ParameterDict cannot handle keys that exist as class attributes:
if hasattr(cls, key):
if key in cls.CLASS_ATTRS:
key = f'<{key}>'

# ParameterDict cannot handle dots in keys:
Expand All @@ -42,7 +43,7 @@ def to_internal_key(cls, key: Key) -> str:
def to_external_key(cls, key: str) -> Key:
key = key.replace('#', '.')

if key[0] == '<' and key[-1] == '>' and hasattr(cls, key[1:-1]):
if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS:
key = key[1:-1]

if key[0] == '<' and key[-1] == '>' and '___' in key:
Expand Down

0 comments on commit ccbbbdd

Please sign in to comment.