Set precision=16 when use_amp is passed as True (Lightning-AI#1145)
* Set precision=16 when use_amp is passed as True

* Update CHANGELOG.md

* add use_amp to deprecated API

* Update trainer.py

* Update trainer.py

* move the use_amp attribute to deprecated API

* move use_amp deprecation back to Trainer's __init__

* drop unused

* drop deprecated

* reorder imports

* typing

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
4 people authored and tullie committed May 6, 2020
1 parent 5b94bda commit 73ebab5
Showing 8 changed files with 34 additions and 18 deletions.
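The heart of the change is a small backward-compatibility shim: the deprecated `use_amp` boolean is translated into the new `precision` argument and a `DeprecationWarning` is emitted. A minimal standalone sketch of that mapping (the function name is illustrative, not the actual Trainer code):

```python
import warnings
from typing import Optional


def resolve_precision(precision: int = 32, use_amp: Optional[bool] = None) -> int:
    """Map the deprecated `use_amp` flag onto the new `precision` argument."""
    if use_amp is not None:
        warnings.warn("`use_amp` has been replaced by `precision` since v0.7.0"
                      " and this argument will be removed in v0.9.0", DeprecationWarning)
        return 16 if use_amp else 32
    return precision


assert resolve_precision(use_amp=True) == 16   # old-style call requests 16-bit
assert resolve_precision(use_amp=False) == 32  # explicit False falls back to 32-bit
assert resolve_precision(precision=16) == 16   # new-style call is passed through
```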
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed checkpointing interval ([#1272](https://github.com/PyTorchLightning/pytorch-lightning/pull/1272))
 - Fixed validation and training loops run the partial dataset ([#1192](https://github.com/PyTorchLightning/pytorch-lightning/pull/1192))
 - Fixed running `on_validation_end` only on main process in DDP ([#1125](https://github.com/PyTorchLightning/pytorch-lightning/pull/1125))
+- Fixes `use_amp` issue ([#1145](https://github.com/PyTorchLightning/pytorch-lightning/pull/1145))
+- Fixes using deprecated `use_amp` attribute ([#1145](https://github.com/PyTorchLightning/pytorch-lightning/pull/1145))
 
 ## [0.7.1] - 2020-03-07
 
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/gpu_template.py
@@ -32,7 +32,7 @@ def main(hparams):
         max_epochs=hparams.epochs,
         gpus=hparams.gpus,
         distributed_backend=hparams.distributed_backend,
-        use_amp=hparams.use_16bit
+        precision=16 if hparams.use_16bit else 32,
     )
 
     # ------------------------
2 changes: 1 addition & 1 deletion pl_examples/domain_templates/imagenet.py
@@ -226,7 +226,7 @@ def main(hparams):
         gpus=hparams.gpus,
         max_epochs=hparams.epochs,
         distributed_backend=hparams.distributed_backend,
-        use_amp=hparams.use_16bit
+        precision=16 if hparams.use_16bit else 32,
     )
     if hparams.evaluate:
         trainer.run_evaluation()
18 changes: 10 additions & 8 deletions pytorch_lightning/trainer/auto_mix_precision.py
@@ -14,19 +14,21 @@ class TrainerAMPMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
-    use_amp: bool
+    precision: int
 
     def init_amp(self, use_amp):
-        self.use_amp = use_amp and APEX_AVAILABLE
-        if self.use_amp:
-            log.info('Using 16bit precision.')
-
         if use_amp and not APEX_AVAILABLE:  # pragma: no-cover
-            msg = """
+            raise ModuleNotFoundError("""
             You set `use_amp=True` but do not have apex installed.
             Install apex first using this guide and rerun with use_amp=True:
             https://github.com/NVIDIA/apex#linux
             this run will NOT use 16 bit precision
-            """
-            raise ModuleNotFoundError(msg)
+            """)
+
+        if self.use_amp:
+            log.info('Using 16bit precision.')
+
+    @property
+    def use_amp(self) -> bool:
+        return self.precision == 16 and APEX_AVAILABLE
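With `use_amp` turned into a property derived from `precision`, code that still reads `self.use_amp` keeps working without storing a separate flag. A toy sketch of the derived-property pattern (standalone, not the library's classes):

```python
class AMPConfig:
    """Toy stand-in for TrainerAMPMixin: `use_amp` is computed, never stored."""

    def __init__(self, precision: int = 32, apex_available: bool = False):
        self.precision = precision
        self._apex_available = apex_available  # stands in for the APEX_AVAILABLE flag

    @property
    def use_amp(self) -> bool:
        # 16-bit precision was requested and apex can actually provide it
        return self.precision == 16 and self._apex_available


assert AMPConfig(precision=16, apex_available=True).use_amp
assert not AMPConfig(precision=32, apex_available=True).use_amp
assert not AMPConfig(precision=16, apex_available=False).use_amp
```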
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -141,7 +141,6 @@ class TrainerDDPMixin(ABC):
     logger: Union[LightningLoggerBase, bool]
     data_parallel_device_ids: ...
     distributed_backend: str
-    use_amp: bool
     amp_level: str
     use_tpu: bool
     default_save_path: str
@@ -151,6 +150,11 @@ class TrainerDDPMixin(ABC):
     def num_gpus(self) -> int:
         """Warning: this is just empty shell for code implemented in other class."""
 
+    @property
+    @abstractmethod
+    def use_amp(self) -> bool:
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @abstractmethod
     def copy_trainer_model_properties(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/distrib_parts.py
@@ -372,7 +372,6 @@ class TrainerDPMixin(ABC):
     use_dp: bool
     use_ddp2: bool
     use_ddp: bool
-    use_amp: bool
     testing: bool
     single_gpu: bool
     root_gpu: ...
@@ -385,6 +384,11 @@ class TrainerDPMixin(ABC):
     use_tpu: bool
     data_parallel_device_ids: ...
 
+    @property
+    @abstractmethod
+    def use_amp(self) -> bool:
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @abstractmethod
     def run_pretrain_routine(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
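The two mixins above only declare `use_amp` as an abstract, read-only property; the concrete implementation lives in `TrainerAMPMixin`. A short sketch of that provider/consumer split under multiple inheritance (class names are illustrative):

```python
from abc import ABC, abstractmethod


class NeedsAMPInfo(ABC):
    """Consumer mixin: declares the attribute it relies on, like TrainerDDPMixin/TrainerDPMixin."""

    @property
    @abstractmethod
    def use_amp(self) -> bool:
        """Empty shell; the real property is provided by another mixin."""


class ProvidesAMPInfo:
    """Provider mixin: supplies the concrete property, like TrainerAMPMixin."""

    precision: int = 32

    @property
    def use_amp(self) -> bool:
        return self.precision == 16


class Combined(ProvidesAMPInfo, NeedsAMPInfo):
    """Mirrors how Trainer now lists the provider mixin ahead of the consumers."""


assert Combined().use_amp is False  # default precision of 32 means no AMP
```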
15 changes: 10 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -21,8 +21,7 @@
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.trainer.deprecated_api import (TrainerDeprecatedAPITillVer0_8,
-                                                      TrainerDeprecatedAPITillVer0_9)
+from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8, TrainerDeprecatedAPITillVer0_9
 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
 from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
@@ -55,13 +54,13 @@
 class Trainer(
     TrainerIOMixin,
     TrainerOptimizersMixin,
+    TrainerAMPMixin,
     TrainerDPMixin,
     TrainerDDPMixin,
     TrainerLoggingMixin,
     TrainerModelHooksMixin,
     TrainerTrainingTricksMixin,
     TrainerDataLoadingMixin,
-    TrainerAMPMixin,
     TrainerEvaluationLoopMixin,
     TrainerTrainLoopMixin,
     TrainerCallbackConfigMixin,
Expand All @@ -88,7 +87,6 @@ def __init__(
gpus: Optional[Union[List[int], str, int]] = None,
num_tpu_cores: Optional[int] = None,
log_gpu_memory: Optional[str] = None,
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
progress_bar_refresh_rate: int = 1,
overfit_pct: float = 0.0,
track_grad_norm: int = -1,
@@ -122,7 +120,8 @@ def __init__(
             nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
             max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
             min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
-            use_amp=False,  # backward compatible, todo: remove in v0.9.0
+            use_amp=None,  # backward compatible, todo: remove in v0.9.0
+            show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
             nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
             **kwargs
     ):
@@ -446,6 +445,12 @@ def __init__(
         self.amp_level = amp_level
         self.precision = precision
 
+        # Backward compatibility, TODO: remove in v0.9.0
+        if use_amp is not None:
+            warnings.warn("`use_amp` has been replaced by `precision` since v0.7.0"
+                          " and this argument will be removed in v0.9.0", DeprecationWarning)
+            self.precision = 16 if use_amp else 32
+
         assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
 
         if self.precision == 16 and self.num_tpu_cores is None:
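From the user's perspective, both spellings still request 16-bit training, but only the new one survives past v0.9.0. A hedged usage sketch (assumes an environment where 16-bit/apex support is actually available):

```python
from pytorch_lightning import Trainer

# Deprecated spelling: still accepted, emits a DeprecationWarning and sets precision=16.
trainer_old = Trainer(use_amp=True)

# Preferred spelling since v0.7.0.
trainer_new = Trainer(precision=16)

assert trainer_old.precision == trainer_new.precision == 16
```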
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -200,7 +200,6 @@ class TrainerTrainLoopMixin(ABC):
     optimizers: ...
     optimizer_frequencies: ...
     accumulate_grad_batches: int
-    use_amp: bool
     track_grad_norm: ...
     model: LightningModule
     interrupted: bool
