WandbLogger _sanitize_callable_params throws AttributeError if param does not have __name__ #4380

Closed · awaelchli opened this issue Oct 26, 2020 · 2 comments · Fixed by #4422

Labels: bug (Something isn't working), help wanted (Open to be worked on), logger (Related to the Loggers)
Milestone: 1.0.x

🐛 Bug

Using the WandbLogger throws the following error:

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/adrian/.conda/envs/lightning/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap  
    fn(i, *args)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/accelerators/ddp_spawn_accelerator.py", line 145, in ddp_train
    self.trainer.train_loop.setup_training(model)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 135, in setup_training
    self.trainer.logger.log_hyperparams(ref_model.hparams_initial)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py", line 35, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/loggers/wandb.py", line 138, in log_hyperparams
    params = self._sanitize_callable_params(params)
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/loggers/base.py", line 194, in _sanitize_callable_params
    return {key: _sanitize_callable(val) for key, val in params.items()}
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/loggers/base.py", line 194, in <dictcomp>
    return {key: _sanitize_callable(val) for key, val in params.items()}
  File "/home/adrian/repositories/pytorch-lightning/pytorch_lightning/loggers/base.py", line 191, in _sanitize_callable
    return val.__name__
  File "/home/adrian/.conda/envs/lightning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 772, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'Backbone' object has no attribute '__name__'

To Reproduce

from argparse import ArgumentParser

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pytorch_lightning.loggers import WandbLogger

try:
    from torchvision.datasets.mnist import MNIST
    from torchvision import transforms
except Exception as e:
    from tests.base.datasets import MNIST


class Backbone(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x


class LitClassifier(pl.LightningModule):
    def __init__(self, backbone, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()  # stores init args, including backbone (an nn.Module), in hparams
        self.backbone = backbone

    def forward(self, x):
        # use forward for inference/predictions
        embedding = self.backbone(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])

    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)

    logger = WandbLogger(project="test", name="test")

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args, max_steps=1, limit_train_batches=2, logger=logger)
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    result = trainer.test(model, test_dataloaders=test_loader)
    print(result)


if __name__ == '__main__':
    cli_main()

Expected behavior

Logging hyperparameters should not raise when a hyperparameter value is callable but has no __name__ (e.g. an nn.Module instance).

Environment

The environment does not matter.

Additional context

A recent PR #4320 introduced the function that throws the error.
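
For context on why the lookup fails: _sanitize_callable checks callable(val) and then reads val.__name__. An nn.Module instance is callable, but torch.nn.Module.__getattr__ raises when __name__ is looked up. A minimal sketch, independent of the logger code:

import torch

class Backbone(torch.nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)

backbone = Backbone()
print(callable(backbone))  # True, so the sanitizer treats it like a function
print(backbone.__name__)   # raises AttributeError (ModuleAttributeError in this PyTorch version)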

cc @tchaton

awaelchli (Contributor, Author) commented:

Looks like we need to be careful with objects that don't have __name__
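
One possible guard, a sketch only (not necessarily what #4422 implements): if a value is callable but has no __name__, fall back to a default instead of raising.

def _sanitize_callable(val):
    # Callables are represented by a name; objects that are callable but have no
    # __name__ (e.g. nn.Module instances) fall back to their class name instead of raising.
    if callable(val):
        return getattr(val, "__name__", type(val).__name__)
    return val


def _sanitize_callable_params(params):
    return {key: _sanitize_callable(val) for key, val in params.items()}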

tchaton (Contributor) commented Oct 29, 2020:

from argparse import ArgumentParser

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.base.boring_model import BoringModel  # BoringModel lives in the Lightning test suite; exact path may differ


def test_wandb_sanitize_callable_params_with_model(tmpdir):
    """
    Callback function are not serializiable. Therefore, we get them a chance to return
    something and if the returned type is not accepted, return None.
    """
    class ExtendedModel(BoringModel):
        @staticmethod
        def add_model_specific_args(parser):
            parser = ArgumentParser(parents=[parser], add_help=False)
            parser.add_argument('--learning_rate', type=float, default=0.0001)
            return parser        

    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser = Trainer.add_argparse_args(parser)
    parser = ExtendedModel.add_model_specific_args(parser)
    args = parser.parse_args('--max_steps 1'.split(' '))

    model = ExtendedModel()
    logger = WandbLogger(project="test", name="test", offline=True)
    # ------------
    trainer = Trainer.from_argparse_args(args, max_steps=1, limit_train_batches=2, logger=logger)
    trainer.fit(model)

I created this test to reproduce the bug, but I couldn't. I have made the __name__ lookup default to None when the attribute is missing.
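
A plausible reason the test above does not hit the error: the argparse-derived hparams contain only plain values, nothing callable, so the __name__ lookup is never reached. The original report stores the Backbone module itself via save_hyperparameters(), and nn.Module instances are callable. A quick check (a sketch, using only torch):

import torch

# hparams as produced by the test above (argparse Namespace): nothing callable
test_hparams = {"batch_size": 32, "hidden_dim": 128, "learning_rate": 1e-4}
print(any(callable(v) for v in test_hparams.values()))  # False

# hparams as in the original report: save_hyperparameters() stored the Backbone module,
# which is callable but has no __name__
repro_hparams = {"backbone": torch.nn.Linear(28 * 28, 128), "learning_rate": 1e-3}
print(any(callable(v) for v in repro_hparams.values()))  # True -> __name__ lookup would raise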
