From 00310235bd511f369b9ea22d583dd041eb80ceb2 Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 13:17:13 +0000 Subject: [PATCH 01/12] Added return_semantic_attention_weights parameter to HANConv --- torch_geometric/nn/conv/han_conv.py | 41 ++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index a7e8624d2a23..3301650f6ead 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -11,8 +11,8 @@ from torch_geometric.utils import softmax -def group(xs: List[Tensor], q: nn.Parameter, - k_lin: nn.Module) -> Optional[Tensor]: +def group(xs: List[Tensor], q: nn.Parameter, k_lin: nn.Module, + return_semantic_attention_weights=None): if len(xs) == 0: return None else: @@ -23,6 +23,8 @@ def group(xs: List[Tensor], q: nn.Parameter, attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1) attn = F.softmax(attn_score, dim=0) out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0) + if isinstance(return_semantic_attention_weights, bool): + return out, attn return out @@ -103,10 +105,9 @@ def reset_parameters(self): self.k_lin.reset_parameters() glorot(self.q) - def forward( - self, x_dict: Dict[NodeType, Tensor], - edge_index_dict: Dict[EdgeType, - Adj]) -> Dict[NodeType, Optional[Tensor]]: + def forward(self, x_dict: Dict[NodeType, Tensor], + edge_index_dict: Dict[EdgeType, Adj], + return_semantic_attention_weights=None): r""" Args: x_dict (Dict[str, Tensor]): A dictionary holding input node @@ -116,12 +117,20 @@ def forward( individual edge type, either as a :obj:`torch.LongTensor` of shape :obj:`[2, num_edges]` or a :obj:`torch_sparse.SparseTensor`. - - :rtype: :obj:`Dict[str, Optional[Tensor]]` - The output node embeddings - for each node type. - In case a node type does not receive any message, its output will - be set to :obj:`None`. 
+ return_semantic_attention_weights (bool, optional): + If set to :obj:`True`, + will additionally return the tensor + :obj:`semantic_attention_weights`, holding the computed + attention weights for each edge type + at semantic-level attention. (default: :obj:`None`) """ + # NOTE: semantic_attention weights will be returned whenever + # `return_attention_weights` is set to a value, regardless of its + # actual value (might be `True` or `False`). This is a current somewhat + # hacky workaround to allow for TorchScript support via the + # `torch.jit._overload` decorator, as we can only change the output + # arguments conditioned on type (`None` or `bool`), not based on its + # actual value. H, D = self.heads, self.out_channels // self.heads x_node_dict, out_dict = {}, {} @@ -149,13 +158,19 @@ def forward( # iterate over node types: for node_type, outs in out_dict.items(): - out = group(outs, self.q, self.k_lin) + if isinstance(return_semantic_attention_weights, bool): + out, semantic_attention_weights = group( + outs, self.q, self.k_lin, + return_semantic_attention_weights) + else: + out = group(outs, self.q, self.k_lin) if out is None: out_dict[node_type] = None continue out_dict[node_type] = out - + if isinstance(return_semantic_attention_weights, bool): + return out_dict, semantic_attention_weights return out_dict def message(self, x_j: Tensor, alpha_i: Tensor, alpha_j: Tensor, From 27e4680d2404f0f7d394b1c181478e4fc8e23e0c Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 13:32:17 +0000 Subject: [PATCH 02/12] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0d2b1fc395b..e1ddac68b228 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
## [2.2.0] - 2022-MM-DD ### Added +- Added `return_semantic_attention_weights` argument to `forward` method of `HANConv` ([#5787](https://github.com/pyg-team/pytorch_geometric/pull/5787)) - Added `disjoint` argument to `NeighborLoader` and `LinkNeighborLoader` ([#5775](https://github.com/pyg-team/pytorch_geometric/pull/5775)) - Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763)) - Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717)) From 6677198e193f9b30addb8646d3751d2889ef80a6 Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 20:39:48 +0200 Subject: [PATCH 03/12] Update torch_geometric/nn/conv/han_conv.py Co-authored-by: Matthias Fey --- torch_geometric/nn/conv/han_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index 3301650f6ead..c32a4b379c2d 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -12,7 +12,8 @@ def group(xs: List[Tensor], q: nn.Parameter, k_lin: nn.Module, - return_semantic_attention_weights=None): + return_semantic_attention_weights: bool = False) -> Union[OptTensor, Tuple[Tensor, Tensor]]: + if len(xs) == 0: return None else: From e46a8568b37e0b84b74dcea3a95b8e14620eb0cc Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 20:40:23 +0200 Subject: [PATCH 04/12] Update torch_geometric/nn/conv/han_conv.py Co-authored-by: Matthias Fey --- torch_geometric/nn/conv/han_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index c32a4b379c2d..c107e193281d 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -24,7 +24,8 @@ def group(xs: List[Tensor], q: nn.Parameter, k_lin: nn.Module, attn_score = (q * 
torch.tanh(k_lin(out)).mean(1)).sum(-1) attn = F.softmax(attn_score, dim=0) out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0) - if isinstance(return_semantic_attention_weights, bool): + if return_semantic_attention_weights: + return out, attn return out From 039bd0dcd090793b763560719439751b8b0528aa Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 20:40:45 +0200 Subject: [PATCH 05/12] Update torch_geometric/nn/conv/han_conv.py Co-authored-by: Matthias Fey --- torch_geometric/nn/conv/han_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index c107e193281d..f48643e42695 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -109,7 +109,8 @@ def reset_parameters(self): def forward(self, x_dict: Dict[NodeType, Tensor], edge_index_dict: Dict[EdgeType, Adj], - return_semantic_attention_weights=None): + return_semantic_attention_weights: bool = False) -> Union[Dict[NodeType, OptTensor], TupleDict[NodeType, OptTensor], Tensor]]: + r""" Args: x_dict (Dict[str, Tensor]): A dictionary holding input node From 60a6a00b8e6f8699954705c54b063e1136f20062 Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 20:40:56 +0200 Subject: [PATCH 06/12] Update torch_geometric/nn/conv/han_conv.py Co-authored-by: Matthias Fey --- torch_geometric/nn/conv/han_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index f48643e42695..87685805ec7e 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -125,7 +125,8 @@ def forward(self, x_dict: Dict[NodeType, Tensor], will additionally return the tensor :obj:`semantic_attention_weights`, holding the computed attention weights for each edge type - at semantic-level attention. (default: :obj:`None`) + at semantic-level attention. 
(default: :obj:`False`) + """ # NOTE: semantic_attention weights will be returned whenever # `return_attention_weights` is set to a value, regardless of its From bd4dcda469e67156b940a40d2db2c14ddc34f6ed Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 20:41:21 +0200 Subject: [PATCH 07/12] Update torch_geometric/nn/conv/han_conv.py Co-authored-by: Matthias Fey --- torch_geometric/nn/conv/han_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index 87685805ec7e..de288ac22a7d 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -162,7 +162,8 @@ def forward(self, x_dict: Dict[NodeType, Tensor], # iterate over node types: for node_type, outs in out_dict.items(): - if isinstance(return_semantic_attention_weights, bool): + if return_semantic_attention_weights: + out, semantic_attention_weights = group( outs, self.q, self.k_lin, return_semantic_attention_weights) From 42d65ff56af0ef0b00f2e1307c0649af11134a49 Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 20:41:51 +0200 Subject: [PATCH 08/12] Update torch_geometric/nn/conv/han_conv.py Co-authored-by: Matthias Fey --- torch_geometric/nn/conv/han_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index de288ac22a7d..33a15fa0aa72 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -174,7 +174,8 @@ def forward(self, x_dict: Dict[NodeType, Tensor], out_dict[node_type] = None continue out_dict[node_type] = out - if isinstance(return_semantic_attention_weights, bool): + if return_semantic_attention_weights: + return out_dict, semantic_attention_weights return out_dict From cc4a33e77db1982503ac36434e316b4f029f7b29 Mon Sep 17 00:00:00 2001 From: Manuel Dileo Date: Thu, 20 Oct 2022 20:45:56 +0200 Subject: [PATCH 09/12] 
Update han_conv.py --- torch_geometric/nn/conv/han_conv.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index 33a15fa0aa72..4437d3cbe9af 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -128,13 +128,6 @@ def forward(self, x_dict: Dict[NodeType, Tensor], at semantic-level attention. (default: :obj:`False`) """ - # NOTE: semantic_attention weights will be returned whenever - # `return_attention_weights` is set to a value, regardless of its - # actual value (might be `True` or `False`). This is a current somewhat - # hacky workaround to allow for TorchScript support via the - # `torch.jit._overload` decorator, as we can only change the output - # arguments conditioned on type (`None` or `bool`), not based on its - # actual value. H, D = self.heads, self.out_channels // self.heads x_node_dict, out_dict = {}, {} From 6ff16e061c14bbcc87d5a516656bccd491b59526 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 20 Oct 2022 20:54:38 +0200 Subject: [PATCH 10/12] update --- CHANGELOG.md | 2 +- torch_geometric/nn/conv/han_conv.py | 57 +++++++++++++---------------- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e1ddac68b228..25e501a8bc81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
## [2.2.0] - 2022-MM-DD ### Added -- Added `return_semantic_attention_weights` argument to `forward` method of `HANConv` ([#5787](https://github.com/pyg-team/pytorch_geometric/pull/5787)) +- Added a `return_semantic_attention_weights` argument to `HANConv` ([#5787](https://github.com/pyg-team/pytorch_geometric/pull/5787)) - Added `disjoint` argument to `NeighborLoader` and `LinkNeighborLoader` ([#5775](https://github.com/pyg-team/pytorch_geometric/pull/5775)) - Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763)) - Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717)) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index 4437d3cbe9af..0943d55b3005 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -7,15 +7,18 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense import Linear from torch_geometric.nn.inits import glorot, reset -from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType +from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType, OptTensor from torch_geometric.utils import softmax -def group(xs: List[Tensor], q: nn.Parameter, k_lin: nn.Module, - return_semantic_attention_weights: bool = False) -> Union[OptTensor, Tuple[Tensor, Tensor]]: +def group( + xs: List[Tensor], + q: nn.Parameter, + k_lin: nn.Module, +) -> Tuple[OptTensor, OptTensor]: if len(xs) == 0: - return None + return None, None else: num_edge_types = len(xs) out = torch.stack(xs) @@ -24,10 +27,7 @@ def group(xs: List[Tensor], q: nn.Parameter, k_lin: nn.Module, attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1) attn = F.softmax(attn_score, dim=0) out = 
torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0) - if return_semantic_attention_weights: - - return out, attn - return out + return out, attn class HANConv(MessagePassing): @@ -107,10 +107,13 @@ def reset_parameters(self): self.k_lin.reset_parameters() glorot(self.q) - def forward(self, x_dict: Dict[NodeType, Tensor], - edge_index_dict: Dict[EdgeType, Adj], - return_semantic_attention_weights: bool = False) -> Union[Dict[NodeType, OptTensor], TupleDict[NodeType, OptTensor], Tensor]]: - + def forward( + self, + x_dict: Dict[NodeType, Tensor], + edge_index_dict: Dict[EdgeType, Adj], + return_semantic_attention_weights: bool = False, + ) -> Union[Dict[NodeType, OptTensor], Tuple[Dict[NodeType, OptTensor], + Dict[NodeType, Tensor]]]: r""" Args: x_dict (Dict[str, Tensor]): A dictionary holding input node @@ -120,13 +123,10 @@ def forward(self, x_dict: Dict[NodeType, Tensor], individual edge type, either as a :obj:`torch.LongTensor` of shape :obj:`[2, num_edges]` or a :obj:`torch_sparse.SparseTensor`. - return_semantic_attention_weights (bool, optional): - If set to :obj:`True`, - will additionally return the tensor - :obj:`semantic_attention_weights`, holding the computed - attention weights for each edge type - at semantic-level attention. (default: :obj:`False`) - + return_semantic_attention_weights (bool, optional): If set to + :obj:`True`, will additionally return the semantic-level + attention weights for each destination node type. 
+ (default: :obj:`False`) """ H, D = self.heads, self.out_channels // self.heads x_node_dict, out_dict = {}, {} @@ -154,22 +154,15 @@ def forward(self, x_dict: Dict[NodeType, Tensor], out_dict[dst_type].append(out) # iterate over node types: + semantic_attn_dict = {} for node_type, outs in out_dict.items(): - if return_semantic_attention_weights: - - out, semantic_attention_weights = group( - outs, self.q, self.k_lin, - return_semantic_attention_weights) - else: - out = group(outs, self.q, self.k_lin) - - if out is None: - out_dict[node_type] = None - continue + out, attn = group(outs, self.q, self.k_lin) out_dict[node_type] = out + semantic_attn_dict[node_type] = attn + if return_semantic_attention_weights: + return out_dict, semantic_attn_dict - return out_dict, semantic_attention_weights return out_dict def message(self, x_j: Tensor, alpha_i: Tensor, alpha_j: Tensor, From 429788099f0748e799339b6395c27157507b1c23 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 20 Oct 2022 20:55:20 +0200 Subject: [PATCH 11/12] update --- torch_geometric/nn/conv/han_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index 0943d55b3005..e1e0128c7443 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -113,7 +113,7 @@ def forward( edge_index_dict: Dict[EdgeType, Adj], return_semantic_attention_weights: bool = False, ) -> Union[Dict[NodeType, OptTensor], Tuple[Dict[NodeType, OptTensor], - Dict[NodeType, Tensor]]]: + Dict[NodeType, OptTensor]]]: r""" Args: x_dict (Dict[str, Tensor]): A dictionary holding input node From 357be3bd8b3c67abb20b4dc03b95a84c36e4966c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 20 Oct 2022 21:00:57 +0200 Subject: [PATCH 12/12] fix test --- torch_geometric/nn/conv/han_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py 
index e1e0128c7443..d406147e89f2 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -23,7 +23,7 @@ def group( num_edge_types = len(xs) out = torch.stack(xs) if out.numel() == 0: - return out.view(0, out.size(-1)) + return out.view(0, out.size(-1)), None attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1) attn = F.softmax(attn_score, dim=0) out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0)