Added return_semantic_attention_weights parameter to HANConv #5787

Merged 13 commits on Oct 20, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added a `return_semantic_attention_weights` argument to `HANConv` ([#5787](https://github.com/pyg-team/pytorch_geometric/pull/5787))
- Added `disjoint` argument to `NeighborLoader` and `LinkNeighborLoader` ([#5775](https://github.com/pyg-team/pytorch_geometric/pull/5775))
- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763))
- Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717))
47 changes: 27 additions & 20 deletions torch_geometric/nn/conv/han_conv.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -7,23 +7,27 @@
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense import Linear
 from torch_geometric.nn.inits import glorot, reset
-from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
+from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType, OptTensor
 from torch_geometric.utils import softmax


-def group(xs: List[Tensor], q: nn.Parameter,
-          k_lin: nn.Module) -> Optional[Tensor]:
+def group(
+    xs: List[Tensor],
+    q: nn.Parameter,
+    k_lin: nn.Module,
+) -> Tuple[OptTensor, OptTensor]:

     if len(xs) == 0:
-        return None
+        return None, None
     else:
         num_edge_types = len(xs)
         out = torch.stack(xs)
         if out.numel() == 0:
-            return out.view(0, out.size(-1))
+            return out.view(0, out.size(-1)), None
         attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1)
         attn = F.softmax(attn_score, dim=0)
         out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0)
-        return out
+        return out, attn


 class HANConv(MessagePassing):
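For reference, the semantic-level attention that `group` now returns alongside the fused output can be reproduced in isolation. The sketch below mirrors the computation in the hunk above with made-up tensor sizes; it is illustrative only and not part of the PR.

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the semantic-level attention computed by `group`.
# Sizes are made up for illustration.
num_edge_types, num_nodes, channels = 3, 5, 8

# One embedding per incoming edge type (metapath) for a single node type.
xs = [torch.randn(num_nodes, channels) for _ in range(num_edge_types)]
q = torch.randn(channels)                    # semantic attention vector (self.q in HANConv)
k_lin = torch.nn.Linear(channels, channels)  # shared projection (self.k_lin in HANConv)

out = torch.stack(xs)                                       # [num_edge_types, num_nodes, channels]
attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1)   # one score per edge type
attn = F.softmax(attn_score, dim=0)                         # weights over edge types, sum to 1
fused = (attn.view(-1, 1, 1) * out).sum(dim=0)              # [num_nodes, channels]
```

With this change, `group` returns both the fused embedding and `attn` (or `(None, None)` when a node type receives no messages).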
@@ -104,9 +108,12 @@ def reset_parameters(self):
         glorot(self.q)

     def forward(
-        self, x_dict: Dict[NodeType, Tensor],
-        edge_index_dict: Dict[EdgeType,
-                              Adj]) -> Dict[NodeType, Optional[Tensor]]:
+        self,
+        x_dict: Dict[NodeType, Tensor],
+        edge_index_dict: Dict[EdgeType, Adj],
+        return_semantic_attention_weights: bool = False,
+    ) -> Union[Dict[NodeType, OptTensor], Tuple[Dict[NodeType, OptTensor],
+                                                Dict[NodeType, OptTensor]]]:
         r"""
         Args:
             x_dict (Dict[str, Tensor]): A dictionary holding input node
@@ -116,11 +123,10 @@ def forward(
                 individual edge type, either as a :obj:`torch.LongTensor` of
                 shape :obj:`[2, num_edges]` or a
                 :obj:`torch_sparse.SparseTensor`.
-
-        :rtype: :obj:`Dict[str, Optional[Tensor]]` - The output node embeddings
-            for each node type.
-            In case a node type does not receive any message, its output will
-            be set to :obj:`None`.
+            return_semantic_attention_weights (bool, optional): If set to
+                :obj:`True`, will additionally return the semantic-level
+                attention weights for each destination node type.
+                (default: :obj:`False`)
         """
         H, D = self.heads, self.out_channels // self.heads
         x_node_dict, out_dict = {}, {}
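To see the new argument end-to-end, here is a minimal usage sketch on a toy heterogeneous graph. The node types, feature sizes, and edge indices are made up for illustration; only the `return_semantic_attention_weights` keyword comes from this PR.

```python
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import HANConv

# Toy heterogeneous graph with hypothetical 'author' and 'paper' node types.
data = HeteroData()
data['author'].x = torch.randn(4, 16)
data['paper'].x = torch.randn(6, 32)
data['author', 'writes', 'paper'].edge_index = torch.tensor([[0, 1, 2, 3],
                                                             [0, 1, 2, 3]])
data['paper', 'written_by', 'author'].edge_index = torch.tensor([[0, 1, 2, 3],
                                                                 [0, 1, 2, 3]])

conv = HANConv(-1, 64, metadata=data.metadata(), heads=4)

# The default call is unchanged and still returns only the node embeddings:
out_dict = conv(data.x_dict, data.edge_index_dict)

# With the new flag, the semantic-level attention weights are returned as well:
out_dict, semantic_attn_dict = conv(data.x_dict, data.edge_index_dict,
                                    return_semantic_attention_weights=True)
```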
@@ -148,13 +154,14 @@ def forward(
             out_dict[dst_type].append(out)

         # iterate over node types:
+        semantic_attn_dict = {}
         for node_type, outs in out_dict.items():
-            out = group(outs, self.q, self.k_lin)
-
-            if out is None:
-                out_dict[node_type] = None
-                continue
+            out, attn = group(outs, self.q, self.k_lin)
             out_dict[node_type] = out
+            semantic_attn_dict[node_type] = attn
+
+        if return_semantic_attention_weights:
+            return out_dict, semantic_attn_dict

         return out_dict
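For each destination node type, the returned dictionary holds one softmax-normalized weight per incoming edge type (the metapaths ending at that type), so the weights sum to one; node types that receive no messages map to `None`, matching the existing behavior of the output embeddings. Continuing the toy sketch above (names are hypothetical):

```python
# 'paper' only receives messages via ('author', 'writes', 'paper') in the toy
# graph, so its semantic attention tensor has a single entry equal to 1.0.
attn = semantic_attn_dict['paper']   # shape: [num_incoming_edge_types]
assert torch.allclose(attn.sum(), torch.tensor(1.0))
```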
