
[DDP] Metrics .compute() / sync_ddp fails with move_metrics_to_cpu #10379

Closed
Quintulius opened this issue Nov 5, 2021 · 5 comments
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · priority: 1 (Medium priority task)

Comments

@Quintulius

Quintulius commented Nov 5, 2021

🐛 Bug

The metric's compute() method fails inside sync_ddp when training with DDP and move_metrics_to_cpu=True:

  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 385, in reduce
    tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 158, in sync_ddp_if_available
    return sync_ddp(result, group=group, reduce_op=reduce_op)
  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 193, in sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "pytorch1.10/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1287, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: Tensors must be CUDA and dense
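
The error comes from the NCCL backend: its collectives only accept dense CUDA tensors, and with move_metrics_to_cpu the logged value is already on the CPU by the time sync_ddp calls all_reduce. A minimal sketch of that constraint outside Lightning (an illustration only, assuming a single-GPU machine with the NCCL backend available):

import os

import torch
import torch.distributed as dist

# single-process "cluster", only to initialise an NCCL process group
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

cuda_metric = torch.tensor(1.0, device="cuda")
dist.all_reduce(cuda_metric)        # fine: NCCL operates on CUDA tensors

cpu_metric = torch.tensor(1.0)      # roughly what move_metrics_to_cpu hands to sync_ddp
try:
    dist.all_reduce(cpu_metric)
except RuntimeError as err:
    print(err)                      # "Tensors must be CUDA and dense"

dist.destroy_process_group()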

To Reproduce

import os

import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from torchmetrics import Accuracy


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.labels = torch.randint(0, 2, (length, 2))

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.metric = Accuracy()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self(x).sum()
        self.log("train_loss", loss)
        self.log("accuracy", self.metric(preds, y))
        return {"loss": loss}

    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        self.log("end_accuracy", self.metric.compute())

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        gpus=1,
        accelerator="ddp",
        move_metrics_to_cpu=True
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Expected behavior

Metrics are computed without trouble!

Environment

PyTorch version: 1.10.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-6ubuntu2) 7.5.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.11.0-38-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 970
Nvidia driver version: 495.29.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] pytorch-lightning==1.4.9
[pip3] torch==1.10.0+cu113
[pip3] torchaudio==0.10.0+cu113
[pip3] torchmetrics==0.5.1
[pip3] torchvision==0.11.1+cu113
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.3.0           h06a4308_520  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.1            py38hd3c417c_0  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] numpy                     1.21.2           py38h20f2e39_0  
[conda] numpy-base                1.21.2           py38h79a1101_0  
[conda] pytorch-lightning         1.4.9                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     1.10.0+cu113             pypi_0    pypi
[conda] torchaudio                0.10.0+cu113             pypi_0    pypi
[conda] torchmetrics              0.5.1                    pypi_0    pypi
[conda] torchvision               0.11.1+cu113             pypi_0    pypi

Additional context

cc @tchaton @rohitgr7

@Quintulius Quintulius added bug (Something isn't working), help wanted (Open to be worked on) labels Nov 5, 2021
@tchaton tchaton added priority: 1 (Medium priority task), priority: 0 (High priority task) and removed priority: 1 (Medium priority task) labels Nov 15, 2021
@tchaton tchaton self-assigned this Nov 15, 2021
@tangbinh
Contributor

Has anyone been looking into this issue? If not, I can try to help.

@tchaton
Contributor

tchaton commented Nov 17, 2021

Hey @tangbinh @Quintulius,

After some investigation, it seems move_metrics_to_cpu no longer works as expected after the refactoring of the logging.

Before:

  • We used to store metrics in a list of Result objects. When an epoch was long, this could result in OOM, so move_metrics_to_cpu was introduced for those extreme use cases.

Now:

  • We now track the metric and keep only a single value / Metric for each key within the Result object. Properly supporting this feature would require significant engineering, and I don't believe the memory savings would justify the performance drop.

I would suggest deprecating this parameter altogether.

Best,
T.C
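
For context, a rough sketch of the logging change described above (an illustration only, not the actual Result implementation): previously one logged tensor was kept per step for the whole epoch, whereas now a single running aggregate is kept per key, so the epoch-length memory growth that motivated move_metrics_to_cpu largely disappears.

import torch

num_steps, value = 1000, torch.tensor(0.5)

# old behaviour: one entry per logged step is kept for the whole epoch,
# so memory grows with epoch length (the OOM risk mentioned above)
per_step_values = [value.clone() for _ in range(num_steps)]
epoch_mean_old = torch.stack(per_step_values).mean()

# current behaviour: a single running aggregate per logged key,
# so memory stays constant regardless of epoch length
running_sum, count = torch.tensor(0.0), 0
for _ in range(num_steps):
    running_sum = running_sum + value
    count += 1
epoch_mean_new = running_sum / count

assert torch.allclose(epoch_mean_old, epoch_mean_new)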

@Quintulius
Author

Quintulius commented Nov 18, 2021

@tchaton Thanks for the update. I use this parameter to deal with a custom metric (a subclass of pytorch_lightning.metrics.Metric) which requires a lot of memory. What strategy should I use now?

@tchaton
Contributor

tchaton commented Nov 19, 2021

Hey @Quintulius,

This discussion has been moved to #10595.

The decision is to add proper support for this and evaluate its performance impact.

I will keep you updated.

@Borda Borda assigned awaelchli and unassigned tchaton Aug 8, 2022
@carmocca
Contributor

Closing in favor of #10595

@carmocca carmocca closed this as not planned Aug 22, 2022
@Borda Borda added priority: 1 (Medium priority task) and removed priority: 0 (High priority task) labels Aug 22, 2022