From ccbbbdd7c725a6c2c0e602c3fc99eea5f55af09f Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Tue, 14 Nov 2023 16:36:44 +0900 Subject: [PATCH] Avoid a graph break in `ModuleDict` and `ParameterDict` (#8363) We have a graph break at `hasattr` call in `ModuleDict.to_internal_key`. Repro: ```python import torch from torch_geometric.nn.module_dict import ModuleDict edge_type = ("a", "to", "b") class SomeModel(torch.nn.Module): def __init__(self): super().__init__() self.module_dict = ModuleDict({ edge_type: torch.nn.Linear(1, 1), }) def forward(self, x): # need to convert tuple to string in advance to avoid a graph break # due to https://github.com/pytorch/pytorch/issues/111551 key = ModuleDict.to_internal_key(edge_type) x = self.module_dict[key](x) return x from torch._dynamo.utils import CompileProfiler model = SomeModel() with CompileProfiler() as prof: model = torch.compile(model) model(torch.randn(1, 1)) print(prof.report()) ``` --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s --- CHANGELOG.md | 1 + test/nn/conv/test_sage_conv.py | 1 - torch_geometric/nn/module_dict.py | 11 ++++++----- torch_geometric/nn/parameter_dict.py | 11 ++++++----- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 485eeb5472b3..e29d3aae55a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363)) - Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357)) - Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345)) - Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344)) diff --git a/test/nn/conv/test_sage_conv.py b/test/nn/conv/test_sage_conv.py index dc54044456ef..1a3fd23337f8 100644 --- a/test/nn/conv/test_sage_conv.py +++ b/test/nn/conv/test_sage_conv.py @@ -138,7 +138,6 @@ def test_multi_aggr_sage_conv(aggr_kwargs): def test_compile_multi_aggr_sage_conv(device): import torch._dynamo as dynamo - device = None x = torch.randn(4, 8, device=device) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device) diff --git a/torch_geometric/nn/module_dict.py b/torch_geometric/nn/module_dict.py index a41775061733..2f69e5eb7c53 100644 --- a/torch_geometric/nn/module_dict.py +++ b/torch_geometric/nn/module_dict.py @@ -1,4 +1,4 @@ -from typing import Iterable, Mapping, Optional, Tuple, Union +from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union import torch from torch.nn import Module @@ -11,6 +11,8 @@ # internal representation and converts it back to `.` in the external # representation. It also allows passing tuples as keys. class ModuleDict(torch.nn.ModuleDict): + CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ModuleDict)) + def __init__( self, modules: Optional[Mapping[Union[str, Tuple[str, ...]], Module]] = None, @@ -24,14 +26,13 @@ def __init__( @classmethod def to_internal_key(cls, key: Key) -> str: - # ModuleDict cannot handle tuples as keys: - if isinstance(key, tuple): + if isinstance(key, tuple): # ModuleDict can't handle tuples as keys assert len(key) > 1 key = f"<{'___'.join(key)}>" assert isinstance(key, str) # ModuleDict cannot handle keys that exists as class attributes: - if hasattr(cls, key): + if key in cls.CLASS_ATTRS: key = f'<{key}>' # ModuleDict cannot handle dots in keys: @@ -41,7 +42,7 @@ def to_internal_key(cls, key: Key) -> str: def to_external_key(cls, key: str) -> Key: key = key.replace('#', '.') - if key[0] == '<' and key[-1] == '>' and hasattr(cls, key[1:-1]): + if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS: key = key[1:-1] if key[0] == '<' and key[-1] == '>' and '___' in key: diff --git a/torch_geometric/nn/parameter_dict.py b/torch_geometric/nn/parameter_dict.py index e492f3e9bb01..a2dafc27c390 100644 --- a/torch_geometric/nn/parameter_dict.py +++ b/torch_geometric/nn/parameter_dict.py @@ -1,4 +1,4 @@ -from typing import Iterable, Mapping, Optional, Tuple, Union +from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union import torch from torch.nn import Parameter @@ -11,6 +11,8 @@ # internal representation and converts it back to `.` in the external # representation. It also allows passing tuples as keys. class ParameterDict(torch.nn.ParameterDict): + CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ParameterDict)) + def __init__( self, parameters: Optional[Mapping[Key, Parameter]] = None, @@ -25,14 +27,13 @@ def __init__( @classmethod def to_internal_key(cls, key: Key) -> str: - # ParameterDict cannot handle tuples as keys: - if isinstance(key, tuple): + if isinstance(key, tuple): # ParameterDict can't handle tuples as keys assert len(key) > 1 key = f"<{'___'.join(key)}>" assert isinstance(key, str) # ParameterDict cannot handle keys that exists as class attributes: - if hasattr(cls, key): + if key in cls.CLASS_ATTRS: key = f'<{key}>' # ParameterDict cannot handle dots in keys: @@ -42,7 +43,7 @@ def to_internal_key(cls, key: Key) -> str: def to_external_key(cls, key: str) -> Key: key = key.replace('#', '.') - if key[0] == '<' and key[-1] == '>' and hasattr(cls, key[1:-1]): + if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS: key = key[1:-1] if key[0] == '<' and key[-1] == '>' and '___' in key: