[PoC] AutoGluon TimeSeries Prototype #494

Merged: 5 commits, Oct 10, 2022
Changes from 3 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -16,6 +16,7 @@ venv/
.idea/
*.iml
*.swp
+launch.json

# tmp files
.ipynb_checkpoints/
23 changes: 20 additions & 3 deletions amlb/benchmark.py
@@ -489,7 +489,20 @@ def load_data(self):
            # TODO
            raise NotImplementedError("OpenML datasets without task_id are not supported yet.")
        elif hasattr(self._task_def, 'dataset'):
-            self._dataset = Benchmark.data_loader.load(DataSourceType.file, dataset=self._task_def.dataset, fold=self.fold)
+            if self._task_def.dataset['type'] == 'timeseries' and self._task_def.dataset['timestamp_column'] is None:
+                log.warning("For timeseries task, setting undefined timestamp column to `timestamp`.")
+                self._task_def.dataset['timestamp_column'] = "timestamp"
+            self._dataset = Benchmark.data_loader.load(DataSourceType.file, dataset=self._task_def.dataset, fold=self.fold, timestamp_column=self._task_def.dataset['timestamp_column'])
+            if self._dataset.type == DatasetType.timeseries:
+                if self._task_def.dataset['id_column'] is None:
+                    log.warning("For timeseries task, setting undefined item id column to `item_id`.")
+                    self._task_def.dataset['id_column'] = "item_id"
+                if self._task_def.dataset['prediction_length'] is None:
+                    log.warning("For timeseries task, setting undefined prediction length to `1`.")
+                    self._task_def.dataset['prediction_length'] = 1
+                self._dataset.timestamp_column = self._task_def.dataset['timestamp_column']
+                self._dataset.id_column = self._task_def.dataset['id_column']
+                self._dataset.prediction_length = self._task_def.dataset['prediction_length']
sebhrusen (Collaborator), Oct 3, 2022:
Looks like most of this logic could reside in the loading logic itself, since it deals with information available in self._task_def.dataset, which is directly available to the file loader.
I'd move the logic to dataset/file.py for now to minimize the scope of changes.

Contributor:
Okay, so you want me to extend the FileDataset or the CsvDataset?

Collaborator:
I think you can extract this logic into a dedicated method in file.py for clarity (it's just mutating the dataset, after all), and if you only support CSV right now, then please apply it only there.

Contributor:
So I added it to a dedicated method in file.py, inside the FileLoader class.
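For illustration, that dedicated method could look roughly like the following sketch (the method name and log messages are assumptions, not the merged code):

```python
# Hypothetical FileLoader helper: fills in time-series defaults by
# mutating the task's dataset definition in place.
def _ensure_timeseries_defaults(self, dataset):
    if dataset['type'] != 'timeseries':
        return
    if dataset['timestamp_column'] is None:
        log.warning("Setting undefined timestamp column to `timestamp`.")
        dataset['timestamp_column'] = "timestamp"
    if dataset['id_column'] is None:
        log.warning("Setting undefined id column to `item_id`.")
        dataset['id_column'] = "item_id"
    if dataset['prediction_length'] is None:
        log.warning("Setting undefined prediction length to `1`.")
        dataset['prediction_length'] = 1
```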

        else:
            raise ValueError("Tasks should have one property among [openml_task_id, openml_dataset_id, dataset].")

@@ -522,7 +535,12 @@ def run(self):
                             predictions_dir=self.benchmark.output_dirs.predictions)
        framework_def = self.benchmark.framework_def
        task_config = copy(self.task_config)
-        task_config.type = 'regression' if self._dataset.type == DatasetType.regression else 'classification'
+        if self._dataset.type == DatasetType.regression:
+            task_config.type = 'regression'
+        elif self._dataset.type == DatasetType.timeseries:
+            task_config.type = 'timeseries'
+        else:
+            task_config.type = 'classification'
        task_config.type_ = self._dataset.type.name
        task_config.framework = self.benchmark.framework_name
        task_config.framework_params = framework_def.params
@@ -552,4 +570,3 @@ def run(self):
        finally:
            self._dataset.release()
        return results.compute_score(result=result, meta_result=meta_result)

1 change: 1 addition & 0 deletions amlb/data.py
@@ -172,6 +172,7 @@ class DatasetType(Enum):
    binary = 1
    multiclass = 2
    regression = 3
+    timeseries = 4


class Dataset(ABC):
20 changes: 11 additions & 9 deletions amlb/datasets/file.py
@@ -30,7 +30,7 @@ def __init__(self, cache_dir=None):
        self._cache_dir = cache_dir if cache_dir else tempfile.mkdtemp(prefix='amlb_cache')

    @profile(logger=log)
-    def load(self, dataset, fold=0):
+    def load(self, dataset, fold=0, timestamp_column=None):
Collaborator:
You obtained this new column using `timestamp_column=self._task_def.dataset['timestamp_column']`, so you already have the information in the dataset object.

Contributor:
true

        dataset = dataset if isinstance(dataset, ns) else ns(path=dataset)
        log.debug("Loading dataset %s", dataset)
        paths = self._extract_train_test_paths(dataset.path if 'path' in dataset else dataset, fold=fold)
@@ -51,7 +51,7 @@ def load(self, dataset, fold=0):
        if ext == '.arff':
            return ArffDataset(train_path, test_path, target=target, features=features, type=type_)
        elif ext == '.csv':
-            return CsvDataset(train_path, test_path, target=target, features=features, type=type_)
+            return CsvDataset(train_path, test_path, target=target, features=features, type=type_, timestamp_column=timestamp_column)
        else:
            raise ValueError(f"Unsupported file type: {ext}")

@@ -302,25 +302,26 @@ def release(self, properties=None):
class CsvDataset(FileDataset):

    def __init__(self, train_path, test_path,
-                 target=None, features=None, type=None):
+                 target=None, features=None, type=None, timestamp_column=None):
        # todo: handle auto-split (if test_path is None): requires loading the training set, split, save
        super().__init__(None, None,
                         target=target, features=features, type=type)
-        self._train = CsvDatasplit(self, train_path)
-        self._test = CsvDatasplit(self, test_path)
+        self._train = CsvDatasplit(self, train_path, timestamp_column=timestamp_column)
+        self._test = CsvDatasplit(self, test_path, timestamp_column=timestamp_column)
        self._dtypes = None


class CsvDatasplit(FileDatasplit):

-    def __init__(self, dataset, path):
+    def __init__(self, dataset, path, timestamp_column=None):
        super().__init__(dataset, format='csv', path=path)
        self._ds = None
+        self.timestamp_column = timestamp_column

    def _ensure_loaded(self):
        if self._ds is None:
            if self.dataset._dtypes is None:
-                df = read_csv(self.path)
+                df = read_csv(self.path, timestamp_column=self.timestamp_column)
                # df = df.convert_dtypes()
                dt_conversions = {name: 'category'
                                  for name, dtype in zip(df.dtypes.index, df.dtypes.values)
@@ -336,8 +337,9 @@ def _ensure_loaded(self):

                self._ds = df
                self.dataset._dtypes = self._ds.dtypes
+
            else:
-                self._ds = read_csv(self.path, dtype=self.dataset._dtypes.to_dict())
+                self._ds = read_csv(self.path, dtype=self.dataset._dtypes.to_dict(), timestamp_column=self.timestamp_column)

    @profile(logger=log)
    def load_metadata(self):
@@ -348,7 +350,7 @@ def load_metadata(self):
                                      else 'number' if pat.is_numeric_dtype(dt)
                                      else 'category' if pat.is_categorical_dtype(dt)
                                      else 'string' if pat.is_string_dtype(dt)
-                                     # else 'datetime' if pat.is_datetime64_dtype(dt)
+                                     else 'datetime' if pat.is_datetime64_dtype(dt)
                                      else 'object')
        features = [Feature(i, col, to_feature_type(dtypes[i]))
                    for i, col in enumerate(self._ds.columns)]
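A quick standalone illustration of the re-enabled branch (pandas only, not part of the diff): once `read_csv` parses the timestamp column, its dtype satisfies `is_datetime64_dtype` and maps to the `datetime` feature type instead of falling through to `object`.

```python
import pandas as pd
import pandas.api.types as pat

# a column parsed via parse_dates ends up as datetime64[ns]
s = pd.Series(pd.to_datetime(["2020-01-01", "2020-01-02"]))
assert pat.is_datetime64_dtype(s.dtype)
```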
10 changes: 8 additions & 2 deletions amlb/datautils.py
@@ -26,7 +26,7 @@
log = logging.getLogger(__name__)


-def read_csv(path, nrows=None, header=True, index=False, as_data_frame=True, dtype=None):
+def read_csv(path, nrows=None, header=True, index=False, as_data_frame=True, dtype=None, timestamp_column=None):
"""
read csv file to DataFrame.

Expand All @@ -37,13 +37,19 @@ def read_csv(path, nrows=None, header=True, index=False, as_data_frame=True, dty
:param header: if the columns header should be read.
:param as_data_frame: if the result should be returned as a data frame (default) or a numpy array.
:param dtype: data type for columns.
:param timestamp_column: column name for timestamp, to ensure dates are correctly parsed by pandas.
:return: a DataFrame
"""
if dtype is not None and timestamp_column is not None and timestamp_column in dtype:
dtype = dtype.copy() # to avoid outer context manipulation
del dtype[timestamp_column]

    df = pd.read_csv(path,
                     nrows=nrows,
                     header=0 if header else None,
                     index_col=0 if index else None,
-                     dtype=dtype)
+                     dtype=dtype,
+                     parse_dates=[timestamp_column] if timestamp_column is not None else None)
    return df if as_data_frame else df.values


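As a usage sketch (the file and column names below are made up for illustration): when a `timestamp_column` is given, it is dropped from the `dtype` map and handed to pandas via `parse_dates` instead, so the column comes back as `datetime64` rather than pinned to a string dtype.

```python
# Hypothetical call to the patched read_csv.
df = read_csv("train.csv",
              dtype={"item_id": "category", "target": "float64", "timestamp": "object"},
              timestamp_column="timestamp")
assert df["timestamp"].dtype.kind == "M"  # parsed as datetime64, not kept as object
```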
101 changes: 95 additions & 6 deletions amlb/results.py
@@ -228,12 +228,16 @@ def load_predictions(predictions_file):
        try:
            df = read_csv(predictions_file, dtype=object)
            log.debug("Predictions preview:\n %s\n", df.head(10).to_string())
-            if rconfig().test_mode:
-                TaskResult.validate_predictions(df)
-            if df.shape[1] > 2:
-                return ClassificationResult(df)
+            if 'y_past_period_error' in df.columns:
+                return TimeSeriesResult(df)
Collaborator, on lines +235 to +236:
please don't bypass test mode by adding your own test block: it should remain the first check and also be applied to time series. I'm not asking you to add the test dataset to our workflow right now, but we will need to add it soon after your changes.

Contributor:
got it

-            else:
-                return RegressionResult(df)
+            if rconfig().test_mode:
+                TaskResult.validate_predictions(df)
+
+            if df.shape[1] > 2:
+                return ClassificationResult(df)
+            else:
+                return RegressionResult(df)
        except Exception as e:
            return ErrorResult(ResultError(e))
        else:
@@ -255,7 +259,8 @@ def save_predictions(dataset: Dataset, output_file: str,
                         predictions: Union[A, DF, S] = None, truth: Union[A, DF, S] = None,
                         probabilities: Union[A, DF] = None, probabilities_labels: Union[list, A] = None,
                         target_is_encoded: bool = False,
-                         preview: bool = True):
+                         preview: bool = True,
+                         quantiles: Union[A, DF] = None):
Collaborator:
nitpick: let's try to group the params functionally; it makes them easier to read and understand. Here, quantiles has a function similar to probabilities.

Contributor:
got it

""" Save class probabilities and predicted labels to file in csv format.

:param dataset:
Expand All @@ -266,6 +271,7 @@ def save_predictions(dataset: Dataset, output_file: str,
:param probabilities_labels:
:param target_is_encoded:
:param preview:
:param quantiles:
:return: None
"""
log.debug("Saving predictions to `%s`.", output_file)
@@ -308,6 +314,24 @@ def save_predictions(dataset: Dataset, output_file: str,

        df = df.assign(predictions=preds)
        df = df.assign(truth=truth)
+
+        if dataset.type == DatasetType.timeseries:
+            if quantiles is not None:
+                quantiles = quantiles.reset_index(drop=True)
+                df = pd.concat([df, quantiles], axis=1)
+
+            period_length = 1  # TODO: the period length could be adapted to the dataset, but that requires passing the information through; this works for now.
+
+            # we aim to calculate the mean period error from the past for each sequence: 1/N sum_{i=1}^N |x(t_i) - x(t_i - T)|
+            # 1. retrieve item_ids for each sequence/item
+            item_ids, inverse_item_ids = np.unique(dataset.test.X[dataset.id_column].squeeze().to_numpy(), return_index=False, return_inverse=True)
+            # 2. capture sequences in a list
+            y_past = [dataset.test.y.squeeze().to_numpy()[inverse_item_ids == i][:-dataset.prediction_length] for i in range(len(item_ids))]
+            # 3. calculate the period error per sequence
+            y_past_period_error = [np.abs(y_past_item[period_length:] - y_past_item[:-period_length]).mean() for y_past_item in y_past]
+            # 4. repeat the period error for each sequence, so one value is saved per prediction row
+            y_past_period_error_rep = np.repeat(y_past_period_error, dataset.prediction_length)
+            df = df.assign(y_past_period_error=y_past_period_error_rep)
Collaborator:
I'd rather not have this here: this is a lot of computation plus assumptions (apparently you can't have time series without an id_column) for a method that is just supposed to save predictions in a standard format. Even more so as this y_past_period_error seems to be useful only for the mase metric; therefore, either you compute it with the metric or you compute it beforehand (in the AG framework integration).

For now, I'd move your computations to the __init__.py or exec.py file, and simply ensure that we can customize the result by adding optional columns (in this case, both quantiles and your additional results).

Suggestion: change the signature to

    def save_predictions(dataset: Dataset, output_file: str,
                         predictions: Union[A, DF, S] = None, truth: Union[A, DF, S] = None,
                         probabilities: Union[A, DF] = None, probabilities_labels: Union[list, A] = None,
                         optional_columns: Union[A, DF] = None,
                         target_is_encoded: bool = False,
                         preview: bool = True):

and automatically concatenate the optional_columns to the predictions if provided. For now, you should be able to generate those in exec.py.

Contributor:
got it
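Under the suggested signature, the time-series branch of save_predictions would shrink to roughly the following (a sketch of the suggestion, not the merged code):

```python
# Callers (e.g. the AutoGluonTS integration's exec.py) would compute quantiles
# and y_past_period_error themselves and pass them in as extra columns.
if optional_columns is not None:
    df = pd.concat([df, optional_columns.reset_index(drop=True)], axis=1)
```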

        if preview:
            log.info("Predictions preview:\n %s\n", df.head(20).to_string())
        backup_file(output_file)
@@ -656,6 +680,71 @@ def r2(self):
        """R^2"""
        return float(r2_score(self.truth, self.predictions))

+class TimeSeriesResult(RegressionResult):
+
+    def __init__(self, predictions_df, info=None):
+        super().__init__(predictions_df, info)
+        self.truth = self.df['truth'].values if self.df is not None else None
+        self.predictions = self.df['predictions'].values if self.df is not None else None
+        self.y_past_period_error = self.df['y_past_period_error'].values
+        self.quantiles = self.df.iloc[:, 2:-1].values
+        self.quantiles_probs = np.array([float(q) for q in self.df.columns[2:-1]])
+        self.truth = self.truth.astype(float, copy=False)
+        self.predictions = self.predictions.astype(float, copy=False)
+        self.quantiles = self.quantiles.astype(float, copy=False)
+        self.y_past_period_error = self.y_past_period_error.astype(float, copy=False)
+
+        self.target = Feature(0, 'target', 'real', is_target=True)
+        self.type = DatasetType.timeseries

+    @metric(higher_is_better=False)
+    def mase(self):
+        """Mean Absolute Scaled Error"""
+        return float(np.nanmean(np.abs(self.truth/self.y_past_period_error - self.predictions/self.y_past_period_error)))

+    @metric(higher_is_better=False)
+    def smape(self):
+        """Symmetric Mean Absolute Percentage Error"""
+        num = np.abs(self.truth - self.predictions)
+        denom = (np.abs(self.truth) + np.abs(self.predictions)) / 2
+        # if the denominator is 0, we set it to float('inf') so that any division yields 0
+        # (this might not be fully mathematically correct, but at least we don't get NaNs)
+        denom[denom == 0] = math.inf
+        return np.mean(num / denom)

+    @metric(higher_is_better=False)
+    def mape(self):
+        """Mean Absolute Percentage Error"""
+        num = np.abs(self.truth - self.predictions)
+        denom = np.abs(self.truth)
+        # if the denominator is 0, we set it to float('inf') so that any division yields 0
+        # (this might not be fully mathematically correct, but at least we don't get NaNs)
+        denom[denom == 0] = math.inf
+        return np.mean(num / denom)

+    @metric(higher_is_better=False)
+    def nrmse(self):
+        """Normalized Root Mean Square Error"""
+        return self.rmse() / np.mean(np.abs(self.truth))
+
+    @metric(higher_is_better=False)
+    def wape(self):
+        """Weighted Average Percentage Error"""
+        return np.sum(np.abs(self.truth - self.predictions)) / np.sum(np.abs(self.truth))

+    @metric(higher_is_better=False)
+    def ncrps(self):
+        """Normalized Continuous Ranked Probability Score"""
+        quantile_losses = 2 * np.sum(
+            np.abs(
+                (self.quantiles - self.truth[:, None])
+                * ((self.quantiles >= self.truth[:, None]) - self.quantiles_probs[None, :])
+            ),
+            axis=0,
+        )  # shape [num_quantiles]
+        denom = np.sum(np.abs(self.truth))  # scalar
+        weighted_losses = quantile_losses / denom  # shape [num_quantiles]
+        return weighted_losses.mean()

_encode_predictions_and_truth_ = False

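Note that TimeSeriesResult relies on the column layout written by save_predictions above, i.e. [predictions, truth, one column per quantile, y_past_period_error], which is what the `iloc[:, 2:-1]` slice assumes. As a hand-computed toy check of the ncrps arithmetic (values made up for illustration):

```python
import numpy as np

truth = np.array([10.0])                   # a single forecast point
quantiles = np.array([[8.0, 10.0, 12.0]])  # its predicted 0.1/0.5/0.9 quantiles
probs = np.array([0.1, 0.5, 0.9])

losses = 2 * np.sum(np.abs((quantiles - truth[:, None])
                           * ((quantiles >= truth[:, None]) - probs[None, :])), axis=0)
print(losses / np.abs(truth).sum())        # [0.04 0.   0.04] -> ncrps of about 0.0267
```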
16 changes: 16 additions & 0 deletions frameworks/AutoGluonTS/README.md
@@ -0,0 +1,16 @@
# AutoGluonTS

AutoGluonTS stands for autogluon.timeseries. This framework handles time series problems.

This code is currently a prototype, since time series support is not fully defined in AutoMLBenchmark yet.
Consider the code a proof of concept.

## Run Steps

To run AutoGluonTS in AutoMLBenchmark on the covid dataset from the AutoGluon tutorial, do the following:

1. Create a fresh Python environment
2. Follow automlbenchmark install instructions
3. Run the following command in a terminal: ```python3 ../automlbenchmark/runbenchmark.py autogluonts ts test```

To run mainline AutoGluonTS instead of v0.5.2: ```python3 ../automlbenchmark/runbenchmark.py autogluonts:latest ts test```
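The `ts` benchmark definition referenced above must declare a timeseries dataset. A minimal sketch of what such a definition could look like (file location and field values below are assumptions based on the covid tutorial data, not the committed file):

```yaml
- name: covid
  dataset:
    type: timeseries
    train: https://autogluon.s3.amazonaws.com/datasets/CovidTimeSeries/train.csv
    test: https://autogluon.s3.amazonaws.com/datasets/CovidTimeSeries/test.csv
    target: ConfirmedCases
    timestamp_column: Date
    id_column: name
    prediction_length: 30
  folds: 1
```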
38 changes: 38 additions & 0 deletions frameworks/AutoGluonTS/__init__.py
@@ -0,0 +1,38 @@
from amlb.benchmark import TaskConfig
from amlb.data import Dataset, DatasetType
from amlb.utils import call_script_in_same_dir


def setup(*args, **kwargs):
    call_script_in_same_dir(__file__, "setup.sh", *args, **kwargs)


def run(dataset: Dataset, config: TaskConfig):
    from frameworks.shared.caller import run_in_venv

    if hasattr(dataset, 'timestamp_column') is False:
        dataset.timestamp_column = None
Collaborator:
for this one and below

Suggested change:
-    if hasattr(dataset, 'timestamp_column') is False:
-        dataset.timestamp_column = None
+    if not hasattr(dataset, 'timestamp_column'):
+        dataset.timestamp_column = None

    if hasattr(dataset, 'id_column') is False:
        dataset.id_column = None
    if hasattr(dataset, 'prediction_length') is False:
        raise AttributeError("Unspecified `prediction_length`.")
    if dataset.type is not DatasetType.timeseries:
        raise ValueError("AutoGluonTS only supports timeseries.")

    data = dict(
        # train=dict(path=dataset.train.data_path('parquet')),
        # test=dict(path=dataset.test.data_path('parquet')),
        train=dict(path=dataset.train.path),
        test=dict(path=dataset.test.path),
        target=dict(
            name=dataset.target.name,
            classes=dataset.target.values
        ),
        problem_type=dataset.type.name,  # AutoGluon problem_type uses the same names as amlb.data.DatasetType
        timestamp_column=dataset.timestamp_column,
        id_column=dataset.id_column,
        prediction_length=dataset.prediction_length
    )

    return run_in_venv(__file__, "exec.py",
                       input_data=data, dataset=dataset, config=config)
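The PR's exec.py is not shown in this preview. As a rough, non-authoritative sketch of how the fields passed in `data` could be consumed on the AutoGluon side (using the public autogluon.timeseries 0.5 API; all names below are assumptions, not the PR's actual exec.py):

```python
# Hedged sketch of an exec.py counterpart; NOT the PR's actual file.
import pandas as pd
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor


def run(dataset, config):
    train_df = pd.read_csv(dataset.train.path, parse_dates=[dataset.timestamp_column])

    # autogluon.timeseries expects long-format data indexed by (item_id, timestamp)
    train_data = TimeSeriesDataFrame.from_data_frame(
        train_df, id_column=dataset.id_column, timestamp_column=dataset.timestamp_column)

    predictor = TimeSeriesPredictor(
        target=dataset.target.name,
        prediction_length=dataset.prediction_length,
        eval_metric="MASE",  # assumption: amlb metric names would need mapping
    )
    predictor.fit(train_data, time_limit=config.max_runtime_seconds)

    # predict() returns the mean forecast plus one column per quantile level
    predictions = predictor.predict(train_data)
    return predictions
```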