Skip to content

Commit

Permalink
make on_colab_kaggle utility func
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Jan 31, 2021
1 parent cd25a6a commit a950f48
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
7 changes: 2 additions & 5 deletions pytorch_lightning/plugins/training_type/single_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn

if _TPU_AVAILABLE:
Expand Down Expand Up @@ -32,10 +33,6 @@ def pre_training(self) -> None:
def post_training(self) -> None:
model = self.lightning_module

if self.on_colab_kaggle:
if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

@property
def on_colab_kaggle(self) -> bool:
return bool(os.getenv("COLAB_GPU") or os.getenv("KAGGLE_URL_BASE"))
7 changes: 2 additions & 5 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
Expand Down Expand Up @@ -74,7 +75,7 @@ def new_process(self, process_idx: int, trainer: Trainer) ->None:

def __save_end_of_training_weights(self, model: LightningModule, trainer: Trainer) -> None:
# when training ends on these platforms dump weights to get out of the main process
if self.on_colab_kaggle:
if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

Expand All @@ -92,10 +93,6 @@ def on_save(self, checkpoint: dict) -> dict:
"""
return move_data_to_device(checkpoint, torch.device("cpu"))

@property
def on_colab_kaggle(self) -> bool:
return bool(os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE'))

def broadcast(self, obj: object, src:int=0)->object:
buffer = io.BytesIO()
torch.save(obj, buffer)
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/plugins/training_type/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os


def on_colab_kaggle() -> bool:
return bool(os.getenv("COLAB_GPU") or os.getenv("KAGGLE_URL_BASE"))

0 comments on commit a950f48

Please sign in to comment.