Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wwymak committed Jun 3, 2023
1 parent 352df4b commit 2b576bb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions test/nn/test_encoding.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import torch

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import (
NodeEncoding,
PositionalEncoding,
TemporalEncoding,
)
from torch_geometric.testing import withCUDA
from torch_geometric.testing import onlyNeighborSampler, withCUDA


@withCUDA
Expand All @@ -28,7 +27,9 @@ def test_temporal_encoding(device):


@withCUDA
@onlyNeighborSampler
def test_node_encoding(get_dataset, device):
from torch_geometric.loader import NeighborLoader
dataset = get_dataset(name='Cora')
data = dataset[0]
loader = NeighborLoader(
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/nn/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,17 @@ class NodeEncoding(torch.nn.Module):
Model Architectures for Temporal Networks?"
<https://openreview.net/forum?id=ayPPc0SyLv1>`_ paper.
:class:`NodeEncoding` captures the node identity and node feature
information via neighbor mean-pooling
information via neighbor mean-pooling
"""
def forward(self, data):
def forward(self, data) -> Tensor:
"""
Args:
data: batch of nodes from NeighborLoader
Returns:
torch.Tensor of the root nodes updated with mean
pooling of neighbor features
"""
x, edge_index, batch_size = data.x, data.edge_index, data.batch_size
root_nodes = x[:, batch_size]
Expand Down

0 comments on commit 2b576bb

Please sign in to comment.