Skip to content

Commit

Permalink
Backports v0.14.3 (#3077)
Browse files Browse the repository at this point in the history
* Fix Rotbaum serialization and deserialization (#3068)

* Fix Rotbaum to handle short series (#3073)

---------

Co-authored-by: Anurag Pant <anuragpant@cs.ucla.edu>
  • Loading branch information
lostella and pantanurag555 authored Dec 7, 2023
1 parent 3c434d8 commit ed8d813
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 10 deletions.
9 changes: 8 additions & 1 deletion src/gluonts/ext/rotbaum/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import gc
from collections import defaultdict

from gluonts.core.component import validated
from gluonts.core.component import equals, validated


class QRF:
Expand Down Expand Up @@ -121,6 +121,13 @@ def _create_xgboost_model(model_params: Optional[dict] = None):
}
return xgboost.sklearn.XGBModel(**model_params)

def __eq__(self, that):
"""
Two QRX instances are considered equal if they have the same
constructor arguments.
"""
return equals(self, that)

def fit(
self,
x_train: Union[pd.DataFrame, List],
Expand Down
27 changes: 27 additions & 0 deletions src/gluonts/ext/rotbaum/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

import concurrent.futures
import logging
import pickle
from itertools import chain
from typing import Iterator, List, Optional, Any, Dict
from toolz import first

import numpy as np
import pandas as pd
from pathlib import Path
from itertools import compress

from gluonts.core.component import validated
Expand Down Expand Up @@ -340,6 +342,31 @@ def predict( # type: ignore
item_id=ts.get("item_id"),
)

def serialize(self, path: Path) -> None:
"""
This function calls parent class serialize() in order to serialize
the class name, version information and constuctor arguments. It
persists the tree predictor by pickling the model list that is
generated when pickling the TreePredictor.
"""
super().serialize(path)
with (path / "predictor.pkl").open("wb") as f:
pickle.dump(self.model_list, f)

@classmethod
def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":
"""
This function loads and returns the serialized model. It loads
the predictor class with the serialized arguments. It then loads
the trained model list by reading the pickle file.
"""

predictor = super().deserialize(path)
assert isinstance(predictor, cls)
with (path / "predictor.pkl").open("rb") as f:
predictor.model_list = pickle.load(f)
return predictor

def explain(
self, importance_type: str = "gain", percentage: bool = True
) -> ExplanationResult:
Expand Down
14 changes: 10 additions & 4 deletions src/gluonts/ext/rotbaum/_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,12 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
end_index = starting_index + self.context_window_size
if starting_index < 0:
prefix = [None] * abs(starting_index)
time_series_window = time_series["target"]
else:
prefix = []
time_series_window = time_series["target"][starting_index:end_index]
time_series_window = time_series["target"][
starting_index:end_index
]
only_lag_features, transform_dict = self._pre_transform(
time_series_window, self.subtract_mean, self.count_nans
)
Expand All @@ -464,7 +467,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
if self.use_feat_static_real
else []
)
if self.cardinality:
if (
self.cardinality
and time_series.get("feat_static_cat", None) is not None
):
feat_static_cat = (
self.encode_one_hot_all(time_series["feat_static_cat"])
if self.one_hot_encode
Expand All @@ -477,10 +483,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
list(
chain(
*[
list(ent[0]) + list(ent[1].values())
prefix + list(ent[0]) + list(ent[1].values())
for ent in [
self._pre_transform(
ts[starting_index:end_index],
ts if prefix else ts[starting_index:end_index],
self.subtract_mean,
self.count_nans,
)
Expand Down
24 changes: 20 additions & 4 deletions test/ext/rotbaum/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.


from pathlib import Path
import pytest
import tempfile

from gluonts.ext.rotbaum import TreeEstimator
from gluonts.ext.rotbaum import TreeEstimator, TreePredictor


@pytest.fixture()
Expand All @@ -33,5 +34,20 @@ def test_accuracy(accuracy_test, hyperparameters, quantiles):
accuracy_test(TreeEstimator, hyperparameters, accuracy=0.20)


def test_serialize(serialize_test, hyperparameters):
serialize_test(TreeEstimator, hyperparameters)
def test_serialize(serialize_test, hyperparameters, dsinfo):
forecaster = TreeEstimator.from_hyperparameters(
freq=dsinfo.freq,
**{
"prediction_length": dsinfo.prediction_length,
"num_parallel_samples": dsinfo.num_parallel_samples,
},
**hyperparameters,
)

predictor_act = forecaster.train(dsinfo.train_ds)

with tempfile.TemporaryDirectory() as temp_dir:
predictor_act.serialize(Path(temp_dir))
predictor_exp = TreePredictor.deserialize(Path(temp_dir))
assert predictor_act == predictor_exp
assert predictor_act.model_list == predictor_exp.model_list
69 changes: 68 additions & 1 deletion test/ext/rotbaum/test_rotbaum_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# permissions and limitations under the License.

import pytest
import numpy as np

from gluonts.ext.rotbaum import TreeEstimator
from gluonts.ext.rotbaum import TreeEstimator, TreePredictor

from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features
from gluonts.dataset.common import ListDataset

# TODO: Add support for categorical and dynamic features.

Expand Down Expand Up @@ -59,3 +61,68 @@ def test_rotbaum_smoke(datasets):
predictor = estimator.train(dataset_train)
forecasts = list(predictor.predict(dataset_test))
assert len(forecasts) == len(dataset_test)


def test_short_history_item_pred():
prediction_length = 7
freq = "D"

dataset = ListDataset(
data_iter=[
{
"start": "2017-10-11",
"item_id": "item_1",
"target": np.array(
[
1.0,
9.0,
2.0,
0.0,
0.0,
1.0,
5.0,
3.0,
4.0,
2.0,
0.0,
0.0,
1.0,
6.0,
]
),
"feat_static_cat": np.array([0.0, 0.0], dtype=float),
"past_feat_dynamic_real": np.array(
[
[1.0222e06 for i in range(14)],
[750.0 for i in range(14)],
]
),
},
{
"start": "2017-10-11",
"item_id": "item_2",
"target": np.array([7.0, 0.0, 0.0, 23.0, 13.0]),
"feat_static_cat": np.array([0.0, 1.0], dtype=float),
"past_feat_dynamic_real": np.array(
[[0 for i in range(5)], [750.0 for i in range(5)]]
),
},
],
freq=freq,
)

predictor = TreePredictor(
freq=freq,
prediction_length=prediction_length,
quantiles=[0.1, 0.5, 0.9],
max_n_datapts=50000,
method="QuantileRegression",
use_past_feat_dynamic_real=True,
use_feat_dynamic_real=False,
use_feat_dynamic_cat=False,
use_feat_static_real=False,
cardinality="auto",
)
predictor = predictor.train(dataset)
forecasts = list(predictor.predict(dataset))
assert forecasts[1].quantile(0.5).shape[0] == prediction_length

0 comments on commit ed8d813

Please sign in to comment.