Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix incorrect input routing for models (awslabs#3186)
Fixes awslabs#3185 *Description of changes:* There is currently a bug where the model inputs may be routed incorrect by the forecast generator. This effectively results in `past_feat_dynamic_real` and `past_feat_dynamic_cat` being ignored by the TFT model. MWE: ```python from unittest import mock import numpy as np import pandas as pd from gluonts.torch.model.tft import TemporalFusionTransformerEstimator freq = "D" N = 50 data = [ {"target": np.arange(N), "past_feat_dynamic_real": np.random.rand(1, N).astype("float32"), "start": pd.Period("2020-01-01", freq=freq)} ] predictor = TemporalFusionTransformerEstimator(prediction_length=1, freq=freq, past_dynamic_dims=[1], trainer_kwargs={"max_epochs": 1}).train(data) with mock.patch("gluonts.torch.model.tft.module.TemporalFusionTransformerModel._preprocess") as mock_fwd: try: fcst = list(predictor.predict(data)) except: pass call_kwargs = mock_fwd.call_args[1] call_kwargs["feat_dynamic_cat"] # tensor([[[0.8073]]]) call_kwargs["past_feat_dynamic_real"] # None ``` The bug occurs because model inputs are passed as positional arguments instead of keyword arguments. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. **Please tag this pr with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup
- Loading branch information