
Non-tensor values do not get aggregated into lists when using the DP accelerator #10155

Closed
jzazo opened this issue Oct 26, 2021 · 4 comments
Assignees: awaelchli
Labels: bug, help wanted, priority: 1, strategy: dp (removed in pl)

Comments

@jzazo

jzazo commented Oct 26, 2021

🐛 Bug

When using DP mode, tensor outputs are concatenated/stacked before being passed to training_step_end, etc. Non-tensor objects, however, are not collected into lists; instead they arrive as the string representation of a map object, e.g. '<map object at 0x7fcd1d1e5700>'.

To Reproduce

import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss, "data": batch_idx, "example": "something"}

    def training_step_end(self, batch):
        mean_loss = batch["loss"].mean()
        outputs = {**batch, "loss": mean_loss}
        print("outputs on step end", outputs)
        return outputs

    def training_epoch_end(self, outputs):
        print("outputs on epoch end", outputs)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=2,
        max_epochs=1,
        accelerator="dp",
        gpus=2,
    )
    trainer.fit(model, train_dataloaders=train_data)

if __name__ == "__main__":
    run()

Expected behavior

print("outputs on step end", outputs) is something like

outputs on step end {'loss': tensor(-0.5886, device='cuda:0', grad_fn=<MeanBackward0>), 'data': tensor([0, 0], device='cuda:0'), 'example': '<map object at 0x7f169de23340>'}

but should be something like

outputs on step end {'loss': tensor(-0.5886, device='cuda:0', grad_fn=<MeanBackward0>), 'data': tensor([0, 0], device='cuda:0'), 'example': ['something', 'something']}

Environment

  • PyTorch Lightning Version (e.g., 1.3.0): 1.4.9
  • PyTorch Version (e.g., 1.8): 1.10.0
  • Python version: 3.8.5
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, source): pipenv

Additional context

Opened following this message and a discussion with @awaelchli.

@jzazo jzazo added the bug and help wanted labels Oct 26, 2021
@awaelchli
Contributor

awaelchli commented Oct 26, 2021

Hello @jzazo

The behavior you are seeing occurs because Lightning simply relies on the default behavior of torch.nn.DataParallel. Here is a demonstration with plain torch DataParallel:

import torch


class BoringModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        out = self.layer(x)
        return {"out": out, "example": "something"}


def run():
    device = torch.device("cuda", 0)
    model = BoringModel().to(device)
    model = torch.nn.DataParallel(model, device_ids=[0, 1])
    output = model(torch.rand(1, 32).to(device))
    print(output)

    # {'out': tensor([[-0.1890,  0.3555]], device='cuda:0', grad_fn=<GatherBackward>), 'example': '<map object at 0x7f8eea4f8640>'}

if __name__ == "__main__":
    run()

As you can see, you get the same map object 😅

I will do some research into why DP implements it like this and which data types it supports.
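
For reference, here is a heavily simplified paraphrase of the gather recursion in torch.nn.parallel.scatter_gather (names and details may differ between PyTorch versions); it runs on CPU and only illustrates why a per-device string comes back as '<map object at 0x...>':

import torch


def gather_map(outputs):
    # simplified stand-in for the recursion in torch.nn.parallel.scatter_gather.gather
    out = outputs[0]
    if isinstance(out, torch.Tensor):
        # real DP gathers onto one device; a stack is enough for illustration
        return torch.stack(outputs)
    if isinstance(out, dict):
        return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
    # fallback branch: rebuild the container from a lazy map object -- for a
    # plain str this is str(<map object>), i.e. the literal '<map object at 0x...>'
    return type(out)(map(gather_map, zip(*outputs)))


# pretend these are the training_step outputs from two GPUs
per_device = [
    {"loss": torch.tensor(0.1), "example": "something"},
    {"loss": torch.tensor(0.2), "example": "something"},
]
print(gather_map(per_device))
# {'loss': tensor([0.1000, 0.2000]), 'example': '<map object at 0x...>'}

In other words, tensors and dicts are handled explicitly, while every other type falls into the last branch, and str(map(...)) is what ends up in the step output.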

@awaelchli awaelchli added the strategy: dp (removed in pl), priority: 2, and priority: 1 labels and removed the priority: 2 label Oct 26, 2021
@awaelchli awaelchli self-assigned this Oct 26, 2021
@jzazo
Author

jzazo commented Oct 28, 2021

This behavior is being tracked in this issue: pytorch/pytorch#62466

Regarding splitting data types, does Lightning process them to aid in the splitting? For example, I pass a list of strings to a BERT-like model that converts the strings into tensors (via an alphabet object in charge of the conversion). Is there a way the list of strings could be split across GPUs properly? It seems that in DP, lists are shallow-copied to all GPUs.

Or is there an accelerator where we would get independent batches per GPU rather than splitting one batch in two? DDP takes a heavy toll on memory because it spawns new processes and my dataset is held in memory. Is there any other viable option?

@awaelchli
Contributor

Regarding splitting data types, does Lightning process them to aid in the splitting? For example, I pass a list of strings to a BERT-like model that converts the strings into tensors (via an alphabet object in charge of the conversion). Is there a way the list of strings could be split across GPUs properly?

No, Lightning does not provide this. There is unfortunately no good/safe way for us to intercept this because it happens inside the torch.nn.DataParallel forward and scatter function.

All multi-device plugins other than DP spawn processes (and get independent batches per GPU), so DDP is a strong recommendation here. For the in-memory dataset, perhaps this could help.
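
For example, reusing the repro script from above, the DP-specific Trainer arguments could be swapped for DDP roughly like this (flags as of Lightning 1.4.x; newer releases use strategy="ddp" with accelerator/devices instead, so treat this as a sketch rather than the exact API):

# sketch only: Lightning 1.4.x flags; each GPU runs its own process and
# loads its own batch, so training_step outputs are never gathered across
# devices and "example" stays a plain Python string in training_step_end
trainer = Trainer(
    default_root_dir=os.getcwd(),
    limit_train_batches=2,
    max_epochs=1,
    accelerator="ddp",
    gpus=2,
)
trainer.fit(model, train_dataloaders=train_data)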

@jzazo
Author

jzazo commented Oct 29, 2021

Your link looks very helpful; I will try it next week. I will close this ticket for now as it seems to be a PyTorch issue. Thanks for all the help.

@jzazo jzazo closed this as completed Oct 29, 2021