
Commit de8fe1b
fix imports
justusschock committed Jan 31, 2021
1 parent 94e0b28 commit de8fe1b
Showing 3 changed files with 12 additions and 9 deletions.
9 changes: 5 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -17,13 +17,14 @@
 from torch.optim import Optimizer

 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.plugins.training_type import TrainingTypePlugin, HorovodPlugin
 from pytorch_lightning.plugins.precision import (
-    PrecisionPlugin,
-    MixedPrecisionPlugin,
     ApexMixedPrecisionPlugin,
+    MixedPrecisionPlugin,
     NativeMixedPrecisionPlugin,
+    PrecisionPlugin,
 )
+from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum

@@ -374,4 +375,4 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
         return optimizer.state_dict()

     def on_save(self, checkpoint):
-        return checkpoint
+        return checkpoint
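
Besides sorting the names inside the parenthesized precision import alphabetically, the hunk above replaces the combined `TrainingTypePlugin, HorovodPlugin` import with a direct import of `HorovodPlugin` from its defining module, `pytorch_lightning.plugins.training_type.horovod`. A likely motivation (not stated in the commit message) is that the name is not re-exported by the package `__init__`, or that importing through the `__init__` pulls in modules that form an import cycle. A minimal sketch of the difference, using a hypothetical package `mypkg` rather than the real Lightning modules:

# Hypothetical layout, for illustration only; `mypkg` and its modules are placeholders.
#
# mypkg/plugins/__init__.py
#     from mypkg.plugins.base import TrainingTypePlugin   # only this name is re-exported
#
# mypkg/plugins/horovod.py
#     from mypkg.plugins.base import TrainingTypePlugin
#     class HorovodPlugin(TrainingTypePlugin): ...
#
# Importing both names through the package __init__ fails for the one that
# is not re-exported:
#     from mypkg.plugins import TrainingTypePlugin, HorovodPlugin  # ImportError on HorovodPlugin
#
# Importing straight from the defining module sidesteps the package __init__:
from mypkg.plugins import TrainingTypePlugin
from mypkg.plugins.horovod import HorovodPlugin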
10 changes: 6 additions & 4 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,14 +13,16 @@
 # limitations under the License.
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Sequence, TYPE_CHECKING, Union

 import torch

 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.base_plugin import Plugin
-from pytorch_lightning.trainer import Trainer
+
+if TYPE_CHECKING:
+    from pytorch_lightning.trainer.trainer import Trainer


 class TrainingTypePlugin(Plugin, ABC):
@@ -105,10 +107,10 @@ def results(self) -> Any:
     def rpc_enabled(self) -> bool:
         return False

-    def start_training(self, trainer: Trainer) -> None:
+    def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
         self._results = trainer.train()

-    def start_testing(self, trainer: Trainer) -> None:
+    def start_testing(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
         self._results = trainer.run_test()
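
The `TYPE_CHECKING` change above is the standard idiom for type-only imports: `typing.TYPE_CHECKING` is `False` at runtime and `True` only while a static type checker analyzes the file, so the `Trainer` import no longer executes when the plugin module is loaded, which avoids a runtime import cycle between the trainer and the plugins; the annotations are then written as string forward references (`'Trainer'`). A small sketch of the idiom with two hypothetical modules, `plugin.py` and `trainer.py`:

# plugin.py (hypothetical) -- annotates with Trainer without importing it at runtime.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, IDEs), never at runtime,
    # so it cannot participate in a circular import with trainer.py.
    from trainer import Trainer


class Plugin:
    def start_training(self, trainer: 'Trainer') -> None:
        # Quoted annotation = forward reference: stored as a string and
        # resolved by the type checker, not at import time.
        trainer.train()


# trainer.py (hypothetical) -- imports plugin.py at runtime, which is safe
# because plugin.py no longer imports trainer.py at runtime:
#
#     from plugin import Plugin
#
#     class Trainer:
#         def __init__(self) -> None:
#             self.plugin = Plugin()
#
#         def train(self) -> None:
#             print("training")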
2 changes: 1 addition & 1 deletion setup.cfg
@@ -142,7 +142,7 @@ ignore_errors = True
 ignore_errors = True

 # todo: add proper typing to this module...
-[mypy-pytorch_lightning.accelerators.legacy.*]
+[mypy-pytorch_lightning.accelerators.*]
 ignore_errors = True

 # todo: add proper typing to this module...
