PyTorch Lightning 1.4.1 crashes during training #8821
Comments
This issue happened to me. I thought it was related to how the dataloaders are handled, and I decided to postpone using multi-GPU on my project, so I didn't investigate it further.
I have the same issue and managed to reproduce it with the following code.

import torch
import math
from torchvision.datasets import CIFAR100
from torchmetrics.functional import accuracy
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torch import nn
from torchsummary import summary
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.plugins import DDPPlugin
from torchvision.models import resnet18 as res18


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        if m.weight is not None: nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None: nn.init.xavier_normal_(m.bias.data)


class resnet(pl.LightningModule):
    def __init__(self, nclasses=10, bs=256, lr=3e-4, epochs=100, workers=2):
        super().__init__()
        self.resnet18 = res18(pretrained=True)
        self.loss = nn.CrossEntropyLoss()
        self.lr = lr
        self.bs = bs
        self.eps = epochs
        self.workers = workers
        self.ngpus = torch.cuda.device_count()
        self.save_hyperparameters()

    def forward(self, x):
        out = self.resnet18(x)
        return out

    def prepare_data(self):
        CIFAR100(root='/tmp', train=True, download=True)
        CIFAR100(root='/tmp', train=False, download=True)

    def setup(self, stage):
        cifar10_mean, cifar10_std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
        cifar100_mean, cifar100_std = (0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761)
        transforms_train = transforms.Compose([
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
            transforms.Normalize(cifar100_mean, cifar100_std)
        ])
        transforms_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar100_mean, cifar100_std)
        ])
        trainset = CIFAR100(root='/tmp', train=True, download=False, transform=transforms_train)
        self.trainset, self.valset = torch.utils.data.random_split(trainset, [45000, 5000])
        self.testset = CIFAR100(root='/tmp', train=False, download=False, transform=transforms_test)
        self.t_steps = math.ceil(len(self.trainset) / self.ngpus / self.bs)

    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.bs, shuffle=True, num_workers=self.workers)

    def val_dataloader(self):
        return DataLoader(self.valset, batch_size=self.bs, shuffle=False, num_workers=self.workers)

    def test_dataloader(self):
        return DataLoader(self.testset, batch_size=self.bs, shuffle=False, num_workers=self.workers)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.lr, weight_decay=1e-5)
        schedulers = [
            {
                'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                    optimizer, cycle_momentum=True, pct_start=0.45,
                    anneal_strategy='cos', max_lr=self.lr, epochs=self.eps,
                    steps_per_epoch=self.t_steps, three_phase=True),
                'interval': 'step',
            }
        ]
        return [optimizer], schedulers

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        _, preds = torch.max(y_hat, 1)
        acc = accuracy(y, preds)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        loss = torch.tensor([x['loss'] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        _, preds = torch.max(out, 1)
        acc = accuracy(y, preds)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        self.log('val_acc', acc, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        return loss

    def validation_epoch_end(self, outputs):
        loss = torch.tensor(outputs).mean()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        _, preds = torch.max(y_hat, 1)
        acc = accuracy(y, preds)
        self.log('test_acc', acc, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        return acc


if __name__ == '__main__':
    pl.seed_everything(20)
    num_gpus = torch.cuda.device_count()
    nepochs = 100
    model = resnet(nclasses=100, bs=512, lr=1, epochs=nepochs, workers=8)
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step', log_momentum=True)
    chkpt = pl.callbacks.ModelCheckpoint(monitor='val_loss')
    es = pl.callbacks.EarlyStopping(monitor='val_loss', patience=10)
    wandb_logger = WandbLogger(project='pruning-cifar100')
    csv_logger = CSVLogger(save_dir='./csv_logs', name='resnet18-cifar100')
    trainer = pl.Trainer(
        gpus=num_gpus, accelerator='ddp', auto_select_gpus=True, precision=16, max_epochs=nepochs,
        default_root_dir='./checkpoints/resnet18',  # logger=[wandb_logger, csv_logger],
        callbacks=[lr_monitor], plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model)
    trainer.test(model, ckpt_path='best')

Reproduced environment:
Maybe it is connected with this?
@InCogNiTo124
I hope you meant that it is PyTorch's bug?
Yeah, it runs successfully for a few epochs. For me it never got past 10 epochs, which is what is weird.
For me, on 1.4.1 it never got past 5 hours of training; I had been trying for 3-4 days continuously. Then I downgraded to 1.4.0 and it has now successfully completed 15 hours. FYI.
@stonelazy I think this is an issue in PL. In my case, the training phase has no problem, and the error occurs when the test phase starts.
Same thing happened to me with multi-GPU DDP and multiple dataloader workers. Downgrading to 1.4.0 resolved the problem.
Same thing happened to me when using multi-GPU DDP with multiple dataloader workers after I upgraded my PL to v1.4.2. But sadly, even after downgrading to 1.4.0, the bug still exists. :(
This problem could be caused by
Is this related to #4471?
I'm not sure, but considering their stack traces, they seem to be related.
Faced the exact same issue; the model crashed during validation. In addition to what @yoichi-yamakawa mentioned above, setting
@tchaton we might want to bump the priority on this; it seems like many users are experiencing it.
I was trying the solution proposed here.
I think I am following all the instructions as per the documentation, so I'm not sure where it's going wrong. I have my skeleton code in discussion #9197; if anybody finds any obvious coding mistakes, let me know.
Hey everyone, I can confirm that I could reproduce the error and I will start investigating. Thanks for your patience, and we apologise for the inconvenience. Best,
Here's the minimal reproduction code:

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning import LightningModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('train_loss', torch.tensor(1), on_epoch=True)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2, num_workers=1)


if __name__ == '__main__':
    model = BoringModel()
    trainer = pl.Trainer(
        gpus=1,
        accelerator='ddp',
        limit_train_batches=1,
        max_epochs=100,
        checkpoint_callback=False,
        logger=False,
    )
    trainer.fit(model)

$ CUDA_LAUNCH_BLOCKING=1 python bug.py
...
terminate called after throwing an instance of 'c10::CUDAError'
what(): CUDA error: initialization error
Exception raised from insert_events at /pytorch/c10/cuda/CUDACachingAllocator.cpp:1089 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f726617fa22 in /home/carlos/venv/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x10e9e (0x7f72663e0e9e in /home/carlos/venv/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1a7 (0x7f72663e2147 in /home/carlos/venv/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::TensorImpl::release_resources() + 0x54 (0x7f72661695a4 in /home/carlos/venv/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #4: <unknown function> + 0xa2822a (0x7f730af8722a in /home/carlos/venv/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #5: /home/carlos/venv/bin/python() [0x4ef828]
frame #6: /home/carlos/venv/bin/python() [0x5fb497]
frame #7: PyTraceBack_Here + 0x6db (0x54242b in /home/carlos/venv/bin/python)
frame #8: _PyEval_EvalFrameDefault + 0x3aec (0x56d32c in /home/carlos/venv/bin/python)
frame #9: /home/carlos/venv/bin/python() [0x50a23e]
frame #10: _PyEval_EvalFrameDefault + 0x5757 (0x56ef97 in /home/carlos/venv/bin/python)
frame #11: _PyFunction_Vectorcall + 0x1b6 (0x5f5e56 in /home/carlos/venv/bin/python)
frame #12: _PyEval_EvalFrameDefault + 0x5757 (0x56ef97 in /home/carlos/venv/bin/python)
frame #13: _PyEval_EvalCodeWithName + 0x26a (0x56822a in /home/carlos/venv/bin/python)
frame #14: _PyFunction_Vectorcall + 0x393 (0x5f6033 in /home/carlos/venv/bin/python)
frame #15: _PyObject_FastCallDict + 0x48 (0x5f5808 in /home/carlos/venv/bin/python)
frame #16: _PyObject_Call_Prepend + 0x61 (0x5f5a21 in /home/carlos/venv/bin/python)
frame #17: /home/carlos/venv/bin/python() [0x59b60b]
frame #18: _PyObject_MakeTpCall + 0x296 (0x5f3446 in /home/carlos/venv/bin/python)
frame #19: _PyEval_EvalFrameDefault + 0x598a (0x56f1ca in /home/carlos/venv/bin/python)
frame #20: _PyFunction_Vectorcall + 0x1b6 (0x5f5e56 in /home/carlos/venv/bin/python)
frame #21: _PyEval_EvalFrameDefault + 0x71e (0x569f5e in /home/carlos/venv/bin/python)
frame #22: _PyEval_EvalCodeWithName + 0x26a (0x56822a in /home/carlos/venv/bin/python)
frame #23: _PyFunction_Vectorcall + 0x393 (0x5f6033 in /home/carlos/venv/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x71e (0x569f5e in /home/carlos/venv/bin/python)
frame #25: _PyFunction_Vectorcall + 0x1b6 (0x5f5e56 in /home/carlos/venv/bin/python)
frame #26: _PyEval_EvalFrameDefault + 0x5757 (0x56ef97 in /home/carlos/venv/bin/python)
frame #27: _PyFunction_Vectorcall + 0x1b6 (0x5f5e56 in /home/carlos/venv/bin/python)
frame #28: /home/carlos/venv/bin/python() [0x50a33c]
frame #29: PyObject_Call + 0x1f7 (0x5f2b87 in /home/carlos/venv/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x1f70 (0x56b7b0 in /home/carlos/venv/bin/python)
frame #31: _PyFunction_Vectorcall + 0x1b6 (0x5f5e56 in /home/carlos/venv/bin/python)
frame #32: _PyEval_EvalFrameDefault + 0x8f6 (0x56a136 in /home/carlos/venv/bin/python)
frame #33: _PyFunction_Vectorcall + 0x1b6 (0x5f5e56 in /home/carlos/venv/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x8f6 (0x56a136 in /home/carlos/venv/bin/python)
frame #35: _PyFunction_Vectorcall + 0x1b6 (0x5f5e56 in /home/carlos/venv/bin/python)
frame #36: /home/carlos/venv/bin/python() [0x50a33c]
frame #37: PyObject_Call + 0x1f7 (0x5f2b87 in /home/carlos/venv/bin/python)
frame #38: /home/carlos/venv/bin/python() [0x654fbc]
frame #39: /home/carlos/venv/bin/python() [0x674aa8]
frame #40: <unknown function> + 0x9609 (0x7f730d495609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #41: clone + 0x43 (0x7f730d5d1293 in /lib/x86_64-linux-gnu/libc.so.6)
Traceback (most recent call last):
File "/home/carlos/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
if not self._poll(timeout):
File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
return self._poll(timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
r = wait([self], timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
ready = selector.select(timeout)
File "/usr/lib/python3.8/selectors.py", line 415, in select
fd_event_list = self._selector.poll(timeout)
File "/home/carlos/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 3397436) is killed by signal: Aborted.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/carlos/pytorch-lightning/kk.py", line 49, in <module>
trainer.fit(model)
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 547, in fit
self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 502, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 577, in _fit_impl
self._run(model)
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1001, in _run
self._dispatch()
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1072, in _dispatch
self.accelerator.start_training(self)
File "/home/carlos/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 91, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/carlos/pytorch-lightning/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 170, in start_training
self._results = trainer.run_stage()
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1082, in run_stage
return self._run_train()
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1123, in _run_train
self.fit_loop.run()
File "/home/carlos/pytorch-lightning/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/carlos/pytorch-lightning/pytorch_lightning/loops/fit_loop.py", line 206, in advance
epoch_output = self.epoch_loop.run(data_fetcher)
File "/home/carlos/pytorch-lightning/pytorch_lightning/loops/base.py", line 106, in run
self.on_run_start(*args, **kwargs)
File "/home/carlos/pytorch-lightning/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 107, in on_run_start
self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_idx + 1)
File "/home/carlos/pytorch-lightning/pytorch_lightning/loops/utilities.py", line 169, in _prepare_dataloader_iter
dataloader_iter = enumerate(data_fetcher, batch_idx)
File "/home/carlos/pytorch-lightning/pytorch_lightning/utilities/fetching.py", line 200, in __iter__
self.prefetching(self.prefetch_batches)
File "/home/carlos/pytorch-lightning/pytorch_lightning/utilities/fetching.py", line 256, in prefetching
self._fetch_next_batch()
File "/home/carlos/pytorch-lightning/pytorch_lightning/utilities/fetching.py", line 298, in _fetch_next_batch
batch = next(self.dataloader_iter)
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/supporters.py", line 569, in __next__
return self.request_next_batch(self.loader_iters)
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/supporters.py", line 597, in request_next_batch
return apply_to_collection(loader_iters, Iterator, next_fn)
File "/home/carlos/pytorch-lightning/pytorch_lightning/utilities/apply_func.py", line 93, in apply_to_collection
return function(data, *args, **kwargs)
File "/home/carlos/pytorch-lightning/pytorch_lightning/trainer/supporters.py", line 584, in next_fn
batch = next(iterator)
File "/home/carlos/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/home/carlos/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
idx, data = self._get_data()
File "/home/carlos/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
success, data = self._try_get_data()
File "/home/carlos/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 3397436) exited unexpectedly

Running on master with the deadlock detection removed. Current findings:
Hey everyone, after a long day of debugging with @carmocca, we finally found the source of the problem. It should be fixed on master and in the next weekly release. Best,
Good job! Can't wait to try the fix as soon as possible |
Can't wait to see what the annoying problem was! 😂😭
copy vs deepcopy, somehow
Hey @InCogNiTo124, it's a mystery to me. But my guess is that PyTorch does quite a lot of work on deepcopy compared to copy: https://github.com/pytorch/pytorch/blob/83e28a7d281c91a6d1a12b86bd5fb212dd424a85/torch/_tensor.py#L80 And copy may not be as strict as deepcopy. Best,
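To illustrate the general difference described above, here is a minimal generic Python sketch (not Lightning's internal code): copy.copy of a results dict keeps references to the same tensor objects, while copy.deepcopy recursively clones everything it finds, including tensor storage.

import copy

import torch

results = {"loss": torch.zeros(1)}

shallow = copy.copy(results)   # new dict, but it still references the same tensor
deep = copy.deepcopy(results)  # new dict and a cloned tensor (noticeably more work)

print(shallow["loss"] is results["loss"])  # True: the tensor object is shared
print(deep["loss"] is results["loss"])     # False: deepcopy duplicated the tensor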
Works well with
The fix introduced in PL 1.4.5 crashes my DDP trainings completely, so that no training step runs at all. However, I also experienced the original bug reported in the 1.4.x versions. @tchaton

File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
self._run(model)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 917, in _run
self._dispatch()
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 985, in _dispatch
self.accelerator.start_training(self)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
self._results = trainer.run_stage()
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 995, in run_stage
return self._run_train()
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1044, in _run_train
self.fit_loop.run()
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
epoch_output = self.epoch_loop.run(train_dataloader)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 130, in advance
batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 100, in run
super().run(batch, batch_idx, dataloader_idx)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 149, in advance
self.batch_outputs[opt_idx].append(deepcopy(result.training_step_output))
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 205, in _deepcopy_list
append(deepcopy(a, memo))
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/copy.py", line 153, in deepcopy
y = copier(memo)
File "/home/meeso/miniconda3/envs/calvin/lib/python3.8/site-packages/torch/_tensor.py", line 55, in __deepcopy__
raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
Related to this issue: is there already an integration test for DDP, or would it be possible to add one?
We have DDP tests! However, testing this is not so easy, as it requires
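Presumably the requirement is more than one physical GPU, which most CI machines don't have; a hypothetical sketch (not Lightning's actual test suite, names are placeholders) of how such a DDP test is typically gated:

import pytest
import torch


# Placeholder test name and body: the point is only the multi-GPU gate.
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires at least 2 GPUs")
def test_ddp_training_with_dataloader_workers():
    # Would launch a small Trainer(gpus=2, accelerator="ddp", ...) run here.
    ...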
So I have found why PL 1.4.5 breaks my training with the error "Only Tensors created explicitly by the user (graph leaves) support the deepcopy" @tchaton @carmocca. I was returning a dictionary from training_step. Basically, I had something like this:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out, mean, std = self.encoder(x)
    loss = self.loss(out, x)
    encoders_dict = {"one": Independent(Normal(mean, std), 1)}
    return {"loss": loss, "encoders_dict": encoders_dict}

Returning only the loss tensor works as a workaround.
@mees in that case, the tensors wrapped in the distribution can be detached, e.g.:

encoders_dict = {"one": Independent(Normal(mean.detach(), std.detach()), 1)}
return {"loss": loss, "encoders_dict": encoders_dict}
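For context, a minimal standalone PyTorch sketch (not Lightning code, and not the snippet above) of why detaching helps: deepcopy is only supported for graph leaves, which is exactly the RuntimeError shown in the traceback.

import copy

import torch

leaf = torch.randn(2, requires_grad=True)  # created directly by the user: a graph leaf
non_leaf = leaf * 2                        # has a grad_fn, so it is not a leaf

copy.deepcopy(leaf)                        # works
try:
    copy.deepcopy(non_leaf)                # raises the RuntimeError quoted above
except RuntimeError as err:
    print(err)

copy.deepcopy(non_leaf.detach())           # detaching first makes deepcopy possible again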
I met the same error with the same traceback as @mees in pytorch-lightning 1.4.8, but this time only with
so the problem might not be because of
@popfido I can check it if you share a repro script, but this should be fixed in master.
Hi @carmocca, this happens when I call

With:

The metric I used is AUROC.
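For what it's worth, one common pattern with torchmetrics AUROC under DDP (a hedged sketch based on my assumptions, not popfido's actual code) is to register the metric as a module attribute and pass the Metric object to self.log, rather than returning it from the step:

import torch
import torchmetrics
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self, num_classes=5):
        super().__init__()
        self.layer = torch.nn.Linear(32, num_classes)
        # Registered as an attribute so Lightning moves it to the right device
        # and synchronizes its state across DDP processes.
        self.val_auroc = torchmetrics.AUROC(num_classes=num_classes)

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        probs = torch.softmax(self(x), dim=-1)
        self.val_auroc(probs, y)
        # Log the Metric object itself; it is computed and reset at epoch end.
        self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True)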
🐛 Bug
When I start training on 2 GPUs using pytorch-lightning 1.4.1, the training crashes after a few epochs. Note that this happens only on 1.4.1.
If I run my code using pytorch-lightning 1.4.0, everything works fine.
There are multiple variants of the same error with different versions; for brevity I'm attaching just one trace.
Here's the error trace:
To Reproduce
Here's my code.
It's a simple script that trains resnet18 on CIFAR using 2 GPUs with DDP.
Expected behavior
It's supposed to train for 100 epochs and
Environment
Additional context
The error happens irrespective of whether I use DP or DDP