Merge pull request #1359 from bendavidsteel/weight-size-bug-fix
jdb78 authored Oct 2, 2023
2 parents 6b9133a + 6ccbc59 commit 8128061
Showing 5 changed files with 84 additions and 15 deletions.
6 changes: 5 additions & 1 deletion pytorch_forecasting/models/base_model.py
@@ -393,6 +393,7 @@ def forward(self, x):
 
     def __init__(
         self,
+        dataset_parameters: Dict[str, Any] = None,
         log_interval: Union[int, float] = -1,
         log_val_interval: Union[int, float] = None,
         learning_rate: Union[float, List[float]] = 1e-3,
@@ -467,6 +468,8 @@ def __init__(
         self.output_transformer = output_transformer
         if not hasattr(self, "optimizer"):  # callables are removed from hyperparameters, so better to save them
             self.optimizer = self.hparams.optimizer
+        if not hasattr(self, "dataset_parameters"):
+            self.dataset_parameters = dataset_parameters
 
         # delete everything from hparams that cannot be serialized with yaml.dump
         # which is particularly important for tensorboard logging
@@ -1235,8 +1238,9 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs) -> LightningModule:
         """
         if "output_transformer" not in kwargs:
             kwargs["output_transformer"] = dataset.target_normalizer
+        if "dataset_parameters" not in kwargs:
+            kwargs["dataset_parameters"] = dataset.get_parameters()
         net = cls(**kwargs)
-        net.dataset_parameters = dataset.get_parameters()
         if dataset.multi_target:
             assert isinstance(
                 net.loss, MultiLoss
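Routing the parameters through the constructor's kwargs, instead of attaching them to the instance after construction, records them alongside the model's other init arguments, so a model restored from a checkpoint keeps them. A minimal usage sketch, assuming an existing TimeSeriesDataSet named `training` and a DataFrame `new_data` (both names are placeholders):

    from pytorch_forecasting import NHiTS, TimeSeriesDataSet

    net = NHiTS.from_dataset(training)  # dataset_parameters is now set inside __init__
    params = net.dataset_parameters     # equivalent to training.get_parameters()
    # Rebuild a compatible dataset for new data, e.g. after reloading the model:
    new_dataset = TimeSeriesDataSet.from_parameters(params, new_data)
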
27 changes: 21 additions & 6 deletions pytorch_forecasting/models/nhits/__init__.py
@@ -179,7 +179,8 @@ def __init__(
             prediction_length=self.hparams.prediction_length,
             output_size=to_list(output_size),
             static_size=self.static_size,
-            covariate_size=self.covariate_size,
+            encoder_covariate_size=self.encoder_covariate_size,
+            decoder_covariate_size=self.decoder_covariate_size,
             static_hidden_size=self.hparams.static_hidden_size,
             n_blocks=self.hparams.n_blocks,
             n_layers=self.hparams.n_layers,
@@ -197,13 +198,24 @@ def __init__(
         )
 
     @property
-    def covariate_size(self) -> int:
-        """Covariate size.
+    def decoder_covariate_size(self) -> int:
+        """Decoder covariates size.
 
         Returns:
-            int: size of time-dependent covariates
+            int: size of time-dependent covariates used by the decoder
         """
         return len(set(self.hparams.time_varying_reals_decoder) - set(self.target_names)) + sum(
             self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder
         )
 
+    @property
+    def encoder_covariate_size(self) -> int:
+        """Encoder covariate size.
+
+        Returns:
+            int: size of time-dependent covariates used by the encoder
+        """
+        return len(set(self.hparams.time_varying_reals_encoder) - set(self.target_names)) + sum(
+            self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder
+        )
+
@@ -239,16 +251,19 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
             Dict[str, torch.Tensor]: output of model
         """
         # covariates
-        if self.covariate_size > 0:
+        if self.encoder_covariate_size > 0:
             encoder_features = self.extract_features(x, self.embeddings, period="encoder")
             encoder_x_t = torch.concat(
                 [encoder_features[name] for name in self.encoder_variables if name not in self.target_names],
                 dim=2,
             )
+        else:
+            encoder_x_t = None
+
+        if self.decoder_covariate_size > 0:
             decoder_features = self.extract_features(x, self.embeddings, period="decoder")
             decoder_x_t = torch.concat([decoder_features[name] for name in self.decoder_variables], dim=2)
         else:
-            encoder_x_t = None
             decoder_x_t = None
 
         # statics
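Before this change a single covariate_size property, computed from the decoder variable lists alone, both gated and sized the encoder and decoder covariate tensors, so shapes broke whenever encoder-only (time-varying unknown) covariates were present. A minimal sketch of the split logic with plain lists, ignoring categorical embeddings (the variable names are illustrative):

    # Encoder sees known + unknown time-varying reals; decoder sees known only.
    time_varying_reals_encoder = ["time_idx", "price", "volume", "target"]
    time_varying_reals_decoder = ["time_idx", "price"]
    target_names = ["target"]  # targets are handled separately, hence excluded

    encoder_covariate_size = len(set(time_varying_reals_encoder) - set(target_names))  # 3
    decoder_covariate_size = len(set(time_varying_reals_decoder) - set(target_names))  # 2
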
31 changes: 23 additions & 8 deletions pytorch_forecasting/models/nhits/sub_modules.py
@@ -91,7 +91,8 @@ def __init__(
         context_length: int,
         prediction_length: int,
         output_size: int,
-        covariate_size: int,
+        encoder_covariate_size: int,
+        decoder_covariate_size: int,
         static_size: int,
         static_hidden_size: int,
         n_theta: int,
@@ -119,14 +120,16 @@ def __init__(
         self.prediction_length = prediction_length
         self.static_size = static_size
         self.static_hidden_size = static_hidden_size
-        self.covariate_size = covariate_size
+        self.encoder_covariate_size = encoder_covariate_size
+        self.decoder_covariate_size = decoder_covariate_size
         self.pooling_sizes = pooling_sizes
         self.batch_normalization = batch_normalization
         self.dropout = dropout
 
         self.hidden_size = [
             self.context_length_pooled * len(self.output_size)
-            + (self.context_length + self.prediction_length) * self.covariate_size
+            + self.context_length * self.encoder_covariate_size
+            + self.prediction_length * self.decoder_covariate_size
             + self.static_hidden_size
         ] + hidden_size
 
@@ -173,11 +176,19 @@ def forward(
         encoder_y = self.pooling_layer(encoder_y)
         encoder_y = encoder_y.transpose(1, 2).reshape(batch_size, -1)
 
-        if self.covariate_size > 0:
+        if self.encoder_covariate_size > 0:
             encoder_y = torch.cat(
                 (
                     encoder_y,
                     encoder_x_t.reshape(batch_size, -1),
+                ),
+                1,
+            )
+
+        if self.decoder_covariate_size > 0:
+            encoder_y = torch.cat(
+                (
+                    encoder_y,
                     decoder_x_t.reshape(batch_size, -1),
                 ),
                 1,
@@ -210,7 +221,8 @@ def __init__(
         prediction_length,
         output_size: int,
         static_size,
-        covariate_size,
+        encoder_covariate_size,
+        decoder_covariate_size,
         static_hidden_size,
         n_blocks: list,
         n_layers: list,
@@ -238,7 +250,8 @@ def __init__(
             context_length=context_length,
             prediction_length=prediction_length,
             output_size=output_size,
-            covariate_size=covariate_size,
+            encoder_covariate_size=encoder_covariate_size,
+            decoder_covariate_size=decoder_covariate_size,
             static_size=static_size,
             static_hidden_size=static_hidden_size,
             n_layers=n_layers,
@@ -261,7 +274,8 @@ def create_stack(
         context_length,
         prediction_length,
         output_size,
-        covariate_size,
+        encoder_covariate_size,
+        decoder_covariate_size,
         static_size,
         static_hidden_size,
         n_layers,
@@ -300,7 +314,8 @@ def create_stack(
             context_length=context_length,
             prediction_length=prediction_length,
             output_size=output_size,
-            covariate_size=covariate_size,
+            encoder_covariate_size=encoder_covariate_size,
+            decoder_covariate_size=decoder_covariate_size,
             static_size=static_size,
             static_hidden_size=static_hidden_size,
             n_theta=n_theta,
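The first MLP width therefore now multiplies each horizon by its own covariate count instead of applying the decoder-derived count to both. A worked example with assumed numbers (all illustrative; one target, so len(output_size) == 1):

    context_length, prediction_length = 30, 6
    context_length_pooled = 15  # assumed pooled encoder length
    encoder_covariate_size, decoder_covariate_size = 3, 2
    static_hidden_size = 0

    # old (buggy): decoder-derived size applied across both horizons
    old_width = 15 * 1 + (30 + 6) * 2 + 0  # = 87
    # new: each horizon weighted by its own covariate count
    new_width = 15 * 1 + 30 * 3 + 6 * 2 + 0  # = 117
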
31 changes: 31 additions & 0 deletions tests/test_models/conftest.py
@@ -131,6 +131,37 @@ def multiple_dataloaders_with_covariates(data_with_covariates, request):
     return make_dataloaders(data_with_covariates, **request.param)
 
 
+@pytest.fixture(scope="session")
+def dataloaders_with_different_encoder_decoder_length(data_with_covariates):
+    return make_dataloaders(
+        data_with_covariates.copy(),
+        target="target",
+        time_varying_known_categoricals=["special_days", "month"],
+        variable_groups=dict(
+            special_days=[
+                "easter_day",
+                "good_friday",
+                "new_year",
+                "christmas",
+                "labor_day",
+                "independence_day",
+                "revolution_day_memorial",
+                "regional_games",
+                "fifa_u_17_world_cup",
+                "football_gold_cup",
+                "beer_capital",
+                "music_fest",
+            ]
+        ),
+        time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"],
+        time_varying_unknown_categoricals=[],
+        time_varying_unknown_reals=["target", "volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
+        static_categoricals=["agency"],
+        add_relative_time_idx=False,
+        target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False),
+    )
+
+
 @pytest.fixture(scope="session")
 def dataloaders_with_covariates(data_with_covariates):
     return make_dataloaders(
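This fixture exercises the fixed code path because the encoder receives the known reals plus the unknown reals while the decoder receives only the known ones, so the two covariate sizes necessarily differ. A quick sanity check of the resulting real-valued covariate counts (categorical embedding widths omitted for brevity):

    known_reals = ["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"]
    unknown_reals = ["target", "volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"]

    encoder_reals = len(set(known_reals + unknown_reals) - {"target"})  # 10
    decoder_reals = len(set(known_reals) - {"target"})                  # 5
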
4 changes: 4 additions & 0 deletions tests/test_models/test_nhits.py
@@ -77,6 +77,7 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
     "dataloader",
     [
         "with_covariates",
+        "different_encoder_decoder_size",
         "fixed_window_without_covariates",
         "multi_target",
         "quantiles",
@@ -86,6 +87,7 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
 )
 def test_integration(
     dataloaders_with_covariates,
+    dataloaders_with_different_encoder_decoder_length,
     dataloaders_fixed_window_without_covariates,
     dataloaders_multi_target,
     tmp_path,
@@ -95,6 +97,8 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
     if dataloader == "with_covariates":
         dataloader = dataloaders_with_covariates
         kwargs["backcast_loss_ratio"] = 0.5
+    elif dataloader == "different_encoder_decoder_size":
+        dataloader = dataloaders_with_different_encoder_decoder_length
     elif dataloader == "fixed_window_without_covariates":
         dataloader = dataloaders_fixed_window_without_covariates
     elif dataloader == "multi_target":
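The new parametrization can be run in isolation with pytest's -k filter, for example:

    pytest tests/test_models/test_nhits.py -k different_encoder_decoder_size
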
