Skip to content

Commit

Permalink
generalize dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Jun 5, 2024
1 parent 129b99a commit d5a9751
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 52 deletions.
24 changes: 11 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_mmd_3_test():
Y = np.random.normal(10, 100, 100)[:, None]
Z = np.random.normal(0, 1, 100)[:, None]

# Use custome kernels with this (TF-sklearn compatibility)
# Use custom kernels with this (TF-sklearn compatibility)
# sigma_XY = tsgm.utils.kernel_median_heuristic(X, Y);
# sigma_XZ = tsgm.utils.kernel_median_heuristic(X, Z);
# sigma = (sigma_XY + sigma_XZ) / 2
Expand All @@ -201,16 +201,14 @@ def test_mmd_3_test():


@pytest.mark.parametrize("dataset_name", [
"beef",
"coffee",
"ecg200",
"electric",
"freezer",
"gunpoint",
"insect",
"mixed_shapes",
"starlight",
"wafer"
"Beef",
"Coffee",
"ECG200",
"ElectricDevices",
"GunPoint",
"MixedShapesRegularTrain",
"StarLightCurves",
"Wafer"
])
def test_ucr_loadable(dataset_name):
ucr_data_manager = tsgm.utils.UCRDataManager(ds=dataset_name)
Expand All @@ -222,11 +220,11 @@ def test_ucr_loadable(dataset_name):
def test_ucr_raises():
with pytest.raises(ValueError) as excinfo:
ucr_data_manager = tsgm.utils.UCRDataManager(ds="does not exist")
assert "ds should be in" in str(excinfo.value)
assert "ds should be listed at UCR website" in str(excinfo.value)


def test_get_wafer():
dataset = "wafer"
dataset = "Wafer"
ucr_data_manager = tsgm.utils.UCRDataManager(ds=dataset)
assert ucr_data_manager.summary() is None
X_train, y_train, X_test, y_test = ucr_data_manager.get()
Expand Down
56 changes: 17 additions & 39 deletions tsgm/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ def gen_sine_vs_const_dataset(N: int, T: int, D: int, max_value: int = 10, const
class UCRDataManager:
"""
A manager for UCR collection of time series datasets.
If you find these datasets useful, please cite:
@misc{UCRArchive2018,
title = {The UCR Time Series Classification Archive},
author = {Dau, Hoang Anh and Keogh, Eamonn and Kamgar, Kaveh and Yeh, Chin-Chia Michael and Zhu, Yan
and Gharghabi, Shaghayegh and Ratanamahatana, Chotirat Ann and Yanping and Hu, Bing
and Begum, Nurjahan and Bagnall, Anthony and Mueen, Abdullah and Batista, Gustavo, and Hexagon-ML},
year = {2018},
month = {October},
note = {\\url{https://www.cs.ucr.edu/~eamonn/time_series_data_2018/}}
}
"""
mirrors = ["https://www.cs.ucr.edu/~eamonn/time_series_data_2018/"]
resources = [("UCRArchive_2018.zip", 0)]
Expand All @@ -140,7 +150,7 @@ def __init__(self, path: str = default_path, ds: str = "gunpoint") -> None:
"""
:param path: a relative path to the stored UCR dataset.
:type path: str
:param ds: Name of the dataset. Should be in (beef | coffee | ecg200 | freezer | gunpoint | insect | mixed_shapes | starlight).
:param ds: Name of the dataset. The list of names is available at https://www.cs.ucr.edu/~eamonn/time_series_data_2018/ (case sensitive!).
:type ds: str
:raises ValueError: When there is no stored UCR archive, or the name of the dataset is incorrect.
Expand All @@ -150,48 +160,16 @@ def __init__(self, path: str = default_path, ds: str = "gunpoint") -> None:

self.ds = ds.strip().lower()
self.y_all: T.Optional[T.Collection[T.Hashable]] = None
path = os.path.join(path, ds)
train_files = glob.glob(os.path.join(path, "*TRAIN.tsv"))

if ds == "beef":
self.regular_train_path = os.path.join(path, "Beef")
self.small_train_path = os.path.join(path, "Beef")
elif ds == "coffee":
self.regular_train_path = os.path.join(path, "Coffee")
self.small_train_path = os.path.join(path, "Coffee")
elif ds == "ecg200":
self.regular_train_path = os.path.join(path, "ECG200")
self.small_train_path = os.path.join(path, "ECG200")
elif ds == "electric":
self.regular_train_path = os.path.join(path, "ElectricDevices")
self.small_train_path = os.path.join(path, "ElectricDevices")
elif ds == "freezer":
self.regular_train_path = os.path.join(path, "FreezerRegularTrain")
self.small_train_path = os.path.join(path, "FreezerSmallTrain")
elif ds == "gunpoint":
self.regular_train_path = os.path.join(path, "GunPoint")
self.small_train_path = os.path.join(path, "GunPoint")
elif ds == "insect":
self.regular_train_path = os.path.join(path, "InsectEPGRegularTrain")
self.small_train_path = os.path.join(path, path, "InsectEPGSmallTrain")
elif ds == "mixed_shapes":
self.regular_train_path = os.path.join(path, path, "MixedShapesRegularTrain")
self.small_train_path = os.path.join(path, path, "MixedShapesSmallTrain")
elif ds == "starlight":
self.regular_train_path = os.path.join(path, path, "StarLightCurves")
self.small_train_path = os.path.join(path, path, "StarLightCurves")
elif ds == "wafer":
self.regular_train_path = os.path.join(path, path, "Wafer")
self.small_train_path = os.path.join(path, path, "Wafer")
else:
raise ValueError("ds should be in (beef | coffee | ecg200 | freezer | gunpoint | insect | mixed_shapes | starlight)")

self.small_train_df = pd.read_csv(
glob.glob(os.path.join(self.small_train_path, "*TRAIN.tsv"))[0],
sep='\t', header=None)
if len(train_files) == 0:
raise ValueError("ds should be listed at UCR website")
self.train_df = pd.read_csv(
glob.glob(os.path.join(self.regular_train_path, "*TRAIN.tsv"))[0],
glob.glob(os.path.join(path, "*TRAIN.tsv"))[0],
sep='\t', header=None)
self.test_df = pd.read_csv(
glob.glob(os.path.join(self.regular_train_path, "*TEST.tsv"))[0],
glob.glob(os.path.join(path, "*TEST.tsv"))[0],
sep='\t', header=None)

self.X_train, self.y_train = self.train_df[self.train_df.columns[1:]].to_numpy(), self.train_df[self.train_df.columns[0]].to_numpy()
Expand Down

0 comments on commit d5a9751

Please sign in to comment.