Skip to content

Commit

Permalink
🧹 Linting done for metrla_dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinmanoj10 committed Jan 22, 2024
1 parent 0fc3a95 commit ce4000e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 44 deletions.
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[lint]
select = ["ALL"]

ignore = ["FBT002", "FBT001", "PLR0913", "TRY003", "EM101"]
ignore = ["FBT002", "FBT001", "PLR0913", "TRY003", "EM101", "ERA001"]
90 changes: 47 additions & 43 deletions stgraph/dataset/temporal/metrla_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
"""Temporal dataset for traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions"""
"""Temporal dataset for traffic forecasting based on Los Angeles city."""

from __future__ import annotations

import torch
import numpy as np
import torch

from stgraph.dataset.temporal.stgraph_temporal_dataset import STGraphTemporalDataset


class METRLADataLoader(STGraphTemporalDataset):
"""Temporal dataset for traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions"""
"""Temporal dataset for traffic forecasting based on the Los Angeles city."""

def __init__(
self,
self: METRLADataLoader,
verbose: bool = True,
url: str = None,
url: str | None = None,
num_timesteps_in: int = 12,
num_timesteps_out: int = 12,
cutoff_time: int = None,
cutoff_time: int | None = None,
redownload: bool = False,
):
r"""A traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions.
) -> None:
r"""Traffic forecasting dataset based on the Los Angeles city.
A dataset for predicting traffic patterns in the Los Angeles Metropolitan area,
comprising traffic data obtained from 207 loop detectors on highways in Los Angeles County.
The dataset includes aggregated 5-minute interval readings spanning a four-month
period from March 2012 to June 2012.
comprising traffic data obtained from 207 loop detectors on highways in Los
Angeles County. The dataset includes aggregated 5-minute interval readings
spanning a four-month period from March 2012 to June 2012.
This class provides functionality for loading, processing, and accessing the METRLA
dataset for use in deep learning tasks such as traffic forecasting.
This class provides functionality for loading, processing, and accessing
the METRLA dataset for use in deep learning tasks such as traffic
forecasting.
.. list-table:: gdata
:widths: 33 33 33
Expand Down Expand Up @@ -58,7 +61,6 @@ def __init__(
Parameters
----------
verbose : bool, optional
Flag to control whether to display verbose info (default is True)
url : str, optional
Expand Down Expand Up @@ -93,7 +95,6 @@ def __init__(
_all_targets : numpy.ndarray
Numpy array of the node target value
"""

super().__init__()

if not isinstance(num_timesteps_in, int):
Expand Down Expand Up @@ -137,16 +138,16 @@ def __init__(

self._process_dataset()

def _process_dataset(self) -> None:
def _process_dataset(self: METRLADataLoader) -> None:
self._set_total_timestamps()
self._set_num_nodes()
self._set_num_edges()
self._set_edges()
self._set_edge_weights()
self._set_targets_and_features()

def _set_total_timestamps(self) -> None:
r"""Sets the total timestamps present in the dataset
def _set_total_timestamps(self: METRLADataLoader) -> None:
r"""Set the total timestamps present in the dataset.
It sets the total timestamps present in the dataset into the
gdata attribute dictionary. It is the minimum of the cutoff time
Expand All @@ -155,33 +156,36 @@ def _set_total_timestamps(self) -> None:
"""
if self._cutoff_time is not None:
self.gdata["total_timestamps"] = min(
self._dataset["time_periods"], self._cutoff_time
self._dataset["time_periods"],
self._cutoff_time,
)
else:
self.gdata["total_timestamps"] = self._dataset["time_periods"]

def _set_num_nodes(self: METRLADataLoader) -> None:
    r"""Set the total number of nodes present in the dataset.

    Collects every node ID that appears as either endpoint of an edge in
    ``self._dataset["edges"]`` and stores the number of distinct IDs in
    ``self.gdata["num_nodes"]``.

    Raises
    ------
    RuntimeError
        If the largest node ID is not equal to ``number of distinct IDs - 1``,
        i.e. the node labelling is not continuous. An empty edge list also
        fails this check.
    """
    node_set = set()
    max_node_id = 0
    for edge in self._dataset["edges"]:
        node_set.add(edge[0])
        node_set.add(edge[1])
        max_node_id = max(max_node_id, edge[0], edge[1])

    # The labelling is valid when the largest ID equals len(node_set) - 1,
    # so the error must fire on inequality (the previous `==` check was
    # inverted: it rejected valid datasets and accepted broken ones).
    if max_node_id != len(node_set) - 1:
        raise RuntimeError("Node ID labelling is not continuous")

    self.gdata["num_nodes"] = len(node_set)

def _set_num_edges(self):
r"""Sets the total number of edges present in the dataset"""
def _set_num_edges(self: METRLADataLoader) -> None:
r"""Set the total number of edges present in the dataset."""
self.gdata["num_edges"] = len(self._dataset["edges"])

def _set_edges(self):
r"""Sets the edge list of the dataset"""
def _set_edges(self: METRLADataLoader) -> None:
r"""Set the edge list of the dataset."""
self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]]

def _set_edge_weights(self):
r"""Sets the edge weights of the dataset"""
def _set_edge_weights(self: METRLADataLoader) -> None:
r"""Set the edge weights of the dataset."""
edges = self._dataset["edges"]
edge_weights = self._dataset["weights"]
comb_edge_list = [
Expand All @@ -190,12 +194,12 @@ def _set_edge_weights(self):
comb_edge_list.sort(key=lambda x: (x[1], x[0]))
self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list])

def _set_targets_and_features(self):
r"""Calculates and sets the target and feature attributes"""
x = []

for timestamp in range(self.gdata["total_timestamps"]):
x.append(self._dataset[str(timestamp)])
def _set_targets_and_features(self: METRLADataLoader) -> None:
r"""Calculate and set the target and feature attributes."""
x = [
self._dataset[str(timestamp)]
for timestamp in range(self.gdata["total_timestamps"])
]

x = np.array(x).transpose(1, 2, 0).astype(np.float32)
# x = x.transpose((1, 2, 0))
Expand All @@ -212,7 +216,7 @@ def _set_targets_and_features(self):
indices = [
(i, i + (self._num_timesteps_in + self._num_timesteps_out))
for i in range(
x.shape[2] - (self._num_timesteps_in + self._num_timesteps_out) + 1
x.shape[2] - (self._num_timesteps_in + self._num_timesteps_out) + 1,
)
]

Expand All @@ -225,18 +229,18 @@ def _set_targets_and_features(self):
self._all_features = np.array(features)
self._all_targets = np.array(target)

def get_edges(self):
r"""Returns the edge list"""
def get_edges(self: METRLADataLoader) -> list:
r"""Return the edge list."""
return self._edge_list

def get_edge_weights(self):
r"""Returns the edge weights"""
def get_edge_weights(self: METRLADataLoader) -> np.ndarray:
r"""Return the edge weights."""
return self._edge_weights

def get_all_targets(self):
r"""Returns the targets for each timestamp"""
def get_all_targets(self: METRLADataLoader) -> np.ndarray:
r"""Return the targets for each timestamp."""
return self._all_targets

def get_all_features(self):
r"""Returns the features for each timestamp"""
def get_all_features(self: METRLADataLoader) -> np.ndarray:
r"""Return the features for each timestamp."""
return self._all_features

0 comments on commit ce4000e

Please sign in to comment.