Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Dec 17, 2022
1 parent 0d0ac96 commit f490606
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
13 changes: 8 additions & 5 deletions test/datasets/motif_generator/test_cycle_motif.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@


def test_cycle_motif():
n = 5
motif_generator = CycleMotif(n)
assert str(motif_generator) == 'CycleMotif()'
motif_generator = CycleMotif(5)
assert str(motif_generator) == 'CycleMotif(5)'

motif = motif_generator()
assert len(motif) == 2
assert motif.num_nodes == n
assert motif.num_edges == n
assert motif.num_nodes == 5
assert motif.num_edges == 10
assert motif.edge_index.tolist() == [
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4],
[1, 4, 0, 2, 1, 3, 2, 4, 0, 3],
]
5 changes: 4 additions & 1 deletion torch_geometric/datasets/motif_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@
from .cycle import CycleMotif

__all__ = classes = [
'MotifGenerator', 'CustomMotif', 'HouseMotif', 'CycleMotif'
'MotifGenerator',
'CustomMotif',
'HouseMotif',
'CycleMotif',
]
25 changes: 16 additions & 9 deletions torch_geometric/datasets/motif_generator/cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,25 @@
class CycleMotif(CustomMotif):
r"""Generates the cycle motif from the `"GNNExplainer:
Generating Explanations for Graph Neural Networks"
<https://arxiv.org/pdf/1903.03894.pdf>`_ paper, containing n nodes and n
undirected edges.
<https://arxiv.org/pdf/1903.03894.pdf>`_ paper.
Args:
n (int): Number of nodes (or edges) in the cycle.
num_nodes (int): The number of nodes in the cycle.
"""
def __init__(self, n: int = 6):
# construct edge_index based on n
def __init__(self, num_nodes: int):
self.num_nodes = num_nodes

row = torch.arange(num_nodes).view(-1, 1).repeat(1, 2).view(-1)
col1 = torch.arange(-1, num_nodes - 1) % num_nodes
col2 = torch.arange(1, num_nodes + 1) % num_nodes
col = torch.stack([col1, col2], dim=1).sort(dim=-1)[0].view(-1)

structure = Data(
num_nodes=n,
edge_index=torch.Tensor([[x for x in range(n)],
[y for y in range(1, n)] + [0]
]).type(torch.int32),
num_nodes=num_nodes,
edge_index=torch.stack([row, col], dim=0),
)

super().__init__(structure)

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.num_nodes})'

0 comments on commit f490606

Please sign in to comment.