diff --git a/CHANGELOG.md b/CHANGELOG.md index b13080eec75db..1a6b3daa5ebe7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647)) +- Added warning when progress bar refresh rate is less than 20 on Google Colab to prevent crashing ([#4654](https://github.com/PyTorchLightning/pytorch-lightning/pull/4654)) + + - Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656)) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index c9ef4ae32be77..d581497d3fd87 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from typing import Optional, Union -from typing import Union, Optional - -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -74,6 +73,14 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None)) def configure_progress_bar(self, refresh_rate=1, process_position=0): + # smaller refresh rate on colab causes crashes, warn user about this + if os.getenv('COLAB_GPU') and refresh_rate < 20: + rank_zero_warn( + "You have set progress_bar_refresh_rate < 20 on Google Colab. This" + " may crash. Consider using progress_bar_refresh_rate >= 20 in Trainer.", + UserWarning + ) + progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] if len(progress_bars) > 1: raise MisconfigurationException( diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 7075c790c60d9..b87aec0bf3cb5 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -18,7 +18,7 @@ import torch from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable try: diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 221844244ad75..aca43c581d4c9 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import pytest +from unittest import mock from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint @@ -239,3 +241,14 @@ def on_validation_epoch_end(self, trainer, pl_module): ) trainer.fit(model) assert trainer.progress_bar_callback.val_progress_bar_total == expected + + +@mock.patch.dict(os.environ, {'COLAB_GPU': '1'}) +def test_progress_bar_warning_on_colab(tmpdir): + with pytest.warns(UserWarning, match='on Google Colab. This may crash.'): + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=19, + ) + + assert trainer.progress_bar_callback.refresh_rate == 19