
Logging non-tensor scalar with result breaks subsequent epoch aggregation #3276

Closed
s-rog opened this issue Aug 31, 2020 · 11 comments · Fixed by #4116
Labels
bug (Something isn't working) · good first issue (Good for newcomers) · help wanted (Open to be worked on)
Milestone

Comments

@s-rog
Contributor

s-rog commented Aug 31, 2020

🐛 Bug

Logging non-tensor scalar with result breaks subsequent epoch/tbptt aggregation
(on both 0.9 and master)

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/accelerators/ddp_spawn_backend.py", line 165, in ddp_train
    results = self.trainer.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1237, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 396, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 543, in run_training_epoch
    self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 672, in run_training_epoch_end
    epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 696, in __auto_reduce_results_on_epoch_end
    tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/core/step_result.py", line 392, in reduce_across_time
    result[k] = tbptt_reduce_fx(value)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not list

To Reproduce

    def training_step(self, batch, batch_idx):
        x, y = batch[0], batch[1]
        x = self.forward(x)
        loss = self.loss(x, y)
        result = pl.TrainResult(loss)
        result.log("non tensor scalar", 1.0)
        result.log("loss", loss, on_step=False, on_epoch=True)
        return result

To Fix

result.log("non tensor scalar", torch.tensor(1.0))

Expected behavior

In log() of result objects, value is typed as Any, so non-tensor values should be accepted without breaking the logging or aggregation of the other metrics.

Additional context

log() could be changed to accept only tensors, or to convert values to tensors internally; I will update as I investigate further.
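
In the meantime, a caller-side workaround is to coerce values to tensors before logging. A minimal sketch, assuming only torch; to_loggable is a hypothetical helper, not part of the Lightning API:

import numbers

import torch


def to_loggable(value):
    # Wrap bare Python numbers in a 0-d tensor so that epoch/tbptt reduction
    # (torch.mean by default) always receives tensors instead of floats.
    if isinstance(value, numbers.Number) and not isinstance(value, bool):
        return torch.tensor(float(value))
    return value


# usage inside training_step:
# result.log("non tensor scalar", to_loggable(1.0))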

@s-rog s-rog added bug Something isn't working help wanted Open to be worked on labels Aug 31, 2020
@edenlightning edenlightning added this to the 0.9.x milestone Sep 1, 2020
@edenlightning
Contributor

@williamFalcon

@Borda Borda added the Result label Sep 4, 2020
@Borda Borda added the good first issue Good for newcomers label Sep 15, 2020
@Borda
Member

Borda commented Sep 15, 2020

@williamFalcon is it still there? @s-rog mind checking whether it is still on master? Or better, add a test for this case...

@s-rog
Contributor Author

s-rog commented Sep 16, 2020

I'll take another look when I get a chance; I haven't probed much due to the refactors... are they mostly done?

@s-rog
Contributor Author

s-rog commented Sep 16, 2020

@Borda the bug is still here; the following template tests for this issue as well as #3278.

For this problem the easiest fix would be to force the values to tensors, though that's probably just a band-aid solution. Thoughts?

Test template for reference
#!/opt/conda/bin/python
"""
Runs a model on a single node across multiple gpus.
"""
import os
from argparse import ArgumentParser

import torch.nn.functional as F

import pytorch_lightning as pl
from pl_examples.models.lightning_template import LightningTemplateModel
from pytorch_lightning import Trainer, seed_everything

seed_everything(234)

class custom_template(LightningTemplateModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def on_epoch_start(self):
        print("on_epoch_start")

    def on_fit_start(self):
        print("on_fit_start")
        
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.TrainResult(loss)
        result.log("non tensor scalar", 1.0)
        result.log("loss", loss, on_step=False, on_epoch=True)
        return result
        

def main(args):
    """ Main training routine specific for this project. """
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    model = custom_template(**vars(args))

    # ------------------------
    # 2 INIT TRAINER
    # ------------------------
    trainer = Trainer.from_argparse_args(args)

    # ------------------------
    # 3 START TRAINING
    # ------------------------
    trainer.fit(model)


def run_cli():
    # ------------------------
    # TRAINING ARGUMENTS
    # ------------------------
    # these are project-wide arguments
    root_dir = os.path.dirname(os.path.realpath(__file__))
    parent_parser = ArgumentParser(add_help=False)

    # each LightningModule defines arguments relevant to it
    parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir)
    parser = Trainer.add_argparse_args(parser)
    parser.set_defaults(gpus=1, distributed_backend=None)
    args = parser.parse_args()

    # ---------------------
    # RUN TRAINING
    # ---------------------
    main(args)


if __name__ == '__main__':
    run_cli()

@s-rog
Contributor Author

s-rog commented Sep 17, 2020

I looked into it a bit: reduce_across_time is called on all metrics whenever any metric in the result is logged with on_epoch=True. Since the default tbptt_reduce_fx is torch.mean, that makes on_epoch=True compatible only with tensor scalars.

This is probably not intended behavior?
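
A minimal illustration of the failure mode, outside Lightning and assuming only torch: values logged as tensors can be stacked and reduced across steps, whereas plain floats accumulate into a Python list that torch.mean rejects:

import torch

# per-step values when the metric was logged as a tensor
tensor_steps = [torch.tensor(1.0), torch.tensor(1.0)]
print(torch.mean(torch.stack(tensor_steps)))  # tensor(1.)

# per-step values when a plain Python float was logged
float_steps = [1.0, 1.0]
try:
    torch.mean(float_steps)  # what the default tbptt_reduce_fx ends up doing
except TypeError as err:
    print(err)  # mean(): argument 'input' (position 1) must be Tensor, not list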

@edenlightning
Contributor

@justusschock @SkafteNicki

@justusschock
Member

@edenlightning this is not related to metrics.

@s-rog I guess a simple type check that converts scalars to scalar tensors should do the trick? If so, could you open a PR with this fix?
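
A rough sketch of such a check, assuming it would run on the value near the top of log(); the helper name and placement are illustrative, not the actual implementation:

import numbers

import torch


def _coerce_scalar(value):
    # Illustrative type check: convert bare Python numbers (int/float) to
    # 0-d float tensors so later reductions always operate on tensors.
    if isinstance(value, numbers.Number):
        return torch.tensor(value, dtype=torch.float)
    return value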

@Borda
Member

Borda commented Oct 5, 2020

fixed by #3855

@Borda Borda closed this as completed Oct 5, 2020
@s-rog
Contributor Author

s-rog commented Oct 13, 2020

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/accelerators/ddp_spawn_accelerator.py", line 152, in ddp_train
    results = self.train_or_test()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 53, in train_or_test
    results = self.trainer.train()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 483, in train
    self.train_loop.run_training_epoch()
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 598, in run_training_epoch
    self.num_optimizers
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/logger_connector.py", line 339, in log_train_epoch_end_metrics
    epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/logger_connector.py", line 449, in __auto_reduce_results_on_epoch_end
    tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
  File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/core/step_result.py", line 483, in reduce_across_time
    result[k] = tbptt_reduce_fx(value.float())
AttributeError: 'list' object has no attribute 'float'

@Borda I don't think the issue was fixed completely; this is on rc5 (using self.log).

@Borda
Member

Borda commented Oct 13, 2020

@s-rog mind adding a test for this case?

@Borda Borda reopened this Oct 13, 2020
williamFalcon added a commit that referenced this issue Oct 13, 2020
williamFalcon added a commit that referenced this issue Oct 13, 2020
@s-rog
Contributor Author

s-rog commented Oct 13, 2020

william beat me to it :]
