
TemporalFusionTransformer predicting with mode "raw" results in RuntimeError: Sizes of tensors must match except in dimension 0 #449

Open
TomSteenbergen opened this issue Apr 19, 2021 · 25 comments
Labels
bug Something isn't working

Comments

@TomSteenbergen

TomSteenbergen commented Apr 19, 2021

  • PyTorch-Forecasting version: 0.8.4
  • PyTorch version: 1.8.1
  • Python version: 3.8
  • Operating System: Ubuntu 20.04.2 LTS

Expected behavior

In order to generate the interpretation plots of a Temporal Fusion Transformer model, I first try to generate the raw predictions using

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder
from pytorch_forecasting.metrics import QuantileLoss

# Generate TimeSeriesDataSets and DataLoaders
train_ts_dataset = TimeSeriesDataSet(
    train_dataset,
    time_idx=feature_config["time_index"],
    target=feature_config["target"],
    group_ids=feature_config["groups"],
    min_encoder_length=dataset_config["min_encoder_length"],  # 7
    max_encoder_length=dataset_config["max_encoder_length"],  # 28
    min_prediction_idx=1,
    min_prediction_length=dataset_config["min_prediction_length"],  # 1
    max_prediction_length=dataset_config["max_prediction_length"],  # 14
    static_categoricals=feature_config["categorical"]["static"],
    static_reals=feature_config["real"]["static"],
    time_varying_known_categoricals=feature_config["categorical"]["dynamic"]["known"],
    time_varying_unknown_categoricals=feature_config["categorical"]["dynamic"]["unknown"],
    time_varying_known_reals=feature_config["real"]["dynamic"]["known"],
    time_varying_unknown_reals=feature_config["real"]["dynamic"]["unknown"],
    target_normalizer=GroupNormalizer(groups=[], transformation="softplus"),
    allow_missings=True,
    categorical_encoders={
        col: NaNLabelEncoder(add_nan=True, warn=False)
        for col in set(
            feature_config["categorical"]["static"]
            + feature_config["categorical"]["dynamic"]["known"]
            + feature_config["categorical"]["dynamic"]["unknown"]
            + feature_config["groups"]
        )
    },
)

test_ts_dataset = TimeSeriesDataSet.from_dataset(
    train_ts_dataset, test_dataset, stop_randomization=True, predict_mode=True
)

test_dataloader = test_ts_dataset.to_dataloader(
    train=False, batch_size=batch_size * 10, num_workers=num_workers
)

tft = TemporalFusionTransformer.from_dataset(
    train_ts_dataset,
    loss=QuantileLoss(quantiles=model_config["quantiles"]),
    output_size=len(model_config["quantiles"]),
    **model_config["hyperparameters"],
)

...  # Train TFT on training set etc...

# Generate raw predictions. 
raw_predictions = tft.predict(test_dataloader, mode="raw")

Actual behavior

The predict method however raises an error when using mode="raw":

...
    raw_predictions = tft.predict(dataloader, mode="raw")
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py", line 982, in predict
    output = _concatenate_output(output)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py", line 84, in _concatenate_output
    output_cat[name] = _torch_cat_na([out[name] for out in output])
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py", line 62, in _torch_cat_na
    return torch.cat(x, dim=0)
RuntimeError: Sizes of tensors must match except in dimension 0. Got 40 and 35 in dimension 3 (The offending index is 1)

Note that I have used the same dataloader earlier in a predict call without using mode="raw" and that works perfectly fine.

The same error was raised elsewhere, as documented in #85, and was fixed using https://github.com/jdb78/pytorch-forecasting/pull/108/files. Could it be that this padded_stack function should also be used somewhere else, or is something else going on? Perhaps somewhere in _concatenate_output(), which is called in this line: https://github.com/jdb78/pytorch-forecasting/blob/master/pytorch_forecasting/models/base_model.py#L987?
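
For illustration, here is a minimal sketch of that padding idea (a hypothetical pad_and_cat helper, not the library's actual padded_stack): tensors that only differ in their last dimension are padded to a common size before torch.cat.

import torch
import torch.nn.functional as F

def pad_and_cat(tensors, dim=0, fill_value=float("nan")):
    """Pad tensors to a common size in their last dimension, then concatenate along `dim`."""
    max_size = max(t.size(-1) for t in tensors)
    padded = [
        F.pad(t, (0, max_size - t.size(-1)), value=fill_value)  # pad on the right of the last dim
        for t in tensors
    ]
    return torch.cat(padded, dim=dim)

# Two outputs whose last dimension differs across batches, e.g. 35 vs. 40 as in the traceback above.
a = torch.rand(8, 14, 4, 35)
b = torch.rand(8, 14, 4, 40)
out = pad_and_cat([a, b])  # torch.cat([a, b], dim=0) would raise the RuntimeError above
print(out.shape)           # torch.Size([16, 14, 4, 40])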

Please let me know whether it is indeed a bug, or if I am doing something wrong 🙏 Thank you!

@Sharad24

Facing a similar issue with another dataset. Any help would be appreciated!

@jdb78
Collaborator

jdb78 commented Apr 29, 2021

Could you try with the newest release 0.8.5?

@jdb78 jdb78 added the bug Something isn't working label Apr 29, 2021
@TomSteenbergen
Author

TomSteenbergen commented Apr 30, 2021

@jdb78 I tried 0.8.5. I had to update pytorch-lightning as well in order for my Pipfile to lock, which I did not expect since it was just a patch version update. I have now updated pytorch-lightning from 1.1.8 to 1.2.10.

Unfortunately, I run into a new issue. The following gets triggered as soon as you start training a TFT model:

Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 160, in new_process
    results = trainer.train_or_test_or_predict()
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 556, in train_or_test_or_predict
    results = self.run_train()
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 637, in run_train
    self.train_loop.run_training_epoch()
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 492, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 654, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 425, in optimizer_step
    model_ref.optimizer_step(
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1390, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 277, in optimizer_step
    self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 282, in run_optimizer_step
    self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 163, in optimizer_step
    optimizer.step(closure=lambda_closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/optim/optimizer.py", line 89, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_forecasting/optim.py", line 182, in step
    buffered = self.radam_buffer[int(state["step"] % 10)]
AttributeError: 'Ranger' object has no attribute 'radam_buffer'

@jdb78
Collaborator

jdb78 commented Apr 30, 2021

This is very strange because it is defined in the __init__ of the class. Tests also all pass. How are you executing the code? Could you try from a script?

@TomSteenbergen
Author

Yes, that's what I saw as well. Very strange indeed.
I'm executing it as a script already. Moving to a Jupyter Notebook is not an option for me.

I'm running this on an EC2 instance with 4 GPUs. Could it be that some process tries to perform an optimizer step before it gets properly initialized?

@jdb78
Collaborator

jdb78 commented Apr 30, 2021

Ok, from a script it is a bit surprising. Notebooks sometimes introduce weird issues, particularly with multi-processing. I assume it executes well on a single GPU? Lightning wraps the optimizer, so maybe that is the issue. Can you check through the debugger that the wrapped optimizer has an _optimizer attribute, and that this in turn has the radam_buffer and alpha attributes?

@TomSteenbergen
Author

TomSteenbergen commented May 6, 2021

Sorry for the late reply, only had time to look into this just now.

Yes, using a single GPU I'm not getting this error.

When I'm using 4 GPUs, it breaks though.
With a breakpoint on the trainer.fit line, inspecting the optimizer of the TemporalFusionTransformer class instance after it starts training in debug mode returns the following:

>>> testopt = tft.optimizers()
>>> testopt
LightningRanger(groups=[{'N_sma_threshhold': 5, 'alpha': 0.5, 'betas': (0.95, 0.999), 'eps': 1e-05, 'k': 6, 'lr': 0.03, 'step_counter': 0, 'weight_decay': 0.0}])
>>> testopt._optimizer
Ranger (
Parameter Group 0
    N_sma_threshhold: 5
    alpha: 0.5
    betas: (0.95, 0.999)
    eps: 1e-05
    k: 6
    lr: 0.03
    step_counter: 0
    weight_decay: 0.0
)
>>> testopt._optimizer.radam_buffer
[[None, None, None], [None, None, None], [None, None, None], [None, None, None], [None, None, None], [None, None, None], [None, None, None], [None, None, None], [None, None, None], [None, None, None]]
>>> testopt._optimizer.alpha
0.5

Everything here looks as you'd expect, I think. However, inspecting tft.optimizers() before the breakpoint (i.e. before calling trainer.fit) returns:

>>> tft.optimizers()
Traceback (most recent call last):
  File "/home/ubuntu/.pycharm_helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 114, in optimizers
    opts = list(self.trainer.lightning_optimizers.values())
AttributeError: 'NoneType' object has no attribute 'lightning_optimizers'

I'm not sure if this is expected?

Anyway, even though everything at the breakpoint looks as expected, continuing the script from there still results in the error:

AttributeError: 'Ranger' object has no attribute 'radam_buffer'

@jdb78
Collaborator

jdb78 commented May 9, 2021

Just as a test, could you try another optimizer such as "Adam"? Simply pass `optimizer="adam"` to the network. This is really strange, as others did not experience the issue. I wonder if something has gone wrong in the installation process.

@vishalmhjn

I faced the same problem: Sizes of tensors must match except in dimension 0.
It seems to be related to predict=True, since this is the only difference between the training_loader and the testing_loader. Any help will be appreciated.

TimeSeriesDataSet.from_dataset(training, X_formatted_test, predict=True, stop_randomization=True)

@TomSteenbergen
Author

TomSteenbergen commented Aug 26, 2021

Just as a test, could you try another optimizer such as "Adam"? Simply pass `optimizer="adam"` to the network. This is really strange, as others did not experience the issue. I wonder if something has gone wrong in the installation process.

@jdb78 A different optimizer works fine. I installed a clean environment on several VMs now, and I run into the same error as above every time I use the default ranger optimizer.

@owoshch

owoshch commented Oct 22, 2021

Hi @TomSteenbergen and @jdb78!

I'm facing the same issue while training TFT model and then running raw_predictions, x = best_tft.predict(val_dataloader, mode="raw", return_x=True)

The error message is: RuntimeError: Sizes of tensors must match except in dimension 0. Got 110 and 109 in dimension 3 (The offending index is 1)

Notably, if I call raw_predictions, x = best_tft.predict(val_dataloader, return_x=True) it works fine.

I tried several setups:
pytorch_forecasting 0.9.0 pytorch_lightning 1.4.2 pytorch 1.9.0 python 3.7.11 linux 18.04.5

pytorch_forecasting 0.9.1 pytorch_lightning 1.4.9 pytorch 1.8.0 python 3.8.12 linux 18.04.5

Did you find a setup (versions of pytorch-forecasting, sklearn, etc.) that solves this issue?
Thank you!

@adamwuyu

adamwuyu commented Nov 4, 2021

Hi @TomSteenbergen and @jdb78!

I'm facing the same issue while training TFT model and then running raw_predictions, x = best_tft.predict(val_dataloader, mode="raw", return_x=True)

The error message is: RuntimeError: Sizes of tensors must match except in dimension 0. Got 110 and 109 in dimension 3 (The offending index is 1)

Notably, if I call raw_predictions, x = best_tft.predict(val_dataloader, return_x=True) it works fine.

I tried several setups: pytorch_forecasting 0.9.0 pytorch_lightning 1.4.2 pytorch 1.9.0 python 3.7.11 linux 18.04.5

pytorch_forecasting 0.9.1 pytorch_lightning 1.4.9 pytorch 1.8.0 python 3.8.12 linux 18.04.5

Did you find a setup (versions of pytorch-forecasting, sklearn, etc.) that solves this issue? Thank you!

I have exactly the same issue as you. But commenting out mode="raw" is not an option: otherwise, best_tft.plot_prediction(x, raw_prediction, idx=0) will raise the error too many indices for tensor of dimension 2.

So any help will be appreciated!

@leonardobarberi

Hi, I also get the same errors as above. Has any of this been solved yet?

@nicocheh

nicocheh commented Feb 3, 2022

Any solution on this? I get the same thing...

@jdb78
Collaborator

jdb78 commented Mar 22, 2022

It seems to be related to the concatenation of the attention tensors, which are shaped n_batches x n_decoder_steps (that attend) x n_attention_heads x n_timesteps (to which is attended). There seems to be an issue with the concatenation logic. Will work on a fix.
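
For context, a minimal reproduction of that failure mode, assuming attention outputs shaped (n_batches, n_decoder_steps, n_attention_heads, n_attended_timesteps); the concrete shapes below are made up:

import torch

# Hypothetical attention tensors from two prediction batches: same decoder length and
# head count, but a different number of attended timesteps.
attn_batch_1 = torch.rand(64, 14, 4, 40)
attn_batch_2 = torch.rand(35, 14, 4, 35)

try:
    torch.cat([attn_batch_1, attn_batch_2], dim=0)  # what _torch_cat_na effectively does
except RuntimeError as err:
    print(err)  # Sizes of tensors must match except in dimension 0 ...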

@jdb78
Collaborator

jdb78 commented Mar 23, 2022

#902

@JSpenced
Contributor

JSpenced commented Aug 10, 2022

@jdb78 I am still getting this issue when retrying a failed run that got killed. I am loading in the state_dict, and it can't find that variable, so it throws a KeyError. Any ideas on why that would be happening or how to debug it?

@luisecastro

luisecastro commented Nov 21, 2022

Same issue here running the https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html tutorial,

# find optimal learning rate
res = trainer.tuner.lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

File ~\anaconda3\envs\bayes\lib\site-packages\pytorch_forecasting\optim.py:133, in Ranger.__setstate__(self, state)
    131 def __setstate__(self, state: dict) -> None:
    132     super().__setstate__(state)
--> 133     self.radam_buffer = state["radam_buffer"]
    134     self.alpha = state["alpha"]
    135     self.k = state["k"]

KeyError: 'radam_buffer'

@akshat-chandna

Same issue here running the https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html tutorial,

# find optimal learning rate
res = trainer.tuner.lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
File ~\anaconda3\envs\bayes\lib\site-packages\pytorch_forecasting\optim.py:133, in Ranger.__setstate__(self, state)
    131 def __setstate__(self, state: dict) -> None:
    132     super().__setstate__(state)
--> 133     self.radam_buffer = state["radam_buffer"]
    134     self.alpha = state["alpha"]
    135     self.k = state["k"]

KeyError: 'radam_buffer'

Pass optimizer='adam' while creating the TFT model. It works with that.

@ghosts8

ghosts8 commented Nov 25, 2022

@akshat-chandna -> Where exactly do you pass the optimizer in the Stallion code?

@akshat-chandna

akshat-chandna commented Nov 25, 2022

@akshat-chandna -> Where exactly do you pass the optimizer in the Stallion code?

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,
    loss=QuantileLoss(),
    reduce_on_plateau_patience=4,
    optimizer='adam',
)

@Kaszanas

Kaszanas commented Jan 8, 2023

Same issue here running the https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html tutorial,

# find optimal learning rate
res = trainer.tuner.lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
File ~\anaconda3\envs\bayes\lib\site-packages\pytorch_forecasting\optim.py:133, in Ranger.__setstate__(self, state)
    131 def __setstate__(self, state: dict) -> None:
    132     super().__setstate__(state)
--> 133     self.radam_buffer = state["radam_buffer"]
    134     self.alpha = state["alpha"]
    135     self.k = state["k"]

KeyError: 'radam_buffer'

I have the exact same issue when running the stallion tutorial code.

@sysublackstar

Same issue here running the https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html tutorial,

# find optimal learning rate
res = trainer.tuner.lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
File ~\anaconda3\envs\bayes\lib\site-packages\pytorch_forecasting\optim.py:133, in Ranger.__setstate__(self, state)
    131 def __setstate__(self, state: dict) -> None:
    132     super().__setstate__(state)
--> 133     self.radam_buffer = state["radam_buffer"]
    134     self.alpha = state["alpha"]
    135     self.k = state["k"]

KeyError: 'radam_buffer'

I have the exact same issue when running the stallion tutorial code.

I solved the problem by following what akshat-chandna suggested and adding

optimizer='adam'

to the from_dataset call; you can try it.

@Jackson906E

---> 1 res = trainer.tuner.lr_find(
2 tft,
3 train_dataloaders=train_dataloader,
4 val_dataloaders=val_dataloader,
5 max_lr=10.0,
6 min_lr=1e-6,
7 )
9 print(f"suggested learning rate: {res.suggestion()}")
10 fig = res.plot(show=True, suggest=True)

File /opt/homebrew/Caskroom/miniforge/base/envs/aaa/lib/python3.9/site-packages/pytorch_lightning/tuner/tuning.py:267, in Tuner.lr_find(self, model, train_dataloaders, val_dataloaders, dataloaders, datamodule, method, min_lr, max_lr, num_training, mode, early_stop_threshold, update_attr)
264 lr_finder_callback._early_exit = True
265 self.trainer.callbacks = [lr_finder_callback] + self.trainer.callbacks
--> 267 self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
269 self.trainer.callbacks = [cb for cb in self.trainer.callbacks if cb is not lr_finder_callback]
271 self.trainer.auto_lr_find = False

File /opt/homebrew/Caskroom/miniforge/base/envs/aaa/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:608, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
606 model = self._maybe_unwrap_optimized(model)
607 self.strategy._lightning_module = model
--> 608 call._call_and_handle_interrupt(
609 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
...
--> 200     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202         allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

@abudis

abudis commented Jul 25, 2023

Not sure if it's the same issue, but we're facing the same problem when running the latest version of the package. What we noticed is that if you set return_y = False, then the predictions work even in raw mode.
After some debugging, it seems like the ys are collected in minibatch-sized chunks, which fails if the dataset size cannot be divided into minibatches without a remainder.
You don't get the ys in the output, but you can easily format them yourself, as they are literally the targets passed during training with the dimensions (MINIBATCH_SIZE, DECODER_LENGTH * N_MINIBATCHES).
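
As a sketch of that workaround (assuming the usual pytorch-forecasting dataloader output where each batch yields (x, (target, weight)) and a fixed decoder length; test_dataloader is a placeholder name), the targets can be collected manually:

import torch

# Collect decoder targets directly from the dataloader instead of relying on return_y.
targets = []
for x, y in test_dataloader:
    # y is usually a (target, weight) tuple; some versions yield the target tensor directly.
    target = y[0] if isinstance(y, (tuple, list)) else y
    targets.append(target)

# Concatenate along the sample dimension -> shape (n_samples, decoder_length),
# assuming all batches share the same decoder length.
all_targets = torch.cat(targets, dim=0)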
