Skip to content

Commit

Permalink
update
Browse files · Browse the repository at this point in the history
  • Branch information
rusty1s committed Mar 20, 2023
1 parent 57711d0 commit 7b3b6fc
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
1 change: 1 addition & 0 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def test_packaging():
path = osp.join(torch.hub._get_torch_home(), 'pyg_test_package.pt')
with torch.package.PackageExporter(path) as pe:
pe.extern('torch_geometric.nn.**')
pe.extern('torch_geometric.utils.trim_to_layer')
pe.extern('_operator')
pe.save_pickle('models', 'model.pkl', model)

Expand Down
14 changes: 5 additions & 9 deletions torch_geometric/nn/models/basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def __init__(
in_channels = hidden_channels
self.lin = Linear(in_channels, self.out_channels)

self.trim = TrimToLayer()
# We define `trim_to_layer` functionality as a module such that we can
# still use `to_hetero` on-top.
self._trim = TrimToLayer()

def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
Expand Down Expand Up @@ -187,12 +189,6 @@ def forward(
scenarios to only operate on minimal-sized representations.
(default: :obj:`None`)
"""
if (num_sampled_edges_per_hop is not None
and num_sampled_nodes_per_hop is None):
raise ValueError("'num_sampled_nodes_per_hop' needs to be given")
if (num_sampled_nodes_per_hop is not None
and num_sampled_edges_per_hop is None):
raise ValueError("'num_sampled_edges_per_hop' needs to be given")
if (num_sampled_nodes_per_hop is not None
and isinstance(edge_weight, Tensor)
and isinstance(edge_attr, Tensor)):
Expand All @@ -202,8 +198,8 @@ def forward(

xs: List[Tensor] = []
for i in range(self.num_layers):
if num_sampled_nodes_per_hop is not None:
x, edge_index, value = self.trim(
if isinstance(num_sampled_nodes_per_hop, (list, tuple)):
x, edge_index, value = self._trim(
i,
num_sampled_nodes_per_hop,
num_sampled_edges_per_hop,
Expand Down
3 changes: 3 additions & 0 deletions torch_geometric/nn/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def hook(module, inputs, output):
name, module, depth = stack.pop()
module_id = id(module)

if name.startswith('(_'): # Do not summarize private modules.
continue

if module_id in hooks: # Avoid duplicated hooks.
hooks[module_id].remove()

Expand Down
7 changes: 7 additions & 0 deletions torch_geometric/utils/trim_to_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def forward(
edge_attr: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:

if (num_sampled_nodes_per_hop is None
and num_sampled_edges_per_hop is not None):
raise ValueError("'num_sampled_nodes_per_hop' needs to be given")
if (num_sampled_edges_per_hop is None
and num_sampled_nodes_per_hop is not None):
raise ValueError("'num_sampled_edges_per_hop' needs to be given")

if num_sampled_nodes_per_hop is None:
return x, edge_index, edge_attr
if num_sampled_edges_per_hop is None:
Expand Down

0 comments on commit 7b3b6fc

Please sign in to comment.