Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backports v0.15.1 #3187

Merged
merged 1 commit into from
May 31, 2024
Merged

Backports v0.15.1 #3187

merged 1 commit into from
May 31, 2024

Conversation

lostella
Copy link
Contributor

Description of changes: backporting fixes

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

Fixes awslabs#3185

*Description of changes:*

There is currently a bug where the model inputs may be routed incorrect
by the forecast generator. This effectively results in
`past_feat_dynamic_real` and `past_feat_dynamic_cat` being ignored by
the TFT model.

MWE:
```python
from unittest import mock
import numpy as np
import pandas as pd
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

freq = "D"
N = 50
data = [
    {"target": np.arange(N), "past_feat_dynamic_real": np.random.rand(1, N).astype("float32"), "start": pd.Period("2020-01-01", freq=freq)}
]

predictor = TemporalFusionTransformerEstimator(prediction_length=1, freq=freq, past_dynamic_dims=[1], trainer_kwargs={"max_epochs": 1}).train(data)

with mock.patch("gluonts.torch.model.tft.module.TemporalFusionTransformerModel._preprocess") as mock_fwd:
    try:
        fcst = list(predictor.predict(data))
    except:
        pass   
    call_kwargs = mock_fwd.call_args[1]

call_kwargs["feat_dynamic_cat"]  
# tensor([[[0.8073]]])
call_kwargs["past_feat_dynamic_real"]  
# None
```

The bug occurs because model inputs are passed as positional arguments
instead of keyword arguments.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
@lostella lostella changed the base branch from dev to v0.15.x May 31, 2024 12:55
@lostella lostella requested a review from shchur May 31, 2024 12:55
@lostella lostella merged commit 0cb0808 into awslabs:v0.15.x May 31, 2024
19 of 20 checks passed
@lostella lostella deleted the backports-v0.15.1 branch May 31, 2024 12:58
@lostella lostella added the backport This PR backports changes to old releases label May 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backport This PR backports changes to old releases
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants