
Commit

Merge branch 'master' into feature/logging-sub-dir
s-rog committed May 18, 2021
2 parents fbe724c + 20f6337 commit 3fb1bec
Showing 26 changed files with 431 additions and 142 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow

- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))


- Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430))

@@ -36,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactored Loops
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
* Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
* Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))

- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))

@@ -371,6 +377,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Remove hardcoding of local rank in accelerator connector ([#6878](https://github.com/PyTorchLightning/pytorch-lightning/pull/6878))


- Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032))


## [1.2.7] - 2021-04-06

### Fixed
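As a quick illustration of the `trainer.predict(ckpt_path)` argument added in the changelog above — a minimal, self-contained sketch assuming the 1.3-style predict API, with hypothetical model and dataset names:

```python
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Hypothetical toy dataset used only for this sketch."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(pl.LightningModule):
    """Hypothetical toy model; the default predict_step calls forward()."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(max_epochs=1)
trainer.fit(TinyModel(), DataLoader(RandomDataset(), batch_size=8))

# New argument: load a checkpoint before predicting. "best" resolves to the checkpoint
# tracked by the ModelCheckpoint callback; a path to a specific .ckpt file also works.
predictions = trainer.predict(dataloaders=DataLoader(RandomDataset(), batch_size=8), ckpt_path="best")
```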
1 change: 1 addition & 0 deletions docs/source/api_references.rst
@@ -125,6 +125,7 @@ Cluster Environments
ClusterEnvironment
LightningEnvironment
TorchElasticEnvironment
KubeflowEnvironment
SLURMEnvironment


1 change: 1 addition & 0 deletions docs/source/extensions/plugins.rst
@@ -151,4 +151,5 @@ Cluster Environments
ClusterEnvironment
LightningEnvironment
TorchElasticEnvironment
KubeflowEnvironment
SLURMEnvironment
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -282,7 +282,7 @@ def _store(

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
for opt_idx, optimizer in trainer.train_loop.prepare_optimizers():
for opt_idx, optimizer in trainer.train_loop.get_active_optimizers():
num_param_groups = len(optimizer.param_groups)
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
current_param_groups = optimizer.param_groups
12 changes: 8 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -36,6 +36,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
from pytorch_lightning.utilities.xla_device import tpu_training_and_local_rank_zero

log = logging.getLogger(__name__)
warning_cache = WarningCache()
@@ -453,7 +454,7 @@ def _do_save(self, trainer: 'pl.Trainer', filepath: str) -> None:
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
if trainer.is_global_zero:
if trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer):
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the trainer
@@ -591,7 +592,7 @@ def __resolve_ckpt_dir(self, trainer: 'pl.Trainer') -> None:

self.dirpath = ckpt_path

if not trainer.fast_dev_run and trainer.is_global_zero:
if (not trainer.fast_dev_run and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))):
self._fs.makedirs(self.dirpath, exist_ok=True)

def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
@@ -654,7 +655,10 @@ def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[

self._save_model(trainer, filepath)

if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero:
if (
self.last_model_path and self.last_model_path != filepath
and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
):
self._del_model(self.last_model_path)

self.last_model_path = filepath
@@ -681,7 +685,7 @@ def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidate

if (
self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
and trainer.is_global_zero
and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
):
self._del_model(self.best_model_path)

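The checks above now also let the local-rank-zero process on each TPU host perform filesystem operations. A rough sketch of what the imported `tpu_training_and_local_rank_zero` helper presumably evaluates — an assumption for illustration only, not the actual implementation in `pytorch_lightning/utilities/xla_device.py`:

```python
# Assumed behaviour (sketch only): on TPU, every host's local-rank-zero process may
# create directories and delete stale checkpoints, not just the global-rank-zero one,
# because each TPU host writes to its own filesystem.
from pytorch_lightning.utilities import DeviceType


def tpu_training_and_local_rank_zero(trainer) -> bool:
    return trainer._device_type == DeviceType.TPU and trainer.local_rank == 0
```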
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1635,7 +1635,7 @@ def get_progress_bar_dict(self):
module_tbptt_enabled = self.truncated_bptt_steps > 0
trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
if module_tbptt_enabled or trainer_tbptt_enabled:
tqdm_dict["split_idx"] = self.trainer.split_idx
tqdm_dict["split_idx"] = self.trainer.train_loop.split_idx

if self.trainer.logger is not None and self.trainer.logger.version is not None:
version = self.trainer.logger.version
3 changes: 1 addition & 2 deletions pytorch_lightning/core/optimizer.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from contextlib import contextmanager
from typing import Callable, Optional
from weakref import proxy
@@ -207,7 +206,7 @@ def closure_dis():
profiler_name = "closure_{self._optimizer_idx}"
closure = do_nothing_closure
else:
if not isinstance(closure, types.FunctionType):
if not callable(closure):
raise MisconfigurationException("When closure is provided, it should be a function")
profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"

1 change: 1 addition & 0 deletions pytorch_lightning/plugins/environments/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
63 changes: 63 additions & 0 deletions pytorch_lightning/plugins/environments/kubeflow_environment.py
@@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

log = logging.getLogger(__name__)


class KubeflowEnvironment(ClusterEnvironment):
"""
Environment for distributed training using the
`PyTorchJob <https://www.kubeflow.org/docs/components/training/pytorch/>`_
operator from `Kubeflow <https://www.kubeflow.org>`_
"""

@staticmethod
def is_using_kubeflow() -> bool:
""" Returns ``True`` if the current process was launched using Kubeflow PyTorchJob. """
required_env_vars = ("KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK")
# torchelastic sets these. Make sure we're not in torchelastic
excluded_env_vars = ("GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
return (all(v in os.environ for v in required_env_vars) and not any(v in os.environ for v in excluded_env_vars))

def creates_children(self) -> bool:
return True

def master_address(self) -> str:
return os.environ['MASTER_ADDR']

def master_port(self) -> int:
return int(os.environ['MASTER_PORT'])

def world_size(self) -> int:
return int(os.environ['WORLD_SIZE'])

def set_world_size(self, size: int) -> None:
log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return int(os.environ["RANK"])

def set_global_rank(self, rank: int) -> None:
log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self) -> int:
return 0

def node_rank(self) -> int:
return self.global_rank()
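Outside of this diff, a minimal usage sketch: inside a PyTorchJob worker pod the environment is detected automatically by the accelerator connector (see `select_cluster_environment` below), but it can also be passed explicitly via `plugins` — assuming one DDP process per pod with a single GPU each, so `num_nodes` equals the operator-provided `WORLD_SIZE`:

```python
import os

import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import KubeflowEnvironment

# Sketch only: explicitly selecting the new cluster environment. MASTER_ADDR,
# MASTER_PORT, WORLD_SIZE and RANK are injected by the PyTorchJob operator.
trainer = pl.Trainer(
    accelerator="ddp",                               # one DDP process per pod
    gpus=1,
    num_nodes=int(os.environ.get("WORLD_SIZE", 1)),  # assumption: 1 GPU per pod
    plugins=[KubeflowEnvironment()],
)
# trainer.fit(model) then uses RANK/WORLD_SIZE provided by the operator.
```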
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -67,7 +67,7 @@ def local_rank(self) -> int:

@property
def world_size(self) -> int:
return self.num_processes
return xm.xrt_world_size()

@property
def root_device(self) -> torch.device:
@@ -168,7 +168,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
self.barrier("end-process")

# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if self.global_rank == 0:
if self.local_rank == 0:
time.sleep(2)

def model_to_device(self) -> None:
@@ -185,7 +185,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

if self.mp_queue is not None:
rank_zero_warn("cleaning up ddp environment...")
rank_zero_warn("cleaning up tpu spawn environment...")

# save the last weights
last_path = None
@@ -196,7 +196,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.save(self.lightning_module.state_dict(), last_path)

if self.global_rank == 0:
if self.local_rank == 0:
# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -46,6 +46,7 @@
)
from pytorch_lightning.plugins.environments import (
ClusterEnvironment,
KubeflowEnvironment,
LightningEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
@@ -397,10 +398,12 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
elif self.use_ddp:
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
@@ -416,7 +419,10 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
ddp_plugin_cls = DDPShardedPlugin
elif use_ddp_sharded_spawn:
ddp_plugin_cls = DDPSpawnShardedPlugin
elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
elif (
use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp
or use_kubeflow_ddp or use_ddp_cpu_kubeflow
):
ddp_plugin_cls = DDPPlugin
elif use_ddp_spawn or use_ddp_cpu_spawn:
ddp_plugin_cls = DDPSpawnPlugin
@@ -488,6 +494,8 @@ def select_cluster_environment(self) -> ClusterEnvironment:
env = SLURMEnvironment()
elif TorchElasticEnvironment.is_using_torchelastic():
env = TorchElasticEnvironment()
elif KubeflowEnvironment.is_using_kubeflow():
env = KubeflowEnvironment()
else:
env = LightningEnvironment()
return env
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
@@ -24,8 +24,9 @@

class DataConnector(object):

def __init__(self, trainer):
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode

def on_trainer_init(
self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -126,7 +126,6 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_ste
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
self.trainer.split_idx = None

@property
def should_flush_logs(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -259,7 +259,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

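Since `multiple_trainloader_mode` now lives on the `DataConnector` (while remaining a `Trainer` argument, as the `trainer.py` diff below shows), here is a hedged sketch of the behaviour it controls when several training dataloaders are combined by `CombinedLoader` — hypothetical model and data, assuming dict-of-dataloaders support:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class MultiLoaderModel(pl.LightningModule):
    """Hypothetical model returning two training dataloaders of different lengths."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def train_dataloader(self):
        short = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
        long = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=2)
        # The trainer wraps this dict in a CombinedLoader using multiple_trainloader_mode.
        return {"short": short, "long": long}

    def training_step(self, batch, batch_idx):
        # batch is a dict with one batch per dataloader key.
        return self.layer(batch["short"][0]).sum() + self.layer(batch["long"][0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# "max_size_cycle" (default) cycles the shorter loader; "min_size" stops at the shorter one.
trainer = pl.Trainer(max_epochs=1, multiple_trainloader_mode="max_size_cycle")
trainer.fit(MultiLoaderModel())
```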
8 changes: 3 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -313,7 +313,7 @@ def __init__(
# init connectors
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self)
self.data_connector = DataConnector(self, multiple_trainloader_mode)
self.optimizer_connector = OptimizerConnector(self)

self.accelerator_connector = AcceleratorConnector(
@@ -329,9 +329,7 @@ def __init__(
self.checkpoint_connector = CheckpointConnector(self)
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)
self.train_loop = TrainLoop(
self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
)
self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
self.evaluation_loop = EvaluationLoop(self)
self.predict_loop = PredictLoop(self)

@@ -1000,7 +998,7 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
self.optimizer_connector.update_learning_rates(
interval='epoch',
opt_indices=[
opt_idx for opt_idx, _ in self.train_loop.get_optimizers_iterable(
opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
batch_idx=(self.train_loop.total_batch_idx - 1)
) # Select the optimizers which were used in the last batch of the epoch
],