
Fix attention concat #902

Merged: 5 commits merged into master from feature/fix-attention-concat on Mar 23, 2022

Conversation

@jdb78 jdb78 (Collaborator) commented Mar 23, 2022

Description

This PR fixes an issue with attention calculation across many samples that caused predict(mode="raw") to fail.
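For context, a minimal illustration of the kind of failure involved (shapes and values are hypothetical, not taken from this PR): encoder attention tensors collected from different batches can have different encoder lengths, so they cannot be concatenated naively when assembling raw predictions.

import torch

# Hypothetical shapes: (batch, decoder steps, attention heads, encoder length).
# Two batches whose encoder lengths differ cannot be concatenated directly.
att_a = torch.rand(64, 5, 4, 21)
att_b = torch.rand(64, 5, 4, 18)

try:
    torch.cat([att_a, att_b], dim=0)
except RuntimeError as err:
    print(err)  # sizes of tensors must match except in dimension 0

# Padding to a common encoder length first makes the concatenation valid.
max_len = max(att_a.size(-1), att_b.size(-1))
att_b = torch.nn.functional.pad(att_b, (0, max_len - att_b.size(-1)))
print(torch.cat([att_a, att_b], dim=0).shape)  # torch.Size([128, 5, 4, 21])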

Checklist

  • Linked issues (if existing)
  • Amended changelog for large changes (and added myself there as contributor)
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independently of a commit, execute pre-commit run --all-files

Make sure to have fun coding!

@jdb78 jdb78 (Collaborator, Author) commented Mar 23, 2022

Fixes #689

@codecov-commenter codecov-commenter commented Mar 23, 2022

Codecov Report

Merging #902 (eeec569) into master (eea8c16) will increase coverage by 0.20%.
The diff coverage is 98.00%.

❗ Current head eeec569 differs from pull request most recent head 8b68e7c. Consider uploading reports for the commit 8b68e7c to get more accurate results

@@            Coverage Diff             @@
##           master     #902      +/-   ##
==========================================
+ Coverage   89.57%   89.78%   +0.20%     
==========================================
  Files          26       26              
  Lines        4173     4200      +27     
==========================================
+ Hits         3738     3771      +33     
+ Misses        435      429       -6     
Flag     Coverage Δ
cpu      89.78% <98.00%> (+0.20%) ⬆️
pytest   89.78% <98.00%> (+0.20%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                                          Coverage Δ
pytorch_forecasting/utils.py                            84.33% <91.66%> (+0.57%) ⬆️
pytorch_forecasting/models/base_model.py                88.82% <100.00%> (+0.38%) ⬆️
...ing/models/temporal_fusion_transformer/__init__.py   99.56% <100.00%> (+2.24%) ⬆️
pytorch_forecasting/models/nn/rnn.py                    87.34% <0.00%> (-5.07%) ⬇️
pytorch_forecasting/metrics.py                          91.96% <0.00%> (+0.21%) ⬆️
pytorch_forecasting/data/timeseries.py                  93.49% <0.00%> (+0.39%) ⬆️

Continue to review the full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update eea8c16...8b68e7c.

@jdb78 jdb78 merged commit af9e1d3 into master Mar 23, 2022
@rohanthavarajah

Hi Jan, thank you for the great library! As of today (03/24/2022), I have started encountering the multi-target bug described in #689 for the first time (in Colab notebooks that were previously forecasting multiple targets with TFT successfully as recently as 03/23/2022).

@jdb78 jdb78 (Collaborator, Author) commented Mar 24, 2022

#908

@nicocheh nicocheh commented Mar 27, 2022

Hi @jdb78, awesome fix! Since this fix, I am getting the error below when trying to use the interpret_output function. Do you know what the problem could be? It seems like a small bug where the code tries to get a dimension of a tensor that is actually a list...

interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
best_tft.plot_interpretation(interpretation)


AttributeError Traceback (most recent call last)
Input In [21], in
----> 1 interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
2 best_tft.plot_interpretation(interpretation)

File ~/.local/share/virtualenvs/SPF-25pQYGzl/lib/python3.8/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:597, in TemporalFusionTransformer.interpret_output(self, out, reduction, attention_prediction_horizon)
595 # roll encoder attention (so start last encoder value is on the right)
596 encoder_attention = out["encoder_attention"]
--> 597 shifts = encoder_attention.size(3) - out["encoder_lengths"]
598 new_index = (
599 torch.arange(encoder_attention.size(3), device=encoder_attention.device)[None, None, None].expand_as(
600 encoder_attention
601 )
602 - shifts[:, None, None, None]
603 ) % encoder_attention.size(3)
604 encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)

AttributeError: 'list' object has no attribute 'size'

@jdb78 jdb78 (Collaborator, Author) commented Mar 27, 2022

Hm, a bit strange. Do you have a minimal working example?

@nicocheh nicocheh commented Mar 27, 2022

Yes, it is strange... When I run it with the tutorial, it works well. I am using a dataset of my own; the parameters of the TimeSeriesDataSet are:

min_encoder_length=max_encoder_length//2, 
max_encoder_length=max_encoder_length,
min_prediction_length=max_prediction_length,
max_prediction_length=max_prediction_length,
allow_missing_timesteps=True,
target_normalizer=GroupNormalizer(groups=['region', 'item_variation_id'], center=False),
add_relative_time_idx=True,
add_encoder_length=True, 
add_target_scales=True

The model is created with:

  TemporalFusionTransformer.from_dataset(
      training,
      learning_rate=0.0992,
      hidden_size=100,  
      attention_head_size=4,
      dropout=0.283,  
      hidden_continuous_size=89, 
      output_size=7,  
      loss=QuantileLoss(),
      log_interval=-1,
      reduce_on_plateau_patience=5,
  )

I create the dataloader exactly as in the tutorial (same parameters, using from_dataset and then to_dataloader). Raw predictions are calculated as in the tutorial as well, with

raw_predictions, x, indexx = best_tft.predict(val_dataloader, mode="raw", return_x=True,return_index=True)
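For reference, the tutorial-style dataloader creation mentioned above looks roughly like this (here training is the TimeSeriesDataSet from above, and data and batch_size stand in for my own dataframe and settings):

from pytorch_forecasting import TimeSeriesDataSet

# Validation dataset reusing the training dataset's parameters, as in the tutorial.
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)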

I don't know where the problem might come from... I think it is that raw_predictions["encoder_attention"] comes back as a list; maybe the code needs an "if statement" for the case where this is a list (in the code this is done only for raw_predictions["decoder_attention"] but not for the encoder attention, and that's where the error comes from)? What do you think, @jdb78?

Thank you very much!

@nicocheh

Hi @jdb78, in my data, if I check raw_predictions["decoder_attention"].shape I get torch.Size([5147487, 5, 4, 5]). raw_predictions["encoder_attention"] is indeed a list of tensors with len() = 5147487, and if I do raw_predictions["encoder_attention"][0].shape I get torch.Size([5, 4, 21]).

I am using 4 attention heads, max_prediction_length is 5, and max_encoder_length is 21.

I think the problem is just that the interpret_output function assumes that raw_predictions["encoder_attention"] is always a tensor. What do you think?

Hope this helps fix the issue. Thank you very much!

@nicocheh

@jdb78 I think it is just a matter of adding an if isinstance(out["encoder_attention"], (list, tuple)) check at the line where I marked the error, and changing the tensor-size lookup to the corresponding one after accessing the list first... (see the rough sketch below).

Did you have time to take a look at this? Thank you!
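A rough sketch of the idea (this is not the actual library code; the helper name stack_encoder_attention and the right-padding choice are my assumptions and would need checking against the roll logic that follows in interpret_output):

import torch
import torch.nn.functional as F

def stack_encoder_attention(encoder_attention):
    # Sketch: if encoder attention arrives as a list of per-sample tensors,
    # each of shape (decoder steps, heads, encoder length), pad each one on
    # the right to the longest encoder length and stack them into a single
    # 4-D tensor, so the existing .size(3) logic works again.
    if isinstance(encoder_attention, (list, tuple)):
        max_len = max(t.size(-1) for t in encoder_attention)
        encoder_attention = torch.stack(
            [F.pad(t, (0, max_len - t.size(-1))) for t in encoder_attention]
        )
    # Result shape: (samples, decoder steps, heads, max encoder length).
    return encoder_attention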

@jdb78 jdb78 deleted the feature/fix-attention-concat branch May 23, 2022 11:29