Skip to content

Commit

Permalink
address coverage gaps
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 15, 2022
1 parent 68244ad commit 8148e8c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
48 changes: 29 additions & 19 deletions test/data/test_summary.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,47 @@
import torch
from pytest import approx
from pytest import approx, raises

from torch_geometric import seed_everything
from torch_geometric.data import Summary
from torch_geometric.datasets import FakeDataset, FakeHeteroDataset
from torch_geometric.testing import withPackage


@withPackage('pandas')
@withPackage('tabulate')
def test_summary():
seed_everything(0)
dataset = FakeDataset(num_graphs=10)
summary = dataset.summary()
assert summary.num_graphs == 10
num_nodes = torch.tensor([d.num_nodes for d in dataset]).float()
def check_summary(summary, num_nodes, num_edges):
assert summary.mean_num_nodes == approx(float(num_nodes.mean()))
assert summary.std_num_nodes == approx(float(num_nodes.std()))
assert summary.min_num_nodes == int(num_nodes.min())
assert summary.max_num_nodes == int(num_nodes.max())
assert summary.median_num_nodes == int(num_nodes.quantile(q=0.5))

num_edges = torch.tensor([d.num_edges for d in dataset]).float()
assert summary.mean_num_edges == approx(float(num_edges.mean()))
assert summary.std_num_edges == approx(float(num_edges.std()))
assert summary.min_num_edges == int(num_edges.min())
assert summary.max_num_edges == int(num_edges.max())
assert summary.median_num_edges == int(num_edges.quantile(q=0.5))


@withPackage('pandas')
@withPackage('tabulate')
def test_summary():
seed_everything(0)
dataset = FakeDataset(num_graphs=10)
summary = dataset.summary()
assert summary.num_graphs == 10
num_nodes = torch.tensor([d.num_nodes for d in dataset]).float()
num_edges = torch.tensor([d.num_edges for d in dataset]).float()
check_summary(summary, num_nodes, num_edges)

Summary.progressbar_threshold(0)
summary = dataset.summary()
assert summary.num_graphs == 10
check_summary(summary, num_nodes, num_edges)

with raises(ValueError, match="threshold must be a positive integer"):
Summary.progressbar_threshold(-1)

with raises(ValueError, match="threshold must be a positive integer"):
Summary.progressbar_threshold(10.0)


@withPackage('pandas')
Expand All @@ -34,13 +52,5 @@ def test_hetero_summary():
summary = dataset.summary()
assert summary.num_graphs == 10
num_nodes = torch.tensor([d.num_nodes for d in dataset]).float()
assert summary.mean_num_nodes == approx(float(num_nodes.mean()))
assert summary.std_num_nodes == approx(float(num_nodes.std()))
assert summary.min_num_nodes == int(num_nodes.min())
assert summary.max_num_nodes == int(num_nodes.max())

num_edges = torch.tensor([d.num_edges for d in dataset]).float()
assert summary.mean_num_edges == approx(float(num_edges.mean()))
assert summary.std_num_edges == approx(float(num_edges.std()))
assert summary.min_num_edges == int(num_edges.min())
assert summary.max_num_edges == int(num_edges.max())
check_summary(summary, num_nodes, num_edges)
8 changes: 5 additions & 3 deletions torch_geometric/data/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def max_num_nodes(self) -> int:
@property
def median_num_nodes(self) -> int:
r"""The median number of nodes"""
return int(self._desc['nodes']['median'])
return int(self._desc['nodes']['50%'])

@property
def mean_num_nodes(self) -> float:
Expand All @@ -68,7 +68,7 @@ def max_num_edges(self) -> int:
@property
def median_num_edges(self) -> int:
r"""The median number of edges"""
return int(self._desc['edges']['median'])
return int(self._desc['edges']['50%'])

@property
def mean_num_edges(self) -> float:
Expand Down Expand Up @@ -104,5 +104,7 @@ def progressbar_threshold(value: Optional[int] = None):
Summary._threshold = Summary._default_threshold
return

assert isinstance(value, int) and not value < 0
if not isinstance(value, int) or value < 0:
raise ValueError("threshold must be a positive integer.")

Summary._threshold = value

0 comments on commit 8148e8c

Please sign in to comment.