From 9bb361faf5235ffeb6735828256d160df7b09e61 Mon Sep 17 00:00:00 2001 From: ftxj <932141413@qq.com> Date: Wed, 15 Feb 2023 12:32:58 +0000 Subject: [PATCH 1/6] Make linkx model jittable --- test/nn/models/test_linkx.py | 9 +++++ torch_geometric/nn/models/linkx.py | 55 ++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/test/nn/models/test_linkx.py b/test/nn/models/test_linkx.py index 5c592c3afcaa..f15f9cd4b145 100644 --- a/test/nn/models/test_linkx.py +++ b/test/nn/models/test_linkx.py @@ -3,6 +3,7 @@ from torch_sparse import SparseTensor from torch_geometric.nn import LINKX +from torch_geometric.testing import is_full_test @pytest.mark.parametrize('num_edge_layers', [1, 2]) @@ -21,6 +22,10 @@ def test_linkx(num_edge_layers): out = model(x, edge_index) assert out.size() == (4, 8) assert torch.allclose(out, model(x, adj1.t()), atol=1e-4) + if is_full_test(): + t = '(OptTensor, SparseTensor, OptTensor) -> Tensor' + jit = torch.jit.script(model.jittable(t)) + assert torch.allclose(jit(x, adj1.t()), out) out = model(None, edge_index) assert out.size() == (4, 8) @@ -29,6 +34,10 @@ def test_linkx(num_edge_layers): out = model(x, edge_index, edge_weight) assert out.size() == (4, 8) assert torch.allclose(out, model(x, adj2.t()), atol=1e-4) + if is_full_test(): + t = '(OptTensor, SparseTensor, OptTensor) -> Tensor' + jit = torch.jit.script(model.jittable(t)) + assert torch.allclose(jit(x, adj2.t()), out) out = model(None, edge_index, edge_weight) assert out.size() == (4, 8) diff --git a/torch_geometric/nn/models/linkx.py b/torch_geometric/nn/models/linkx.py index aa16d33e4502..d8871b815b48 100644 --- a/torch_geometric/nn/models/linkx.py +++ b/torch_geometric/nn/models/linkx.py @@ -30,6 +30,16 @@ def reset_parameters(self): a=math.sqrt(5)) inits.uniform(self.in_channels, self.bias) + @torch.jit._overload_method + def forward(self, edge_index: SparseTensor, edge_weight=None): + # type: (SparseTensor, OptTensor) -> Tensor + pass + + @torch.jit._overload_method + def forward(self, edge_index: Tensor, edge_weight=None): + # type: (Tensor, OptTensor) -> Tensor + pass + def forward(self, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: # propagate_type: (weight: Tensor, edge_weight: OptTensor) @@ -102,10 +112,12 @@ def __init__( self.num_edge_layers = num_edge_layers self.edge_lin = SparseLinear(num_nodes, hidden_channels) + # just make TorchScript happy + self.edge_norm = BatchNorm1d(hidden_channels) + channels = [hidden_channels] * 2 if self.num_edge_layers > 1: - self.edge_norm = BatchNorm1d(hidden_channels) - channels = [hidden_channels] * num_edge_layers - self.edge_mlp = MLP(channels, dropout=0., act_first=True) + channels = [hidden_channels] * self.num_edge_layers + self.edge_mlp = MLP(channels, dropout=0., act_first=True) channels = [in_channels] + [hidden_channels] * num_node_layers self.node_mlp = MLP(channels, dropout=0., act_first=True) @@ -129,6 +141,16 @@ def reset_parameters(self): self.cat_lin2.reset_parameters() self.final_mlp.reset_parameters() + @torch.jit._overload_method + def forward(self, x:OptTensor, edge_index: SparseTensor, edge_weight=None): + # type: (OptTensor, SparseTensor, OptTensor) -> Tensor + pass + + @torch.jit._overload_method + def forward(self, x:OptTensor, edge_index: Tensor, edge_weight=None): + # type: (OptTensor, Tensor, OptTensor) -> Tensor + pass + def forward(self, x: OptTensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: """""" @@ -147,6 +169,33 @@ def forward(self, x: OptTensor, edge_index: 
Adj, return self.final_mlp(out.relu_()) + def jittable(self, typing: str): + edge_index_type = typing.split(",")[1] + print(self.edge_lin) + class Jittable(torch.nn.Module): + def __init__(self, unjittable): + super().__init__() + self.sub_module = unjittable + def forward(self, x: Tensor, edge_index: Tensor, edge_weight: OptTensor = None ): + return self.sub_module(x, edge_index, edge_weight) + + class SparseJittable(torch.nn.Module): + def __init__(self, unjittable): + super().__init__() + self.sub_module = unjittable + def forward(self, x: Tensor, edge_index: SparseTensor, edge_weight: OptTensor = None): + return self.sub_module(x, edge_index, edge_weight) + + if self.edge_lin.jittable is not None: + edge_lin_typin = '(' + edge_index_type + ", OptTensor) -> Tensor" + self.edge_lin = self.edge_lin.jittable() + + if "SparseTensor" in edge_index_type: + jittable_module = SparseJittable(self) + elif "Tensor" in edge_index_type: + jittable_module = Jittable(self) + return jittable_module + def __repr__(self) -> str: return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, ' f'in_channels={self.in_channels}, ' From d53b158e4ffe46af741058f31a2294b0cc64fa4c Mon Sep 17 00:00:00 2001 From: ftxj <932141413@qq.com> Date: Wed, 15 Feb 2023 12:56:01 +0000 Subject: [PATCH 2/6] Format --- examples/jit/linkx.py | 52 ++++++++++++++++++++++++++++++ torch_geometric/nn/models/linkx.py | 19 ++++++----- torch_geometric/nn/models/mlp.py | 4 +-- 3 files changed, 65 insertions(+), 10 deletions(-) create mode 100644 examples/jit/linkx.py diff --git a/examples/jit/linkx.py b/examples/jit/linkx.py new file mode 100644 index 000000000000..05f7593a7458 --- /dev/null +++ b/examples/jit/linkx.py @@ -0,0 +1,52 @@ +import os.path as osp + +import torch +import torch.nn.functional as F + +from torch_geometric.datasets import LINKXDataset +from torch_geometric.nn import LINKX + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'LINKX') +dataset = LINKXDataset(path, name='Penn94') +# torch._C._jit_set_nvfuser_enabled(True) + +data = dataset[0].to(device) + +t = '(OptTensor, Tensor, OptTensor) -> Tensor' +model = LINKX(data.num_nodes, data.num_features, hidden_channels=32, + out_channels=dataset.num_classes, num_layers=1, + num_edge_layers=1, num_node_layers=1, + dropout=0.5).jittable(t).to(device) +model = torch.jit.script(model) +optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3) + + +def train(): + model.train() + optimizer.zero_grad() + out = model(data.x, data.edge_index) + mask = data.train_mask[:, 0] # Use the first set of the five masks. + loss = F.cross_entropy(out[mask], data.y[mask]) + loss.backward() + optimizer.step() + return float(loss) + + +@torch.no_grad() +def test(): + accs = [] + model.eval() + pred = model(data.x, data.edge_index).argmax(dim=-1) + for _, mask in data('train_mask', 'val_mask', 'test_mask'): + mask = mask[:, 0] # Use the first set of the five masks. 
+ accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) + return accs + + +for epoch in range(1, 201): + loss = train() + train_acc, val_acc, test_acc = test() + print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' + f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') diff --git a/torch_geometric/nn/models/linkx.py b/torch_geometric/nn/models/linkx.py index d8871b815b48..982741c3e86b 100644 --- a/torch_geometric/nn/models/linkx.py +++ b/torch_geometric/nn/models/linkx.py @@ -114,7 +114,7 @@ def __init__( self.edge_lin = SparseLinear(num_nodes, hidden_channels) # just make TorchScript happy self.edge_norm = BatchNorm1d(hidden_channels) - channels = [hidden_channels] * 2 + channels = [hidden_channels] * 2 if self.num_edge_layers > 1: channels = [hidden_channels] * self.num_edge_layers self.edge_mlp = MLP(channels, dropout=0., act_first=True) @@ -142,12 +142,13 @@ def reset_parameters(self): self.final_mlp.reset_parameters() @torch.jit._overload_method - def forward(self, x:OptTensor, edge_index: SparseTensor, edge_weight=None): + def forward(self, x: OptTensor, edge_index: SparseTensor, + edge_weight=None): # type: (OptTensor, SparseTensor, OptTensor) -> Tensor pass @torch.jit._overload_method - def forward(self, x:OptTensor, edge_index: Tensor, edge_weight=None): + def forward(self, x: OptTensor, edge_index: Tensor, edge_weight=None): # type: (OptTensor, Tensor, OptTensor) -> Tensor pass @@ -171,25 +172,27 @@ def forward(self, x: OptTensor, edge_index: Adj, def jittable(self, typing: str): edge_index_type = typing.split(",")[1] - print(self.edge_lin) + class Jittable(torch.nn.Module): def __init__(self, unjittable): super().__init__() self.sub_module = unjittable - def forward(self, x: Tensor, edge_index: Tensor, edge_weight: OptTensor = None ): + + def forward(self, x: Tensor, edge_index: Tensor, + edge_weight: OptTensor = None): return self.sub_module(x, edge_index, edge_weight) class SparseJittable(torch.nn.Module): def __init__(self, unjittable): super().__init__() self.sub_module = unjittable - def forward(self, x: Tensor, edge_index: SparseTensor, edge_weight: OptTensor = None): + + def forward(self, x: Tensor, edge_index: SparseTensor, + edge_weight: OptTensor = None): return self.sub_module(x, edge_index, edge_weight) if self.edge_lin.jittable is not None: - edge_lin_typin = '(' + edge_index_type + ", OptTensor) -> Tensor" self.edge_lin = self.edge_lin.jittable() - if "SparseTensor" in edge_index_type: jittable_module = SparseJittable(self) elif "Tensor" in edge_index_type: diff --git a/torch_geometric/nn/models/mlp.py b/torch_geometric/nn/models/mlp.py index 26356b09e341..2f0db75a70b8 100644 --- a/torch_geometric/nn/models/mlp.py +++ b/torch_geometric/nn/models/mlp.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -189,7 +189,7 @@ def forward( self, x: Tensor, return_emb: NoneType = None, - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + ): r""" Args: x (torch.Tensor): The source tensor. 
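
A note on the TorchScript idiom the first two patches rely on: `torch.jit.script` cannot compile a `forward` whose `edge_index` argument is the union type `Adj = Union[Tensor, SparseTensor]`, so the patch declares the two admissible signatures as `@torch.jit._overload_method` stubs (the signature lives in the mypy-style `# type:` comment) and keeps a single shared implementation. `jittable(typing)` then wraps the model in a thin module whose `forward` pins exactly one of those signatures, giving the compiler a fully typed entry point. What follows is a minimal, self-contained reduction of that pattern, not the patch code — `LINKXLike` is a hypothetical stand-in and its trivial body replaces the real `propagate()` call:

import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.typing import Adj, OptTensor


class LINKXLike(torch.nn.Module):
    # The overload stubs declare the admissible call signatures ...
    @torch.jit._overload_method
    def forward(self, x, edge_index, edge_weight=None):
        # type: (OptTensor, Tensor, OptTensor) -> Tensor
        pass

    @torch.jit._overload_method
    def forward(self, x, edge_index, edge_weight=None):
        # type: (OptTensor, SparseTensor, OptTensor) -> Tensor
        pass

    # ... and one implementation serves both of them.
    def forward(self, x: OptTensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        assert x is not None  # refines OptTensor to Tensor for TorchScript
        return x


class SparseTensorJittable(torch.nn.Module):
    # Pinning `edge_index: SparseTensor` leaves TorchScript a single
    # concrete signature to compile against.
    def __init__(self, child: torch.nn.Module):
        super().__init__()
        self.child = child

    def forward(self, x: Tensor, edge_index: SparseTensor,
                edge_weight: OptTensor = None) -> Tensor:
        return self.child(x, edge_index, edge_weight)


jit = torch.jit.script(SparseTensorJittable(LINKXLike()))
adj = SparseTensor.from_edge_index(torch.tensor([[0, 1], [1, 2]]))
out = jit(torch.randn(3, 4), adj)
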
From 2de8c1984e5fcfbb4b39ef810623a9495d459dff Mon Sep 17 00:00:00 2001 From: ftxj <932141413@qq.com> Date: Wed, 15 Feb 2023 13:09:41 +0000 Subject: [PATCH 3/6] ChangeLog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 274db40ce143..a5fc9bde7d06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Make `LINKX` model to be jittable([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712)) - Added `torch.jit` examples for `example/film.py` and `example/gcn.py`([#6602](https://github.com/pyg-team/pytorch_geometric/pull/6692)) - Added `Pad` transform ([#5940](https://github.com/pyg-team/pytorch_geometric/pull/5940)) - Added full batch mode to the inference benchmark ([#6631](https://github.com/pyg-team/pytorch_geometric/pull/6631)) From 4efebfa5cb4681056936ad98d586478068859733 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 15 Feb 2023 14:26:18 +0000 Subject: [PATCH 4/6] update --- CHANGELOG.md | 2 +- torch_geometric/nn/models/linkx.py | 82 +++++++++++++++++++----------- 2 files changed, 53 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5fc9bde7d06..fb721efca15e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Make `LINKX` model to be jittable([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712)) +- Added TorchScript support to the `LINKX` model ([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712)) - Added `torch.jit` examples for `example/film.py` and `example/gcn.py`([#6602](https://github.com/pyg-team/pytorch_geometric/pull/6692)) - Added `Pad` transform ([#5940](https://github.com/pyg-team/pytorch_geometric/pull/5940)) - Added full batch mode to the inference benchmark ([#6631](https://github.com/pyg-team/pytorch_geometric/pull/6631)) diff --git a/torch_geometric/nn/models/linkx.py b/torch_geometric/nn/models/linkx.py index 982741c3e86b..dce1eaae24b3 100644 --- a/torch_geometric/nn/models/linkx.py +++ b/torch_geometric/nn/models/linkx.py @@ -31,22 +31,27 @@ def reset_parameters(self): inits.uniform(self.in_channels, self.bias) @torch.jit._overload_method - def forward(self, edge_index: SparseTensor, edge_weight=None): + def forward(self, edge_index, edge_weight=None): # type: (SparseTensor, OptTensor) -> Tensor pass @torch.jit._overload_method - def forward(self, edge_index: Tensor, edge_weight=None): + def forward(self, edge_index, edge_weight=None): # type: (Tensor, OptTensor) -> Tensor pass - def forward(self, edge_index: Adj, - edge_weight: OptTensor = None) -> Tensor: + def forward( + self, + edge_index: Adj, + edge_weight: OptTensor = None, + ) -> Tensor: # propagate_type: (weight: Tensor, edge_weight: OptTensor) out = self.propagate(edge_index, weight=self.weight, edge_weight=edge_weight, size=None) + if self.bias is not None: out = out + self.bias + return out def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor: @@ -112,12 +117,14 @@ def __init__( self.num_edge_layers = num_edge_layers self.edge_lin = SparseLinear(num_nodes, hidden_channels) - # just make TorchScript happy - self.edge_norm = BatchNorm1d(hidden_channels) - channels = [hidden_channels] * 2 + if self.num_edge_layers > 1: + self.edge_norm = BatchNorm1d(hidden_channels) channels = [hidden_channels] * self.num_edge_layers - self.edge_mlp = MLP(channels, dropout=0., act_first=True) + 
self.edge_mlp = MLP(channels, dropout=0., act_first=True) + else: + self.edge_norm = None + self.edge_mlp = None channels = [in_channels] + [hidden_channels] * num_node_layers self.node_mlp = MLP(channels, dropout=0., act_first=True) @@ -133,8 +140,9 @@ def __init__( def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.edge_lin.reset_parameters() - if self.num_edge_layers > 1: + if self.edge_norm is not None: self.edge_norm.reset_parameters() + if self.edge_mlp is not None: self.edge_mlp.reset_parameters() self.node_mlp.reset_parameters() self.cat_lin1.reset_parameters() @@ -142,21 +150,25 @@ def reset_parameters(self): self.final_mlp.reset_parameters() @torch.jit._overload_method - def forward(self, x: OptTensor, edge_index: SparseTensor, - edge_weight=None): + def forward(self, x, edge_index, edge_weight=None): # type: (OptTensor, SparseTensor, OptTensor) -> Tensor pass @torch.jit._overload_method - def forward(self, x: OptTensor, edge_index: Tensor, edge_weight=None): + def forward(self, x, edge_index, edge_weight=None): # type: (OptTensor, Tensor, OptTensor) -> Tensor pass - def forward(self, x: OptTensor, edge_index: Adj, - edge_weight: OptTensor = None) -> Tensor: + def forward( + self, + x: OptTensor, + edge_index: Adj, + edge_weight: OptTensor = None, + ) -> Tensor: """""" out = self.edge_lin(edge_index, edge_weight) - if self.num_edge_layers > 1: + + if self.edge_norm is not None and self.edge_mlp is not None: out = out.relu_() out = self.edge_norm(out) out = self.edge_mlp(out) @@ -170,33 +182,43 @@ def forward(self, x: OptTensor, edge_index: Adj, return self.final_mlp(out.relu_()) - def jittable(self, typing: str): - edge_index_type = typing.split(",")[1] + def jittable(self, typing: str) -> torch.nn.Module: + edge_index_type = typing.split(',')[1].strip() - class Jittable(torch.nn.Module): - def __init__(self, unjittable): + class EdgeIndexJittable(torch.nn.Module): + def __init__(self, child): super().__init__() - self.sub_module = unjittable + self.child = child + + def reset_parameters(self): + self.child.reset_parameters() def forward(self, x: Tensor, edge_index: Tensor, - edge_weight: OptTensor = None): - return self.sub_module(x, edge_index, edge_weight) + edge_weight: OptTensor = None) -> Tensor: + return self.child(x, edge_index, edge_weight) - class SparseJittable(torch.nn.Module): - def __init__(self, unjittable): + class SparseTensorJittable(torch.nn.Module): + def __init__(self, child): super().__init__() - self.sub_module = unjittable + self.child = child + + def reset_parameters(self): + self.child.reset_parameters() def forward(self, x: Tensor, edge_index: SparseTensor, edge_weight: OptTensor = None): - return self.sub_module(x, edge_index, edge_weight) + return self.child(x, edge_index, edge_weight) if self.edge_lin.jittable is not None: self.edge_lin = self.edge_lin.jittable() - if "SparseTensor" in edge_index_type: - jittable_module = SparseJittable(self) - elif "Tensor" in edge_index_type: - jittable_module = Jittable(self) + + if 'Tensor' == edge_index_type: + jittable_module = EdgeIndexJittable(self) + elif 'SparseTensor' == edge_index_type: + jittable_module = SparseTensorJittable(self) + else: + raise ValueError(f"Could not parse types '{typing}'") + return jittable_module def __repr__(self) -> str: From ad121a9ba5d6c3f7480412fb677a0e8135ac2c6e Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 15 Feb 2023 14:28:17 +0000 Subject: [PATCH 5/6] update --- test/nn/models/test_linkx.py | 9 +++++---- 
torch_geometric/nn/models/linkx.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/nn/models/test_linkx.py b/test/nn/models/test_linkx.py index f15f9cd4b145..d0ed377bc26a 100644 --- a/test/nn/models/test_linkx.py +++ b/test/nn/models/test_linkx.py @@ -22,7 +22,12 @@ def test_linkx(num_edge_layers): out = model(x, edge_index) assert out.size() == (4, 8) assert torch.allclose(out, model(x, adj1.t()), atol=1e-4) + if is_full_test(): + t = '(OptTensor, Tensor, OptTensor) -> Tensor' + jit = torch.jit.script(model.jittable(t)) + assert torch.allclose(jit(x, edge_index), out) + t = '(OptTensor, SparseTensor, OptTensor) -> Tensor' jit = torch.jit.script(model.jittable(t)) assert torch.allclose(jit(x, adj1.t()), out) @@ -34,10 +39,6 @@ def test_linkx(num_edge_layers): out = model(x, edge_index, edge_weight) assert out.size() == (4, 8) assert torch.allclose(out, model(x, adj2.t()), atol=1e-4) - if is_full_test(): - t = '(OptTensor, SparseTensor, OptTensor) -> Tensor' - jit = torch.jit.script(model.jittable(t)) - assert torch.allclose(jit(x, adj2.t()), out) out = model(None, edge_index, edge_weight) assert out.size() == (4, 8) diff --git a/torch_geometric/nn/models/linkx.py b/torch_geometric/nn/models/linkx.py index dce1eaae24b3..f85dabc76e5f 100644 --- a/torch_geometric/nn/models/linkx.py +++ b/torch_geometric/nn/models/linkx.py @@ -182,7 +182,7 @@ def forward( return self.final_mlp(out.relu_()) - def jittable(self, typing: str) -> torch.nn.Module: + def jittable(self, typing: str) -> torch.nn.Module: # pragma: no cover edge_index_type = typing.split(',')[1].strip() class EdgeIndexJittable(torch.nn.Module): From 6b4c2db04cb000bb63c60046fb12c96674ffd8f2 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 15 Feb 2023 14:29:54 +0000 Subject: [PATCH 6/6] update --- examples/jit/linkx.py | 52 ------------------------------ torch_geometric/nn/models/linkx.py | 2 +- torch_geometric/nn/models/mlp.py | 2 +- 3 files changed, 2 insertions(+), 54 deletions(-) delete mode 100644 examples/jit/linkx.py diff --git a/examples/jit/linkx.py b/examples/jit/linkx.py deleted file mode 100644 index 05f7593a7458..000000000000 --- a/examples/jit/linkx.py +++ /dev/null @@ -1,52 +0,0 @@ -import os.path as osp - -import torch -import torch.nn.functional as F - -from torch_geometric.datasets import LINKXDataset -from torch_geometric.nn import LINKX - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'LINKX') -dataset = LINKXDataset(path, name='Penn94') -# torch._C._jit_set_nvfuser_enabled(True) - -data = dataset[0].to(device) - -t = '(OptTensor, Tensor, OptTensor) -> Tensor' -model = LINKX(data.num_nodes, data.num_features, hidden_channels=32, - out_channels=dataset.num_classes, num_layers=1, - num_edge_layers=1, num_node_layers=1, - dropout=0.5).jittable(t).to(device) -model = torch.jit.script(model) -optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3) - - -def train(): - model.train() - optimizer.zero_grad() - out = model(data.x, data.edge_index) - mask = data.train_mask[:, 0] # Use the first set of the five masks. - loss = F.cross_entropy(out[mask], data.y[mask]) - loss.backward() - optimizer.step() - return float(loss) - - -@torch.no_grad() -def test(): - accs = [] - model.eval() - pred = model(data.x, data.edge_index).argmax(dim=-1) - for _, mask in data('train_mask', 'val_mask', 'test_mask'): - mask = mask[:, 0] # Use the first set of the five masks. 
- accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) - return accs - - -for epoch in range(1, 201): - loss = train() - train_acc, val_acc, test_acc = test() - print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' - f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') diff --git a/torch_geometric/nn/models/linkx.py b/torch_geometric/nn/models/linkx.py index f85dabc76e5f..5804f532c3ba 100644 --- a/torch_geometric/nn/models/linkx.py +++ b/torch_geometric/nn/models/linkx.py @@ -120,7 +120,7 @@ def __init__( if self.num_edge_layers > 1: self.edge_norm = BatchNorm1d(hidden_channels) - channels = [hidden_channels] * self.num_edge_layers + channels = [hidden_channels] * num_edge_layers self.edge_mlp = MLP(channels, dropout=0., act_first=True) else: self.edge_norm = None diff --git a/torch_geometric/nn/models/mlp.py b/torch_geometric/nn/models/mlp.py index 2f0db75a70b8..402a6c087363 100644 --- a/torch_geometric/nn/models/mlp.py +++ b/torch_geometric/nn/models/mlp.py @@ -189,7 +189,7 @@ def forward( self, x: Tensor, return_emb: NoneType = None, - ): + ) -> Tensor: r""" Args: x (torch.Tensor): The source tensor.
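
Two closing notes on the final state of the series. First, the constructor rewrite in patch 4 replaces the "just make TorchScript happy" workaround (always building `edge_norm`) with attributes that always exist but may be `None`, each use guarded by an `is not None` check. Because `torch.jit.script` specializes on the concrete instance, the guard compiles whichever branch the constructor took. Below is a reduced sketch of that pattern, using a hypothetical `EdgeEncoder` built from plain `torch.nn` modules:

import torch
from torch import Tensor


class EdgeEncoder(torch.nn.Module):
    def __init__(self, hidden_channels: int, num_edge_layers: int):
        super().__init__()
        self.lin = torch.nn.Linear(hidden_channels, hidden_channels)
        if num_edge_layers > 1:
            self.edge_norm = torch.nn.BatchNorm1d(hidden_channels)
        else:
            # The attribute always exists; the scripted instance carries
            # either a concrete module or a concrete None.
            self.edge_norm = None

    def forward(self, x: Tensor) -> Tensor:
        x = self.lin(x)
        if self.edge_norm is not None:
            x = self.edge_norm(x.relu())
        return x


for num_edge_layers in (1, 2):  # both instantiations script cleanly
    jit = torch.jit.script(EdgeEncoder(16, num_edge_layers))
    assert jit(torch.randn(8, 16)).size() == (8, 16)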
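
Second, end-to-end usage of the API the series settles on, mirroring the final version of test_linkx.py (the signature string selects the wrapper; note that `jittable()` also swaps `edge_lin` for its scripted form in place):

import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import LINKX

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))

model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32,
              out_channels=8, num_layers=2)
out = model(x, edge_index)

jit = torch.jit.script(
    model.jittable('(OptTensor, Tensor, OptTensor) -> Tensor'))
assert torch.allclose(jit(x, edge_index), out)

jit = torch.jit.script(
    model.jittable('(OptTensor, SparseTensor, OptTensor) -> Tensor'))
assert torch.allclose(jit(x, adj.t()), out, atol=1e-4)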