
max_encoder_length and log_prediction issue with TFT and TimeSeriesDataset #864

Closed
zhaodongw opened this issue Feb 7, 2022 · 5 comments

zhaodongw commented Feb 7, 2022

Hi team, I found a strange issue with TimeSeriesDataSet. I initialized it with the following code:

max_encoder_length = 24
training = TimeSeriesDataSet(
    df[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    ...
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    ...
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

Basically, max_encoder_length was set to 24. However, when I checked the values with:

x, y = next(iter(train_dataloader))
x["encoder_lengths"]

I got a tensor containing values greater than 24: e.g., tensor([24, 14, 24, ... , 27, 29, ..., 30, ..., 24]).
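A quick guard makes the violation explicit before training starts. This is a sketch; the length values are the ones observed in the batch above:

```python
# Sketch: sanity-check encoder lengths in one batch against the configured
# maximum before training. The lengths below are the values observed above.
max_encoder_length = 24
encoder_lengths = [24, 14, 24, 27, 29, 30, 24]  # from x["encoder_lengths"]

too_long = [length for length in encoder_lengths if length > max_encoder_length]
if too_long:
    print(f"encoder lengths exceed max_encoder_length={max_encoder_length}: {too_long}")
```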

This caused a runtime error during training, as follows:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_21808/1997652008.py in <module>
      3     tft,
      4     train_dataloader=train_dataloader,
----> 5     val_dataloaders=val_dataloader,
      6 )

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    739             train_dataloaders = train_dataloader
    740         self._call_and_handle_interrupt(
--> 741             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742         )
    743 

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    683         """
    684         try:
--> 685             return trainer_fn(*args, **kwargs)
    686         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687         except KeyboardInterrupt as exception:

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775         # TODO: ckpt_path only in v1.7
    776         ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777         self._run(model, ckpt_path=ckpt_path)
    778 
    779         assert self.state.stopped

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1197 
   1198         # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1199         self._dispatch()
   1200 
   1201         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
   1277             self.training_type_plugin.start_predicting(self)
   1278         else:
-> 1279             self.training_type_plugin.start_training(self)
   1280 
   1281     def run_stage(self):

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    200     def start_training(self, trainer: "pl.Trainer") -> None:
    201         # double dispatch to initiate the training loop
--> 202         self._results = trainer.run_stage()
    203 
    204     def start_evaluating(self, trainer: "pl.Trainer") -> None:

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
   1287         if self.predicting:
   1288             return self._run_predict()
-> 1289         return self._run_train()
   1290 
   1291     def _pre_training_routine(self):

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1309             self.progress_bar_callback.disable()
   1310 
-> 1311         self._run_sanity_check(self.lightning_module)
   1312 
   1313         # enable train mode

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self, ref_model)
   1373             # run eval step
   1374             with torch.no_grad():
-> 1375                 self._evaluation_loop.run()
   1376 
   1377             self.call_hook("on_sanity_check_end")

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    143             try:
    144                 self.on_advance_start(*args, **kwargs)
--> 145                 self.advance(*args, **kwargs)
    146                 self.on_advance_end()
    147                 self.restarting = False

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
    108         dl_max_batches = self._max_batches[dataloader_idx]
    109 
--> 110         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
    111 
    112         # store batch level output per dataloader

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    143             try:
    144                 self.on_advance_start(*args, **kwargs)
--> 145                 self.advance(*args, **kwargs)
    146                 self.on_advance_end()
    147                 self.restarting = False

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dataloader_idx, dl_max_batches, num_dataloaders)
    120         # lightning module methods
    121         with self.trainer.profiler.profile("evaluation_step_and_end"):
--> 122             output = self._evaluation_step(batch, batch_idx, dataloader_idx)
    123             output = self._evaluation_step_end(output)
    124 

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, batch, batch_idx, dataloader_idx)
    215             self.trainer.lightning_module._current_fx_name = "validation_step"
    216             with self.trainer.profiler.profile("validation_step"):
--> 217                 output = self.trainer.accelerator.validation_step(step_kwargs)
    218 
    219         return output

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, step_kwargs)
    234         """
    235         with self.precision_plugin.val_step_context():
--> 236             return self.training_type_plugin.validation_step(*step_kwargs.values())
    237 
    238     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:

/apps/python3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
    217 
    218     def validation_step(self, *args, **kwargs):
--> 219         return self.model.validation_step(*args, **kwargs)
    220 
    221     def test_step(self, *args, **kwargs):

/apps/python3/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py in validation_step(self, batch, batch_idx)
    387         x, y = batch
    388         log, out = self.step(x, y, batch_idx)
--> 389         log.update(self.create_log(x, y, out, batch_idx))
    390         return log
    391 

/apps/python3/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py in create_log(self, x, y, out, batch_idx, **kwargs)
    513 
    514     def create_log(self, x, y, out, batch_idx, **kwargs):
--> 515         log = super().create_log(x, y, out, batch_idx, **kwargs)
    516         if self.log_interval > 0:
    517             log["interpretation"] = self._log_interpretation(out)

/apps/python3/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py in create_log(self, x, y, out, batch_idx, prediction_kwargs, quantiles_kwargs)
    431         if self.log_interval > 0:
    432             self.log_prediction(
--> 433                 x, out, batch_idx, prediction_kwargs=prediction_kwargs, quantiles_kwargs=quantiles_kwargs
    434             )
    435         return {}

/apps/python3/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py in log_prediction(self, x, out, batch_idx, **kwargs)
    689                 log_indices = [0]
    690             for idx in log_indices:
--> 691                 fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
    692                 tag = f"{self.current_stage} prediction"
    693                 if self.training:

/apps/python3/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py in plot_prediction(self, x, out, idx, plot_attention, add_loss_to_title, show_future_observed, ax, **kwargs)
    676         # add attention on secondary axis
    677         if plot_attention:
--> 678             interpretation = self.interpret_output(out)
    679             for f in to_list(fig):
    680                 ax = f.axes[0]

/apps/python3/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py in interpret_output(self, out, reduction, attention_prediction_horizon, attention_as_autocorrelation)
    557 
    558         # histogram of decode and encode lengths
--> 559         encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
    560         decoder_length_histogram = integer_histogram(
    561             out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1)

/apps/python3/lib/python3.7/site-packages/pytorch_forecasting/utils.py in integer_histogram(data, min, max)
     32         max = uniques.max()
     33     hist = torch.zeros(max - min + 1, dtype=torch.long, device=data.device).scatter(
---> 34         dim=0, index=uniques - min, src=counts
     35     )
     36     return hist

RuntimeError: index 30 is out of bounds for dimension 0 with size 25

Basically, since encoder_lengths contains values greater than 24, an out-of-bounds error occurs in integer_histogram.
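The failure can be reproduced with a minimal pure-Python analogue of integer_histogram (a sketch of the same indexing logic, not the library code):

```python
# Pure-Python analogue of pytorch_forecasting.utils.integer_histogram (sketch).
# With max=24 the histogram has 25 bins (indices 0..24), so any encoder length
# above 24 indexes past the end of the histogram.
from collections import Counter

def integer_histogram(data, min_val, max_val):
    hist = [0] * (max_val - min_val + 1)
    for value, count in Counter(data).items():
        hist[value - min_val] = count  # fails when value > max_val
    return hist

encoder_lengths = [24, 14, 24, 27, 29, 30, 24]
try:
    integer_histogram(encoder_lengths, min_val=0, max_val=24)
except IndexError as err:
    print("out of bounds:", err)  # same failure mode as the traceback above
```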

I am not sure how to fix this, since I didn't find any checks in timeseries.py that limit encoder lengths to at most max_encoder_length.


zhaodongw commented Feb 7, 2022

After digging deeper into the code, it seems that in this line sequence_length can be greater than index.sequence_length, which should be unexpected. I am not sure which part of the calculation is wrong, but something is problematic here; perhaps some edge cases in the data are being ignored.


NazyS commented Feb 8, 2022

I faced the same issue. I strongly suspect that the allow_missing_timesteps=True kwarg in TimeSeriesDataSet results in sequences longer than max_encoder_length for groups with missing timesteps (#376).

This does not cause any issues during training, but it fails during logging if you have a nonzero log_interval in TemporalFusionTransformer: in the self.interpret_output method, several tensors are created with size max_encoder_length and are then filled with values from tensors that are actually larger.

It happens in the following lines:

As a temporary fix, you can turn off logging by supplying log_interval=0, or override the interpret_output method to create tensors of the maximum sequence length in the above-mentioned lines, but neither is a sufficient solution.
You can also fill in the missing timesteps yourself and use allow_missing_timesteps=False.
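Filling the gaps yourself can be sketched with pandas; the column names ("group", "time_idx", "value") and the forward-fill choice are assumptions for illustration:

```python
import pandas as pd

# Sketch: fill missing time steps per group so the data can be used with
# allow_missing_timesteps=False. Column names are assumptions for illustration.
df = pd.DataFrame({
    "group": ["a", "a", "a", "b", "b"],
    "time_idx": [0, 1, 3, 0, 2],  # "a" is missing t=2, "b" is missing t=1
    "value": [1.0, 2.0, 4.0, 5.0, 7.0],
})

def fill_missing_timesteps(frame):
    filled = []
    for group, g in frame.groupby("group"):
        # Reindex each group onto a contiguous time index.
        full_idx = range(g["time_idx"].min(), g["time_idx"].max() + 1)
        g = g.set_index("time_idx").reindex(full_idx)
        g.index.name = "time_idx"
        g["group"] = group  # restore the group key on the inserted rows
        g["value"] = g["value"].ffill()  # forward fill; pick what fits your data
        filled.append(g.reset_index())
    return pd.concat(filled, ignore_index=True)

filled = fill_missing_timesteps(df)
```

With the gaps filled like this, the sequences handed to the encoder stay within max_encoder_length.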

In any case, even if there is a different reason for this behavior, the presence of inconsistent batches with sequence lengths larger than max_encoder_length in the dataloader can cause other conflicts, for example a shape mismatch in predictions:

tft.predict(test_dataloader, mode="raw")

just as in #449, which I also suspect has the root cause discussed here.

@zhaodongw

Thanks @NazyS, this is super helpful! It does seem like things don't behave as expected when allow_missing_timesteps=True; I'm not sure whether it affects model quality. But filling in the missing timestamps yourself can definitely fix this in a more predictable way.

@raisbecka

Thank you SO SO much for creating this issue - I wasted over a week completely stumped on why I was getting CUDA errors relating to indexes.

At least I can train my model now. Setting log_interval=0 worked.

@mrzaizai2k

"Also you can fill missing timesteps by yourself and use allow_missing_timesteps=False" works for me. Thanks
