diff --git a/CHANGELOG.md b/CHANGELOG.md index 07450cf95d57..20cc6100caff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) ### Changed +- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494)) - Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490)) - Improved `utils.scatter` performance by explicitly choosing better implementation for `add` and `mean` reduction ([#5399](https://github.com/pyg-team/pytorch_geometric/pull/5399)) - Fix `to_dense_adj` with empty `edge_index` ([#5476](https://github.com/pyg-team/pytorch_geometric/pull/5476)) @@ -54,7 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.1.0] - 2022-08-17 ### Added -- Added `CustomModuleDict` as a replacement for `torch.nn.ModuleDict` ([#5227](https://github.com/pyg-team/pytorch_geometric/pull/5227)) +- Allow `.` in `ModuleDict` key names ([#5227](https://github.com/pyg-team/pytorch_geometric/pull/5227)) - Added `edge_label_time` argument to `LinkNeighborLoader` ([#5137](https://github.com/pyg-team/pytorch_geometric/pull/5137), [#5173](https://github.com/pyg-team/pytorch_geometric/pull/5173)) - Let `ImbalancedSampler` accept `torch.Tensor` as input ([#5138](https://github.com/pyg-team/pytorch_geometric/pull/5138)) - Added `flow` argument to `gcn_norm` to correctly normalize the adjacency matrix in `GCNConv` ([#5149](https://github.com/pyg-team/pytorch_geometric/pull/5149)) diff --git a/test/nn/test_module_dict.py b/test/nn/test_module_dict.py index 6485c7b5c380..8a4df545b861 100644 --- a/test/nn/test_module_dict.py +++ b/test/nn/test_module_dict.py @@ -16,7 +16,6 @@ def test_internal_external_key_conversion(): def test_dot_syntax_keys(): - # Create a sample modules dict. modules: Mapping[str, Module] = { "lin1": torch.nn.Linear(16, 16), "model.lin2": torch.nn.Linear(8, 8), @@ -24,14 +23,11 @@ def test_dot_syntax_keys(): } module_dict = ModuleDict(modules) - expected_keys = ["lin1", "model.lin2", "model.sub_model.lin3"] - # Check the keys. - assert module_dict.keys() == expected_keys + expected_keys = {"lin1", "model.lin2", "model.sub_model.lin3"} + assert set(module_dict.keys()) == expected_keys - # Check for the existence of the keys. for key in expected_keys: assert key in module_dict - # Check deletion using keys. del module_dict["model.lin2"] assert "model.lin2" not in module_dict diff --git a/test/nn/test_parameter_dict.py b/test/nn/test_parameter_dict.py new file mode 100644 index 000000000000..34cde8c5406b --- /dev/null +++ b/test/nn/test_parameter_dict.py @@ -0,0 +1,33 @@ +from typing import Mapping + +import torch +from torch.nn import Parameter + +from torch_geometric.nn.parameter_dict import ParameterDict + + +def test_internal_external_key_conversion(): + assert ParameterDict.to_internal_key("a.b") == "a#b" + assert ParameterDict.to_internal_key("ab") == "ab" + assert ParameterDict.to_internal_key("a.b.c") == "a#b#c" + + assert ParameterDict.to_external_key("a#b") == "a.b" + assert ParameterDict.to_external_key("a#b#c") == "a.b.c" + + +def test_dot_syntax_keys(): + parameters: Mapping[str, Parameter] = { + "param1": Parameter(torch.Tensor(16, 16)), + "model.param2": Parameter(torch.Tensor(8, 8)), + "model.sub_model.param3": Parameter(torch.Tensor(4, 4)), + } + parameter_dict = ParameterDict(parameters) + + expected_keys = {"param1", "model.param2", "model.sub_model.param3"} + assert set(parameter_dict.keys()) == expected_keys + + for key in expected_keys: + assert key in parameter_dict + + del parameter_dict["model.param2"] + assert "model.param2" not in parameter_dict diff --git a/torch_geometric/nn/conv/hgt_conv.py b/torch_geometric/nn/conv/hgt_conv.py index 39bde0c06266..5ed8e3c2c85c 100644 --- a/torch_geometric/nn/conv/hgt_conv.py +++ b/torch_geometric/nn/conv/hgt_conv.py @@ -10,6 +10,8 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense import Linear from torch_geometric.nn.inits import glorot, ones, reset +from torch_geometric.nn.module_dict import ModuleDict +from torch_geometric.nn.parameter_dict import ParameterDict from torch_geometric.typing import EdgeType, Metadata, NodeType from torch_geometric.utils import softmax @@ -77,11 +79,11 @@ def __init__( self.heads = heads self.group = group - self.k_lin = torch.nn.ModuleDict() - self.q_lin = torch.nn.ModuleDict() - self.v_lin = torch.nn.ModuleDict() - self.a_lin = torch.nn.ModuleDict() - self.skip = torch.nn.ParameterDict() + self.k_lin = ModuleDict() + self.q_lin = ModuleDict() + self.v_lin = ModuleDict() + self.a_lin = ModuleDict() + self.skip = ParameterDict() for node_type, in_channels in self.in_channels.items(): self.k_lin[node_type] = Linear(in_channels, out_channels) self.q_lin[node_type] = Linear(in_channels, out_channels) @@ -89,9 +91,9 @@ def __init__( self.a_lin[node_type] = Linear(out_channels, out_channels) self.skip[node_type] = Parameter(torch.Tensor(1)) - self.a_rel = torch.nn.ParameterDict() - self.m_rel = torch.nn.ParameterDict() - self.p_rel = torch.nn.ParameterDict() + self.a_rel = ParameterDict() + self.m_rel = ParameterDict() + self.p_rel = ParameterDict() dim = out_channels // heads for edge_type in metadata[1]: edge_type = '__'.join(edge_type) diff --git a/torch_geometric/nn/module_dict.py b/torch_geometric/nn/module_dict.py index 3d2fbfa82e35..3c0cbef849c3 100644 --- a/torch_geometric/nn/module_dict.py +++ b/torch_geometric/nn/module_dict.py @@ -4,16 +4,16 @@ from torch.nn import Module -# `torch.nn.ModuleDict` doesn't allow `.` to be used in the keys. +# `torch.nn.ModuleDict` doesn't allow `.` to be used in key names. # This `ModuleDict` will support it by converting the `.` to `#` in the # internal representation and converts it back to `.` in the external # representation. class ModuleDict(torch.nn.ModuleDict): - def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + def __init__(self, modules: Optional[Mapping[str, Module]] = None): # Replace the keys in modules. if modules: modules = { - ModuleDict.to_internal_key(key): module + self.to_internal_key(key): module for key, module in modules.items() } super().__init__(modules) diff --git a/torch_geometric/nn/parameter_dict.py b/torch_geometric/nn/parameter_dict.py new file mode 100644 index 000000000000..40fd1ee8ba45 --- /dev/null +++ b/torch_geometric/nn/parameter_dict.py @@ -0,0 +1,42 @@ +from typing import Iterable, Mapping, Optional + +import torch +from torch.nn import Parameter + + +# `torch.nn.ParameterDict` doesn't allow `.` to be used in key names. +# This `ParameterDict` will support it by converting the `.` to `#` in the +# internal representation and converts it back to `.` in the external +# representation. +class ParameterDict(torch.nn.ParameterDict): + def __init__(self, parameters: Optional[Mapping[str, Parameter]] = None): + # Replace the keys in modules. + if parameters: + parameters = { + self.to_internal_key(key): module + for key, module in parameters.items() + } + super().__init__(parameters) + + @staticmethod + def to_internal_key(key: str) -> str: + return key.replace(".", "#") + + @staticmethod + def to_external_key(key: str) -> str: + return key.replace("#", ".") + + def __getitem__(self, key: str) -> Parameter: + return super().__getitem__(self.to_internal_key(key)) + + def __setitem__(self, key: str, parameter: Parameter) -> None: + return super().__setitem__(self.to_internal_key(key), parameter) + + def __delitem__(self, key: str) -> None: + return super().__delitem__(self.to_internal_key(key)) + + def __contains__(self, key: str) -> bool: + return super().__contains__(self.to_internal_key(key)) + + def keys(self) -> Iterable[str]: + return [self.to_external_key(key) for key in super().keys()]