Skip to content

Commit

Permalink
Fix item order in GluonTS models predictions (open-mmlab#2092)
Browse files (browse the repository at this point in the history)
  • Loading branch information
shchur authored Sep 2, 2022
1 parent 94231ac commit 686510e
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(

def _fit(self, predictions, labels, time_limit=None, sample_weight=None):
self.dummy_pred = copy.deepcopy(predictions[0])
# This should never happen; sanity check to make sure that all predictions have the same index
assert all(self.dummy_pred.index.equals(pred.index) for pred in predictions)
super()._fit(
predictions=[d.values for d in predictions],
labels=labels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def freq(self):
return self.time_series_df.freq

def __len__(self):
return len(self.time_series_df.index.levels[0]) # noqa
return len(self.time_series_df.item_ids) # noqa

def __iter__(self) -> Iterator[Dict[str, Any]]:
for j in self.time_series_df.index.levels[0]: # noqa
for j in self.time_series_df.item_ids: # noqa
df = self.time_series_df.loc[j]
yield {
"item_id": j,
Expand Down Expand Up @@ -247,7 +247,8 @@ def predict(self, data: TimeSeriesDataFrame, quantile_levels: List[float] = None
inplace=True,
)

return df
# Make sure the item_ids are sorted in the same order as in data
return df.loc[data.item_ids]

def _predict_gluonts_forecasts(self, data: TimeSeriesDataFrame, **kwargs) -> List[Forecast]:
gts_data = self._to_gluonts_dataset(data)
Expand Down
2 changes: 1 addition & 1 deletion timeseries/tests/unittests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_data_frame_with_item_index(
)


DUMMY_TS_DATAFRAME = get_data_frame_with_item_index(["A", "B", "C", "D"])
DUMMY_TS_DATAFRAME = get_data_frame_with_item_index(["10", "A", "2", "1"])


def get_data_frame_with_variable_lengths(item_id_to_length: Dict[str, int]):
Expand Down
15 changes: 8 additions & 7 deletions timeseries/tests/unittests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from ..common import DUMMY_TS_DATAFRAME, dict_equal_primitive, get_data_frame_with_item_index
from .test_gluonts import TESTABLE_MODELS as GLUONTS_TESTABLE_MODELS
from .test_sktime import TESTABLE_MODELS as SKTIME_TESTABLE_MODELS
from .test_statsmodels import TESTABLE_MODELS as STATSMODELS_TESTABLE_MODELS

AVAILABLE_METRICS = TimeSeriesEvaluator.AVAILABLE_METRICS
TESTABLE_MODELS = GLUONTS_TESTABLE_MODELS + SKTIME_TESTABLE_MODELS
TESTABLE_MODELS = GLUONTS_TESTABLE_MODELS + SKTIME_TESTABLE_MODELS + STATSMODELS_TESTABLE_MODELS
TESTABLE_PREDICTION_LENGTHS = [1, 5]


Expand Down Expand Up @@ -177,8 +178,8 @@ def test_when_fit_called_then_models_train_and_returned_predictor_inference_has_

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(k in predictions.columns for k in ["mean"] + [str(q) for q in quantile_levels])


Expand All @@ -196,8 +197,8 @@ def test_when_fit_called_then_models_train_and_returned_predictor_inference_corr

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == train_data.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == train_data.item_ids)
assert all(len(predictions.loc[i]) == prediction_length for i in predicted_item_index)
assert all(predictions.loc[i].index[0].hour > 0 for i in predicted_item_index)

Expand Down Expand Up @@ -272,7 +273,7 @@ def test_when_predict_called_with_test_data_then_predictor_inference_correct(
assert isinstance(predictions, TimeSeriesDataFrame)
assert len(predictions) == test_data.num_items * prediction_length

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == test_data.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == test_data.item_ids) # noqa
assert all(len(predictions.loc[i]) == prediction_length for i in predicted_item_index)
assert all(predictions.loc[i].index[0].hour > 0 for i in predicted_item_index)
8 changes: 4 additions & 4 deletions timeseries/tests/unittests/test_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def test_given_hyperparameters_when_learner_called_then_model_can_predict(

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 3 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand Down Expand Up @@ -179,8 +179,8 @@ def test_given_hyperparameters_when_learner_called_and_loaded_back_then_all_mode

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids)
assert all(len(predictions.loc[i]) == 2 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand Down
26 changes: 13 additions & 13 deletions timeseries/tests/unittests/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def test_given_hyperparameters_when_predictor_called_then_model_can_predict(temp

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 3 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand All @@ -75,8 +75,8 @@ def test_given_different_target_name_when_predictor_called_then_model_can_predic

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 3 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand All @@ -92,8 +92,8 @@ def test_given_no_tuning_data_when_predictor_called_then_model_can_predict(temp_

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 3 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand Down Expand Up @@ -169,8 +169,8 @@ def test_given_hyperparameters_when_predictor_called_and_loaded_back_then_all_mo

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 2 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand Down Expand Up @@ -220,8 +220,8 @@ def test_given_hp_spaces_and_custom_target_when_predictor_called_predictor_can_p

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == df.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == df.item_ids) # noqa
assert all(len(predictions.loc[i]) == 2 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand All @@ -245,8 +245,8 @@ def test_given_hyperparameters_when_predictor_called_and_loaded_back_then_loaded

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 2 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand Down Expand Up @@ -347,7 +347,7 @@ def test_given_irregular_time_series_when_predictor_called_with_ignore_then_pred
assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.item_ids
assert all(predicted_item_index == df._item_index) # noqa
assert all(predicted_item_index == df.item_ids) # noqa
assert all(len(predictions.loc[i]) == 1 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand Down
8 changes: 4 additions & 4 deletions timeseries/tests/unittests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def test_given_hyperparameters_when_trainer_called_then_model_can_predict(

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 3 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))

Expand Down Expand Up @@ -529,7 +529,7 @@ def test_when_trainer_fit_and_deleted_models_load_back_correctly_and_can_predict

assert isinstance(predictions, TimeSeriesDataFrame)

predicted_item_index = predictions.index.levels[0]
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.index.levels[0]) # noqa
predicted_item_index = predictions.item_ids
assert all(predicted_item_index == DUMMY_TS_DATAFRAME.item_ids) # noqa
assert all(len(predictions.loc[i]) == 2 for i in predicted_item_index)
assert not np.any(np.isnan(predictions))
12 changes: 6 additions & 6 deletions timeseries/tests/unittests/test_ts_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,11 @@ def test_when_dataset_sliced_by_step_then_output_times_and_values_correct(
)
def test_when_dataset_sliced_by_step_then_order_of_item_index_is_preserved(input_iterable, input_slice):
df = TimeSeriesDataFrame.from_iterable_dataset(input_iterable)
new_idx = df._item_index[::-1]
new_idx = df.item_ids[::-1]
df.index = df.index.set_levels(new_idx, level=ITEMID)
dfv = df.slice_by_timestep(input_slice)

assert dfv._item_index.equals(new_idx)
assert dfv.item_ids.equals(new_idx)


@pytest.mark.parametrize("input_df", [SAMPLE_TS_DATAFRAME, SAMPLE_TS_DATAFRAME_EMPTY])
Expand Down Expand Up @@ -591,7 +591,7 @@ def test_when_dataset_sliced_by_step_then_static_features_are_correct():
dfv = df.slice_by_timestep(slice(-2, None))

assert isinstance(dfv, TimeSeriesDataFrame)
assert len(dfv) == 2 * len(dfv.index.levels[0])
assert len(dfv) == 2 * len(dfv.item_ids)

assert dfv.static_features.equals(df.static_features)

Expand All @@ -601,16 +601,16 @@ def test_when_dataset_subsequenced_then_static_features_are_correct():
dfv = df.subsequence(START_TIMESTAMP, START_TIMESTAMP + datetime.timedelta(days=1))

assert isinstance(dfv, TimeSeriesDataFrame)
assert len(dfv) == 1 * len(dfv.index.levels[0])
assert len(dfv) == 1 * len(dfv.item_ids)

assert dfv.static_features.equals(df.static_features)


def test_when_dataset_split_by_time_then_static_features_are_correct():
left, right = SAMPLE_TS_DATAFRAME_STATIC.split_by_time(START_TIMESTAMP + datetime.timedelta(days=1))

assert len(left) == 1 * len(SAMPLE_TS_DATAFRAME_STATIC.index.levels[0])
assert len(right) == 2 * len(SAMPLE_TS_DATAFRAME_STATIC.index.levels[0])
assert len(left) == 1 * len(SAMPLE_TS_DATAFRAME_STATIC.item_ids)
assert len(right) == 2 * len(SAMPLE_TS_DATAFRAME_STATIC.item_ids)

assert left.static_features.equals(SAMPLE_TS_DATAFRAME_STATIC.static_features)
assert right.static_features.equals(SAMPLE_TS_DATAFRAME_STATIC.static_features)
Expand Down

0 comments on commit 686510e

Please sign in to comment.