Allow `.` in `ParameterDict` key names (pyg-team#5494)
* parameter dict

* changelog

* update

* update
rusty1s authored and JakubPietrakIntel committed Nov 25, 2022
1 parent 36394b7 commit 419232c
Showing 6 changed files with 92 additions and 18 deletions.
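For context, a minimal sketch (not part of the commit) of what this change enables, mirroring the new test added below: PyG's `ParameterDict` wrapper accepts keys containing `.` by escaping them to `#` internally and mapping them back whenever keys are exposed to the user.

import torch
from torch.nn import Parameter

from torch_geometric.nn.parameter_dict import ParameterDict

# Dotted keys are stored internally as e.g. "model#param2" but are always
# presented back in their original "model.param2" form.
params = ParameterDict({
    "param1": Parameter(torch.randn(16, 16)),
    "model.param2": Parameter(torch.randn(8, 8)),
})

assert "model.param2" in params
assert set(params.keys()) == {"param1", "model.param2"}
params["model.param3"] = Parameter(torch.randn(4, 4))  # set via a dotted key
del params["model.param2"]                             # delete via a dotted key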
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
+- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))
- Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490))
- Improved `utils.scatter` performance by explicitly choosing better implementation for `add` and `mean` reduction ([#5399](https://github.com/pyg-team/pytorch_geometric/pull/5399))
- Fix `to_dense_adj` with empty `edge_index` ([#5476](https://github.com/pyg-team/pytorch_geometric/pull/5476))
@@ -54,7 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.1.0] - 2022-08-17
### Added
-- Added `CustomModuleDict` as a replacement for `torch.nn.ModuleDict` ([#5227](https://github.com/pyg-team/pytorch_geometric/pull/5227))
+- Allow `.` in `ModuleDict` key names ([#5227](https://github.com/pyg-team/pytorch_geometric/pull/5227))
- Added `edge_label_time` argument to `LinkNeighborLoader` ([#5137](https://github.com/pyg-team/pytorch_geometric/pull/5137), [#5173](https://github.com/pyg-team/pytorch_geometric/pull/5173))
- Let `ImbalancedSampler` accept `torch.Tensor` as input ([#5138](https://github.com/pyg-team/pytorch_geometric/pull/5138))
- Added `flow` argument to `gcn_norm` to correctly normalize the adjacency matrix in `GCNConv` ([#5149](https://github.com/pyg-team/pytorch_geometric/pull/5149))
8 changes: 2 additions & 6 deletions test/nn/test_module_dict.py
@@ -16,22 +16,18 @@ def test_internal_external_key_conversion():


def test_dot_syntax_keys():
-    # Create a sample modules dict.
    modules: Mapping[str, Module] = {
        "lin1": torch.nn.Linear(16, 16),
        "model.lin2": torch.nn.Linear(8, 8),
        "model.sub_model.lin3": torch.nn.Linear(4, 4),
    }
    module_dict = ModuleDict(modules)

-    expected_keys = ["lin1", "model.lin2", "model.sub_model.lin3"]
-    # Check the keys.
-    assert module_dict.keys() == expected_keys
+    expected_keys = {"lin1", "model.lin2", "model.sub_model.lin3"}
+    assert set(module_dict.keys()) == expected_keys

-    # Check for the existence of the keys.
    for key in expected_keys:
        assert key in module_dict

-    # Check deletion using keys.
    del module_dict["model.lin2"]
    assert "model.lin2" not in module_dict
33 changes: 33 additions & 0 deletions test/nn/test_parameter_dict.py
@@ -0,0 +1,33 @@
from typing import Mapping

import torch
from torch.nn import Parameter

from torch_geometric.nn.parameter_dict import ParameterDict


def test_internal_external_key_conversion():
    assert ParameterDict.to_internal_key("a.b") == "a#b"
    assert ParameterDict.to_internal_key("ab") == "ab"
    assert ParameterDict.to_internal_key("a.b.c") == "a#b#c"

    assert ParameterDict.to_external_key("a#b") == "a.b"
    assert ParameterDict.to_external_key("a#b#c") == "a.b.c"


def test_dot_syntax_keys():
    parameters: Mapping[str, Parameter] = {
        "param1": Parameter(torch.Tensor(16, 16)),
        "model.param2": Parameter(torch.Tensor(8, 8)),
        "model.sub_model.param3": Parameter(torch.Tensor(4, 4)),
    }
    parameter_dict = ParameterDict(parameters)

    expected_keys = {"param1", "model.param2", "model.sub_model.param3"}
    assert set(parameter_dict.keys()) == expected_keys

    for key in expected_keys:
        assert key in parameter_dict

    del parameter_dict["model.param2"]
    assert "model.param2" not in parameter_dict
18 changes: 10 additions & 8 deletions torch_geometric/nn/conv/hgt_conv.py
@@ -10,6 +10,8 @@
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.inits import glorot, ones, reset
+from torch_geometric.nn.module_dict import ModuleDict
+from torch_geometric.nn.parameter_dict import ParameterDict
from torch_geometric.typing import EdgeType, Metadata, NodeType
from torch_geometric.utils import softmax

@@ -77,21 +79,21 @@ def __init__(
        self.heads = heads
        self.group = group

-        self.k_lin = torch.nn.ModuleDict()
-        self.q_lin = torch.nn.ModuleDict()
-        self.v_lin = torch.nn.ModuleDict()
-        self.a_lin = torch.nn.ModuleDict()
-        self.skip = torch.nn.ParameterDict()
+        self.k_lin = ModuleDict()
+        self.q_lin = ModuleDict()
+        self.v_lin = ModuleDict()
+        self.a_lin = ModuleDict()
+        self.skip = ParameterDict()
        for node_type, in_channels in self.in_channels.items():
            self.k_lin[node_type] = Linear(in_channels, out_channels)
            self.q_lin[node_type] = Linear(in_channels, out_channels)
            self.v_lin[node_type] = Linear(in_channels, out_channels)
            self.a_lin[node_type] = Linear(out_channels, out_channels)
            self.skip[node_type] = Parameter(torch.Tensor(1))

-        self.a_rel = torch.nn.ParameterDict()
-        self.m_rel = torch.nn.ParameterDict()
-        self.p_rel = torch.nn.ParameterDict()
+        self.a_rel = ParameterDict()
+        self.m_rel = ParameterDict()
+        self.p_rel = ParameterDict()
        dim = out_channels // heads
        for edge_type in metadata[1]:
            edge_type = '__'.join(edge_type)
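Why the swap to the PyG wrappers matters here: `HGTConv` keys these dictionaries directly by node-type and joined edge-type names, and names containing `.` are rejected by plain `torch.nn.ModuleDict`/`ParameterDict`. A sketch of the now-supported case; the metadata, shapes, and type names below are made up for illustration:

import torch

from torch_geometric.nn import HGTConv

# Hypothetical hetero-graph metadata in which a node-type name contains a dot.
metadata = (
    ['user', 'item.variant'],
    [('user', 'buys', 'item.variant'), ('item.variant', 'rev_buys', 'user')],
)
conv = HGTConv(in_channels=16, out_channels=32, metadata=metadata, heads=2)

x_dict = {
    'user': torch.randn(4, 16),
    'item.variant': torch.randn(3, 16),
}
edge_index = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 0]])
edge_index_dict = {
    ('user', 'buys', 'item.variant'): edge_index,
    ('item.variant', 'rev_buys', 'user'): edge_index.flip(0),
}
out_dict = conv(x_dict, edge_index_dict)  # expected: [4, 32] for 'user', [3, 32] for 'item.variant'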
6 changes: 3 additions & 3 deletions torch_geometric/nn/module_dict.py
@@ -4,16 +4,16 @@
from torch.nn import Module


-# `torch.nn.ModuleDict` doesn't allow `.` to be used in the keys.
+# `torch.nn.ModuleDict` doesn't allow `.` to be used in key names.
# This `ModuleDict` will support it by converting the `.` to `#` in the
# internal representation and converts it back to `.` in the external
# representation.
class ModuleDict(torch.nn.ModuleDict):
-    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
+    def __init__(self, modules: Optional[Mapping[str, Module]] = None):
        # Replace the keys in modules.
        if modules:
            modules = {
-                ModuleDict.to_internal_key(key): module
+                self.to_internal_key(key): module
                for key, module in modules.items()
            }
        super().__init__(modules)
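The existing `ModuleDict` wrapper uses the same escaping scheme; a short sketch of the key round-trip, assuming it defines the same `to_internal_key`/`to_external_key` helpers as the `ParameterDict` added below (its test above exercises exactly this conversion):

import torch

from torch_geometric.nn.module_dict import ModuleDict

# External (user-facing) keys may contain dots; internally they are escaped to '#'.
assert ModuleDict.to_internal_key("model.lin2") == "model#lin2"
assert ModuleDict.to_external_key("model#lin2") == "model.lin2"

modules = ModuleDict({
    "lin1": torch.nn.Linear(16, 16),
    "model.lin2": torch.nn.Linear(8, 8),
})
assert "model.lin2" in modules
assert set(modules.keys()) == {"lin1", "model.lin2"}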
42 changes: 42 additions & 0 deletions torch_geometric/nn/parameter_dict.py
@@ -0,0 +1,42 @@
from typing import Iterable, Mapping, Optional

import torch
from torch.nn import Parameter


# `torch.nn.ParameterDict` doesn't allow `.` to be used in key names.
# This `ParameterDict` will support it by converting the `.` to `#` in the
# internal representation and converts it back to `.` in the external
# representation.
class ParameterDict(torch.nn.ParameterDict):
    def __init__(self, parameters: Optional[Mapping[str, Parameter]] = None):
        # Replace `.` in the keys of the given parameters.
        if parameters:
            parameters = {
                self.to_internal_key(key): param
                for key, param in parameters.items()
            }
        super().__init__(parameters)

    @staticmethod
    def to_internal_key(key: str) -> str:
        return key.replace(".", "#")

    @staticmethod
    def to_external_key(key: str) -> str:
        return key.replace("#", ".")

    def __getitem__(self, key: str) -> Parameter:
        return super().__getitem__(self.to_internal_key(key))

    def __setitem__(self, key: str, parameter: Parameter) -> None:
        return super().__setitem__(self.to_internal_key(key), parameter)

    def __delitem__(self, key: str) -> None:
        return super().__delitem__(self.to_internal_key(key))

    def __contains__(self, key: str) -> bool:
        return super().__contains__(self.to_internal_key(key))

    def keys(self) -> Iterable[str]:
        return [self.to_external_key(key) for key in super().keys()]
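One practical consequence, inferred from `to_internal_key` rather than stated anywhere in this diff: parameters are registered under the escaped name, so an owning module is expected to report `#` instead of `.` in `named_parameters()` / `state_dict()` keys.

import torch
from torch.nn import Parameter

from torch_geometric.nn.parameter_dict import ParameterDict


class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.skip = ParameterDict({"model.param": Parameter(torch.randn(1))})


# The dotted key exists only in the wrapper's external view; the registered
# (internal) name is expected to surface as 'skip.model#param'.
print([name for name, _ in Wrapper().named_parameters()])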
