
Commit

Added return_semantic_attention_weights parameter to HANConv (#5787)
Added a return_semantic_attention_weights parameter to the forward method of
HANConv, in the same way GATConv returns its attention weights.

return_semantic_attention_weights (bool, optional): If set to
:obj:`True`, will additionally return the tensor
:obj:`semantic_attention_weights`, holding the computed attention
weights for each edge type at the semantic level. (default:
:obj:`False`)

Further work could be dedicated to also returning the node-level attention
weights as a dictionary of (edge_type, attention_weights) items.
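
For illustration, a minimal usage sketch of the new flag, based on the forward
signature added in this commit. The toy metadata, node counts, channel sizes and
random edge indices below are made up for the example:

import torch
from torch_geometric.nn import HANConv

# Hypothetical two-type heterogeneous graph:
metadata = (['author', 'paper'],
            [('author', 'writes', 'paper'), ('paper', 'rev_writes', 'author')])

x_dict = {
    'author': torch.randn(4, 16),
    'paper': torch.randn(6, 16),
}
edge_index_dict = {
    ('author', 'writes', 'paper'): torch.stack([
        torch.randint(0, 4, (10,)),   # source author indices
        torch.randint(0, 6, (10,)),   # target paper indices
    ]),
    ('paper', 'rev_writes', 'author'): torch.stack([
        torch.randint(0, 6, (10,)),
        torch.randint(0, 4, (10,)),
    ]),
}

conv = HANConv(in_channels=16, out_channels=32, metadata=metadata, heads=4)

out_dict = conv(x_dict, edge_index_dict)  # previous behaviour, unchanged

# With the new flag, a second dictionary is returned that holds, for every
# destination node type, one semantic attention weight per incoming edge type:
out_dict, semantic_attn_dict = conv(x_dict, edge_index_dict,
                                    return_semantic_attention_weights=True)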

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
manuel-dileo and rusty1s authored Oct 20, 2022
1 parent 9f86132 commit 2ffd0b7
Showing 2 changed files with 28 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.2.0] - 2022-MM-DD
 ### Added
+- Added a `return_semantic_attention_weights` argument to `HANConv` ([#5787](https://github.com/pyg-team/pytorch_geometric/pull/5787))
 - Added `disjoint` argument to `NeighborLoader` and `LinkNeighborLoader` ([#5775](https://github.com/pyg-team/pytorch_geometric/pull/5775))
 - Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763))
 - Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717))
47 changes: 27 additions & 20 deletions torch_geometric/nn/conv/han_conv.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -7,23 +7,27 @@
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense import Linear
 from torch_geometric.nn.inits import glorot, reset
-from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
+from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType, OptTensor
 from torch_geometric.utils import softmax
 
 
-def group(xs: List[Tensor], q: nn.Parameter,
-          k_lin: nn.Module) -> Optional[Tensor]:
+def group(
+    xs: List[Tensor],
+    q: nn.Parameter,
+    k_lin: nn.Module,
+) -> Tuple[OptTensor, OptTensor]:
+
     if len(xs) == 0:
-        return None
+        return None, None
     else:
         num_edge_types = len(xs)
         out = torch.stack(xs)
         if out.numel() == 0:
-            return out.view(0, out.size(-1))
+            return out.view(0, out.size(-1)), None
         attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1)
         attn = F.softmax(attn_score, dim=0)
         out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0)
-        return out
+        return out, attn
 
 
 class HANConv(MessagePassing):
@@ -104,9 +108,12 @@ def reset_parameters(self):
         glorot(self.q)
 
     def forward(
-        self, x_dict: Dict[NodeType, Tensor],
-        edge_index_dict: Dict[EdgeType,
-                              Adj]) -> Dict[NodeType, Optional[Tensor]]:
+        self,
+        x_dict: Dict[NodeType, Tensor],
+        edge_index_dict: Dict[EdgeType, Adj],
+        return_semantic_attention_weights: bool = False,
+    ) -> Union[Dict[NodeType, OptTensor], Tuple[Dict[NodeType, OptTensor],
+               Dict[NodeType, OptTensor]]]:
         r"""
         Args:
             x_dict (Dict[str, Tensor]): A dictionary holding input node
@@ -116,11 +123,10 @@ def forward(
                 individual edge type, either as a :obj:`torch.LongTensor` of
                 shape :obj:`[2, num_edges]` or a
                 :obj:`torch_sparse.SparseTensor`.
-
-        :rtype: :obj:`Dict[str, Optional[Tensor]]` - The output node embeddings
-            for each node type.
-            In case a node type does not receive any message, its output will
-            be set to :obj:`None`.
+            return_semantic_attention_weights (bool, optional): If set to
+                :obj:`True`, will additionally return the semantic-level
+                attention weights for each destination node type.
+                (default: :obj:`False`)
         """
         H, D = self.heads, self.out_channels // self.heads
         x_node_dict, out_dict = {}, {}
@@ -148,13 +154,14 @@ def forward(
             out_dict[dst_type].append(out)
 
         # iterate over node types:
+        semantic_attn_dict = {}
         for node_type, outs in out_dict.items():
-            out = group(outs, self.q, self.k_lin)
-
-            if out is None:
-                out_dict[node_type] = None
-                continue
+            out, attn = group(outs, self.q, self.k_lin)
             out_dict[node_type] = out
+            semantic_attn_dict[node_type] = attn
 
+        if return_semantic_attention_weights:
+            return out_dict, semantic_attn_dict
+
         return out_dict
 
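
As a side note, the semantic-level attention that `group` computes (and that the
new flag exposes) can be reproduced standalone. The sketch below mirrors the
stacked-softmax logic from the diff; the tensor sizes are made up, and the plain
tensors stand in for the module's learnable `self.q` and `self.k_lin`:

import torch
import torch.nn.functional as F

# Made-up sizes for illustration only:
num_edge_types, num_nodes, channels = 3, 5, 8

# One aggregated result per edge type pointing to the same destination node
# type, mirroring what HANConv collects in out_dict before calling group():
xs = [torch.randn(num_nodes, channels) for _ in range(num_edge_types)]

q = torch.randn(1, channels)                 # stands in for self.q (nn.Parameter)
k_lin = torch.nn.Linear(channels, channels)  # stands in for self.k_lin

out = torch.stack(xs)                                      # [num_edge_types, num_nodes, channels]
attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1)  # one score per edge type
attn = F.softmax(attn_score, dim=0)                        # semantic attention weights, sum to 1
out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0)  # fused embeddings [num_nodes, channels]

print(attn.shape)  # torch.Size([3]); this per-edge-type vector is what
                   # semantic_attn_dict[node_type] holds for each destination type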
