Skip to content

Commit

Permalink
🧹 Linted CoraDataLoader using ruff
Browse files Browse the repository at this point in the history
Now will be following ruff for linting the project
  • Loading branch information
nithinmanoj10 committed Jan 20, 2024
1 parent fe06396 commit 62353ab
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 47 deletions.
4 changes: 4 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[lint]
select = ["ALL"]

ignore = ["FBT002", "FBT001"]
4 changes: 2 additions & 2 deletions stgraph/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from stgraph.dataset.static.cora_dataloader import CoraDataLoader

from stgraph.dataset.temporal.HungaryCPDataLoader import HungaryCPDataLoader
from stgraph.dataset.temporal.METRLADataLoader import METRLADataLoader
from stgraph.dataset.temporal.hungarycp_dataloader import HungaryCPDataLoader
from stgraph.dataset.temporal.metrla_dataloader import METRLADataLoader
from stgraph.dataset.temporal.MontevideoBusDataLoader import MontevideoBusDataLoader
from stgraph.dataset.temporal.PedalMeDataLoader import PedalMeDataLoader
from stgraph.dataset.temporal.WikiMathDataLoader import WikiMathDataLoader
Expand Down
1 change: 1 addition & 0 deletions stgraph/dataset/static/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Collection of dataset loaders for Static real-world datasets."""
53 changes: 30 additions & 23 deletions stgraph/dataset/static/cora_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
"""Citation network consisting of scientific publications"""
"""Citation network consisting of scientific publications."""


from typing import Optional

import numpy as np
from rich.console import Console

from stgraph.dataset.static.STGraphStaticDataset import STGraphStaticDataset

from stgraph.dataset.static.stgraph_static_dataset import STGraphStaticDataset

console = Console()


class CoraDataLoader(STGraphStaticDataset):
"""Dataloader provided for Citation network consisting of scientific publications"""

def __init__(self, verbose=False, url=None, redownload=False) -> None:
r"""Citation network consisting of scientific publications
The Cora dataset consists of 2708 scientific publications classified into one of seven classes.
The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued
word vector indicating the absence/presence of the corresponding word from the dictionary.
"""Citation network consisting of scientific publications."""

def __init__(
self: "CoraDataLoader",
verbose: bool = False,
url: Optional[str] = None,
redownload: bool = False,
) -> None:
r"""Citation network consisting of scientific publications.
The Cora dataset consists of 2708 scientific publications classified into
one of seven classes. The citation network consists of 5429 links. Each
publication in the dataset is described by a 0/1-valued word vector
indicating the absence/presence of the corresponding word from the dictionary.
The dictionary consists of 1433 unique words.
This class provides functionality for loading, processing, and accessing the Cora dataset
for use in deep learning tasks such as graph-based node classification.
This class provides functionality for loading, processing, and accessing the
Cora dataset for use in deep learning tasks such as graph-based
node classification.
.. list-table:: gdata
:widths: 25 25 25 25
Expand Down Expand Up @@ -50,7 +58,6 @@ def __init__(self, verbose=False, url=None, redownload=False) -> None:
Parameters
----------
verbose : bool, optional
Flag to control whether to display verbose info (default is False)
url : str, optional
Expand Down Expand Up @@ -96,7 +103,7 @@ def __init__(self, verbose=False, url=None, redownload=False) -> None:

self._process_dataset()

def _process_dataset(self) -> None:
def _process_dataset(self: "CoraDataLoader") -> None:
r"""Process the Cora dataset.
Calls private methods to extract edge list, node features, target classes
Expand All @@ -106,8 +113,8 @@ def _process_dataset(self) -> None:
self._set_targets_and_features()
self._set_graph_attributes()

def _set_edge_info(self) -> None:
r"""Extract edge information from the dataset"""
def _set_edge_info(self: "CoraDataLoader") -> None:
r"""Extract edge information from the dataset."""
edges = np.array(self._dataset["edges"])
edge_list = []

Expand All @@ -116,13 +123,13 @@ def _set_edge_info(self) -> None:

self._edge_list = edge_list

def _set_targets_and_features(self):
def _set_targets_and_features(self: "CoraDataLoader") -> None:
r"""Extract targets and features from the dataset."""
self._all_features = np.array(self._dataset["features"])
self._all_targets = np.array(self._dataset["labels"]).T

def _set_graph_attributes(self):
r"""Calculates and stores graph meta data inside ``gdata``"""
def _set_graph_attributes(self: "CoraDataLoader") -> None:
r"""Calculate and stores graph meta data inside ``gdata``."""
node_set = set()
for edge in self._edge_list:
node_set.add(edge[0])
Expand All @@ -133,14 +140,14 @@ def _set_graph_attributes(self):
self.gdata["num_feats"] = len(self._all_features[0])
self.gdata["num_classes"] = len(set(self._all_targets))

def get_edges(self) -> list:
def get_edges(self: "CoraDataLoader") -> list:
r"""Get the edge list."""
return self._edge_list

def get_all_features(self) -> np.ndarray:
def get_all_features(self: "CoraDataLoader") -> np.ndarray:
r"""Get all features."""
return self._all_features

def get_all_targets(self) -> np.ndarray:
def get_all_targets(self: "CoraDataLoader") -> np.ndarray:
r"""Get all targets."""
return self._all_targets
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Temporal dataset for County level chicken pox cases in Hungary"""

import numpy as np

from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset


class HungaryCPDataLoader(STGraphTemporalDataset):
"""Temporal dataset provided for County level chicken pox cases in Hungary"""

def __init__(
self,
verbose: bool = False,
Expand Down Expand Up @@ -78,20 +82,23 @@ def __init__(

super().__init__()

if type(lags) != int:
if not isinstance(lags, int):
raise TypeError("lags must be of type int")
if lags < 0:
raise ValueError("lags must be a positive integer")

if cutoff_time != None and type(cutoff_time) != int:
if cutoff_time is not None and not isinstance(cutoff_time, int):
raise TypeError("cutoff_time must be of type int")
if cutoff_time != None and cutoff_time < 0:
if cutoff_time is not None and cutoff_time < 0:
raise ValueError("cutoff_time must be a positive integer")

self.name = "Hungary_Chickenpox"
self._verbose = verbose
self._lags = lags
self._cutoff_time = cutoff_time
self._edge_list = None
self._edge_weights = None
self._all_targets = None

if not url:
self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/HungaryCP.json"
Expand Down Expand Up @@ -125,7 +132,7 @@ def _set_total_timestamps(self) -> None:
choosen by the user and the total time periods present in the
original dataset.
"""
if self._cutoff_time != None:
if self._cutoff_time is not None:
self.gdata["total_timestamps"] = min(
len(self._dataset["FX"]), self._cutoff_time
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Temporal dataset for traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions"""

import torch
import numpy as np

from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset


class METRLADataLoader(STGraphTemporalDataset):
"""Temporal dataset for traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions"""

def __init__(
self,
verbose: bool = True,
Expand Down Expand Up @@ -92,26 +96,30 @@ def __init__(

super().__init__()

if type(num_timesteps_in) != int:
if not isinstance(num_timesteps_in, int):
raise TypeError("num_timesteps_in must be of type int")
if num_timesteps_in < 0:
raise ValueError("num_timesteps_in must be a positive integer")

if type(num_timesteps_out) != int:
if not isinstance(num_timesteps_out, int):
raise TypeError("num_timesteps_out must be of type int")
if num_timesteps_out < 0:
raise ValueError("num_timesteps_out must be a positive integer")

if cutoff_time != None and type(cutoff_time) != int:
if cutoff_time is not None and not isinstance(cutoff_time, int):
raise TypeError("cutoff_time must be of type int")
if cutoff_time != None and cutoff_time < 0:
if cutoff_time is not None and cutoff_time < 0:
raise ValueError("cutoff_time must be a positive integer")

self.name = "METRLA"
self._verbose = verbose
self._num_timesteps_in = num_timesteps_in
self._num_timesteps_out = num_timesteps_out
self._cutoff_time = cutoff_time
self._edge_list = None
self._edge_weights = None
self._all_features = None
self._all_targets = None

if not url:
self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/METRLA.json"
Expand Down Expand Up @@ -145,7 +153,7 @@ def _set_total_timestamps(self) -> None:
choosen by the user and the total time periods present in the
original dataset.
"""
if self._cutoff_time != None:
if self._cutoff_time is not None:
self.gdata["total_timestamps"] = min(
self._dataset["time_periods"], self._cutoff_time
)
Expand Down Expand Up @@ -184,35 +192,35 @@ def _set_edge_weights(self):

def _set_targets_and_features(self):
r"""Calculates and sets the target and feature attributes"""
X = []
x = []

for timestamp in range(self.gdata["total_timestamps"]):
X.append(self._dataset[str(timestamp)])
x.append(self._dataset[str(timestamp)])

X = np.array(X)
X = X.transpose((1, 2, 0))
X = X.astype(np.float32)
x = np.array(x).transpose(1, 2, 0).astype(np.float32)
# x = x.transpose((1, 2, 0))
# x = x.astype(np.float32)

# Normalise as in DCRNN paper (via Z-Score Method)
means = np.mean(X, axis=(0, 2))
X = X - means.reshape(1, -1, 1)
stds = np.std(X, axis=(0, 2))
X = X / stds.reshape(1, -1, 1)
means = np.mean(x, axis=(0, 2))
x = x - means.reshape(1, -1, 1)
stds = np.std(x, axis=(0, 2))
x = x / stds.reshape(1, -1, 1)

X = torch.from_numpy(X)
x = torch.from_numpy(x)

indices = [
(i, i + (self._num_timesteps_in + self._num_timesteps_out))
for i in range(
X.shape[2] - (self._num_timesteps_in + self._num_timesteps_out) + 1
x.shape[2] - (self._num_timesteps_in + self._num_timesteps_out) + 1
)
]

# Generate observations
features, target = [], []
for i, j in indices:
features.append((X[:, :, i : i + self._num_timesteps_in]).numpy())
target.append((X[:, 0, i + self._num_timesteps_in : j]).numpy())
features.append((x[:, :, i : i + self._num_timesteps_in]).numpy())
target.append((x[:, 0, i + self._num_timesteps_in : j]).numpy())

self._all_features = np.array(features)
self._all_targets = np.array(target)
Expand Down

0 comments on commit 62353ab

Please sign in to comment.