Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Mar 23, 2021
1 parent f35dda8 commit 40d952c
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Optional, Union, Sequence
from typing import Any, Callable, Iterable, Optional, Sequence, Union

import torch
from torch.optim import Optimizer
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import Any, TYPE_CHECKING

import torch

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
from typing import List, Dict, Any, Type, Callable
from typing import Any, Callable, Dict, List, Type

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
Expand Down
7 changes: 1 addition & 6 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@
import os
from typing import List, Union

from pytorch_lightning.callbacks import (
Callback,
ModelCheckpoint,
ProgressBar,
ProgressBarBase,
)
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down
3 changes: 1 addition & 2 deletions tests/accelerators/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_unsupported_precision_plugins():
trainer = Mock()
model = Mock()
accelerator = CPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
precision_plugin=MixedPrecisionPlugin()
training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
)
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
accelerator.setup(trainer=trainer, model=model)
2 changes: 1 addition & 1 deletion tests/accelerators/test_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel
from tests.base import EvalModelTemplate


class CustomClassificationModelDP(ClassificationModel):
Expand Down
7 changes: 6 additions & 1 deletion tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest

from pytorch_lightning import Trainer, Callback
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call
Expand All @@ -30,7 +30,9 @@ def test_v1_5_0_wandb_unused_sync_step(tmpdir):


def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):

class OldSignature(Callback):

def on_save_checkpoint(self, trainer, pl_module): # noqa
...

Expand All @@ -49,14 +51,17 @@ def on_save_checkpoint(self, trainer, pl_module): # noqa
trainer.save_checkpoint(filepath)

class NewSignature(Callback):

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
...

class ValidSignature1(Callback):

def on_save_checkpoint(self, trainer, *args):
...

class ValidSignature2(Callback):

def on_save_checkpoint(self, *args):
...

Expand Down
18 changes: 7 additions & 11 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,8 @@ def on_train_start(self) -> None:
raise SystemExit()

model = TestModel()
ds = DeepSpeedPlugin(
loss_scale=10, initial_scale_power=10, loss_scale_window=10, hysteresis=10, min_loss_scale=10
)
trainer = Trainer(
plugins=[
ds
],
precision=16,
gpus=1
)
ds = DeepSpeedPlugin(loss_scale=10, initial_scale_power=10, loss_scale_window=10, hysteresis=10, min_loss_scale=10)
trainer = Trainer(plugins=[ds], precision=16, gpus=1)
with pytest.raises(SystemExit):
trainer.fit(model)

Expand All @@ -318,7 +310,11 @@ def on_train_start(self) -> None:
raise SystemExit()

model = TestModel()
trainer = Trainer(plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], precision=16, gpus=1,)
trainer = Trainer(
plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)],
precision=16,
gpus=1,
)
with pytest.raises(SystemExit):
trainer.fit(model)

Expand Down

0 comments on commit 40d952c

Please sign in to comment.