
"Wrong" code example in doc auto-scaling-of-batch-size section #5967

Closed
ifsheldon opened this issue Feb 14, 2021 · 9 comments · Fixed by #5968
Labels: bug · help wanted · priority: 1 (Medium priority)

Comments

@ifsheldon (Contributor) commented on Feb 14, 2021

🐛 Bug

In the doc auto-scaling-of-batch-size section, a code example is

# Use default in trainer construction
trainer = Trainer()
tuner = Tuner(trainer)

# Invoke method
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)

# Override old batch size
model.hparams.batch_size = new_batch_size

# Fit as normal
trainer.fit(model)

However, this will not work as expected when a LightningModule contains an attribute self.datamodule. Following this code raises MisconfigurationException: Field batch_size not found in both model and model.hparams.

To Reproduce

See my one-page code

import torch
import torchvision
from torchvision import transforms
from torchvision import models
import pytorch_lightning as pl

class ImageNetDataModule(pl.LightningDataModule):
    def __init__(self, imagenet_root, batch_size = 128, num_workers = 32):
        super().__init__()
        self.batch_size = batch_size
        self.imagenet_root = imagenet_root
        self.num_workers = num_workers
        
    def setup(self, stage):
        train_transform = transforms.Compose([transforms.Resize(256), 
                                        transforms.CenterCrop(224), 
                                        transforms.RandomGrayscale(),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(), 
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
        self.imagenet_train = torchvision.datasets.ImageNet(root = self.imagenet_root+"train/", split="train", transform = train_transform)
        val_transform = transforms.Compose([transforms.Resize(256), 
                                        transforms.CenterCrop(224), 
                                        transforms.ToTensor(), 
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
        self.imagenet_val = torchvision.datasets.ImageNet(root= self.imagenet_root+"val/", split="val", transform = val_transform)
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.imagenet_train, 
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers = self.num_workers,
                                           persistent_workers=True
                                          )
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.imagenet_val, 
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           num_workers = self.num_workers,
                                           persistent_workers=True
                                          )
    
class NetWrapper(pl.LightningModule):
    def __init__(self, model, datamodule, criterion=torch.nn.CrossEntropyLoss()):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.lr = 1e-3
        self.datamodule = datamodule
    
    def forward(self, x):
        raw_prob = self.model(x) #(batch, 1000)
        return raw_prob
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x) #(batch, word_emb_dim)
        loss = self.criterion(preds, y)
        self.log("cross_entropy_loss_training", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x) #(batch, word_emb_dim)
        loss = self.criterion(preds, y)
        self.log("cross_entropy_loss_val", loss)
        
        return loss
    
    def validation_epoch_end(self, validation_step_outputs):
        all_outputs = torch.tensor(validation_step_outputs)
        std, mean = torch.std_mean(all_outputs)
        self.log("validation_mean", mean)
        self.log("validation_std", std)
        
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

imagenet_dm = ImageNetDataModule("../../datasets/ImageNet/", batch_size = 1024)
resnet50 = models.resnet50(pretrained=False)
resnet50_pl_module = NetWrapper(resnet50, imagenet_dm)
trainer = pl.Trainer(gpus=1, 
                     accelerator='dp', 
                     auto_scale_batch_size='binsearch')

tuner = pl.tuner.tuning.Tuner(trainer)
# this call fails with MisconfigurationException: Field batch_size not found in both model and model.hparams
new_batch_size = tuner.scale_batch_size(resnet50_pl_module, mode="binsearch", init_val=128)
# the line below works fine
trainer.tune(resnet50_pl_module)

Expected behavior

Tuner should find the attribute batch_size in model.datamodule in the method Tuner.scale_batch_size().

Environment

This issue should be independent of the environment.

Additional context

I took a look at the source code and found that if we call Trainer.tune() directly, the call chain is
trainer.tune() -> tuner.tune() -> tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...
while the call chain when calling Tuner.scale_batch_size() directly is
tuner.scale_batch_size() -> batch_size_scaling.scale_batch_size() -> lightning_hasattr(model, attribute) -> ...
The problem is that lightning_hasattr(model, attribute) cannot find the attribute model.datamodule.batch_size if we skip the registration steps in trainer.tune().
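
For reference, a rough sketch of the lookup that lightning_hasattr performs (simplified, not the actual library source; the helper name lightning_hasattr_sketch is made up for illustration). It shows why the datamodule branch is unreachable until the trainer has been attached to the model:

def lightning_hasattr_sketch(model, attribute):
    # 1) plain attribute on the LightningModule itself
    if hasattr(model, attribute):
        return True
    # 2) attribute inside model.hparams (dict or namespace)
    hparams = getattr(model, "hparams", None)
    if hparams is not None:
        if isinstance(hparams, dict) and attribute in hparams:
            return True
        if hasattr(hparams, attribute):
            return True
    # 3) attribute on the datamodule known to the trainer; this branch is only
    #    reachable after trainer.tune()/fit() has attached the trainer to the
    #    model, which is exactly the step that is skipped when calling
    #    Tuner.scale_batch_size() directly
    trainer = getattr(model, "trainer", None)
    if trainer is not None and getattr(trainer, "datamodule", None) is not None:
        return hasattr(trainer.datamodule, attribute)
    return False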

@awaelchli (Contributor) commented

Something like this #5968 will probably fix this issue but it's not good.
I need to find a better way.

@awaelchli (Contributor) commented

from argparse import ArgumentParser

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):

    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


if __name__ == '__main__':
    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    dm = MNISTDataModule.from_argparse_args(args)
    model = LitClassifier(args.hidden_dim, args.learning_rate)
    trainer = Trainer(
        gpus=1,
        accelerator='dp',
        auto_scale_batch_size='binsearch'
    )

    tuner = Tuner(trainer)
    new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=128, max_trials=3, datamodule=dm)
    model.hparams.batch_size = new_batch_size

    trainer.fit(model, datamodule=dm)

A minimal reproducible example.

@VCasecnikovs commented

I have the same issue.
I found out where the problem is.
The problem is that the model doesn't have a trainer attribute (model.trainer), so lightning_hasattr(model, batch_arg_name) cannot find model.trainer.datamodule.

If you first call fit, the problem does not occur, because fit calls self.model_connector.copy_trainer_model_properties(model), which attaches the trainer to the model via m.trainer = proxy(self.trainer).
To fix the issue we need to attach the trainer to the model by calling ModelConnector.copy_trainer_model_properties(model) in tune.

So it would be self.trainer.model_connector.copy_trainer_model_properties(model) in the tune function of Tuner:

    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
        # setup data, etc...
        self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # hook
        self.trainer.data_connector.prepare_data(model)
        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, bool):
                self.trainer.auto_scale_batch_size = 'power'
            self.scale_batch_size(
                model,
                mode=self.trainer.auto_scale_batch_size,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
            )

        # Run learning rate finder:
        if self.trainer.auto_lr_find:
            self.lr_find(model, update_attr=True)

Another workaround would be to add this line of code to TrainLoop.setup_fit and delete it from trainer.fit.
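
A rough sketch of that alternative (the setup_fit signature is taken from the tune code above; the surrounding body is abbreviated, not the actual source):

class TrainLoop:
    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        # ... existing setup ...
        # attach the trainer (and therefore the registered datamodule) to the
        # model here, so that both fit() and tune() pass through this step
        self.trainer.model_connector.copy_trainer_model_properties(model)
        # ... rest of the existing setup ...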

I do not know which solution is more idiomatic for PyTorch Lightning, but I will prepare a pull request for the first solution.

@awaelchli (Contributor) commented

See the linked PR; I am already working on it.

VCasecnikovs added a commit to VCasecnikovs/pytorch-lightning that referenced this issue Mar 5, 2021
@VCasecnikovs commented

@awaelchli, sorry, I didn't manage to find it. It looks nice and allows using auto batch size scaling and the LR finder without calling tune. I hope it gets approved quickly.

@awaelchli (Contributor) commented on Mar 5, 2021

Yes, working on adding tests right now, hopefully done soon! Thanks for your help.

@VCasecnikovs commented

Could you explain why we have to add this to lr_finder as well?

@awaelchli (Contributor) commented on Mar 5, 2021

Because the LR finder has a similar bug, just with the lr attribute instead of batch_size.
But I will remove these lines and do that in a different PR; this was just for testing. It is a work in progress.

@VCasecnikovs commented

Will it? We cannot put the LR in a datamodule, only in the LightningModule. I cannot reproduce this bug with the learning rate.
