Commit
Merge branch 'master' into bug/5459
carmocca authored Jan 18, 2021
2 parents b20bd4d + a56f745 commit 6ffc2b0
Showing 10 changed files with 88 additions and 61 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed logging on_train_batch_end in a callback with multiple optimizers ([#5521](https://github.com/PyTorchLightning/pytorch-lightning/pull/5521))


- Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519))


- Fixed `val_check_interval` with `fast_dev_run` ([#5540](https://github.com/PyTorchLightning/pytorch-lightning/pull/5540))


## [1.1.4] - 2021-01-12

### Added
@@ -43,6 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Logging only on `not should_accumulate()` during training ([#5417](https://github.com/PyTorchLightning/pytorch-lightning/pull/5417))
- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406))
- Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
- Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505))


## [1.1.3] - 2021-01-05
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -342,8 +342,8 @@ def set_distributed_mode(self):
# throw error to force user ddp or ddp2 choice
if self.trainer.num_nodes > 1 and not (self.trainer.use_ddp2 or self.trainer.use_ddp):
raise MisconfigurationException(
'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
'DataParallel does not support num_nodes > 1. '
'To avoid this exception, set `accelerator="ddp"` or `accelerator="ddp2"`'
)

rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
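The hunk above replaces the old silent fallback to DDP with a hard error. A minimal sketch of the two configurations involved, assuming a 2-GPU, 2-node environment; the exact Trainer arguments are illustrative and not taken from this commit:

```python
# Illustrative sketch: with num_nodes > 1 the distributed mode must be chosen
# explicitly, otherwise set_distributed_mode() raises the exception shown above.
from pytorch_lightning import Trainer

# Raises MisconfigurationException in this version: DataParallel cannot span nodes.
# trainer = Trainer(gpus=2, num_nodes=2, accelerator="dp")

# Explicitly choosing DDP (or DDP2) avoids the exception.
trainer = Trainer(gpus=2, num_nodes=2, accelerator="ddp")
```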
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py
@@ -42,7 +42,7 @@ def __init__(self,
super().__init__(trainer, cluster_environment, ddp_plugin)
self.nickname = 'ddp_cpu'

def model_to_device(self, model, process_idx):
def model_to_device(self, model):
model.cpu()

def get_device_ids(self):
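For reference, a hedged sketch of the corrected hook shape; the class below is a made-up stand-in that only mirrors the one-argument signature this diff introduces:

```python
# Hypothetical stand-in, not the library class: shows the new one-argument hook.
import torch


class CPUAcceleratorSketch:
    def model_to_device(self, model):  # the process_idx argument is gone
        model.cpu()


CPUAcceleratorSketch().model_to_device(torch.nn.Linear(2, 2))
```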
24 changes: 0 additions & 24 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -144,30 +144,6 @@ def test_step_end(self, output):
output = output.mean()
return output

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
"""
Reinitialize optimizer.step properties added by schedulers
"""
for scheduler in schedulers:
scheduler = scheduler['scheduler']

for optimizer in optimizers:
# check that we dont mix users optimizers and schedulers
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
is_regular_scheduler = optim.lr_scheduler._LRScheduler
is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau
if is_regular_scheduler or is_lr_reduce_on_plateau:
idx = i
state = scheduler.state_dict()
else:
state = None

scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)

def get_reference_model(self, model) -> LightningModule:
if isinstance(model, LightningDataParallel):
return model.module
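The override deleted above duplicated the trainer-level helper and carried a bug: `is_regular_scheduler` and `is_lr_reduce_on_plateau` were bound to the scheduler classes themselves, so the `if` was truthy on every pass through the MRO loop and `idx`/`state` were overwritten each iteration. A tiny demonstration of why that check never discriminated anything (the shared implementation in pytorch_lightning/trainer/optimizers.py is fixed further down):

```python
# Classes are truthy objects, so the removed check was True for every MRO entry.
from torch import optim

is_regular_scheduler = optim.lr_scheduler._LRScheduler
is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau
print(bool(is_regular_scheduler or is_lr_reduce_on_plateau))  # always True
```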
17 changes: 15 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
@@ -144,8 +144,21 @@ def experiment(self) -> SummaryWriter:
return self._experiment

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
metrics: Optional[Dict[str, Any]] = None) -> None:
def log_hyperparams(
self,
params: Union[Dict[str, Any], Namespace],
metrics: Optional[Dict[str, Any]] = None,
) -> None:
"""
Record hyperparameters. TensorBoard logs with and without saved hyperparameters
are incompatible, the hyperparameters are then not displayed in the TensorBoard.
Please delete or move the previously saved logs to display the new ones with hyperparameters.
Args:
params: a dictionary-like container with the hyperparameters
metrics: Dictionary with metric names as keys and measured quantities as values
"""

params = self._convert_params(params)

# store params to output
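A short usage sketch of the expanded signature; the directory, hyperparameter values, and metric name below are made up for illustration:

```python
# Hedged usage sketch: log hyperparameters together with reference metric values
# so they appear in TensorBoard's HPARAMS tab.
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="lightning_logs", name="hparams_example")
logger.log_hyperparams(
    params={"learning_rate": 1e-3, "batch_size": 32},
    metrics={"hp_metric": 0.0},
)
```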
2 changes: 0 additions & 2 deletions pytorch_lightning/metrics/classification/precision_recall.py
@@ -42,7 +42,6 @@ class Precision(Metric):
Args:
num_classes: Number of classes in the dataset.
beta: Beta coefficient in the F measure.
threshold:
Threshold value for binary or multi-label logits. default: 0.5
@@ -135,7 +134,6 @@ class Recall(Metric):
Args:
num_classes: Number of classes in the dataset.
beta: Beta coefficient in the F measure.
threshold:
Threshold value for binary or multi-label logits. default: 0.5
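For context, a minimal usage sketch of the metric whose docstring is trimmed above, assuming the class-based metrics API of this release; the tensors and class count are illustrative:

```python
# Hedged sketch: Precision takes num_classes and threshold, as documented above.
import torch
from pytorch_lightning.metrics import Precision

precision = Precision(num_classes=3)
preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
print(precision(preds, target))
```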
@@ -59,7 +59,7 @@ def on_init_start(
self.trainer.max_steps = fast_dev_run
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
self.trainer.val_check_interval = 1.0
val_check_interval = 1.0
self.trainer.check_val_every_n_epoch = 1
self.trainer.logger = DummyLogger()

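The hunk above swaps the direct assignment to `self.trainer.val_check_interval` for a local `val_check_interval`. The resulting trainer flags are pinned down by the updated test further down; a hedged sketch mirroring those assertions:

```python
# Hedged sketch based on the assertions in tests/trainer/flags/test_fast_dev_run.py:
# an integer fast_dev_run caps the run at that many batches and normalises the
# validation-check flags.
from pytorch_lightning import Trainer

trainer = Trainer(fast_dev_run=7)
assert trainer.max_steps == 7
assert trainer.max_epochs == 1
assert trainer.num_sanity_val_steps == 0
assert trainer.val_check_interval == 1.0
assert trainer.check_val_every_n_epoch == 1
```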
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/optimizers.py
@@ -145,21 +145,21 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
scheduler = scheduler['scheduler']
state = None

for optimizer in optimizers:
# check that we dont mix users optimizers and schedulers
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
idx = i
state = scheduler.state_dict()
else:
state = None
scheduler.__class__.__mro__[i].__init__(scheduler, optimizer)
scheduler.load_state_dict(state)
break

scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)
break


class _MockOptimizer(Optimizer):
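Standalone illustration of the pattern the fixed helper relies on: find the base LR-scheduler class in the instance's MRO, re-run its `__init__` against the matching optimizer, and restore the saved state. The snippet below is built from plain torch.optim objects and is not Lightning code:

```python
# Illustrative sketch of the MRO-based re-initialisation used above (plain PyTorch).
import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

state = scheduler.state_dict()
for i, mro in enumerate(scheduler.__class__.__mro__):
    if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
        # Re-run the base-class __init__ so scheduler.optimizer points at `optimizer`,
        # then restore the previously saved scheduler state.
        scheduler.__class__.__mro__[i].__init__(scheduler, optimizer)
        scheduler.load_state_dict(state)
        break

assert scheduler.optimizer is optimizer
```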
22 changes: 16 additions & 6 deletions tests/models/test_amp.py
@@ -16,15 +16,16 @@

import pytest
import torch
from torch import optim

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils


@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
@@ -189,9 +190,15 @@ def test_amp_without_apex(tmpdir):
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
def test_amp_with_apex(tmpdir):
"""Check calling apex scaling in training."""

model = EvalModelTemplate()

class CustomModel(EvalModelTemplate):
def configure_optimizers(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
optimizer2 = optim.SGD(self.parameters(), lr=self.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]

model = CustomModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
@@ -202,4 +209,7 @@ def test_amp_with_apex(tmpdir):
assert str(trainer.amp_backend) == "AMPType.APEX"
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED
assert trainer.dev_debugger.count_events('AMP') == 10
assert trainer.dev_debugger.count_events('AMP') == 20

assert isinstance(trainer.lr_schedulers[0]['scheduler'].optimizer, optim.Adam)
assert isinstance(trainer.lr_schedulers[1]['scheduler'].optimizer, optim.SGD)
59 changes: 41 additions & 18 deletions tests/trainer/flags/test_fast_dev_run.py
@@ -36,30 +36,59 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
class FastDevRunModel(BoringModel):
def __init__(self):
super().__init__()
self.training_step_called = False
self.validation_step_called = False
self.test_step_called = False
self.training_step_call_count = 0
self.training_epoch_end_call_count = 0
self.validation_step_call_count = 0
self.validation_epoch_end_call_count = 0
self.test_step_call_count = 0

def training_step(self, batch, batch_idx):
self.log('some_metric', torch.tensor(7.))
self.logger.experiment.dummy_log('some_distribution', torch.randn(7) + batch_idx)
self.training_step_called = True
self.training_step_call_count += 1
return super().training_step(batch, batch_idx)

def training_epoch_end(self, outputs):
self.training_epoch_end_call_count += 1
super().training_epoch_end(outputs)

def validation_step(self, batch, batch_idx):
self.validation_step_called = True
self.validation_step_call_count += 1
return super().validation_step(batch, batch_idx)

def validation_epoch_end(self, outputs):
self.validation_epoch_end_call_count += 1
super().validation_epoch_end(outputs)

def test_step(self, batch, batch_idx):
self.test_step_call_count += 1
return super().test_step(batch, batch_idx)

checkpoint_callback = ModelCheckpoint()
early_stopping_callback = EarlyStopping()
trainer_config = dict(
fast_dev_run=fast_dev_run,
val_check_interval=2,
logger=True,
log_every_n_steps=1,
callbacks=[checkpoint_callback, early_stopping_callback],
)

def _make_fast_dev_run_assertions(trainer):
def _make_fast_dev_run_assertions(trainer, model):
# check the call count for train/val/test step/epoch
assert model.training_step_call_count == fast_dev_run
assert model.training_epoch_end_call_count == 1
assert model.validation_step_call_count == 0 if model.validation_step is None else fast_dev_run
assert model.validation_epoch_end_call_count == 0 if model.validation_step is None else 1
assert model.test_step_call_count == fast_dev_run

# check trainer arguments
assert trainer.max_steps == fast_dev_run
assert trainer.num_sanity_val_steps == 0
assert trainer.max_epochs == 1
assert trainer.val_check_interval == 1.0
assert trainer.check_val_every_n_epoch == 1

# there should be no logger with fast_dev_run
assert isinstance(trainer.logger, DummyLogger)
assert len(trainer.dev_debugger.logged_metrics) == fast_dev_run
@@ -76,13 +105,10 @@ def _make_fast_dev_run_assertions(trainer):
train_val_step_model = FastDevRunModel()
trainer = Trainer(**trainer_config)
results = trainer.fit(train_val_step_model)
assert results
trainer.test(ckpt_path=None)

# make sure both training_step and validation_step were called
assert train_val_step_model.training_step_called
assert train_val_step_model.validation_step_called

_make_fast_dev_run_assertions(trainer)
assert results
_make_fast_dev_run_assertions(trainer, train_val_step_model)

# -----------------------
# also called once with no val step
@@ -92,10 +118,7 @@ def _make_fast_dev_run_assertions(trainer):

trainer = Trainer(**trainer_config)
results = trainer.fit(train_step_only_model)
assert results
trainer.test(ckpt_path=None)

# make sure only training_step was called
assert train_step_only_model.training_step_called
assert not train_step_only_model.validation_step_called

_make_fast_dev_run_assertions(trainer)
assert results
_make_fast_dev_run_assertions(trainer, train_step_only_model)
