From 7055331513fea09b6d745868f2b7d5bf0b7da6af Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 22 Mar 2024 16:44:44 +0100 Subject: [PATCH] Add ETT datasets (#3149) *Description of changes:* Add electricity transformer datasets from https://github.com/zhouhaoyi/ETDataset By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. **Please tag this pr with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup --- src/gluonts/dataset/repository/_ett_small.py | 69 ++++++++++++++++++++ src/gluonts/dataset/repository/datasets.py | 13 ++++ 2 files changed, 82 insertions(+) create mode 100644 src/gluonts/dataset/repository/_ett_small.py diff --git a/src/gluonts/dataset/repository/_ett_small.py b/src/gluonts/dataset/repository/_ett_small.py new file mode 100644 index 0000000000..bf021a9659 --- /dev/null +++ b/src/gluonts/dataset/repository/_ett_small.py @@ -0,0 +1,69 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from pathlib import Path + +import pandas as pd +from gluonts.dataset import DatasetWriter +from gluonts.dataset.common import MetaData, TrainDatasets + + +# Currently data from only two regions are made public. +NUM_REGIONS = 2 + + +def generate_ett_small_dataset( + dataset_path: Path, + dataset_writer: DatasetWriter, + base_file_name: str, + freq: str, + prediction_length: int, +): + dfs = [] + for i in range(NUM_REGIONS): + df = pd.read_csv( + f"https://raw.githubusercontent.com/zhouhaoyi/ETDataset" + f"/main/ETT-small/{base_file_name}{i+1}.csv" + ) + df["date"] = df["date"].astype("datetime64[ms]") + dfs.append(df) + + test = [] + for df in dfs: + start = pd.Period(df["date"][0], freq=freq) + for col in df.columns: + if col in ["date"]: + continue + test.append( + { + "start": start, + "target": df[col].values, + } + ) + + train = [] + for df in dfs: + start = pd.Period(df["date"][0], freq=freq) + for col in df.columns: + if col in ["date"]: + continue + train.append( + { + "start": start, + "target": df[col].values[:-prediction_length], + } + ) + + metadata = MetaData(freq=freq, prediction_length=prediction_length) + dataset = TrainDatasets(metadata=metadata, train=train, test=test) + dataset.save(str(dataset_path), writer=dataset_writer, overwrite=True) diff --git a/src/gluonts/dataset/repository/datasets.py b/src/gluonts/dataset/repository/datasets.py index 87f0369446..bfc71b5031 100644 --- a/src/gluonts/dataset/repository/datasets.py +++ b/src/gluonts/dataset/repository/datasets.py @@ -25,6 +25,7 @@ from ._artificial import generate_artificial_dataset from ._airpassengers import generate_airpassengers_dataset from ._ercot import generate_ercot_dataset +from ._ett_small import generate_ett_small_dataset from ._gp_copula_2019 import generate_gp_copula_dataset from ._lstnet import generate_lstnet_dataset from ._m3 import generate_m3_dataset @@ -243,6 +244,18 @@ def get_download_path() -> Path: dataset_name="vehicle_trips_without_missing", ), "ercot": partial(generate_ercot_dataset), + "ett_small_15min": partial( + generate_ett_small_dataset, + base_file_name="ETTm", + freq="15min", + prediction_length=24, + ), + "ett_small_1h": partial( + generate_ett_small_dataset, + base_file_name="ETTh", + freq="1h", + prediction_length=24, + ), }