Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Nov 14, 2023
1 parent 171e494 commit c2f1378
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357))
- Added support for `torch.compile` in `ModuleDict` when keys are `str` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
- Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345))
- Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344))
- Added support for weighted `sparse_cross_entropy` ([#8340](https://github.com/pyg-team/pytorch_geometric/pull/8340))
Expand Down
8 changes: 3 additions & 5 deletions torch_geometric/nn/module_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Mapping, Optional, Tuple, Union
from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union

import torch
from torch.nn import Module
Expand All @@ -11,7 +11,7 @@
# internal representation and converts it back to `.` in the external
# representation. It also allows passing tuples as keys.
class ModuleDict(torch.nn.ModuleDict):
CLASS_ATTRS = dir(torch.nn.ModuleDict)
CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ModuleDict))

def __init__(
self,
Expand All @@ -26,11 +26,9 @@ def __init__(

@classmethod
def to_internal_key(cls, key: Key) -> str:
if isinstance(key, tuple):
# ModuleDict cannot handle tuples as keys:
if isinstance(key, tuple): # ModuleDict can't handle tuples as keys
assert len(key) > 1
key = f"<{'___'.join(key)}>"

assert isinstance(key, str)

# ModuleDict cannot handle keys that exists as class attributes:
Expand Down
11 changes: 6 additions & 5 deletions torch_geometric/nn/parameter_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Mapping, Optional, Tuple, Union
from typing import Final, Iterable, Mapping, Optional, Set, Tuple, Union

import torch
from torch.nn import Parameter
Expand All @@ -11,6 +11,8 @@
# internal representation and converts it back to `.` in the external
# representation. It also allows passing tuples as keys.
class ParameterDict(torch.nn.ParameterDict):
CLASS_ATTRS: Final[Set[str]] = set(dir(torch.nn.ParameterDict))

def __init__(
self,
parameters: Optional[Mapping[Key, Parameter]] = None,
Expand All @@ -25,14 +27,13 @@ def __init__(

@classmethod
def to_internal_key(cls, key: Key) -> str:
# ParameterDict cannot handle tuples as keys:
if isinstance(key, tuple):
if isinstance(key, tuple): # ParameterDict can't handle tuples as keys
assert len(key) > 1
key = f"<{'___'.join(key)}>"
assert isinstance(key, str)

# ParameterDict cannot handle keys that exists as class attributes:
if hasattr(cls, key):
if key in cls.CLASS_ATTRS:
key = f'<{key}>'

# ParameterDict cannot handle dots in keys:
Expand All @@ -42,7 +43,7 @@ def to_internal_key(cls, key: Key) -> str:
def to_external_key(cls, key: str) -> Key:
key = key.replace('#', '.')

if key[0] == '<' and key[-1] == '>' and hasattr(cls, key[1:-1]):
if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS:
key = key[1:-1]

if key[0] == '<' and key[-1] == '>' and '___' in key:
Expand Down

0 comments on commit c2f1378

Please sign in to comment.