Fix incorrect input routing for models #3186

Merged (3 commits) on May 31, 2024
Conversation

@shchur (Contributor) commented May 31, 2024

Fixes #3185

Description of changes:

There is currently a bug where the model inputs may be routed incorrectly by the forecast generator. This effectively results in `past_feat_dynamic_real` and `past_feat_dynamic_cat` being ignored by the TFT model.

MWE:

```python
from unittest import mock

import numpy as np
import pandas as pd

from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

freq = "D"
N = 50
data = [
    {
        "target": np.arange(N),
        "past_feat_dynamic_real": np.random.rand(1, N).astype("float32"),
        "start": pd.Period("2020-01-01", freq=freq),
    }
]

predictor = TemporalFusionTransformerEstimator(
    prediction_length=1,
    freq=freq,
    past_dynamic_dims=[1],
    trainer_kwargs={"max_epochs": 1},
).train(data)

with mock.patch(
    "gluonts.torch.model.tft.module.TemporalFusionTransformerModel._preprocess"
) as mock_fwd:
    try:
        fcst = list(predictor.predict(data))
    except Exception:
        pass
    call_kwargs = mock_fwd.call_args[1]

call_kwargs["feat_dynamic_cat"]
# tensor([[[0.8073]]])
call_kwargs["past_feat_dynamic_real"]
# None
```

The bug occurs because model inputs are passed as positional arguments instead of keyword arguments.
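The failure mode can be reproduced without GluonTS at all. Below is a minimal sketch (a hypothetical toy network, not the GluonTS API) of how positional passing silently misroutes inputs when the input dict does not cover every parameter in the callee's signature:

```python
def toy_network(feat_dynamic_cat=None, past_feat_dynamic_real=None):
    # Echo the received arguments so we can inspect the routing.
    return {
        "feat_dynamic_cat": feat_dynamic_cat,
        "past_feat_dynamic_real": past_feat_dynamic_real,
    }

inputs = {"past_feat_dynamic_real": [0.8073]}

# Positional call: the value lands in the *first* parameter slot,
# regardless of which key it was stored under.
misrouted = toy_network(*inputs.values())
print(misrouted)
# {'feat_dynamic_cat': [0.8073], 'past_feat_dynamic_real': None}

# Keyword call: each value reaches the parameter it was meant for.
routed = toy_network(**inputs)
print(routed)
# {'feat_dynamic_cat': None, 'past_feat_dynamic_real': [0.8073]}
```

This mirrors the mocked output above: the `past_feat_dynamic_real` tensor shows up under `feat_dynamic_cat`, while `past_feat_dynamic_real` itself is `None`.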

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@shchur shchur requested a review from lostella May 31, 2024 09:51
@shchur shchur added the bug fix (one of pr required labels) label and removed the bug (Something isn't working) label May 31, 2024
```diff
@@ -82,6 +82,14 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:
     raise NotImplementedError


+def make_predictions(prediction_net, inputs: dict):
+    # MXNet predictors only support positional arguments
+    if prediction_net.__class__.__module__.startswith("gluonts.mx"):
```
@shchur (Contributor, author) commented:
I couldn't find a more elegant way to use different logic for MXNet and PyTorch models :/

I tried `@singledispatch`, but that doesn't work for subclasses (i.e., we'd need to define it for all subclasses of `pl.LightningModule` in GluonTS, and the same for MXNet).
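The diff hunk above is truncated after the module-name check. A hedged sketch of how such a dispatch might look (an assumed shape, not the actual GluonTS implementation; the `Fake*` classes are stand-ins used only to exercise both branches):

```python
def make_predictions(prediction_net, inputs: dict):
    # Assumed sketch of the dispatch: MXNet networks only accept
    # positional arguments, so unpack the dict by value order;
    # PyTorch networks can take keyword arguments, which keeps the
    # input routing unambiguous.
    if prediction_net.__class__.__module__.startswith("gluonts.mx"):
        return prediction_net(*inputs.values())
    return prediction_net(**inputs)


# Stand-in callables pretending to live in the respective packages.
class FakeMXNet:
    __module__ = "gluonts.mx.model.simple_feedforward"

    def __call__(self, a, b):
        return ("positional", a, b)


class FakeTorchNet:
    __module__ = "gluonts.torch.model.tft"

    def __call__(self, a=None, b=None):
        return ("keyword", a, b)


print(make_predictions(FakeMXNet(), {"a": 1, "b": 2}))  # ('positional', 1, 2)
print(make_predictions(FakeTorchNet(), {"b": 2}))       # ('keyword', None, 2)
```

Checking `__class__.__module__` is a pragmatic alternative to type-based dispatch here: it needs no registry of subclasses and degrades to the safer keyword path for anything outside `gluonts.mx`.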

@lostella lostella added pending v0.15.x backport This contains a fix to be backported to the v0.15.x branch pending v0.14.x backport This contains a fix to be backported to the v0.14.x branch labels May 31, 2024
@lostella lostella merged commit 5e30960 into awslabs:dev May 31, 2024
20 checks passed
lostella pushed a commit to lostella/gluonts that referenced this pull request May 31, 2024
@lostella lostella mentioned this pull request May 31, 2024
lostella added a commit that referenced this pull request May 31, 2024
*Description of changes:* backporting fixes
- #3186

Co-authored-by: Oleksandr Shchur <shchuro@amazon.com>
@lostella lostella removed the pending v0.15.x backport This contains a fix to be backported to the v0.15.x branch label May 31, 2024
kashif pushed a commit to kashif/gluon-ts that referenced this pull request Jun 15, 2024
@lostella lostella removed the pending v0.14.x backport This contains a fix to be backported to the v0.14.x branch label Nov 7, 2024
Successfully merging this pull request may close these issues.

Error calling make_evaluation_predictions with TFT using past_feat_dynamic_real after update to 0.15.0