Skip to content

Commit

Permalink
adding tests + removing type annotation to avoid circular dep
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 14, 2022
1 parent 2e18bee commit a0a6086
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
18 changes: 18 additions & 0 deletions test/data/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import torch
from torch_sparse import SparseTensor

Expand Down Expand Up @@ -42,6 +43,20 @@ def test_in_memory_dataset():
assert dataset[1].test_int == 2
assert dataset[1].test_str == '2'

summary = dataset.summary()
num_nodes = torch.tensor([5., 10.])
assert summary.num_graphs == 2
assert summary.mean_num_nodes == pytest.approx(float(num_nodes.mean()))
assert summary.std_num_nodes == pytest.approx(float(num_nodes.std()))
assert summary.min_num_nodes == int(num_nodes.min())
assert summary.max_num_nodes == int(num_nodes.max())

num_edges = torch.tensor([4., 4.])
assert summary.mean_num_edges == pytest.approx(float(num_edges.mean()))
assert summary.std_num_edges == pytest.approx(float(num_edges.std()))
assert summary.min_num_edges == int(num_edges.min())
assert summary.max_num_edges == int(num_edges.max())


def test_in_memory_sparse_tensor_dataset():
x = torch.randn(11, 16)
Expand Down Expand Up @@ -118,6 +133,9 @@ def test_hetero_in_memory_dataset():
assert (dataset[1]['paper', 'paper'].edge_index.tolist() == data2[
'paper', 'paper'].edge_index.tolist())

summary = dataset.summary()
assert summary.num_graphs == 2


def test_override_behaviour():
class DS(Dataset):
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import torch.utils.data
from torch import Tensor

from torch_geometric.data import Data, Summary
from torch_geometric.data import Data
from torch_geometric.data.makedirs import makedirs
from torch_geometric.data.summary import Summary

IndexType = Union[slice, Tensor, np.ndarray, Sequence]

Expand Down
4 changes: 1 addition & 3 deletions torch_geometric/data/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
from tabulate import tabulate
from tqdm import tqdm

from torch_geometric.data import Dataset


class Summary:
r"""Summary of graph datasets
Args:
dataset (Dataset): :obj:`torch_geometric.data.Dataset`
"""
def __init__(self, dataset: Dataset):
def __init__(self, dataset):
self.dataset_str = repr(dataset)

def map_data(data):
Expand Down

0 comments on commit a0a6086

Please sign in to comment.