
Enable logger connector re-design #7891

Merged: carmocca merged 22 commits into master from refactor/use-new-logger-connector on Jun 9, 2021

Conversation

@carmocca (Contributor) commented Jun 9, 2021

What does this PR do?

  • Integrate the logger connector re-design with the loops
  • Fix tests
  • Remove legacy tests

Part of #7631

pseudo-benchmark (MASTER vs. Logging PoC):

                                 MASTER        Logging PoC
ParityModuleMNIST
  time                           0:01:14.02    0:01:12.17
  profiler (training_step)       19.160 s      20.911 s

HeavyLoggingBoringModel
  memory                         18.3 MiB      70.3 KiB
  time                           0:00:08.62    0:00:06.86
  profiler (training_step)       3.298 s       7.130 s

Code
import gc
import io
import pstats
import tracemalloc
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.profiler import AdvancedProfiler
from tests import PATH_DATASETS
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.datasets import MNIST


def collect_stats():
    # Print the three largest allocation sites tracked by tracemalloc.
    gc.collect()
    snapshot = tracemalloc.take_snapshot()
    top_stats = snapshot.statistics("lineno")
    for stat in top_stats[:3]:
        print(stat)


class HeavyLoggingBoringModel(BoringModel):
    def __init__(self, memory=False):
        super().__init__()
        self.memory = memory

    def on_fit_start(self):
        if self.memory:
            tracemalloc.start(10)

    def on_fit_end(self):
        if self.memory:
            tracemalloc.stop()

    def training_step(self, batch, batch_idx):
        if self.memory and batch_idx % 50 == 49:
            collect_stats()

        loss = super().training_step(batch, batch_idx)["loss"]

        output_dict = {f"loss_{i}": loss for i in range(200)}
        self.log_dict(output_dict, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 500))


class ParityModuleMNIST(LightningModule):
    def __init__(self, memory=False):
        super().__init__()
        self.memory = memory
        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.c_d1_bn = nn.BatchNorm1d(128)
        self.c_d1_drop = nn.Dropout(0.3)
        self.c_d2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.c_d1(x)
        x = torch.tanh(x)
        x = self.c_d1_bn(x)
        x = self.c_d1_drop(x)
        x = self.c_d2(x)
        return x

    def on_fit_start(self):
        if self.memory:
            tracemalloc.start(10)

    def on_fit_end(self):
        if self.memory:
            tracemalloc.stop()

    def training_step(self, batch, batch_idx):
        if self.memory and batch_idx % 50 == 49:
            collect_stats()

        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        return DataLoader(MNIST(root=PATH_DATASETS, train=True, download=True), batch_size=2)


class MyAdvancedProfiler(AdvancedProfiler):
    # Restrict the profiler summary to the training_step action only.
    def summary(self) -> str:
        recorded_stats = {}
        for action_name, pr in self.profiled_actions.items():
            if action_name != "training_step":
                continue
            s = io.StringIO()
            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("tottime")
            ps.print_stats(20)
            recorded_stats[action_name] = s.getvalue()
        return self._stats_to_str(recorded_stats)


def run(model_cls, test):
    print(f"{'=' * 30}\n {test} - {model_cls.__name__}\n{'=' * 30}")

    if test == "memory":
        trainer = Trainer(max_epochs=1, logger=False, progress_bar_refresh_rate=50, weights_summary=None)
        model = model_cls(memory=True)
        trainer.fit(model)

    elif test == "time":
        trainer = Trainer(max_epochs=1, logger=False, progress_bar_refresh_rate=50, weights_summary=None)
        model = model_cls()
        start = datetime.now()
        trainer.fit(model)
        end = datetime.now()
        print("Time: ", end - start)

    elif test == "profiler":
        trainer = Trainer(
            max_epochs=1,
            logger=False,
            profiler=MyAdvancedProfiler(),
            progress_bar_refresh_rate=50,
            weights_summary=None,
        )
        model = model_cls()
        trainer.fit(model)


if __name__ == "__main__":
    for model_cls in (ParityModuleMNIST, HeavyLoggingBoringModel):
        for test in ("memory", "time", "profiler"):
            if model_cls is ParityModuleMNIST and test == "memory":
                continue
            run(model_cls, test)

Recap:

  • Negligible speed difference when using a real (medium-large+) model
  • training_step itself takes longer because self.log does more work now
  • Memory accumulation is fixed (18.3 MiB → 70.3 KiB for HeavyLoggingBoringModel)
  • Overall runtime decreases thanks to faster aggregation

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@carmocca carmocca added the feature Is an improvement or enhancement label Jun 9, 2021
@carmocca carmocca added this to the v1.4 milestone Jun 9, 2021
@carmocca carmocca self-assigned this Jun 9, 2021
@pep8speaks commented Jun 9, 2021

Hello @carmocca! Thanks for updating this PR.

Line 212:17: W503 line break before binary operator

Comment last updated at 2021-06-09 12:31:58 UTC

@codecov bot commented Jun 9, 2021

Codecov Report

Merging #7891 (d53fe03) into master (6fee926) will decrease coverage by 3%.
The diff coverage is 98%.

@@           Coverage Diff           @@
##           master   #7891    +/-   ##
=======================================
- Coverage      91%     88%    -3%     
=======================================
  Files         204     204            
  Lines       13669   13667     -2     
=======================================
- Hits        12445   12009   -436     
- Misses       1224    1658   +434     

@tchaton (Contributor) left a comment

Amazing job!

pytorch_lightning/core/lightning.py (outdated, resolved)
f" of {list(self._metric_attributes.values())}"
)

value = apply_to_collection(value, numbers.Number, self.__to_float)
Contributor:

Should we preserve the logged type?

@carmocca (Contributor Author) Jun 9, 2021:
So not converting to float tensor but just wrapping it in tensor?

We can, but I don't think this matters, as ResultMetric.update will convert it to float anyway.

edit: changed to __to_tensor
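
For context, here is a minimal sketch of what wrapping logged numbers into tensors can look like, using Lightning's apply_to_collection utility (import path as of Lightning 1.x). The converter below is a made-up stand-in for the private helper discussed above, not the actual implementation:

import numbers

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


def to_tensor(value: numbers.Number) -> torch.Tensor:
    # Wrap the raw Python number in a tensor without forcing a float cast,
    # so e.g. ints stay ints until the metric itself converts them.
    return torch.tensor(value)


logged = {"acc": 0.91, "num_samples": 32}
logged = apply_to_collection(logged, numbers.Number, to_tensor)
print(logged)  # {'acc': tensor(0.9100), 'num_samples': tensor(32)}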

pytorch_lightning/core/lightning.py (outdated, resolved)
pytorch_lightning/trainer/evaluation_loop.py (resolved)
@@ -126,6 +142,7 @@ def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoad
        self.num_dataloaders = self._get_num_dataloaders(dataloaders)

    def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        self.trainer.logger_connector.on_epoch_start()
Contributor:

Suggested change:
- self.trainer.logger_connector.on_epoch_start()
+ # update ResultCollection.
+ self.trainer.logger_connector.on_epoch_start()

@carmocca (Contributor Author):

This doesn't update the ResultCollection; it just sets a flag.

I don't think that comment is useful.

pytorch_lightning/trainer/evaluation_loop.py (resolved)
pytorch_lightning/trainer/trainer.py (resolved)
pytorch_lightning/trainer/training_loop.py (outdated, resolved)
pytorch_lightning/trainer/training_loop.py (outdated, resolved)
pytorch_lightning/trainer/training_loop.py (resolved)
@mergify mergify bot added the has conflicts label Jun 9, 2021
pytorch_lightning/core/lightning.py (outdated, resolved)
def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    self.trainer.logger_connector.on_batch_start()
    # FIXME(@carmocca): missing hook?
    # self.trainer.call_hook('on_batch_start')
Contributor:

It's missing on purpose. I thought we decided not to run on_batch_start/end across all of train/val/predict; it runs only for training.

cc @ananthsub

@carmocca (Contributor Author):

Where did we discuss this? 😂
It's weird then because we do run on_epoch_{start,end} for them.

This is unrelated to this PR though; I can remove the FIXME and address it again in #7738.

@awaelchli (Contributor) Jun 9, 2021:

I initially suggested it in this issue a long time ago: #1440
It later came up again in a Slack thread. @ananthsub had an argument against it, which I don't remember, so we decided not to do it.
It would also be impossible to make backward compatible: the docs say the hook runs for training.

pytorch_lightning/trainer/evaluation_loop.py (outdated, resolved)
@ethanwharris (Member) left a comment

LGTM, small comment

pytorch_lightning/core/lightning.py (outdated, resolved)
pytorch_lightning/core/lightning.py (outdated, resolved)
Comment on lines 486 to 492
@property
def active_loop(self) -> Optional[Union[TrainLoop, EvaluationLoop]]:
    if self.training:
        return self.train_loop
    elif self.sanity_checking or self.evaluating:
        return self.evaluation_loop

Contributor:

Why does this need to be exposed as a property? Doesn't this leak an implementation detail? What if someone accesses the active_loop and then modifies properties on it?

@carmocca (Contributor Author):

We want to access the current ResultCollection object from the logger connector, and each Loop has its own ResultCollection, so we need this to get the currently running loop.

I guess we can make this property protected to discourage external modifications.

cc: @awaelchli
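
For illustration, here is a minimal, self-contained sketch of the pattern being discussed. These are toy classes, not the actual Trainer/LoggerConnector code, and the protected name _active_loop is the variant suggested above:

from typing import Optional


class _Loop:
    def __init__(self) -> None:
        self.results = {}  # stand-in for the loop's own ResultCollection


class Trainer:
    def __init__(self) -> None:
        self.training = False
        self.sanity_checking = False
        self.evaluating = False
        self.train_loop = _Loop()
        self.evaluation_loop = _Loop()

    @property
    def _active_loop(self) -> Optional[_Loop]:
        # Protected: external code is discouraged from grabbing and mutating it.
        if self.training:
            return self.train_loop
        if self.sanity_checking or self.evaluating:
            return self.evaluation_loop
        return None


class LoggerConnector:
    def __init__(self, trainer: Trainer) -> None:
        self.trainer = trainer

    @property
    def results(self) -> Optional[dict]:
        # The connector always reads metrics from whichever loop is running.
        loop = self.trainer._active_loop
        return loop.results if loop is not None else None


trainer = Trainer()
trainer.training = True
print(LoggerConnector(trainer).results)  # {}  (the train loop's collection)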

Contributor:

Do you want me to do it directly in the new loops now? It should be no problem.

@carmocca (Contributor Author):

Did it already with ab28850

Contributor:

so this was not about the results in loops?

pytorch_lightning/trainer/training_loop.py (outdated, resolved)
@mergify mergify bot removed the has conflicts label Jun 9, 2021
@awaelchli awaelchli mentioned this pull request Jun 9, 2021
CHANGELOG.md (outdated, resolved)
@awaelchli awaelchli added the logging Related to the `LoggerConnector` and `log()` label Jun 9, 2021
@carmocca carmocca enabled auto-merge (squash) June 9, 2021 12:11
@mergify mergify bot added the has conflicts label Jun 9, 2021
@mergify mergify bot removed the has conflicts label Jun 9, 2021
@carmocca carmocca merged commit ec4f885 into master Jun 9, 2021
@carmocca carmocca deleted the refactor/use-new-logger-connector branch June 9, 2021 14:24
@Queuecumber (Contributor):

Could someone explain what the batch_size parameter does? I don't see it being used in the code anywhere and the docs don't explain it.

Comment on lines +309 to +310
batch_size: Current batch_size. This will be directly inferred from the loaded batch,
but some data structures might need to explicitly provide it.
Contributor:

Here are the docs for batch size, is this what you mean?

Contributor:

Yeah, but what I'm missing is the why: what is that parameter used for (inferred or otherwise)?

@awaelchli (Contributor) Jun 10, 2021:

It's used to compute the correct average when we ask self.log to average the metric at epoch end.

It has to be weighted by the batch size because the last batch often does not have the same size as the others:
the dataset size is not guaranteed to be divisible by the batch size, and drop_last in the PyTorch DataLoader is False by default.

Contributor:

In the ResultMetric in result.py you will find the line:

self.cumulated_batch_size += batch_size
and the cumulated_batch_size is then used in the compute() method.
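
For illustration, here is a minimal, self-contained sketch of the weighting idea (a toy class, not Lightning's actual ResultMetric):

import torch


class WeightedMean:
    def __init__(self) -> None:
        self.value = torch.tensor(0.0)
        self.cumulated_batch_size = torch.tensor(0.0)

    def update(self, batch_mean: torch.Tensor, batch_size: int) -> None:
        # Weight each logged per-batch mean by the number of samples in the batch.
        self.value += batch_mean * batch_size
        self.cumulated_batch_size += batch_size

    def compute(self) -> torch.Tensor:
        # Divide by the total sample count so a smaller last batch does not skew the average.
        return self.value / self.cumulated_batch_size


metric = WeightedMean()
metric.update(torch.tensor(0.5), batch_size=32)
metric.update(torch.tensor(0.7), batch_size=8)  # smaller final batch
print(metric.compute())  # tensor(0.5400), not the unweighted mean of 0.6
# Under DDP, the per-process weighted sums and cumulated batch sizes would
# additionally be summed across processes before the final division.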

Contributor:

OK, that's what I thought it was for, but I couldn't find where in the code it actually does that.

So does this mean that my _step should log a scalar which is the mean of the current batch and PL will correctly average (including across DDP processes) by multiplying with the batch size, summing, then dividing by the dataset size?

@carmocca (Contributor Author):

result.track_batch_size(len(split_batch))

# track metrics without grads for epoch reduction
training_step_output_for_epoch_end = copy(result)
Contributor:

The removal of this line could possibly be the cause of #8613.

@mergify mergify bot added the ready PRs ready to be merged label Jul 29, 2021
Labels: feature (Is an improvement or enhancement), logging (Related to the `LoggerConnector` and `log()`), ready (PRs ready to be merged)
Projects: none yet
7 participants