[RFC] Deprecate the `move_metrics_to_cpu` Trainer argument. #10595
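For context, `move_metrics_to_cpu` is a `Trainer` constructor argument. A minimal usage example (assuming a single-GPU machine) looks like:

```python
# Minimal usage of the flag under discussion (assumes a GPU is available):
# when True, Lightning moves the logged results to the CPU to free accelerator memory.
import pytorch_lightning as pl

trainer = pl.Trainer(gpus=1, move_metrics_to_cpu=True)
```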
The memory consideration you outlined with O(num_batches_in_epoch) vs O(1) seems convincing enough to me to go with Option 2.
I don't see how the difference in memory between versions justifies either of the options. I understand the flag was more "useful" before because there was "more memory to move", but the flag still has its use today after the reduced memory requirements. What the flag does today is:

a. (edit: fixed with #10631)
b. Moves logged results after `training_step` (for both automatic and manual optimization):
https://github.com/PyTorchLightning/pytorch-lightning/blob/261ea90822e2bf1cfa5d56171ab1f95a81d5c571/pytorch_lightning/loops/optimization/optimizer_loop.py#L448-L451
https://github.com/PyTorchLightning/pytorch-lightning/blob/261ea90822e2bf1cfa5d56171ab1f95a81d5c571/pytorch_lightning/loops/optimization/manual_loop.py#L118-L122

Now, (b) is a different story. We could remove it, but that leaves users with no alternative if they still want to move the saved results to CPU, which they might want regardless of the memory usage. Their only equivalent solution would be:

```python
def training_step(self, batch, batch_idx):
    ...
    self.log(...)
    self.trainer._results.cpu()
    return loss
```

which is not great UX. Coming back to your proposal, option (2) is the easy-but-lazy option which also removes a feature. It's worth considering because doing it well is hard (cf. your cons) and the current value of the flag is unclear.

Personally, I don't know what the best option is; it depends on the value this flag provides to the community.
We don't have OOMs related to Metrics as we used to. I believe it was an abuse of this argument, and it wasn't properly designed at the time.

I don't think it is hard, just engineering heavy. It could make the Result code more complex. I guess the best way forward is to implement it, benchmark its impact, and make a decision then.

Best,
Hey @awaelchli @ananthsub @Quintulius. After investigating with @carmocca, here are the current challenges. The code below would fail on master with `move_metrics_to_cpu=True`:

```python
class TestModel(BoringModel):
    def __init__(self):
        super().__init__()
        # the metric is automatically moved to the GPU at the beginning
        self.metric = DummyMetric()

    def training_step(self, batch, batch_idx):
        # the first time this runs, the metric and the batch are on the GPU. All good.
        # the second time, the metric is on the CPU but the batch is on the GPU. Error.
        self.metric(batch.sum().long())
        self.log("train_loss_2", self.metric, on_epoch=True)
        return super().training_step(batch, batch_idx)
        # `trainer._results.cpu()` is called right after, and this includes logged `Metric`s
```

This raises:

```
def update(self, x):
>       self.x += x
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

However, if we were to move the inputs to the metric's device,

```python
self.metric(batch.sum().long().cpu())
```

it would fail later on while trying to sync the metric across processes:

```
pytorch_lightning/trainer/connectors/logger_connector/result.py:285: in wrapped_func
self._computed = compute(*args, **kwargs)
pytorch_lightning/trainer/connectors/logger_connector/result.py:253: in compute
return self.value.compute()
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:362: in wrapped_func
with self.sync_context(
../.pyenv/versions/3.8.5/lib/python3.8/contextlib.py:113: in __enter__
return next(self.gen)
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:333: in sync_context
self.sync(
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:285: in sync
self._sync_dist(dist_sync_fn, process_group=process_group)
.venv/lib/python3.8/site-packages/torchmetrics/metric.py:224: in _sync_dist
output_dict = apply_to_collection(
.venv/lib/python3.8/site-packages/torchmetrics/utilities/data.py:191: in apply_to_collection
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
.venv/lib/python3.8/site-packages/torchmetrics/utilities/data.py:191: in <dictcomp>
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
.venv/lib/python3.8/site-packages/torchmetrics/utilities/data.py:187: in apply_to_collection
return function(data, *args, **kwargs)
.venv/lib/python3.8/site-packages/torchmetrics/utilities/distributed.py:120: in gather_all_tensors
return _simple_gather_all_tensors(result, group, world_size)
.venv/lib/python3.8/site-packages/torchmetrics/utilities/distributed.py:92: in _simple_gather_all_tensors
torch.distributed.all_gather(gathered_result, result, group)
RuntimeError: Tensors must be CUDA and dense
```

This seems to be quite an unsolvable problem for TorchMetrics `Metric` objects, unless we start patching the `Metric` objects, which I don't think is a good option. It could be left to advanced users with a more concrete example in the documentation, but the user would have to move the metric manually (breaking the Lightning philosophy around device abstraction). However, it is possible to support this for plain tensors.

Options are:
```python
self.log(..., store_on_cpu=True)
# note: it would raise a warning if passing a `Metric` and setting it
```
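As a rough illustration of how such a `store_on_cpu` flag could behave (this is not actual Lightning code; the helper name and structure are assumptions): plain tensors would be detached to the CPU when stored, while `Metric` objects would be left on their device and a warning would be emitted, since moving them breaks the later distributed sync.

```python
# Sketch only: hypothetical handling of a `store_on_cpu` flag when storing a logged value.
import warnings
from typing import Union

import torch
from torchmetrics import Metric


def store_logged_value(value: Union[torch.Tensor, Metric], store_on_cpu: bool):
    if isinstance(value, Metric):
        if store_on_cpu:
            # moving a Metric's state to CPU would break the distributed sync later on
            warnings.warn("`store_on_cpu=True` is ignored for `Metric` objects")
        return value
    if store_on_cpu:
        # plain tensors can safely be kept on the CPU between steps
        return value.detach().cpu()
    return value
```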
Any other options or ideas are welcome :)
@tchaton Thanks for the detailed explanation! On the metrics side, this seems reasonable if this feature is rarely used.
📚 Documentation
After investigating this community issue #10379, it seems the parameter `move_metrics_to_cpu` isn't working as expected with `ddp`.

Bug reason: The `ResultCollection` object storing the metrics is moved to the CPU here, and on epoch end, when performing `compute`, the tensors are reduced across processes, which raises `RuntimeError: Tensors must be CUDA and dense`.

Note: this should also fail with the `sync_dist` argument of the `self.log` method.

Background:

Before: Prior to the `ResultCollection` object, logged tensor metrics had an O(num_batches_in_epoch) memory footprint, as they were stored in a list and reduced at epoch end. When an epoch had a very large number of batches, this could raise an OOM.

Now: The `ResultCollection` keeps an O(1) memory footprint for tensor metrics, so the `move_metrics_to_cpu` argument isn't as impactful as before.
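As a rough, self-contained illustration of this memory difference (toy numbers, plain PyTorch, not Lightning internals):

```python
# The "before" behaviour keeps every per-batch value until epoch end
# (O(num_batches_in_epoch) tensors); the "now" behaviour keeps a single
# running aggregate (O(1) tensors). Both produce the same epoch result.
import torch

num_batches_in_epoch = 1000
batch_values = [torch.tensor(float(i)) for i in range(num_batches_in_epoch)]

# Before: store everything, reduce at epoch end -> memory grows with the epoch length
stored = []
for value in batch_values:
    stored.append(value)
epoch_mean_before = torch.stack(stored).mean()

# Now: keep only a running sum and a count -> constant memory
running_sum, count = torch.tensor(0.0), 0
for value in batch_values:
    running_sum += value
    count += 1
epoch_mean_now = running_sum / count

assert torch.allclose(epoch_mean_before, epoch_mean_now)
```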
I believe there are two options forward.

Option 1 🚀:

Make the `ResultCollection` aware of `move_metrics_to_cpu`. The `ResultCollection` would be responsible for moving the `ResultMetric` back to the device before the distributed collection and back to the CPU right after.

Pros:

Cons:

Here is the pseudo-code for such a solution. This would be spread across several parts of the code.
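A rough sketch of what this could look like (the class and method names are illustrative assumptions, not actual Lightning internals): the stored value lives on the CPU between steps and is shuttled back to the device only around the collective.

```python
# Sketch only: a toy stand-in for a ResultMetric that honours `move_to_cpu`.
import torch
import torch.distributed as dist


class ResultMetricSketch:
    def __init__(self, device: torch.device, move_to_cpu: bool) -> None:
        self.device = device
        self.move_to_cpu = move_to_cpu
        self.value = torch.zeros(1, device=device)

    def update(self, batch_value: torch.Tensor) -> None:
        # accumulate on whichever device the state currently lives on
        self.value = self.value + batch_value.to(self.value.device)
        if self.move_to_cpu:
            # free accelerator memory between steps
            self.value = self.value.cpu()

    def compute(self) -> torch.Tensor:
        if self.move_to_cpu:
            # move back to the accelerator so the collective op can run
            self.value = self.value.to(self.device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(self.value)
        result = self.value.clone()
        if self.move_to_cpu:
            # return the state to the CPU right after the sync
            self.value = self.value.cpu()
        return result
```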
Option 2 😋:

Deprecate and remove the section of code moving the `ResultCollection` to the CPU.
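A possible deprecation path for Option 2, sketched with plain `warnings` rather than Lightning's internal helpers (the function name is illustrative): keep accepting the argument for a release cycle and warn instead of moving the results.

```python
# Sketch only: emit a deprecation warning when the argument is still passed.
import warnings


def _init_move_metrics_to_cpu(move_metrics_to_cpu: bool) -> None:
    if move_metrics_to_cpu:
        warnings.warn(
            "`Trainer(move_metrics_to_cpu=True)` is deprecated and will be removed"
            " in a future release.",
            DeprecationWarning,
        )
```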
cc @carmocca @edward-io @tchaton