
'NoneType' object has no attribute 'test_step' when DDP #577

Closed
jgsch opened this issue Dec 3, 2019 · 9 comments · Fixed by #1017
Labels
bug Something isn't working help wanted Open to be worked on


@jgsch

jgsch commented Dec 3, 2019

Describe the bug

When I activate DDP, trainer.test() fails. The model that the trainer inspects is None, so looking up its test_step raises an AttributeError. There is no problem when I run on one GPU.

To Reproduce

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torch.utils.data as data
from pytorch_lightning import Trainer
import torchvision
import torchvision.transforms as transforms

num_workers = 0
classes = (
    'plane', 'car', 'bird', 'cat', 'deer', 'dog',
    'frog', 'horse', 'ship', 'truck'
)
n_classes = len(classes)
ddp = True


class PlModule(pl.LightningModule):
    def __init__(self):
        super(PlModule, self).__init__()
        model = torchvision.models.squeezenet1_1(True)
        model.n_classes = n_classes

        final_conv = nn.Conv2d(512, n_classes, kernel_size=1)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.model = model
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        inputs, targets = batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)
        return {"loss": loss}

    def test_step(self, batch, batch_nb):
        inputs, targets = batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)
        return {'test_loss': loss}

    def test_end(self, outputs):
        metric = [o["test_loss"] for o in outputs]
        val_loss = np.sum(metric) / len(outputs)
        tqdm_dict = {"test_loss": val_loss}
        return {
            'test_loss': tqdm_dict["test_loss"],
            'progress_bar': tqdm_dict,
        }

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=0.001)

    @pl.data_loader
    def train_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        trainset = torchvision.datasets.CIFAR10(
            root='./data', train=True,
            download=True, transform=transform)

        t_sampler = None
        if ddp:
            t_sampler = data.distributed.DistributedSampler(trainset)
        return torch.utils.data.DataLoader(
            trainset,
            sampler=t_sampler,
            batch_size=400,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
        )

    @pl.data_loader
    def test_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        testset = torchvision.datasets.CIFAR10(
            root='./data', train=False,
            download=True, transform=transform)

        test_sampler = None
        if ddp:
            test_sampler = data.distributed.DistributedSampler(testset)

        return torch.utils.data.DataLoader(
            testset,
            sampler=test_sampler,
            batch_size=400,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
        )


if __name__ == "__main__":

    distributed = {
        "gpus": 2 if ddp else 1,
        "distributed_backend": 'ddp' if ddp else None
    }

    trainer = Trainer(
        logger=False,
        checkpoint_callback=False,
        early_stop_callback=False,
        max_nb_epochs=1,
        nb_sanity_val_steps=1,
        **distributed
    )
    model = PlModule()

    trainer.fit(model)
    trainer.test()

This gives the following error:

Traceback (most recent call last):
  File "test.py", line 128, in <module>
    trainer.test()
  File "/home/j/miniconda3/envs/alp36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 478, in test
    self.run_evaluation(test=True)
  File "/home/j/miniconda3/envs/alp36/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop_mixin.py", line 88, in run_evaluation
    can_run_test_step = self.is_overriden('test_step') and self.is_overriden('test_end')
  File "/home/j/miniconda3/envs/alp36/lib/python3.6/site-packages/pytorch_lightning/trainer/model_hooks_mixin.py", line 16, in is_overriden
    is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
AttributeError: 'NoneType' object has no attribute 'test_step'

Changing ddp = True to ddp = False makes the error go away.
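The next snippet shows why the message mentions NoneType. It is a simplified, hypothetical re-creation of the comparison on the last line of the traceback. `BaseModule`, `UserModule` and `is_overridden` are illustrative names, not the library's code. The check works on a real module, but it raises exactly this error when the model reference it receives is None:

```python
class BaseModule:
    """Hypothetical stand-in for the library base class."""

    def test_step(self, batch, batch_nb):
        return None


class UserModule(BaseModule):
    """Hypothetical user module that overrides test_step."""

    def test_step(self, batch, batch_nb):
        return {"test_loss": 0.0}


def is_overridden(f_name, model, super_object=BaseModule):
    # Same idea as the traceback line: compare the code object of the
    # model's method with the code object of the base-class method.
    return getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__


print(is_overridden("test_step", UserModule()))  # True
print(is_overridden("test_step", BaseModule()))  # False
try:
    # What happens when the trainer's model reference is None
    is_overridden("test_step", None)
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'test_step'
```

The error therefore says nothing about test_step being missing from the module. It only tells us that the object passed in as the model was None.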

Version:

  • pytorch-lightning: 0.5.3.2
  • pytorch: 1.3.1
  • torchvision: 0.4.2
@jgsch jgsch added the bug Something isn't working label Dec 3, 2019
@jgsch jgsch changed the title from "NoneType' object has no attribute 'test_step' when DDP" to "'NoneType' object has no attribute 'test_step' when DDP" Dec 3, 2019
@williamFalcon
Contributor

Good catch. Mind submitting a PR?

@jgsch
Author

jgsch commented Dec 6, 2019

@williamFalcon I investigated, and I think the problem comes from this line:

mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))

After this line returns, self.model is still None, even though the model is correctly set inside self.ddp_train. Presumably this is because self.ddp_train runs in the spawned worker processes, so the assignment never reaches the trainer object in the parent process.
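Here is a minimal sketch of the likely mechanism. It assumes the cause is process isolation: an attribute assigned on self inside a worker changes only that worker's copy of the object, so the parent still sees None after the workers exit. `ToyTrainer` is a hypothetical stand-in. For simplicity the sketch uses the standard-library "fork" start method (POSIX only) rather than torch.multiprocessing.spawn, which uses "spawn". Each worker gets its own address space either way.

```python
import multiprocessing as mp


class ToyTrainer:
    """Hypothetical stand-in for the trainer, not Lightning's class."""

    def __init__(self):
        self.model = None

    def ddp_train(self, model):
        # Runs in a worker process: this assigns to the worker's copy of self.
        self.model = model

    def fit(self, model, nprocs=2):
        # "fork" (POSIX only) avoids needing a __main__ guard in this sketch.
        ctx = mp.get_context("fork")
        workers = [ctx.Process(target=self.ddp_train, args=(model,))
                   for _ in range(nprocs)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        return [w.exitcode for w in workers]


trainer = ToyTrainer()
exit_codes = trainer.fit("my-model")
print(exit_codes)     # [0, 0]
print(trainer.model)  # None: the parent never sees the assignment
```

Both workers finish cleanly, yet the parent's trainer.model is unchanged. A later trainer.test() in the parent would hit the same None as in the traceback.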

@sneiman
Copy link
Contributor

sneiman commented Jan 21, 2020

Is there a workaround available in the meantime?

@williamFalcon
Contributor

williamFalcon commented Jan 21, 2020

Option A:
call fit(model)?

Option B:
submit a PR that tracks self.model before the ddp call?
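Option B could look roughly like the hypothetical sketch below. It keeps the model reference in the parent before starting the workers and lets test() fall back to it. `ToyTrainer`, `fit` and `test` are illustrative, not Lightning's implementation. A plain dict stands in for the model, and the POSIX-only "fork" start method replaces torch.multiprocessing.spawn.

```python
import multiprocessing as mp


class ToyTrainer:
    """Hypothetical trainer used only to illustrate Option B."""

    def __init__(self):
        self.model = None

    def ddp_train(self, model):
        # Runs in a worker process: these changes stay in the worker.
        model["weights"] = "trained"
        self.model = model

    def fit(self, model, nprocs=2):
        # Option B: track the model in the parent before starting workers.
        self.model = model
        ctx = mp.get_context("fork")  # POSIX-only start method
        workers = [ctx.Process(target=self.ddp_train, args=(model,))
                   for _ in range(nprocs)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()

    def test(self, model=None):
        # Fall back to the tracked model when none is passed explicitly.
        model = model if model is not None else self.model
        if model is None:
            raise ValueError("no model: call fit(model) or pass one to test()")
        return model


model = {"weights": "initial"}
trainer = ToyTrainer()
trainer.fit(model)
print(trainer.test() is model)  # True: the reference survives
print(model["weights"])         # initial: the workers' updates are lost
```

The sketch also exposes a caveat. The parent keeps only its own copy of the model, so anything the workers change, such as trained weights, is not sent back. A complete fix would also need to move the trained state back to the parent, for example by saving a checkpoint in a worker and reloading it.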

@sneiman
Contributor

sneiman commented Jan 21, 2020

Sorry, I assumed this had already led to a PR covering the fact that both trainer.test() and trainer.test(model) fail with various errors when using ddp. I'll do some work to narrow it down and submit a PR.

@Borda Borda added the help wanted Open to be worked on label Jan 24, 2020
@pableeto

pableeto commented Mar 2, 2020

Hi, I am using the latest version (pip install https://github.com/PytorchLightning/pytorch-lightning/archive/master.zip --upgrade), and I get the same error when calling trainer.test(model).
Is there any update or a workaround for this? Thank you very much.

@Borda
Member

Borda commented Mar 2, 2020

We are working on it now. Does this appear only with multiple GPUs, or have you observed it elsewhere?
Could you also check #979?

@williamFalcon
Contributor

@pableeto can you put code here?

@sneiman
Contributor

sneiman commented Mar 2, 2020 via email


5 participants