From 1e12d41c28b1fb9793f17646b018071b508864d7 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Fri, 29 Sep 2023 11:57:55 +0200 Subject: [PATCH] Add `module_headers` property to `nn.Sequential` models (#8093) Fixes https://github.com/pyg-team/pytorch_geometric/issues/8082. --- CHANGELOG.md | 1 + test/nn/test_sequential.py | 16 +++++++++++----- torch_geometric/nn/sequential.jinja | 9 +++++++-- torch_geometric/nn/sequential.py | 10 +++++++++- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c54a41c005f2..190cceba0cc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `module_headers` property to `nn.Sequential` models ([#8093](https://github.com/pyg-team/pytorch_geometric/pull/8093) - Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088), [#8092](https://github.com/pyg-team/pytorch_geometric/pull/8092)) - Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938) - Added a tutorial for multi-GPU training with pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894) diff --git a/test/nn/test_sequential.py b/test/nn/test_sequential.py index 43adf1b27a8a..6f61fab47a36 100644 --- a/test/nn/test_sequential.py +++ b/test/nn/test_sequential.py @@ -34,11 +34,11 @@ def test_sequential(): assert len(model) == 5 assert str(model) == ( 'Sequential(\n' - ' (0): GCNConv(16, 64)\n' - ' (1): ReLU(inplace=True)\n' - ' (2): GCNConv(64, 64)\n' - ' (3): ReLU(inplace=True)\n' - ' (4): Linear(in_features=64, out_features=7, bias=True)\n' + ' (0) - GCNConv(16, 64): x, edge_index -> x\n' + ' (1) - ReLU(inplace=True): x -> x\n' + ' (2) - GCNConv(64, 64): x, edge_index -> x\n' + ' (3) - ReLU(inplace=True): x -> x\n' + ' (4) - Linear(in_features=64, out_features=7, bias=True): x -> x\n' ')') assert isinstance(model[0], GCNConv) @@ -47,6 +47,12 @@ def test_sequential(): assert isinstance(model[3], ReLU) assert isinstance(model[4], Linear) + assert model.module_headers[0] == (['x', 'edge_index'], ['x']) + assert model.module_headers[1] == (['x'], ['x']) + assert model.module_headers[2] == (['x', 'edge_index'], ['x']) + assert model.module_headers[3] == (['x'], ['x']) + assert model.module_headers[4] == (['x'], ['x']) + out = model(x, edge_index) assert out.size() == (4, 7) diff --git a/torch_geometric/nn/sequential.jinja b/torch_geometric/nn/sequential.jinja index 013d6f46cb05..16bbfaea60de 100644 --- a/torch_geometric/nn/sequential.jinja +++ b/torch_geometric/nn/sequential.jinja @@ -24,5 +24,10 @@ class {{cls_name}}(torch.nn.Module): return {{calls|length}} def __repr__(self) -> str: - return 'Sequential(\n{}\n)'.format('\n'.join( - [f' ({idx}): ' + str(self[idx]) for idx in range(len(self))])) + module_reprs = [ + (f" ({i}) - {self[i]}: {', '.join(self.module_headers[i].args)} " + f"-> {', '.join(self.module_headers[i].output)}") + for i in range(len(self)) + ] + + return 'Sequential(\n{}\n)'.format('\n'.join(module_reprs)) diff --git a/torch_geometric/nn/sequential.py b/torch_geometric/nn/sequential.py index af4d4a88f028..3a2ea7e013b6 100644 --- a/torch_geometric/nn/sequential.py +++ b/torch_geometric/nn/sequential.py @@ -1,6 +1,6 @@ import os import os.path as osp -from typing import Callable, List, Tuple, Union +from typing import Callable, List, NamedTuple, Tuple, Union from uuid import uuid1 import torch @@ -8,6 +8,11 @@ from torch_geometric.nn.conv.utils.jit import class_from_module_repr +class HeaderDesc(NamedTuple): + args: List[str] + output: List[str] + + def Sequential( input_args: str, modules: List[Union[Tuple[Callable, str], Callable]], @@ -110,6 +115,9 @@ def Sequential( # Instantiate a class from the rendered module representation. module = class_from_module_repr(cls_name, module_repr)() + module.module_headers = [ + HeaderDesc(in_desc, out_desc) for _, _, in_desc, out_desc in calls + ] module._names = list(modules.keys()) for name, submodule, _, _ in calls: setattr(module, name, submodule)