
Weird logging to console behavior. #4621

Closed
vegovs opened this issue Nov 11, 2020 · 14 comments · Fixed by #5509 or #6275
Labels: bug Something isn't working · help wanted Open to be worked on · logging Related to the `LoggerConnector` and `log()`

Comments


vegovs commented Nov 11, 2020

πŸ› Bug

Logging to the console prints some messages twice, and my custom logging is not output at all. Verbose EarlyStopping also does not print to the console:

 |segmentation|base|py-3.8.5 Stanley in ~/Repos/segmentation
± |master U:1 ?:1 ✗| → python train.py 
GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: False, using: 0 TPU cores
INFO:lightning:TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.
INFO:lightning:Using native 16bit precision.
Missing logger folder: ./logs/11-11-2020-04-29-21_LR0_001_BS5_IS512
WARNING:lightning:Missing logger folder: ./logs/11-11-2020-04-29-21_LR0_001_BS5_IS512
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:lightning:initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1

   | Name        | Type              | Params | In sizes                                 | Out sizes         
-------------------------------------------------------------------------------------------------------------------
0  | criterion   | BCEWithLogitsLoss | 0      | ?                                        | ?                 
1  | in_conv     | DoubleConvolution | 38 K   | [5, 3, 512, 512]                         | [5, 64, 512, 512] 
2  | down_conv_1 | Down              | 221 K  | [5, 64, 512, 512]                        | [5, 128, 256, 256]
3  | down_conv_2 | Down              | 885 K  | [5, 128, 256, 256]                       | [5, 256, 128, 128]
4  | down_conv_3 | Down              | 3 M    | [5, 256, 128, 128]                       | [5, 512, 64, 64]  
5  | down_conv_4 | Down              | 4 M    | [5, 512, 64, 64]                         | [5, 512, 32, 32]  
6  | up_conv_1   | Up                | 5 M    | [[5, 512, 32, 32], [5, 512, 64, 64]]     | [5, 256, 64, 64]  
7  | up_conv_2   | Up                | 1 M    | [[5, 256, 64, 64], [5, 256, 128, 128]]   | [5, 128, 128, 128]
8  | up_conv_3   | Up                | 369 K  | [[5, 128, 128, 128], [5, 128, 256, 256]] | [5, 64, 256, 256] 
9  | up_conv_4   | Up                | 110 K  | [[5, 64, 256, 256], [5, 64, 512, 512]]   | [5, 64, 512, 512] 
10 | out_conv    | OutConvolution    | 65     | [5, 64, 512, 512]                        | [5, 1, 512, 512]  
INFO:lightning:
   | Name        | Type              | Params | In sizes                                 | Out sizes         
-------------------------------------------------------------------------------------------------------------------
0  | criterion   | BCEWithLogitsLoss | 0      | ?                                        | ?                 
1  | in_conv     | DoubleConvolution | 38 K   | [5, 3, 512, 512]                         | [5, 64, 512, 512] 
2  | down_conv_1 | Down              | 221 K  | [5, 64, 512, 512]                        | [5, 128, 256, 256]
3  | down_conv_2 | Down              | 885 K  | [5, 128, 256, 256]                       | [5, 256, 128, 128]
4  | down_conv_3 | Down              | 3 M    | [5, 256, 128, 128]                       | [5, 512, 64, 64]  
5  | down_conv_4 | Down              | 4 M    | [5, 512, 64, 64]                         | [5, 512, 32, 32]  
6  | up_conv_1   | Up                | 5 M    | [[5, 512, 32, 32], [5, 512, 64, 64]]     | [5, 256, 64, 64]  
7  | up_conv_2   | Up                | 1 M    | [[5, 256, 64, 64], [5, 256, 128, 128]]   | [5, 128, 128, 128]
8  | up_conv_3   | Up                | 369 K  | [[5, 128, 128, 128], [5, 128, 256, 256]] | [5, 64, 256, 256] 
9  | up_conv_4   | Up                | 110 K  | [[5, 64, 256, 256], [5, 64, 512, 512]]   | [5, 64, 512, 512] 
10 | out_conv    | OutConvolution    | 65     | [5, 64, 512, 512]                        | [5, 1, 512, 512]  
Epoch 3:  70%|███████████████████████████████████████████████                    | 5173/7395 [33:13<14:16,  2.60it/s, loss=0.327, v_num=0]
Testing: 100%|███████████████████████████████████████████████████████████████████| 1752/1752 [43:15<00:00,  1.49s/it]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_f1': tensor(0.9091, device='cuda:0'),
 'test_loss': tensor(0.2796, device='cuda:0'),
 'test_precision': tensor(0.9091, device='cuda:0'),
 'test_recall': tensor(0.9091, device='cuda:0'),
 'train_f1': tensor(0.9245, device='cuda:0'),
 'train_loss': tensor(0.2836, device='cuda:0'),
 'train_precision': tensor(0.9245, device='cuda:0'),
 'train_recall': tensor(0.9245, device='cuda:0'),
 'val_f1': tensor(0.9164, device='cuda:0'),
 'val_loss': tensor(0.2818, device='cuda:0'),
 'val_precision': tensor(0.9164, device='cuda:0'),
 'val_recall': tensor(0.9164, device='cuda:0')}
--------------------------------------------------------------------------------
Testing: 100%|███████████████████████████████████████████████████████████████████| 1752/1752 [43:16<00:00,  1.48s/it]

 |segmentation|base|py-3.8.5 Stanley in ~/Repos/segmentation
± |master U:1 ?:2 ✗| → 

To Reproduce

Here is my training code:

import logging
import os
import sys
from argparse import ArgumentParser
from datetime import datetime

from knockknock import discord_sender

import torch
from dotenv import load_dotenv
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from torch.backends import cudnn

from unet.unet_model import UNet

load_dotenv(verbose=True)


@discord_sender(webhook_url=os.getenv("DISCORD_WH"))
def main():
    """
    Main training loop.
    """
    parser = ArgumentParser()

    parser = UNet.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    prod = bool(os.getenv("PROD"))
    logging.getLogger("lightning").setLevel(logging.INFO)

    if prod:
        logging.info("Training i production mode, disabling all debugging APIs")
        torch.autograd.set_detect_anomaly(False)
        torch.autograd.profiler.profile(enabled=False)
        torch.autograd.profiler.emit_nvtx(enabled=False)
    else:
        logging.info("Training i development mode, debugging APIs active.")
        torch.autograd.set_detect_anomaly(True)
        torch.autograd.profiler.profile(
            enabled=True, use_cuda=True, record_shapes=True, profile_memory=True
        )
        torch.autograd.profiler.emit_nvtx(enabled=True, record_shapes=True)

    model = UNet(**vars(args))

    logging.info(
        f"Network:\n"
        f"\t{model.hparams.n_channels} input channels\n"
        f"\t{model.hparams.n_classes} output channels (classes)\n"
        f'\t{"Bilinear" if model.hparams.bilinear else "Transposed conv"} upscaling'
    )

    cudnn.benchmark = True  # cudnn Autotuner
    cudnn.enabled = True  # look for optimal algorithms

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        mode="min",
        patience=3 if not os.getenv("EARLY_STOP") else int(os.getenv("EARLY_STOP")),
        verbose=True,
    )

    run_name = "{}_LR{}_BS{}_IS{}".format(
        datetime.now().strftime("%d-%m-%Y-%H-%M-%S"),
        args.lr,
        args.batch_size,
        args.image_size,
    ).replace(".", "_")

    log_folder = (
        "./logs" if not os.getenv("DIR_ROOT_DIR") else os.getenv("DIR_ROOT_DIR")
    )
    if not os.path.isdir(log_folder):
        os.mkdir(log_folder)
    logger = TensorBoardLogger(log_folder, name=run_name)

    try:
        trainer = Trainer.from_argparse_args(
            args,
            gpus=-1,
            precision=16,
            distributed_backend="ddp",
            logger=logger,
            callbacks=[early_stop_callback],
            accumulate_grad_batches=1.0
            if not os.getenv("ACC_GRAD")
            else int(os.getenv("ACC_GRAD")),
            gradient_clip_val=0.0
            if not os.getenv("GRAD_CLIP")
            else float(os.getenv("GRAD_CLIP")),
            max_epochs=100 if not os.getenv("EPOCHS") else int(os.getenv("EPOCHS")),
            val_check_interval=0.1
            if not os.getenv("VAL_INT_PER")
            else float(os.getenv("VAL_INT_PER")),
            default_root_dir=os.getcwd()
            if not os.getenv("DIR_ROOT_DIR")
            else os.getenv("DIR_ROOT_DIR"),
        )
        trainer.fit(model)
        trainer.test(model)
    except KeyboardInterrupt:
        torch.save(model.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)


if __name__ == "__main__":
    main()

Expected behavior

Environment

  • CUDA:
    • GPU:
      • GeForce RTX 2070 SUPER
    • available: True
    • version: 11.0
  • Packages:
    • numpy: 1.19.4
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0+cu110
    • pytorch-lightning: 1.0.5
    • tqdm: 4.51.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.5
    • version: #57-Ubuntu SMP Thu Oct 15 10:57:00 UTC 2020
@vegovs vegovs added bug Something isn't working help wanted Open to be worked on labels Nov 11, 2020
@edenlightning edenlightning assigned tchaton and unassigned tchaton Nov 12, 2020
@edenlightning edenlightning added the logging Related to the `LoggerConnector` and `log()` label Nov 16, 2020
@edenlightning (Contributor)

@tchaton

@edenlightning (Contributor)

From @awaelchli: this might be a platform issue; it behaves differently on Linux vs. Windows.

@edenlightning (Contributor)

We couldn't find any culprit so far.


Huizerd commented Jan 11, 2021

Even more minimal reproduction:

import logging

import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class Model(pl.LightningModule):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()

        self.fc1 = nn.Linear(in_size, hid_size)
        self.fc2 = nn.Linear(hid_size, out_size)

    def forward(self, x):
        return self.fc2(self.fc1(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


def main(bug):
    logging.getLogger("lightning").setLevel(logging.INFO)

    if bug:
        logging.info("Hoi")

    model = Model(5, 10, 2)
    trainer = pl.Trainer()


if __name__ == "__main__":
    main(bug=False)

Setting bug to True leads to duplicate console output and no custom logging:

GPU available: True, used: False
INFO:lightning:GPU available: True, used: False
TPU available: None, using: 0 TPU cores
INFO:lightning:TPU available: None, using: 0 TPU cores

while leaving bug as False leads to the correct output:

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

Output of pip freeze:

absl-py==0.11.0
antlr4-python3-runtime==4.8
cachetools==4.2.0
certifi==2020.12.5
chardet==4.0.0
fsspec==0.8.5
future==0.18.2
google-auth==1.24.0
google-auth-oauthlib==0.4.2
grpcio==1.34.0
hydra-core==1.0.4
idna==2.10
importlib-resources==4.0.0
Markdown==3.3.3
numpy==1.19.4
oauthlib==3.1.0
omegaconf==2.0.5
pkg-resources==0.0.0
protobuf==3.14.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pytorch-lightning==1.1.2
PyYAML==5.3.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.6
six==1.15.0
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
torch==1.7.1
tqdm==4.54.1
typing-extensions==3.7.4.3
urllib3==1.26.2
Werkzeug==1.0.1

This also leads to duplicate logging with Hydra.


toliz commented Jan 12, 2021

Same issue on Ubuntu 20.04.


toliz commented Jan 12, 2021

Also, I didn't find any way to suppress internal logging such as:

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


awaelchli commented Jan 12, 2021

@toliz I think you have to do

my_logger = logging.getLogger("lightning")
my_logger.setLevel(logging.INFO)
my_logger.info("Hoi")

I spent some time on this "duplicated logging" problem a few weeks ago, but it behaves differently on different platforms, which can drive a human crazy. I will try again with your sample and env and see if I can get any further this time. Thanks for reporting.


toliz commented Jan 12, 2021

@toliz I think you have to do

my_logger = logging.getLogger("lightning")
my_logger.setLevel(logging.INFO)
my_logger.info("Hoi")

@awaelchli For this to suppress the Trainer's logging I have to set the logging level to logging.WARNING or higher, which means every informational log will be suppressed.

What if you switched the following to the logging.DEBUG level? I guess most of this info is redundant after a point.

GPU available: True, used: False
TPU available: None, using: 0 TPU cores
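
In the meantime, a possible workaround using only standard-library logging (not a Lightning API) is to attach a filter to the "lightning" logger. The INFO:lightning: prefix in the duplicated lines suggests these records are created under that exact logger name, so a sketch like the one below should drop just the hardware-summary lines; the filter class name is purely illustrative:

import logging

class SkipHardwareSummary(logging.Filter):
    """Drop only the GPU/TPU availability messages."""

    def filter(self, record):
        msg = record.getMessage()
        return not msg.startswith(("GPU available", "TPU available"))

# Attach to the logger assumed to emit these records.
logging.getLogger("lightning").addFilter(SkipHardwareSummary())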

@awaelchli (Contributor)

Minimal example:

import pytorch_lightning as pl
import logging
logging.info("I'm not getting logged")
pl.seed_everything(1234)  # but this gets logged twice

# console output:
# Global seed set to 1234
# INFO:lightning:Global seed set to 1234
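
For what it's worth, this looks like plain Python logging behavior rather than anything Lightning-specific: the module-level logging.info() call finds no handlers on the root logger and implicitly runs logging.basicConfig(), which installs a root StreamHandler (at the default WARNING level, which is why that message itself is dropped). If the "lightning" logger also carries its own handler and still propagates to the root logger, every Lightning record is then printed twice, once plainly and once with the root formatter's INFO:lightning: prefix. Below is a standalone sketch of that suspected mechanism; the handler setup is an assumption about what the package does on import, not a quote of it:

import logging

# Stand-in for a library logger that ships its own console handler,
# roughly what pytorch_lightning is assumed to set up in __init__.py.
lib_logger = logging.getLogger("lightning")
lib_logger.addHandler(logging.StreamHandler())
lib_logger.setLevel(logging.INFO)

# "User" code: with no root handlers yet, this implicitly calls
# logging.basicConfig(); the message is dropped (root defaults to WARNING),
# but a second console handler now exists on the root logger.
logging.info("I'm not getting logged")

# Emitted once by the library handler and again, via propagation,
# by the new root handler:
#   Global seed set to 1234
#   INFO:lightning:Global seed set to 1234
lib_logger.info("Global seed set to 1234")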


awaelchli commented Jan 14, 2021

In pytorch_lightning/__init__.py I added _logger.propagate = False to stop it from propagating logs to the Python root logger.
Does this resolve your issue?
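
Roughly, that would make the package-level logger setup look like the sketch below; the surrounding lines are an assumption about how the logger is configured in __init__.py, and only the propagate flag is the actual change:

import logging

_logger = logging.getLogger("lightning")
_logger.addHandler(logging.StreamHandler())
_logger.setLevel(logging.INFO)
# The change: keep records on this logger's own handler and stop forwarding
# them to the root logger, so they can no longer be printed a second time there.
_logger.propagate = False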


Huizerd commented Jan 14, 2021

Yes! Works for me.

@awaelchli (Contributor)

Alright, I will finalize the PR. I need to make sure that this change doesn't affect any existing logging.


toliz commented Jan 14, 2021

@awaelchli Sorry if it's a basic question, but how can I test it? Do I just add this line to my local PL package?

@awaelchli (Contributor)

a few ways :)

a) you can modify the PyTorch Lightning source code directly, as I did in the linked PR.
b) you can pip install Lightning from my branch: pip install --upgrade git+https://github.com/PyTorchLightning/pytorch-lightning@bugfix/duplicate-logs2
c) you can try to do this in your training script:

pl_logger = logging.getLogger("lightning")
pl_logger.propagate = False

before anything else runs
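
For option c), a minimal placement sketch (the variable name is just illustrative); since propagation is checked each time a record is emitted, flipping the flag before the Trainer logs anything is enough:

import logging

import pytorch_lightning as pl

# Stop the "lightning" logger from forwarding its records to the root logger,
# so the hardware summary and similar messages are printed only once.
pl_logger = logging.getLogger("lightning")
pl_logger.propagate = False

trainer = pl.Trainer()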
