You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi, thank you for your great library! As of 03/25/2022, my Google Colab notebooks, which previously trained multi-target DeepAR and TFT models successfully, now fail when calling trainer.fit(). I have tried installing older versions of pytorch-forecasting (0.9.0, 0.9.2, 0.8.5), with no luck.
Actual behavior
I'm getting the same error message as in #689, which was also mentioned in the latest commit (#902): RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Code to reproduce the problem
A minimal example that reproduces the error in Google Colab:
# Minimal reproduction (originally run in Google Colab) of
# "RuntimeError: element 0 of tensors does not require grad and does not have
# a grad_fn", which is raised from trainer.fit() when a TemporalFusionTransformer
# is trained on two targets at once.
# NOTE(review): torch and plt are imported but not used in this snippet.
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pytorch_forecasting.data import (
TimeSeriesDataSet,
TorchNormalizer,
MultiNormalizer
)
import pytorch_lightning as pl
from pytorch_forecasting.metrics import QuantileLoss, MultiLoss
from pytorch_forecasting.models import TemporalFusionTransformer
# gen synthetic data
# 20 groups x 20 time steps = 400 rows: time_idx cycles 0..19 within each
# group, and group_column holds the group id (0..19) repeated 20 times.
time_idx = np.tile(np.arange(20), 20)
group_column = np.repeat(np.arange(20), 20)
df = pd.DataFrame({'time_idx': time_idx, 'group_column': group_column})
# Two independent standard-normal targets, one value per row. No random seed
# is set, so the data differs on every run.
df['target1'] = np.random.normal(0,1,df.shape[0])
df['target2'] = np.random.normal(0,1,df.shape[0])
# Multi-target dataset: `target` is a list of two columns, so the normalizer
# is a MultiNormalizer holding one TorchNormalizer per target, in the same order.
training = TimeSeriesDataSet(
df,
time_idx="time_idx",
target=["target1", "target2"],
group_ids=["group_column"],
# Encoder windows of 5-10 steps; the prediction window is fixed at 5 steps
# (min == max).
min_encoder_length=5,
max_encoder_length=10,
min_prediction_length=5,
max_prediction_length=5,
add_relative_time_idx=True,
# NOTE(review): presumably the parameters of the random encoder-length
# sampling; confirm against the TimeSeriesDataSet docs.
randomize_length=[5, 2],
# 'identity' presumably applies no scaling; confirm.
target_normalizer=MultiNormalizer([TorchNormalizer('identity'), TorchNormalizer('identity')]),
)
train_dataloader = training.to_dataloader(train=True, batch_size=32)
# CPU-only trainer (gpus=0), 50 epochs, learning-rate finder disabled.
trainer = pl.Trainer(
auto_lr_find=False,
# NOTE(review): "ddr" looks like a typo for "ddp"; the line is commented out,
# so it has no effect here.
# accelerator="ddr",
gpus=0,
max_epochs=50,
)
# One output size and one loss per target. The 7 presumably matches the
# number of quantiles in QuantileLoss's default; confirm.
tft = TemporalFusionTransformer.from_dataset(
training,
learning_rate=0.03,
hidden_size=16,
hidden_continuous_size=8,
output_size=[7, 7],
loss=MultiLoss([QuantileLoss(), QuantileLoss()]),
)
# Only displays a result in a notebook cell; in a plain script the return value
# is discarded.
tft.size()
# This call raises the RuntimeError. NOTE: the traceback that accompanies this
# snippet was produced with val_dataloaders passed (not commented out).
trainer.fit(
tft,
train_dataloaders=train_dataloader,
# val_dataloaders=training.to_dataloader(train=False, batch_size=32), # need to specify a val_dataloader for older versions
)
Trace
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
[<ipython-input-2-3dc75bc406c0>](https://localhost:8080/#) in <module>()
56 tft,
57 train_dataloaders=train_dataloader,
---> 58 val_dataloaders=training.to_dataloader(train=False, batch_size=32),
59 )
32 frames
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
739 train_dataloaders = train_dataloader
740 self._call_and_handle_interrupt(
--> 741 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
742 )
743
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
683 """
684 try:
--> 685 return trainer_fn(*args, **kwargs)
686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
687 except KeyboardInterrupt as exception:
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
775 # TODO: ckpt_path only in v1.7
776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
778
779 assert self.state.stopped
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run(self, model, ckpt_path)
1197
1198 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1199 self._dispatch()
1200
1201 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _dispatch(self)
1277 self.training_type_plugin.start_predicting(self)
1278 else:
-> 1279 self.training_type_plugin.start_training(self)
1280
1281 def run_stage(self):
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py](https://localhost:8080/#) in start_training(self, trainer)
200 def start_training(self, trainer: "pl.Trainer") -> None:
201 # double dispatch to initiate the training loop
--> 202 self._results = trainer.run_stage()
203
204 def start_evaluating(self, trainer: "pl.Trainer") -> None:
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in run_stage(self)
1287 if self.predicting:
1288 return self._run_predict()
-> 1289 return self._run_train()
1290
1291 def _pre_training_routine(self):
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run_train(self)
1317 self.fit_loop.trainer = self
1318 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1319 self.fit_loop.run()
1320
1321 def _run_evaluate(self) -> _EVALUATE_OUTPUT:
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py](https://localhost:8080/#) in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/fit_loop.py](https://localhost:8080/#) in advance(self)
232
233 with self.trainer.profiler.profile("run_training_epoch"):
--> 234 self.epoch_loop.run(data_fetcher)
235
236 # the global step is manually decreased here due to backwards compatibility with existing loggers
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py](https://localhost:8080/#) in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py](https://localhost:8080/#) in advance(self, *args, **kwargs)
191
192 with self.trainer.profiler.profile("run_training_batch"):
--> 193 batch_output = self.batch_loop.run(batch, batch_idx)
194
195 self.batch_progress.increment_processed()
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py](https://localhost:8080/#) in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py](https://localhost:8080/#) in advance(self, batch, batch_idx)
86 if self.trainer.lightning_module.automatic_optimization:
87 optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
---> 88 outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
89 else:
90 outputs = self.manual_loop.run(split_batch, batch_idx)
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py](https://localhost:8080/#) in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in advance(self, batch, *args, **kwargs)
217 self._batch_idx,
218 self._optimizers[self.optim_progress.optimizer_position],
--> 219 self.optimizer_idx,
220 )
221 if result.loss is not None:
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in _run_optimization(self, split_batch, batch_idx, optimizer, opt_idx)
264 # gradient update with accumulated gradients
265 else:
--> 266 self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
267
268 result = closure.consume_result()
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in _optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
384 on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE),
385 using_native_amp=(self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE),
--> 386 using_lbfgs=is_lbfgs,
387 )
388
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/lightning.py](https://localhost:8080/#) in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
1650
1651 """
-> 1652 optimizer.step(closure=optimizer_closure)
1653
1654 def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py](https://localhost:8080/#) in step(self, closure, **kwargs)
162 assert trainer is not None
163 with trainer.profiler.profile(profiler_action):
--> 164 trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py](https://localhost:8080/#) in optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
337 """
338 model = model or self.lightning_module
--> 339 self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
340
341 def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py](https://localhost:8080/#) in optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs)
161 if isinstance(model, pl.LightningModule):
162 closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
--> 163 optimizer.step(closure=closure, **kwargs)
164
165 def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
[/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
86 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
87 with torch.autograd.profiler.record_function(profile_name):
---> 88 return func(*args, **kwargs)
89 return wrapper
90
[/usr/local/lib/python3.7/dist-packages/pytorch_forecasting/optim.py](https://localhost:8080/#) in step(self, closure)
141 closure: A closure that reevaluates the model and returns the loss.
142 """
--> 143 _ = closure()
144 loss = None
145 # note - below is commented out b/c I have other work that passes back
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py](https://localhost:8080/#) in _wrap_closure(self, model, optimizer, optimizer_idx, closure)
146 consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
147 """
--> 148 closure_result = closure()
149 self._after_closure(model, optimizer, optimizer_idx)
150 return closure_result
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
158
159 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 160 self._result = self.closure(*args, **kwargs)
161 return self._result.loss
162
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in closure(self, *args, **kwargs)
153 if self._backward_fn is not None and step_output.closure_loss is not None:
154 with self._profiler.profile("backward"):
--> 155 self._backward_fn(step_output.closure_loss)
156
157 return step_output
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in backward_fn(loss)
325
326 def backward_fn(loss: Tensor) -> None:
--> 327 self.trainer.accelerator.backward(loss, optimizer, opt_idx)
328
329 # check if model weights are nan
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py](https://localhost:8080/#) in backward(self, closure_loss, *args, **kwargs)
312 closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)
313
--> 314 self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
315
316 closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py](https://localhost:8080/#) in backward(self, model, closure_loss, optimizer, *args, **kwargs)
89 # do backward pass
90 if model is not None and isinstance(model, pl.LightningModule):
---> 91 model.backward(closure_loss, optimizer, *args, **kwargs)
92 else:
93 self._run_backward(closure_loss, *args, **kwargs)
[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/lightning.py](https://localhost:8080/#) in backward(self, loss, optimizer, optimizer_idx, *args, **kwargs)
1432 loss.backward()
1433 """
-> 1434 loss.backward(*args, **kwargs)
1435
1436 def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], optimizer_idx: int) -> None:
[/usr/local/lib/python3.7/dist-packages/torch/_tensor.py](https://localhost:8080/#) in backward(self, gradient, retain_graph, create_graph, inputs)
305 create_graph=create_graph,
306 inputs=inputs)
--> 307 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
308
309 def register_hook(self, hook):
[/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py](https://localhost:8080/#) in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
154 Variable._execution_engine.run_backward(
155 tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 156 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
157
158
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
The text was updated successfully, but these errors were encountered:
rohanthavarajah
changed the title
Multitarget not working for TFT and DeepAR sincce 03/24/2022
Multitarget not working for TFT and DeepAR since 03/24/2022
Mar 24, 2022
Expected behavior
Hi, thank you for your great library! As of 03/25/2022, my Google Colab notebooks, which previously trained multi-target DeepAR and TFT models successfully, now fail when calling trainer.fit(). I have tried installing older versions of pytorch-forecasting (0.9.0, 0.9.2, 0.8.5), with no luck.
Actual behavior
I'm getting the same error message as in #689, which was also mentioned in the latest commit (#902):
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Code to reproduce the problem
A minimal example that reproduces the error in Google Colab:
Trace
The text was updated successfully, but these errors were encountered: