Skip to content

Commit

Permalink
Use high progress_bar_refresh_rate on Google Colab (#4654)
Browse files Browse the repository at this point in the history
* Use high refresh rate on Google Colab (#3786)

Automatically override progress_bar_refresh_rate when on Google
Colab. Also added a constant IS_COLAB in utilities to check
whether it is being run in colab or not.
(#3786)

* Show a warning instead of overriding when rate is low on colab

* Change warning to suggestion and move it

Moved warning to configure_progress_bar instead of on_trainer_init

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* add a mock test

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

(cherry picked from commit ccf38ce)
  • Loading branch information
Samyak2 authored and Borda committed Nov 24, 2020
1 parent 278b9a9 commit c8e83a1
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647))


- Added warning when progress bar refresh rate is less than 20 on Google Colab to prevent crashing ([#4654](https://github.com/PyTorchLightning/pytorch-lightning/pull/4654))


- Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))


Expand Down
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Union

from typing import Union, Optional

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

Expand Down Expand Up @@ -74,6 +73,14 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None))

def configure_progress_bar(self, refresh_rate=1, process_position=0):
# smaller refresh rate on colab causes crashes, warn user about this
if os.getenv('COLAB_GPU') and refresh_rate < 20:
rank_zero_warn(
"You have set progress_bar_refresh_rate < 20 on Google Colab. This"
" may crash. Consider using progress_bar_refresh_rate >= 20 in Trainer.",
UserWarning
)

progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
if len(progress_bars) > 1:
raise MisconfigurationException(
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable

try:
Expand Down
13 changes: 13 additions & 0 deletions tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint
Expand Down Expand Up @@ -239,3 +241,14 @@ def on_validation_epoch_end(self, trainer, pl_module):
)
trainer.fit(model)
assert trainer.progress_bar_callback.val_progress_bar_total == expected


@mock.patch.dict(os.environ, {'COLAB_GPU': '1'})
def test_progress_bar_warning_on_colab(tmpdir):
with pytest.warns(UserWarning, match='on Google Colab. This may crash.'):
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=19,
)

assert trainer.progress_bar_callback.refresh_rate == 19

0 comments on commit c8e83a1

Please sign in to comment.