
[RFC] Deprecate the move_metrics_to_cpu Trainer argument. #10595

Closed
tchaton opened this issue Nov 17, 2021 · 5 comments
Labels
deprecation Includes a deprecation
logging Related to the `LoggerConnector` and `log()`
Milestone

Comments

@tchaton
Contributor

tchaton commented Nov 17, 2021

📚 Documentation

After investigating this community issue #10379, it seems the move_metrics_to_cpu parameter isn't working as expected with DDP.

Bug reason: The ResultCollection object storing the metrics is moved to the CPU here, and on epoch end, when compute is performed, the tensors are reduced across processes, which raises RuntimeError: Tensors must be CUDA and dense.
Note: this should also fail with the sync_dist argument of the self.log method.
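
For intuition, here is a minimal sketch of the failure mode, assuming an initialized NCCL process group (the backend used for GPU DDP); the function name is illustrative, not Lightning's actual code:

import torch
import torch.distributed as dist

def reduce_logged_value(value: torch.Tensor) -> torch.Tensor:
    # `value` lives on the CPU because the whole ResultCollection was moved there
    gathered = [torch.zeros_like(value) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, value)  # NCCL raises: "Tensors must be CUDA and dense"
    return torch.stack(gathered).mean()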

Background:

Before: Prior to the ResultCollection object, logged tensor metrics had an O(num_batches_in_epoch) memory footprint, as they were stored in a list and reduced when the epoch ended. When an epoch had a very large number of batches, this could cause an OOM error.

Now: The ResultCollection has an O(1) memory footprint for tensor metrics; therefore, the move_metrics_to_cpu argument isn't as impactful as before.
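
For illustration only (this is not Lightning's actual implementation), the two memory regimes roughly correspond to:

import torch

# O(num_batches_in_epoch): every logged value is kept until epoch end
class ListAccumulator:
    def __init__(self):
        self.values = []

    def update(self, value: torch.Tensor) -> None:
        self.values.append(value.detach())

    def compute(self) -> torch.Tensor:
        return torch.stack(self.values).mean()

# O(1): only a running sum and a count are kept, similar in spirit to what the
# ResultCollection does now for tensor metrics
class RunningMeanAccumulator:
    def __init__(self):
        self.total = None
        self.count = 0

    def update(self, value: torch.Tensor) -> None:
        value = value.detach()
        self.total = value if self.total is None else self.total + value
        self.count += 1

    def compute(self) -> torch.Tensor:
        return self.total / self.count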

I believe there are two options forward:

Option 1 🚀:

Make the ResultCollection aware of move_metrics_to_cpu. The ResultCollection would be responsible for moving the ResultMetric objects back to the device before distributed collection and to the CPU right after.

Pros:

  • This might be impactful for large Metric objects from TorchMetrics, although a user could alternatively do this manually within their LightningModule

Cons:

  • Engineering heavy
  • Performance drop from moving metrics back and forth between devices.

Here is the pseudo-code for such a solution. In practice, this logic would be spread across several parts of the code.

class ResultCollection:

    def __init__(self, ..., move_metrics_to_cpu):
        self.move_metrics_to_cpu = move_metrics_to_cpu

    def metrics(self, on_step=False):
        if on_step:
            for result_metric in self.result_metrics:
                # move the metric back to the device before the reduction
                if self.move_metrics_to_cpu:
                    result_metric.to(self.device)
                ... = result_metric.compute()  # perform distributed reduction
                # move the metric back to the CPU right after
                if self.move_metrics_to_cpu:
                    result_metric.to("cpu")

Option 2 😋:

Deprecate and remove the section of code moving the ResultCollection to CPU.
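
A minimal sketch of what the deprecation path could look like; the helper name and message wording are assumptions, and Lightning would presumably use its own rank-zero deprecation utility rather than warnings directly:

import warnings

def _handle_move_metrics_to_cpu(move_metrics_to_cpu: bool) -> None:
    if move_metrics_to_cpu:
        warnings.warn(
            "`Trainer(move_metrics_to_cpu=True)` is deprecated and the"
            " ResultCollection will no longer be moved to the CPU.",
            DeprecationWarning,
        )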



cc @carmocca @edward-io @tchaton

@tchaton tchaton added the deprecation label Nov 17, 2021
@tchaton tchaton added this to the 1.6 milestone Nov 17, 2021
@awaelchli
Contributor

The memory consideration you outlined with O(num_batches_in_epoch) vs O(1) seems convincing enough to me to go with Option 2.

@carmocca carmocca added the logging label Nov 17, 2021
@carmocca
Contributor

carmocca commented Nov 18, 2021

I don't see how the difference in memory between versions justifies either of the options. I understand the flag was more "useful" before because there was "more memory to move", but the flag still has its use today, even with the reduced memory requirements.

What the flag does today is:
a. Moves saved evaluation outputs (if necessary) on epoch end https://github.com/PyTorchLightning/pytorch-lightning/blob/261ea90822e2bf1cfa5d56171ab1f95a81d5c571/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py#L137-L138

(a) seems like a bug to me, as the user might need the outputs on the correct device for an operation in validation_epoch_end. Additionally, this move is not done for stored training outputs. This was added by you in #4592, where I guess it made more sense as those old Result objects were shared between logging and epoch-end outputs. I'd advocate for removing (a) because the user can easily move the outputs to CPU themselves if desired:

{'loss': ..., 'something_else': something_else.cpu()}
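
Spelled out, a hedged sketch of that manual alternative (the helper names are illustrative):

def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)  # hypothetical helper
    preds = self(batch)
    # only the potentially large tensor is moved to CPU by the user
    return {"loss": loss, "preds": preds.cpu()}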

edit: fixed with #10631

b. Moves logged results after training_step (for both automatic and manual optimization) https://github.com/PyTorchLightning/pytorch-lightning/blob/261ea90822e2bf1cfa5d56171ab1f95a81d5c571/pytorch_lightning/loops/optimization/optimizer_loop.py#L448-L451 https://github.com/PyTorchLightning/pytorch-lightning/blob/261ea90822e2bf1cfa5d56171ab1f95a81d5c571/pytorch_lightning/loops/optimization/manual_loop.py#L118-L122

Now, (b) is a different story. We could remove it but it leaves the user with no alternative if they still want to move the saved results to CPU. They could still want this regardless of the memory usage. Their only equivalent solution would be:

def training_step(self, batch, batch_idx):
    ...
    self.log(...)
    self.trainer._results.cpu()
    return loss

which is not great UX.

Now, coming back to your proposal, option (2) is the easy-but-lazy option, which also removes a feature. It's worth considering because doing it well is hard (cf. your cons) and the current value of move_metrics_to_cpu is low, but doing so also goes against the "we do the hard work" ideal.

Personally, I don't know what the best option is; it depends on the value this flag provides to the community.

@tchaton
Contributor Author

tchaton commented Nov 18, 2021

@carmocca

I don't see how the difference in memory between versions justifies either of the options

We no longer hit OOMs related to metrics the way we used to.

a. Moves saved evaluation outputs (if necessary) on epoch end

I believe this is an abuse of the argument that wasn't properly designed at the time; move_outputs_to_cpu would have been a more appropriate name.

It's worth considering because doing it well is hard

I don't think it is hard, just engineering heavy. This could make the Result code more complex.
But I am more concerned about the performance drop of moving the ResultMetric back and forth between devices.
A user might activate this feature and report that Lightning is slow without fully understanding the impact of such a feature.

I guess the best way forward is to implement it, benchmark its impact and make a decision then.
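
For reference, a rough sketch of the kind of micro-benchmark this could start from (requires a CUDA device; the state size and step count are arbitrary):

import time
import torch

def benchmark_round_trips(state_numel: int = 1_000, steps: int = 1_000) -> float:
    state = torch.randn(state_numel, device="cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        cpu_state = state.to("cpu")   # what move_metrics_to_cpu does after each step
        state = cpu_state.to("cuda")  # what Option 1 would undo before the reduction
    torch.cuda.synchronize()
    return time.perf_counter() - start

Printing benchmark_round_trips() on a GPU machine would give a first-order number to compare against the per-step time of a training run.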

Best,
T.C

@tchaton
Contributor Author

tchaton commented Nov 24, 2021

Hey @awaelchli @ananthsub @Quintulius. After investigation with @carmocca, here are the current challenges.

The code below would fail on master with move_metrics_to_cpu=True, gpus=1 as the metric will be moved to CPU after the first training_step call:

class TestModel(BoringModel):
    def __init__(self):
        super().__init__()
        # the metric is automatically moved to GPU at the beginning
        self.metric = DummyMetric()

    def training_step(self, batch, batch_idx):
        # first time this runs, the metric and the batch are on GPU. All good.
        # the second time, the metric is on CPU but the batch is on GPU. Error
        self.metric(batch.sum().long())
        self.log("train_loss_2", self.metric, on_epoch=True)
        return super().training_step(batch, batch_idx)
        # will call `trainer._results.cpu()` right after, this includes logged `Metric`s

This raises:

    def update(self, x):
>       self.x += x
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and CPU!
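
For reference, a plausible minimal definition of the DummyMetric used above (an assumption based on the update shown in the traceback, not the exact test helper):

import torch
from torchmetrics import Metric

class DummyMetric(Metric):
    def __init__(self):
        super().__init__()
        # a single tensor state that is summed across processes on compute
        self.add_state("x", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x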

However, if we were to move the metric's inputs to the CPU to match the metric's device:

self.metric(batch.sum().long().cpu())

it would still fail later on when trying to sync the metric across processes:

pytorch_lightning/trainer/connectors/logger_connector/result.py:285: in wrapped_func
    self._computed = compute(*args, **kwargs)
pytorch_lightning/trainer/connectors/logger_connector/result.py:253: in compute
    return self.value.compute()
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:362: in wrapped_func
    with self.sync_context(
../.pyenv/versions/3.8.5/lib/python3.8/contextlib.py:113: in __enter__
    return next(self.gen)
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:333: in sync_context
    self.sync(
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:285: in sync
    self._sync_dist(dist_sync_fn, process_group=process_group)
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:224: in _sync_dist
    output_dict = apply_to_collection(
.venv/lib/python3.8/site-packages/torchmetrics/utilities/data.py:191: in apply_to_collection
    return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
.venv/lib/python3.8/site-packages/torchmetrics/utilities/data.py:191: in <dictcomp>
    return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
.venv/lib/python3.8/site-packages/torchmetrics/utilities/data.py:187: in apply_to_collection
    return function(data, *args, **kwargs)
.venv/lib/python3.8/site-packages/torchmetrics/utilities/distributed.py:120: in gather_all_tensors
    return _simple_gather_all_tensors(result, group, world_size)
.venv/lib/python3.8/site-packages/torchmetrics/utilities/distributed.py:92: in _simple_gather_all_tensors
    torch.distributed.all_gather(gathered_result, result, group)
RuntimeError: Tensors must be CUDA and dense

This seems to be quite an unsolvable problem for TorchMetrics Metric objects unless we start patching the Metric objects, which I don't think is a good option. This could be left to advanced users with a more concrete example in the documentation, but the user would have to manually move the metric (breaking the Lightning philosophy around device abstraction).

However, it is possible to support this for plain tensors, although the move_metrics_to_cpu name would then be quite confusing.

Options are:

  1. Add support for moving only logged tensors to CPU with a new parameter in self.log, while deprecating move_metrics_to_cpu from the Trainer. During the deprecation process, if the user sets Trainer(move_metrics_to_cpu=True), self.log will default to store_on_cpu=True if a tensor is passed.
self.log(..., store_on_cpu=True)
# note: it would raise a warning if a `Metric` is passed while setting it.
  2. Remove this feature and provide an example of how to do this manually with TorchMetrics (see the sketch after this list). Moving tensors to CPU would not be supported. This should not be a big deal, but there might be people with very strict memory requirements for whom keeping the logged tensors on GPU is not an option.
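
As a starting point for such a documentation example, here is a hedged sketch of the manual approach with TorchMetrics; the module, the MeanMetric choice, and the hook placement are illustrative, and the transfer cost would need to be benchmarked as discussed above:

import torch
import pytorch_lightning as pl
from torchmetrics import MeanMetric

class ManualCPUMetricModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        self.train_loss = MeanMetric()

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.train_loss.to(self.device)   # state back on the device for the update
        self.train_loss.update(loss.detach())
        self.train_loss.to("cpu")         # keep the state on the CPU between steps
        return loss

    def on_train_epoch_end(self):
        # move the state back so the computed value is a CUDA tensor before any
        # cross-process reduction (NCCL requires CUDA tensors)
        self.train_loss.to(self.device)
        self.log("train_loss_epoch", self.train_loss.compute(), sync_dist=True)
        self.train_loss.to("cpu")
        self.train_loss.reset()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

The trade-off is the same as in option 1 of the original proposal: the metric state stays off the GPU between steps, but every step pays two host-device transfers.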

Any other options or ideas are welcome :)

@Quintulius

Quintulius commented Dec 6, 2021

@tchaton Thanks for the detailed explanation!

On the metrics side:

This could be left to advanced users with a more concrete example in the documentation, but the user would have to manually move the metric (breaking the Lightning philosophy around device abstraction).

seems reasonable if this feature is rarely used.

Projects
No open projects
Status: Accepted
Development

No branches or pull requests

4 participants