From c8e83a138f1d466a6ce8912e3c72e1aa4a7ba76d Mon Sep 17 00:00:00 2001 From: Samyak S Sarnayak <34161949+samyak2@users.noreply.github.com> Date: Mon, 23 Nov 2020 21:43:33 +0100 Subject: [PATCH] Use high progress_bar_refresh_rate on Google Colab (#4654) * Use high refresh rate on Google Colab (#3786) Automatically override progress_bar_refresh_rate when on Google Colab. Also added a constant IS_COLAB in utilities to check whether it is being run in colab or not. (#3786) * Show a warning instead of overriding when rate is low on colab * Change warning to suggestion and move it Moved warning to configure_progress_bar instead of on_trainer_init * Apply suggestions from code review Co-authored-by: Rohit Gupta * add a mock test Co-authored-by: chaton Co-authored-by: Jirka Borovec Co-authored-by: Rohit Gupta (cherry picked from commit ccf38ced2e2e4191c7dfa5b139ce9c08a6f64c1f) --- CHANGELOG.md | 3 +++ .../trainer/connectors/callback_connector.py | 13 ++++++++++--- pytorch_lightning/utilities/__init__.py | 2 +- tests/callbacks/test_progress_bar.py | 13 +++++++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b13080eec75db..1a6b3daa5ebe7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647)) +- Added warning when progress bar refresh rate is less than 20 on Google Colab to prevent crashing ([#4654](https://github.com/PyTorchLightning/pytorch-lightning/pull/4654)) + + - Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656)) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index c9ef4ae32be77..d581497d3fd87 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from typing import Optional, Union -from typing import Union, Optional - -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -74,6 +73,14 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None)) def configure_progress_bar(self, refresh_rate=1, process_position=0): + # smaller refresh rate on colab causes crashes, warn user about this + if os.getenv('COLAB_GPU') and refresh_rate < 20: + rank_zero_warn( + "You have set progress_bar_refresh_rate < 20 on Google Colab. This" + " may crash. Consider using progress_bar_refresh_rate >= 20 in Trainer.", + UserWarning + ) + progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] if len(progress_bars) > 1: raise MisconfigurationException( diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 7075c790c60d9..b87aec0bf3cb5 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -18,7 +18,7 @@ import torch from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable try: diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 221844244ad75..aca43c581d4c9 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import pytest +from unittest import mock from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint @@ -239,3 +241,14 @@ def on_validation_epoch_end(self, trainer, pl_module): ) trainer.fit(model) assert trainer.progress_bar_callback.val_progress_bar_total == expected + + +@mock.patch.dict(os.environ, {'COLAB_GPU': '1'}) +def test_progress_bar_warning_on_colab(tmpdir): + with pytest.warns(UserWarning, match='on Google Colab. This may crash.'): + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=19, + ) + + assert trainer.progress_bar_callback.refresh_rate == 19