Bug Fix for #1071: NHiTs encoder/decoder covariate size difference bug #1359

Merged · 8 commits · Oct 2, 2023
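
This PR splits NHiTS's single `covariate_size` into separate `encoder_covariate_size` and `decoder_covariate_size` properties, threads both through `NHiTSBlock` and the stack builders, and sizes the first fully connected layer from each window's own covariate count. It also makes `from_dataset` pass the dataset parameters through `__init__`, and adds a fixture and test case covering datasets whose encoder and decoder covariate sets differ.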
6 changes: 5 additions & 1 deletion pytorch_forecasting/models/base_model.py
```diff
@@ -393,6 +393,7 @@ def forward(self, x):
 
     def __init__(
         self,
+        dataset_parameters: Dict[str, Any] = None,
         log_interval: Union[int, float] = -1,
         log_val_interval: Union[int, float] = None,
         learning_rate: Union[float, List[float]] = 1e-3,
@@ -467,6 +468,8 @@ def __init__(
         self.output_transformer = output_transformer
         if not hasattr(self, "optimizer"):  # callables are removed from hyperparameters, so better to save them
             self.optimizer = self.hparams.optimizer
+        if not hasattr(self, "dataset_parameters"):
+            self.dataset_parameters = dataset_parameters
 
         # delete everything from hparams that cannot be serialized with yaml.dump
         # which is particularly important for tensorboard logging
@@ -1235,8 +1238,9 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs) -> LightningModule:
         """
         if "output_transformer" not in kwargs:
             kwargs["output_transformer"] = dataset.target_normalizer
+        if "dataset_parameters" not in kwargs:
+            kwargs["dataset_parameters"] = dataset.get_parameters()
         net = cls(**kwargs)
-        net.dataset_parameters = dataset.get_parameters()
         if dataset.multi_target:
             assert isinstance(
                 net.loss, MultiLoss
```
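
With this change, `from_dataset` forwards the dataset parameters through `__init__` instead of attaching them to the instance after construction, so they become a regular constructor argument. A minimal sketch of the intended round trip — `training` and `new_data` are hypothetical stand-ins for an existing `TimeSeriesDataSet` and a new dataframe, and `predict=True` is assumed to be accepted as an override by `from_parameters`:

```python
from pytorch_forecasting import NHiTS, TimeSeriesDataSet

net = NHiTS.from_dataset(training)  # dataset_parameters is now filled in automatically

# the captured parameters can later rebuild an identically configured dataset
params = net.dataset_parameters  # equivalent to training.get_parameters()
inference_dataset = TimeSeriesDataSet.from_parameters(params, new_data, predict=True)
```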
27 changes: 21 additions & 6 deletions pytorch_forecasting/models/nhits/__init__.py
```diff
@@ -179,7 +179,8 @@ def __init__(
             prediction_length=self.hparams.prediction_length,
             output_size=to_list(output_size),
             static_size=self.static_size,
-            covariate_size=self.covariate_size,
+            encoder_covariate_size=self.encoder_covariate_size,
+            decoder_covariate_size=self.decoder_covariate_size,
             static_hidden_size=self.hparams.static_hidden_size,
             n_blocks=self.hparams.n_blocks,
             n_layers=self.hparams.n_layers,
@@ -197,13 +198,24 @@
         )
 
     @property
-    def covariate_size(self) -> int:
-        """Covariate size.
+    def decoder_covariate_size(self) -> int:
+        """Decoder covariates size.
 
         Returns:
-            int: size of time-dependent covariates
+            int: size of time-dependent covariates used by the decoder
         """
         return len(set(self.hparams.time_varying_reals_decoder) - set(self.target_names)) + sum(
             self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder
         )
+
+    @property
+    def encoder_covariate_size(self) -> int:
+        """Encoder covariate size.
+
+        Returns:
+            int: size of time-dependent covariates used by the encoder
+        """
+        return len(set(self.hparams.time_varying_reals_encoder) - set(self.target_names)) + sum(
+            self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder
+        )
 
```
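The split makes the asymmetry explicit: unknown time-varying variables are visible to the encoder but never to the decoder, so the two counts can legitimately differ. A self-contained illustration of the set arithmetic, with made-up column names and the categorical embedding term omitted for brevity:

```python
# illustrative only: mirrors the counting done by the two properties above
encoder_reals = ["time_idx", "price", "target", "volume"]  # known + unknown reals seen by the encoder
decoder_reals = ["time_idx", "price"]                      # only known reals are available to the decoder
target_names = ["target"]                                  # targets are never counted as covariates

encoder_covariate_size = len(set(encoder_reals) - set(target_names))  # 3
decoder_covariate_size = len(set(decoder_reals) - set(target_names))  # 2
assert (encoder_covariate_size, decoder_covariate_size) == (3, 2)
```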
```diff
@@ -239,16 +251,19 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
             Dict[str, torch.Tensor]: output of model
         """
         # covariates
-        if self.covariate_size > 0:
+        if self.encoder_covariate_size > 0:
             encoder_features = self.extract_features(x, self.embeddings, period="encoder")
             encoder_x_t = torch.concat(
                 [encoder_features[name] for name in self.encoder_variables if name not in self.target_names],
                 dim=2,
             )
+        else:
+            encoder_x_t = None
+
+        if self.decoder_covariate_size > 0:
             decoder_features = self.extract_features(x, self.embeddings, period="decoder")
             decoder_x_t = torch.concat([decoder_features[name] for name in self.decoder_variables], dim=2)
         else:
-            encoder_x_t = None
             decoder_x_t = None
 
         # statics
```
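
This decoupling is the behavioral half of the fix: both extractions used to be gated on the single, decoder-derived `covariate_size`, so whenever the encoder and decoder variable sets differed, encoder covariates were either dropped or concatenated at a width the downstream layers were not sized for — the shape mismatch reported in #1071. Each side now checks its own size and falls back to `None` independently.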
31 changes: 23 additions & 8 deletions pytorch_forecasting/models/nhits/sub_modules.py
```diff
@@ -91,7 +91,8 @@
         context_length: int,
         prediction_length: int,
         output_size: int,
-        covariate_size: int,
+        encoder_covariate_size: int,
+        decoder_covariate_size: int,
         static_size: int,
         static_hidden_size: int,
         n_theta: int,
@@ -119,14 +120,16 @@
         self.prediction_length = prediction_length
         self.static_size = static_size
         self.static_hidden_size = static_hidden_size
-        self.covariate_size = covariate_size
+        self.encoder_covariate_size = encoder_covariate_size
+        self.decoder_covariate_size = decoder_covariate_size
         self.pooling_sizes = pooling_sizes
         self.batch_normalization = batch_normalization
         self.dropout = dropout
 
         self.hidden_size = [
             self.context_length_pooled * len(self.output_size)
-            + (self.context_length + self.prediction_length) * self.covariate_size
+            + self.context_length * self.encoder_covariate_size
+            + self.prediction_length * self.decoder_covariate_size
             + self.static_hidden_size
         ] + hidden_size
 
```
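This is the structural half of the fix: the block's first fully connected layer consumes the flattened covariates of both windows, so its width must count `context_length` encoder steps and `prediction_length` decoder steps separately. A worked example with made-up sizes:

```python
# illustrative numbers only
context_length, prediction_length = 24, 6
encoder_covariate_size, decoder_covariate_size = 3, 2
context_length_pooled, n_outputs, static_hidden_size = 12, 1, 0

# old (buggy): one decoder-derived count was applied to both windows
covariate_size = decoder_covariate_size
old_width = (
    context_length_pooled * n_outputs
    + (context_length + prediction_length) * covariate_size
    + static_hidden_size
)  # 12 + 30 * 2 + 0 = 72

# new: each window contributes its own covariate count
new_width = (
    context_length_pooled * n_outputs
    + context_length * encoder_covariate_size
    + prediction_length * decoder_covariate_size
    + static_hidden_size
)  # 12 + 24 * 3 + 6 * 2 + 0 = 96

assert (old_width, new_width) == (72, 96)  # the old layer was 24 inputs too narrow
```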
```diff
@@ -173,11 +176,19 @@ def forward(
         encoder_y = self.pooling_layer(encoder_y)
         encoder_y = encoder_y.transpose(1, 2).reshape(batch_size, -1)
 
-        if self.covariate_size > 0:
+        if self.encoder_covariate_size > 0:
             encoder_y = torch.cat(
                 (
                     encoder_y,
                     encoder_x_t.reshape(batch_size, -1),
+                ),
+                1,
+            )
+
+        if self.decoder_covariate_size > 0:
+            encoder_y = torch.cat(
+                (
+                    encoder_y,
                     decoder_x_t.reshape(batch_size, -1),
                 ),
                 1,
@@ -210,7 +221,8 @@ def __init__(
         prediction_length,
         output_size: int,
         static_size,
-        covariate_size,
+        encoder_covariate_size,
+        decoder_covariate_size,
         static_hidden_size,
         n_blocks: list,
         n_layers: list,
@@ -238,7 +250,8 @@
             context_length=context_length,
             prediction_length=prediction_length,
             output_size=output_size,
-            covariate_size=covariate_size,
+            encoder_covariate_size=encoder_covariate_size,
+            decoder_covariate_size=decoder_covariate_size,
             static_size=static_size,
             static_hidden_size=static_hidden_size,
             n_layers=n_layers,
@@ -261,7 +274,8 @@ def create_stack(
         context_length,
         prediction_length,
         output_size,
-        covariate_size,
+        encoder_covariate_size,
+        decoder_covariate_size,
         static_size,
         static_hidden_size,
         n_layers,
@@ -300,7 +314,8 @@ def create_stack(
             context_length=context_length,
             prediction_length=prediction_length,
             output_size=output_size,
-            covariate_size=covariate_size,
+            encoder_covariate_size=encoder_covariate_size,
+            decoder_covariate_size=decoder_covariate_size,
             static_size=static_size,
             static_hidden_size=static_hidden_size,
             n_theta=n_theta,
```
29 changes: 29 additions & 0 deletions tests/test_models/conftest.py
```diff
@@ -130,6 +130,35 @@ def make_dataloaders(data_with_covariates, **kwargs):
 def multiple_dataloaders_with_covariates(data_with_covariates, request):
     return make_dataloaders(data_with_covariates, **request.param)
 
+@pytest.fixture(scope="session")
+def dataloaders_with_different_encoder_decoder_length(data_with_covariates):
+    return make_dataloaders(
+        data_with_covariates.copy(),
+        target="target",
+        time_varying_known_categoricals=["special_days", "month"],
+        variable_groups=dict(
+            special_days=[
+                "easter_day",
+                "good_friday",
+                "new_year",
+                "christmas",
+                "labor_day",
+                "independence_day",
+                "revolution_day_memorial",
+                "regional_games",
+                "fifa_u_17_world_cup",
+                "football_gold_cup",
+                "beer_capital",
+                "music_fest",
+            ]
+        ),
+        time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"],
+        time_varying_unknown_categoricals=[],
+        time_varying_unknown_reals=["target", "volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
+        static_categoricals=["agency"],
+        add_relative_time_idx=False,
+        target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False),
+    )
 
 @pytest.fixture(scope="session")
 def dataloaders_with_covariates(data_with_covariates):
```
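
The fixture keeps several `time_varying_unknown_reals`, which only the encoder sees, while the decoder receives just the known variables — exactly the asymmetric configuration behind #1071. A sanity-check sketch of what it sets up (not part of the test suite), assuming `make_dataloaders` returns its usual dict of train/val dataloaders:

```python
from pytorch_forecasting import NHiTS

# `dataloaders` stands for the fixture's return value
train_dataset = dataloaders["train"].dataset  # the underlying TimeSeriesDataSet
net = NHiTS.from_dataset(train_dataset)
assert net.encoder_covariate_size > net.decoder_covariate_size
```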
4 changes: 4 additions & 0 deletions tests/test_models/test_nhits.py
```diff
@@ -77,6 +77,7 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
     "dataloader",
     [
         "with_covariates",
+        "different_encoder_decoder_size",
         "fixed_window_without_covariates",
         "multi_target",
         "quantiles",
@@ … @@
 )
 def test_integration(
     dataloaders_with_covariates,
+    dataloaders_with_different_encoder_decoder_length,
     dataloaders_fixed_window_without_covariates,
     dataloaders_multi_target,
     tmp_path,
@@ -95,6 +97,8 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
     if dataloader == "with_covariates":
         dataloader = dataloaders_with_covariates
         kwargs["backcast_loss_ratio"] = 0.5
+    elif dataloader == "different_encoder_decoder_size":
+        dataloader = dataloaders_with_different_encoder_decoder_length
     elif dataloader == "fixed_window_without_covariates":
         dataloader = dataloaders_fixed_window_without_covariates
     elif dataloader == "multi_target":
```