diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 7087f6a261010..984f9a6842b4a 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -17,13 +17,14 @@
 from torch.optim import Optimizer
 
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.plugins.training_type import TrainingTypePlugin, HorovodPlugin
 from pytorch_lightning.plugins.precision import (
-    PrecisionPlugin,
-    MixedPrecisionPlugin,
     ApexMixedPrecisionPlugin,
+    MixedPrecisionPlugin,
     NativeMixedPrecisionPlugin,
+    PrecisionPlugin,
 )
+from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 
@@ -374,4 +375,4 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
         return optimizer.state_dict()
 
     def on_save(self, checkpoint):
-        return checkpoint
\ No newline at end of file
+        return checkpoint
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 349ed689254ad..5dbbf23881373 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,14 +13,16 @@
 # limitations under the License.
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Sequence, TYPE_CHECKING, Union
 
 import torch
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.base_plugin import Plugin
-from pytorch_lightning.trainer import Trainer
+
+if TYPE_CHECKING:
+    from pytorch_lightning.trainer.trainer import Trainer
 
 
 class TrainingTypePlugin(Plugin, ABC):
@@ -105,10 +107,10 @@ def results(self) -> Any:
     def rpc_enabled(self) -> bool:
         return False
 
-    def start_training(self, trainer: Trainer) -> None:
+    def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
         self._results = trainer.train()
 
-    def start_testing(self, trainer: Trainer) -> None:
+    def start_testing(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
         self._results = trainer.run_test()
diff --git a/setup.cfg b/setup.cfg
index deccd35af8f98..ee23dd130de10 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -142,7 +142,7 @@ ignore_errors = True
 ignore_errors = True
 
 # todo: add proper typing to this module...
-[mypy-pytorch_lightning.accelerators.legacy.*]
+[mypy-pytorch_lightning.accelerators.*]
 ignore_errors = True
 
 # todo: add proper typing to this module...
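
The training_type_plugin.py change above breaks a circular import by deferring the Trainer import behind typing.TYPE_CHECKING and quoting the annotation. A minimal standalone sketch of that idiom, with hypothetical module and class names in place of the Lightning internals:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers such as mypy; at runtime the
    # TYPE_CHECKING constant is False, so importing this module never
    # triggers the import cycle.
    from myproject.trainer import Trainer  # hypothetical module


class BasePlugin:
    def start_training(self, trainer: 'Trainer') -> None:
        # The quoted ("forward reference") annotation is stored as a string
        # and only resolved when a type checker, or typing.get_type_hints(),
        # asks for it.
        trainer.train()

The net effect is that mypy still checks the trainer argument against the real Trainer type, while the runtime import graph no longer contains the plugin -> trainer -> plugin cycle.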