forked from pyg-team/pytorch_geometric
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* parameter dict * changelog * update * update
- Loading branch information
1 parent
36394b7
commit 419232c
Showing
6 changed files
with
92 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import Mapping | ||
|
||
import torch | ||
from torch.nn import Parameter | ||
|
||
from torch_geometric.nn.parameter_dict import ParameterDict | ||
|
||
|
||
def test_internal_external_key_conversion(): | ||
assert ParameterDict.to_internal_key("a.b") == "a#b" | ||
assert ParameterDict.to_internal_key("ab") == "ab" | ||
assert ParameterDict.to_internal_key("a.b.c") == "a#b#c" | ||
|
||
assert ParameterDict.to_external_key("a#b") == "a.b" | ||
assert ParameterDict.to_external_key("a#b#c") == "a.b.c" | ||
|
||
|
||
def test_dot_syntax_keys(): | ||
parameters: Mapping[str, Parameter] = { | ||
"param1": Parameter(torch.Tensor(16, 16)), | ||
"model.param2": Parameter(torch.Tensor(8, 8)), | ||
"model.sub_model.param3": Parameter(torch.Tensor(4, 4)), | ||
} | ||
parameter_dict = ParameterDict(parameters) | ||
|
||
expected_keys = {"param1", "model.param2", "model.sub_model.param3"} | ||
assert set(parameter_dict.keys()) == expected_keys | ||
|
||
for key in expected_keys: | ||
assert key in parameter_dict | ||
|
||
del parameter_dict["model.param2"] | ||
assert "model.param2" not in parameter_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Iterable, Mapping, Optional | ||
|
||
import torch | ||
from torch.nn import Parameter | ||
|
||
|
||
# `torch.nn.ParameterDict` doesn't allow `.` to be used in key names. | ||
# This `ParameterDict` will support it by converting the `.` to `#` in the | ||
# internal representation and converts it back to `.` in the external | ||
# representation. | ||
class ParameterDict(torch.nn.ParameterDict): | ||
def __init__(self, parameters: Optional[Mapping[str, Parameter]] = None): | ||
# Replace the keys in modules. | ||
if parameters: | ||
parameters = { | ||
self.to_internal_key(key): module | ||
for key, module in parameters.items() | ||
} | ||
super().__init__(parameters) | ||
|
||
@staticmethod | ||
def to_internal_key(key: str) -> str: | ||
return key.replace(".", "#") | ||
|
||
@staticmethod | ||
def to_external_key(key: str) -> str: | ||
return key.replace("#", ".") | ||
|
||
def __getitem__(self, key: str) -> Parameter: | ||
return super().__getitem__(self.to_internal_key(key)) | ||
|
||
def __setitem__(self, key: str, parameter: Parameter) -> None: | ||
return super().__setitem__(self.to_internal_key(key), parameter) | ||
|
||
def __delitem__(self, key: str) -> None: | ||
return super().__delitem__(self.to_internal_key(key)) | ||
|
||
def __contains__(self, key: str) -> bool: | ||
return super().__contains__(self.to_internal_key(key)) | ||
|
||
def keys(self) -> Iterable[str]: | ||
return [self.to_external_key(key) for key in super().keys()] |