
Using multiple dataloaders in the training_step? #2457

Closed
christofer-f opened this issue Jul 1, 2020 · 10 comments · Fixed by #1959
Labels
data handling (Generic data-related topic) · feature (Is an improvement or enhancement) · question (Further information is requested) · won't fix (This will not be worked on)

Comments

@christofer-f
Contributor

christofer-f commented Jul 1, 2020

Hi!

In the pseudo-code below I have two models that I want to fit with two different datasets.
I have tried to figure out whether this is possible by reading test_dataloaders.py, with no success...

In the documentation, it states that:
Multiple training dataloaders
For training, the best way to use multiple-dataloaders is to create a Dataloader class which wraps both your dataloaders. (This of course also works for testing and validation dataloaders).

But that doesn't really help me...
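For reference, the wrapper the docs describe might look roughly like this (a minimal sketch; ZipLoader is an illustrative name, not a Lightning class). It iterates several loaders in lockstep so each batch is a tuple, stopping at the shortest loader:

class ZipLoader:
    """Wrap several dataloaders and iterate them in lockstep."""
    def __init__(self, *loaders):
        self.loaders = loaders

    def __iter__(self):
        # zip stops when the shortest loader is exhausted
        return zip(*self.loaders)

    def __len__(self):
        return min(len(loader) for loader in self.loaders)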

I guess that this already has been discussed in: #1089
And that I should study: https://gist.github.com/Dref360/2524e524244569ed47428f19c487f264

But it would be nice to have a dataloader_idx, just like the optimizer_idx parameter...

Or perhaps a batch could have a dictionary-like structure where you sample data into different "baskets"
so that I could write something like:

        if optimizer_idx == 0:
            # REQUIRED
            x, y = batch[0]

class FashionMNIST_and_MNISTModel(pl.LightningModule):

    def __init__(self):
        super(FashionMNIST_and_MNISTModel, self).__init__()

        # l1 should be fit to the MNIST dataset
        self.l1 = torch.nn.Linear(28 * 28, 10)

        # l2 should be fit to the FashionMNIST dataset
        self.l2 = torch.nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_nb, optimizer_idx):
        if optimizer_idx == 0:
            # REQUIRED
            x, y = batch
            y_hat = torch.relu(self.l1(x.view(x.size(0), -1)))
            loss_l1 = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_l1}
            return {'loss': loss_l1, 'log': tensorboard_logs}
        if optimizer_idx == 1:            
            # REQUIRED
            x, y = batch
            y_hat = torch.relu(self.l2(x.view(x.size(0), -1)))
            loss_l2 = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_l2}
            return {'loss': loss_l2, 'log': tensorboard_logs}

    def train_dataloader(self):
        # REQUIRED
        return [
            DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32),
            DataLoader(FashionMNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
        ]
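
For the optimizer_idx branches above to fire, configure_optimizers would also need to return two optimizers; a sketch filling in what the pseudo-code implies (it mirrors the full example later in this thread):

    def configure_optimizers(self):
        # One optimizer per sub-model, so training_step runs once per
        # batch with optimizer_idx 0 and once with optimizer_idx 1.
        opt_l1 = torch.optim.Adam(self.l1.parameters(), lr=0.02)
        opt_l2 = torch.optim.Adam(self.l2.parameters(), lr=0.02)
        return [opt_l1, opt_l2], []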

//Christofer

@christofer-f added the question label Jul 1, 2020
@github-actions
Contributor

github-actions bot commented Jul 1, 2020

Hi! Thanks for your contribution, great first issue!

@justusschock
Member

justusschock commented Jul 2, 2020

Hi @christofer-f, I've actually prototyped this feature already (in #1959). If your dataset has a length, this already works (the failing tests are due to the case where your dataset does not have a defined length).
Any feedback is highly appreciated :)

@christofer-f
Contributor Author

Hi! This looks very promising.

In my case, I have several processes that create datasets on the fly.
So the requirement of a minimum length is not a problem, at least not for me.

I will try to apply your code to the toy example above.

Br,
Christofer

@christofer-f
Contributor Author

This is exactly what I wanted!!! Great job.

pip install -e git://github.com/PyTorchLightning/pytorch-lightning.git@train_loaders#egg=pytorch-lightning_train_loaders

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, FashionMNIST
from torchvision import transforms
import pytorch_lightning as pl

class FashionMNIST_and_MNISTModel(pl.LightningModule):

    def __init__(self):
        super(FashionMNIST_and_MNISTModel, self).__init__()

        self.l_mnist = torch.nn.Linear(28 * 28, 10)
        self.l_fashion_mnist = torch.nn.Linear(28 * 28, 10)        

    def forward(self, x):
        # called with self(x)
        return torch.relu(self.l_mnist(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            x, y = batch['mnist']
            y_hat = torch.relu(self.l_mnist(x.view(x.size(0), -1)))
            loss_mnist = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_mnist}
            return {'loss': loss_mnist, 'log': tensorboard_logs}
        if optimizer_idx == 1:
            x, y = batch['fashion_mnist']
            y_hat = torch.relu(self.l_fashion_mnist(x.view(x.size(0), -1)))
            loss_fashion_mnist = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_fashion_mnist}
            return {'loss': loss_fashion_mnist, 'log': tensorboard_logs}

    def configure_optimizers(self):
        opt_mnist = torch.optim.Adam(self.l_mnist.parameters(), lr=0.02)
        opt_fashion_mnist = torch.optim.Adam(self.l_fashion_mnist.parameters(), lr=0.02)
        return [opt_mnist, opt_fashion_mnist], []

    def train_dataloader(self):
        loader_mnist = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
        loader_fashion_mnist = DataLoader(FashionMNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
        # Returning a dict of loaders; each training_step batch then
        # arrives as {'mnist': (x, y), 'fashion_mnist': (x, y)}.
        loaders = {"mnist": loader_mnist, "fashion_mnist": loader_fashion_mnist}
        return loaders


def main():
    mnist_model = FashionMNIST_and_MNISTModel()

    trainer = pl.Trainer(gpus=1, fast_dev_run=True)    
    trainer.fit(mnist_model)   

if __name__ == "__main__":
    main()

@justusschock
Member

Great to hear! So the feature should almost be ready to merge to master. There is also a trainer flag that controls how to deal with datasets of different lengths.
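
A minimal sketch of how that flag is used, assuming the multiple_trainloader_mode argument that later Lightning releases expose (the name and values are not stated in this comment):

# Flag name/values assumed from later releases, not from this thread:
#   'max_size_cycle' cycles the shorter loaders until the longest finishes;
#   'min_size' ends the epoch when the shortest loader is exhausted.
trainer = pl.Trainer(multiple_trainloader_mode='max_size_cycle')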

@christofer-f
Contributor Author

I'll close this. It is a very good and useful feature. I need to tinker a bit with this before I understand how it really works...

@omiita

omiita commented Jul 7, 2020

Thank you for this wonderful work @justusschock!
I played around with this new feature a bit, and I encountered a problem.
I used the following code (from @christofer-f) to check how many steps there are in one epoch.

pip install -e git://github.com/PyTorchLightning/pytorch-lightning.git@train_loaders#egg=pytorch-lightning_train_loaders

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, FashionMNIST
from torchvision import transforms
import pytorch_lightning as pl

class FashionMNIST_and_MNISTModel(pl.LightningModule):
    def __init__(self):
        super(FashionMNIST_and_MNISTModel, self).__init__()

        self.l_mnist = torch.nn.Linear(28 * 28, 10)
        self.l_fashion_mnist = torch.nn.Linear(28 * 28, 10)        

    def forward(self, x):
        return torch.relu(self.l_mnist(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            print('mnist')
            x, y = batch['mnist']
            y_hat = self(x)
            loss_mnist = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_mnist}
            return {'loss': loss_mnist, 'log': tensorboard_logs}
        if optimizer_idx == 1:
            print('fashion')
            x, y = batch['fashion_mnist']
            y_hat = torch.relu(self.l_fashion_mnist(x.view(x.size(0), -1)))
            loss_fashion_mnist = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss_fashion_mnist}
            return {'loss': loss_fashion_mnist, 'log': tensorboard_logs}

    def configure_optimizers(self):
        opt_mnist = torch.optim.Adam(self.l_mnist.parameters(), lr=0.02)
        opt_fashion_mnist = torch.optim.Adam(self.l_fashion_mnist.parameters(), lr=0.02)
        return [opt_mnist, opt_fashion_mnist], []

    def train_dataloader(self):
        loader_mnist = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
        loader_fashion_mnist = DataLoader(FashionMNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)       
        loaders = {"mnist": loader_mnist, "fashion_mnist": loader_fashion_mnist}
        return loaders

mnist_model = FashionMNIST_and_MNISTModel()

trainer = pl.Trainer(gpus=0, fast_dev_run=False, max_epochs=1)    
trainer.fit(mnist_model)

And I got the following output.

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name            | Type   | Params
-------------------------------------------
0 | l_mnist         | Linear | 7 K   
1 | l_fashion_mnist | Linear | 7 K   

Epoch 1: 100%
2/2 [00:00<00:00, 30.03it/s, loss=2.272, v_num=16]

mnist
fashion
mnist
fashion

It seems there are only two steps in one epoch according to the print messages.
There should be
(#samples) / (batch_size) = 60,000 / 32 = 1875 steps
Am I missing something? If so, please correct me!

Thanks in advance!
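
A quick sanity check of that arithmetic (the 60,000-sample figure is MNIST's training split; batch_size=32 comes from the code above):

import math

# MNIST's training split has 60,000 samples; with batch_size=32 each
# loader yields 60000 / 32 = 1875 batches per epoch (it divides evenly,
# so drop_last does not matter here).
steps_per_epoch = math.ceil(60_000 / 32)
print(steps_per_epoch)  # 1875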

@christofer-f reopened this Jul 8, 2020
@christofer-f
Contributor Author

christofer-f commented Jul 8, 2020

Hi,

I think @omiita is right.
Comparing with the training loop when running only the MNIST dataset, the difference is obvious.
For details:
#1959 (comment)

@christofer-f
Contributor Author

I'll close this again... the problem has been identified...

@edenlightning reopened this Jul 27, 2020
@edenlightning linked a pull request Jul 27, 2020 that will close this issue
@justusschock added the data handling and feature labels Nov 23, 2020
@stale

stale bot commented Dec 23, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
