From e4af1e9951aa13a73487e8ee3ec38cf860f7afdc Mon Sep 17 00:00:00 2001 From: James Bishop Date: Sun, 5 May 2024 14:11:39 +0100 Subject: [PATCH 1/6] update reqs --- requirements/base.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/base.txt b/requirements/base.txt index 3dd079d0b..74fcae877 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,5 +1,5 @@ numpy <1.26.0 -pytorch-lightning >1.7.0, <2.0.0 # strict +lightning >=2.0.0 # strict torchmetrics >=0.10.0, <0.12.0 lightning-utilities >0.3.1 # this is needed for PL 1.7 torchvision >=0.10.0 # todo: move to topic related extras From e4c3dd66cb9dc4e14f35d46c58095ac2a180fc19 Mon Sep 17 00:00:00 2001 From: James Bishop Date: Sun, 5 May 2024 14:11:48 +0100 Subject: [PATCH 2/6] refactor imports --- src/pl_bolts/callbacks/byol_updates.py | 2 +- src/pl_bolts/callbacks/data_monitor.py | 19 +++++----- src/pl_bolts/callbacks/knn_online.py | 7 ++-- src/pl_bolts/callbacks/printing.py | 7 ++-- src/pl_bolts/callbacks/sparseml.py | 11 +++--- src/pl_bolts/callbacks/ssl_online.py | 9 +++-- src/pl_bolts/callbacks/torch_ort.py | 5 ++- src/pl_bolts/callbacks/variational.py | 7 ++-- src/pl_bolts/callbacks/verification/base.py | 7 ++-- .../callbacks/verification/batch_gradient.py | 14 ++++---- .../callbacks/vision/confused_logit.py | 5 ++- .../callbacks/vision/image_generation.py | 3 +- .../callbacks/vision/sr_image_logger.py | 3 +- .../datamodules/cityscapes_datamodule.py | 5 ++- .../datamodules/imagenet_datamodule.py | 5 ++- src/pl_bolts/datamodules/kitti_datamodule.py | 7 ++-- .../datamodules/sklearn_datamodule.py | 5 ++- src/pl_bolts/datamodules/sr_datamodule.py | 5 ++- .../datamodules/ssl_imagenet_datamodule.py | 5 ++- src/pl_bolts/datamodules/stl10_datamodule.py | 5 ++- src/pl_bolts/datamodules/vision_datamodule.py | 2 +- .../datamodules/vocdetection_datamodule.py | 7 ++-- src/pl_bolts/datasets/array_dataset.py | 5 ++- .../losses/self_supervised_learning.py | 5 ++- .../autoencoders/basic_ae/basic_ae_module.py | 13 ++++--- .../basic_vae/basic_vae_module.py | 13 ++++--- .../faster_rcnn/faster_rcnn_module.py | 12 ++++--- .../detection/retinanet/retinanet_module.py | 15 ++++---- .../models/detection/yolo/darknet_network.py | 15 +++++--- .../models/detection/yolo/yolo_module.py | 6 ++-- .../models/gans/basic/basic_gan_module.py | 19 ++++++---- .../models/gans/dcgan/dcgan_module.py | 12 ++++--- .../models/gans/pix2pix/pix2pix_module.py | 5 ++- src/pl_bolts/models/mnist_module.py | 5 ++- .../models/regression/linear_regression.py | 5 ++- .../models/regression/logistic_regression.py | 5 ++- .../models/rl/advantage_actor_critic_model.py | 12 +++---- src/pl_bolts/models/rl/double_dqn_model.py | 6 ++-- src/pl_bolts/models/rl/dqn_model.py | 14 ++++---- src/pl_bolts/models/rl/dueling_dqn_model.py | 4 +-- src/pl_bolts/models/rl/noisy_dqn_model.py | 6 ++-- src/pl_bolts/models/rl/per_dqn_model.py | 8 ++--- src/pl_bolts/models/rl/ppo_model.py | 9 +++-- src/pl_bolts/models/rl/reinforce_model.py | 13 ++++--- src/pl_bolts/models/rl/sac_model.py | 14 ++++---- .../rl/vanilla_policy_gradient_model.py | 13 ++++--- .../self_supervised/amdim/amdim_module.py | 7 ++-- .../self_supervised/byol/byol_module.py | 20 +++++++---- .../self_supervised/cpc/cpc_finetuner.py | 3 +- .../models/self_supervised/cpc/cpc_module.py | 8 ++--- .../models/self_supervised/moco/callbacks.py | 3 +- .../self_supervised/moco/moco_module.py | 16 ++++++--- .../simclr/simclr_finetuner.py | 13 ++++--- .../self_supervised/simclr/simclr_module.py | 20 +++++++---- .../self_supervised/simsiam/simsiam_module.py | 16 ++++++--- .../models/self_supervised/ssl_finetuner.py | 5 ++- .../self_supervised/swav/swav_finetuner.py | 8 +++-- .../self_supervised/swav/swav_module.py | 19 ++++++---- src/pl_bolts/models/vision/image_gpt/gpt2.py | 5 ++- .../models/vision/image_gpt/igpt_module.py | 5 ++- src/pl_bolts/models/vision/segmentation.py | 5 ++- src/pl_bolts/utils/__init__.py | 5 +-- src/pl_bolts/utils/_dependency.py | 5 ++- src/pl_bolts/utils/arguments.py | 3 +- src/pl_bolts/utils/pretrained_weights.py | 3 +- src/pl_bolts/utils/stability.py | 2 +- tests/__init__.py | 2 +- tests/callbacks/test_data_monitor.py | 2 +- tests/callbacks/test_ort.py | 7 ++-- tests/callbacks/test_sparseml.py | 5 ++- tests/callbacks/test_variational_callbacks.py | 4 +-- tests/callbacks/verification/test_base.py | 5 ++- .../verification/test_batch_gradient.py | 13 ++++--- tests/conftest.py | 10 ++++-- tests/datamodules/test_sklearn_dataloaders.py | 2 +- tests/datasets/test_array_dataset.py | 2 +- tests/helpers/boring_model.py | 2 +- tests/models/gans/integration/test_gans.py | 4 +-- .../models/gans/unit/test_basic_components.py | 2 +- .../integration/test_actor_critic_models.py | 2 +- .../rl/integration/test_policy_models.py | 2 +- .../rl/integration/test_value_models.py | 2 +- tests/models/rl/unit/test_ppo.py | 2 +- tests/models/self_supervised/test_models.py | 35 ++++++++++++++----- tests/models/test_autoencoders.py | 10 ++++-- tests/models/test_classic_ml.py | 2 +- tests/models/test_detection.py | 7 ++-- tests/models/test_mnist_templates.py | 4 +-- tests/models/test_vision.py | 9 ++--- .../models/yolo/unit/test_darknet_network.py | 2 +- tests/models/yolo/unit/test_utils.py | 2 +- tests/optimizers/test_lr_scheduler.py | 2 +- tests/transforms/test_normalizations.py | 2 +- tests/transforms/test_transforms.py | 7 ++-- tests/utils/test_arguments.py | 2 +- 95 files changed, 388 insertions(+), 314 deletions(-) diff --git a/src/pl_bolts/callbacks/byol_updates.py b/src/pl_bolts/callbacks/byol_updates.py index c3e2bd237..cc2999499 100644 --- a/src/pl_bolts/callbacks/byol_updates.py +++ b/src/pl_bolts/callbacks/byol_updates.py @@ -2,7 +2,7 @@ from typing import Sequence, Union import torch.nn as nn -from pytorch_lightning import Callback, LightningModule, Trainer +from lightning import Callback, LightningModule, Trainer from torch import Tensor diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py index 7a39a1e70..7a6466821 100644 --- a/src/pl_bolts/callbacks/data_monitor.py +++ b/src/pl_bolts/callbacks/data_monitor.py @@ -2,23 +2,22 @@ import numpy as np import torch -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection -from torch import Tensor, nn -from torch.nn import Module -from torch.utils.hooks import RemovableHandle - +from lightning import Callback, LightningModule, Trainer +from lightning.fabric.utilities import rank_zero_warn +from lightning.fabric.utilities.apply_func import apply_to_collection +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from pl_bolts.utils import _WANDB_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor, nn +from torch.nn import Module +from torch.utils.hooks import RemovableHandle # Backward compatibility for Lightning Logger try: - from pytorch_lightning.loggers import Logger + from lightning.pytorch.loggers import Logger except ImportError: - from pytorch_lightning.loggers import LightningLoggerBase as Logger + from lightning.pytorch.loggers import LightningLoggerBase as Logger if _WANDB_AVAILABLE: import wandb diff --git a/src/pl_bolts/callbacks/knn_online.py b/src/pl_bolts/callbacks/knn_online.py index 3168799c2..0d0cb8896 100644 --- a/src/pl_bolts/callbacks/knn_online.py +++ b/src/pl_bolts/callbacks/knn_online.py @@ -1,13 +1,12 @@ from typing import Optional, Tuple, Union import torch -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.accelerators import Accelerator +from lightning import Callback, LightningModule, Trainer +from lightning.accelerators import Accelerator +from pl_bolts.utils.stability import under_review from torch import Tensor from torch.nn import functional as F # noqa: N812 -from pl_bolts.utils.stability import under_review - @under_review() class KNNOnlineEvaluator(Callback): diff --git a/src/pl_bolts/callbacks/printing.py b/src/pl_bolts/callbacks/printing.py index 1ba298aa1..555250585 100644 --- a/src/pl_bolts/callbacks/printing.py +++ b/src/pl_bolts/callbacks/printing.py @@ -2,10 +2,9 @@ from itertools import zip_longest from typing import Any, Callable, Dict, List, Optional -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_info - +from lightning import LightningModule, Trainer +from lightning.fabric.utilities import rank_zero_info +from lightning.pytorch.callbacks import Callback from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/callbacks/sparseml.py b/src/pl_bolts/callbacks/sparseml.py index 3a5d48b47..430b867c4 100644 --- a/src/pl_bolts/callbacks/sparseml.py +++ b/src/pl_bolts/callbacks/sparseml.py @@ -14,10 +14,13 @@ from typing import Any, Optional import torch -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -from pl_bolts.utils import _SPARSEML_AVAILABLE, _SPARSEML_TORCH_SATISFIED, _SPARSEML_TORCH_SATISFIED_ERROR +from lightning import Callback, LightningModule, Trainer +from lightning.fabric.utilities.exceptions import MisconfigurationException +from pl_bolts.utils import ( + _SPARSEML_AVAILABLE, + _SPARSEML_TORCH_SATISFIED, + _SPARSEML_TORCH_SATISFIED_ERROR, +) if _SPARSEML_TORCH_SATISFIED: from sparseml.pytorch.optim import ScheduledModifierManager diff --git a/src/pl_bolts/callbacks/ssl_online.py b/src/pl_bolts/callbacks/ssl_online.py index 943a2a28f..22bbd718f 100644 --- a/src/pl_bolts/callbacks/ssl_online.py +++ b/src/pl_bolts/callbacks/ssl_online.py @@ -2,16 +2,15 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.utilities import rank_zero_warn +from lightning import Callback, LightningModule, Trainer +from lightning.fabric.utilities import rank_zero_warn +from pl_bolts.models.self_supervised.evaluator import SSLEvaluator +from pl_bolts.utils.stability import under_review from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 from torch.optim import Optimizer from torchmetrics.functional import accuracy -from pl_bolts.models.self_supervised.evaluator import SSLEvaluator -from pl_bolts.utils.stability import under_review - @under_review() class SSLOnlineEvaluator(Callback): # pragma: no cover diff --git a/src/pl_bolts/callbacks/torch_ort.py b/src/pl_bolts/callbacks/torch_ort.py index 1ce9c0262..c9feef774 100644 --- a/src/pl_bolts/callbacks/torch_ort.py +++ b/src/pl_bolts/callbacks/torch_ort.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException - +from lightning import Callback, LightningModule, Trainer +from lightning.fabric.utilities.exceptions import MisconfigurationException from pl_bolts.utils import _TORCH_ORT_AVAILABLE if _TORCH_ORT_AVAILABLE: diff --git a/src/pl_bolts/callbacks/variational.py b/src/pl_bolts/callbacks/variational.py index 66f74871a..eac65a9d9 100644 --- a/src/pl_bolts/callbacks/variational.py +++ b/src/pl_bolts/callbacks/variational.py @@ -2,13 +2,12 @@ import numpy as np import torch -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import Callback -from torch import Tensor - +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor if _TORCHVISION_AVAILABLE: import torchvision diff --git a/src/pl_bolts/callbacks/verification/base.py b/src/pl_bolts/callbacks/verification/base.py index 49e2e3593..1134831f6 100644 --- a/src/pl_bolts/callbacks/verification/base.py +++ b/src/pl_bolts/callbacks/verification/base.py @@ -4,10 +4,9 @@ from typing import Any, Optional import torch.nn as nn -from pytorch_lightning import Callback, LightningModule -from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature - +from lightning import Callback, LightningModule +from lightning.fabric.utilities import move_data_to_device, rank_zero_warn +from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/callbacks/verification/batch_gradient.py b/src/pl_bolts/callbacks/verification/batch_gradient.py index 834184c15..91299337e 100644 --- a/src/pl_bolts/callbacks/verification/batch_gradient.py +++ b/src/pl_bolts/callbacks/verification/batch_gradient.py @@ -4,13 +4,15 @@ import torch import torch.nn as nn -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor - -from pl_bolts.callbacks.verification.base import VerificationBase, VerificationCallbackBase +from lightning import LightningModule, Trainer +from lightning.fabric.utilities.apply_func import apply_to_collection +from lightning.fabric.utilities.exceptions import MisconfigurationException +from pl_bolts.callbacks.verification.base import ( + VerificationBase, + VerificationCallbackBase, +) from pl_bolts.utils.stability import under_review +from torch import Tensor @under_review() diff --git a/src/pl_bolts/callbacks/vision/confused_logit.py b/src/pl_bolts/callbacks/vision/confused_logit.py index e66a088a3..1c23c423f 100644 --- a/src/pl_bolts/callbacks/vision/confused_logit.py +++ b/src/pl_bolts/callbacks/vision/confused_logit.py @@ -1,12 +1,11 @@ from typing import Sequence import torch -from pytorch_lightning import Callback, LightningModule, Trainer -from torch import Tensor, nn - +from lightning import Callback, LightningModule, Trainer from pl_bolts.utils import _MATPLOTLIB_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor, nn if _MATPLOTLIB_AVAILABLE: from matplotlib import pyplot as plt diff --git a/src/pl_bolts/callbacks/vision/image_generation.py b/src/pl_bolts/callbacks/vision/image_generation.py index a30e78972..3a7cc3df5 100644 --- a/src/pl_bolts/callbacks/vision/image_generation.py +++ b/src/pl_bolts/callbacks/vision/image_generation.py @@ -1,8 +1,7 @@ from typing import Optional, Tuple import torch -from pytorch_lightning import Callback, LightningModule, Trainer - +from lightning import Callback, LightningModule, Trainer from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/src/pl_bolts/callbacks/vision/sr_image_logger.py b/src/pl_bolts/callbacks/vision/sr_image_logger.py index f27bd294e..d4b178a8e 100644 --- a/src/pl_bolts/callbacks/vision/sr_image_logger.py +++ b/src/pl_bolts/callbacks/vision/sr_image_logger.py @@ -3,8 +3,7 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F # noqa: N812 -from pytorch_lightning import Callback - +from lightning import Callback from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/src/pl_bolts/datamodules/cityscapes_datamodule.py b/src/pl_bolts/datamodules/cityscapes_datamodule.py index 351314d83..a1933eb7d 100644 --- a/src/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/src/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,11 +1,10 @@ from typing import Any, Callable, Optional -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader - +from lightning import LightningDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/imagenet_datamodule.py b/src/pl_bolts/datamodules/imagenet_datamodule.py index 90f2bb641..601186619 100644 --- a/src/pl_bolts/datamodules/imagenet_datamodule.py +++ b/src/pl_bolts/datamodules/imagenet_datamodule.py @@ -2,14 +2,13 @@ from argparse import ArgumentParser from typing import Any, Callable, Optional -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader - +from lightning import LightningDataModule from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/kitti_datamodule.py b/src/pl_bolts/datamodules/kitti_datamodule.py index f30f99689..e7416bebb 100644 --- a/src/pl_bolts/datamodules/kitti_datamodule.py +++ b/src/pl_bolts/datamodules/kitti_datamodule.py @@ -2,14 +2,13 @@ from typing import Any, Callable, Optional import torch -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader -from torch.utils.data.dataset import random_split - +from lightning import LightningDataModule from pl_bolts.datasets import KittiDataset from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch.utils.data import DataLoader +from torch.utils.data.dataset import random_split if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/src/pl_bolts/datamodules/sklearn_datamodule.py b/src/pl_bolts/datamodules/sklearn_datamodule.py index 1c9361369..3c179d5b3 100644 --- a/src/pl_bolts/datamodules/sklearn_datamodule.py +++ b/src/pl_bolts/datamodules/sklearn_datamodule.py @@ -2,12 +2,11 @@ from typing import Any, Tuple import numpy as np -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader, Dataset - +from lightning import LightningDataModule from pl_bolts.utils import _SKLEARN_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch.utils.data import DataLoader, Dataset if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle diff --git a/src/pl_bolts/datamodules/sr_datamodule.py b/src/pl_bolts/datamodules/sr_datamodule.py index b6d30b3c8..657e5d320 100644 --- a/src/pl_bolts/datamodules/sr_datamodule.py +++ b/src/pl_bolts/datamodules/sr_datamodule.py @@ -1,9 +1,8 @@ from typing import Any -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader, Dataset - +from lightning import LightningDataModule from pl_bolts.utils.stability import under_review +from torch.utils.data import DataLoader, Dataset @under_review() diff --git a/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 0c811406d..5a489d850 100644 --- a/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,14 +1,13 @@ import os from typing import Any, Callable, Optional -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader - +from lightning import LightningDataModule from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/stl10_datamodule.py b/src/pl_bolts/datamodules/stl10_datamodule.py index 158baee0b..fa77682be 100644 --- a/src/pl_bolts/datamodules/stl10_datamodule.py +++ b/src/pl_bolts/datamodules/stl10_datamodule.py @@ -3,14 +3,13 @@ from typing import Any, Callable, Optional import torch -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader, random_split - +from lightning import LightningDataModule from pl_bolts.datasets import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/vision_datamodule.py b/src/pl_bolts/datamodules/vision_datamodule.py index 204ebf3ae..1c0791d52 100644 --- a/src/pl_bolts/datamodules/vision_datamodule.py +++ b/src/pl_bolts/datamodules/vision_datamodule.py @@ -3,7 +3,7 @@ from typing import Any, Callable, List, Optional, Union import torch -from pytorch_lightning import LightningDataModule +from lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split diff --git a/src/pl_bolts/datamodules/vocdetection_datamodule.py b/src/pl_bolts/datamodules/vocdetection_datamodule.py index de8a84b0f..98af7df88 100644 --- a/src/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/src/pl_bolts/datamodules/vocdetection_datamodule.py @@ -2,13 +2,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import torch -from pytorch_lightning import LightningDataModule -from torch import Tensor -from torch.utils.data import DataLoader, Dataset - +from lightning import LightningDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor +from torch.utils.data import DataLoader, Dataset if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datasets/array_dataset.py b/src/pl_bolts/datasets/array_dataset.py index ab17d3b7e..476cd4efa 100644 --- a/src/pl_bolts/datasets/array_dataset.py +++ b/src/pl_bolts/datasets/array_dataset.py @@ -1,9 +1,8 @@ from typing import Tuple, Union -from pytorch_lightning.utilities import exceptions -from torch.utils.data import Dataset - +from lightning.fabric.utilities import exceptions from pl_bolts.datasets.base_dataset import DataModel, TArrays +from torch.utils.data import Dataset class ArrayDataset(Dataset): diff --git a/src/pl_bolts/losses/self_supervised_learning.py b/src/pl_bolts/losses/self_supervised_learning.py index 2588fc607..8fbe6f5f3 100644 --- a/src/pl_bolts/losses/self_supervised_learning.py +++ b/src/pl_bolts/losses/self_supervised_learning.py @@ -1,9 +1,8 @@ import numpy as np import torch -from torch import nn - from pl_bolts.models.vision.pixel_cnn import PixelCNN from pl_bolts.utils.stability import under_review +from torch import nn @under_review() @@ -307,7 +306,7 @@ def forward(self, anchor_maps, positive_maps): Example: >>> import torch - >>> from pytorch_lightning import seed_everything + >>> from lightning import seed_everything >>> seed_everything(0) 0 >>> a1 = torch.rand(3, 5, 2, 2) diff --git a/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 300d45949..112f40a6f 100644 --- a/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -2,10 +2,7 @@ from argparse import ArgumentParser import torch -from pytorch_lightning import LightningModule, Trainer -from torch import nn -from torch.nn import functional as F # noqa: N812 - +from lightning import LightningModule, Trainer from pl_bolts import _HTTPS_AWS_HUB from pl_bolts.models.autoencoders.components import ( resnet18_decoder, @@ -14,6 +11,8 @@ resnet50_encoder, ) from pl_bolts.utils.stability import under_review +from torch import nn +from torch.nn import functional as F # noqa: N812 @under_review() @@ -154,7 +153,11 @@ def add_model_specific_args(parent_parser): @under_review() def cli_main(args=None): - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) parser = ArgumentParser() parser.add_argument("--dataset", default="cifar10", type=str, choices=["cifar10", "stl10", "imagenet"]) diff --git a/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index 0cf6729df..5f6295492 100644 --- a/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -2,10 +2,7 @@ from argparse import ArgumentParser import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from torch import nn -from torch.nn import functional as F # noqa: N812 - +from lightning import LightningModule, Trainer, seed_everything from pl_bolts import _HTTPS_AWS_HUB from pl_bolts.models.autoencoders.components import ( resnet18_decoder, @@ -14,6 +11,8 @@ resnet50_encoder, ) from pl_bolts.utils.stability import under_review +from torch import nn +from torch.nn import functional as F # noqa: N812 @under_review() @@ -187,7 +186,11 @@ def add_model_specific_args(parent_parser): @under_review() def cli_main(args=None): - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) seed_everything() diff --git a/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index 440df81d7..1cc2d22e1 100644 --- a/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -2,16 +2,20 @@ from typing import Any, Optional, Union import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything - +from lightning import LightningModule, Trainer, seed_everything from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: - from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN - from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, fasterrcnn_resnet50_fpn + from torchvision.models.detection.faster_rcnn import ( + FasterRCNN as torchvision_FasterRCNN, + ) + from torchvision.models.detection.faster_rcnn import ( + FastRCNNPredictor, + fasterrcnn_resnet50_fpn, + ) from torchvision.ops import box_iou else: # pragma: no cover warn_missing_pkg("torchvision") diff --git a/src/pl_bolts/models/detection/retinanet/retinanet_module.py b/src/pl_bolts/models/detection/retinanet/retinanet_module.py index 78c140340..589b179bd 100644 --- a/src/pl_bolts/models/detection/retinanet/retinanet_module.py +++ b/src/pl_bolts/models/detection/retinanet/retinanet_module.py @@ -1,16 +1,20 @@ from typing import Any, Optional import torch -from pytorch_lightning import LightningModule - +from lightning import LightningModule from pl_bolts.models.detection.retinanet import create_retinanet_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13 from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: - from torchvision.models.detection.retinanet import RetinaNet as torchvision_RetinaNet - from torchvision.models.detection.retinanet import RetinaNetHead, retinanet_resnet50_fpn + from torchvision.models.detection.retinanet import ( + RetinaNet as torchvision_RetinaNet, + ) + from torchvision.models.detection.retinanet import ( + RetinaNetHead, + retinanet_resnet50_fpn, + ) from torchvision.ops import box_iou else: # pragma: no cover warn_missing_pkg("torchvision") @@ -136,8 +140,7 @@ def configure_optimizers(self): @under_review() def cli_main(): - from pytorch_lightning.cli import LightningCLI - + from lightning.cli import LightningCLI from pl_bolts.datamodules import VOCDetectionDataModule LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42) diff --git a/src/pl_bolts/models/detection/yolo/darknet_network.py b/src/pl_bolts/models/detection/yolo/darknet_network.py index 7a38a0f5d..cdd357ad8 100644 --- a/src/pl_bolts/models/detection/yolo/darknet_network.py +++ b/src/pl_bolts/models/detection/yolo/darknet_network.py @@ -6,16 +6,23 @@ import numpy as np import torch import torch.nn as nn -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from lightning.fabric.utilities.exceptions import MisconfigurationException try: - from pytorch_lightning.utilities.rank_zero import rank_zero_info + from lightning.fabric.utilities.rank_zero import rank_zero_info except ModuleNotFoundError: - from pytorch_lightning.utilities.distributed import rank_zero_info + from lightning.fabric.utilities.distributed import rank_zero_info from torch import Tensor -from .layers import Conv, DetectionLayer, MaxPool, RouteLayer, ShortcutLayer, create_detection_layer +from .layers import ( + Conv, + DetectionLayer, + MaxPool, + RouteLayer, + ShortcutLayer, + create_detection_layer, +) from .torch_networks import NETWORK_OUTPUT from .types import TARGETS from .utils import get_image_size diff --git a/src/pl_bolts/models/detection/yolo/yolo_module.py b/src/pl_bolts/models/detection/yolo/yolo_module.py index 429ed087b..6b0a79035 100644 --- a/src/pl_bolts/models/detection/yolo/yolo_module.py +++ b/src/pl_bolts/models/detection/yolo/yolo_module.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from pytorch_lightning import LightningModule -from pytorch_lightning.utilities.types import STEP_OUTPUT +from lightning import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, optim # It seems to be impossible to avoid mypy errors if using import instead of getattr(). @@ -621,6 +621,6 @@ def _resize(self, image: Tensor, target: TARGET) -> Tuple[Tensor, TARGET]: if __name__ == "__main__": - from pytorch_lightning.cli import LightningCLI + from lightning.cli import LightningCLI LightningCLI(CLIYOLO, ResizedVOCDetectionDataModule, seed_everything_default=42) diff --git a/src/pl_bolts/models/gans/basic/basic_gan_module.py b/src/pl_bolts/models/gans/basic/basic_gan_module.py index 965577b6b..45e7ecc15 100644 --- a/src/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/src/pl_bolts/models/gans/basic/basic_gan_module.py @@ -1,11 +1,10 @@ from argparse import ArgumentParser import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks.progress import TQDMProgressBar -from torch.nn import functional as F # noqa: N812 - +from lightning import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks.progress import TQDMProgressBar from pl_bolts.models.gans.basic.components import Discriminator, Generator +from torch.nn import functional as F # noqa: N812 class GAN(LightningModule): @@ -152,8 +151,16 @@ def add_model_specific_args(parent_parser): def cli_main(args=None): - from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule + from pl_bolts.callbacks import ( + LatentDimInterpolator, + TensorboardGenerativeModelImageSampler, + ) + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + MNISTDataModule, + STL10DataModule, + ) seed_everything(1234) diff --git a/src/pl_bolts/models/gans/dcgan/dcgan_module.py b/src/pl_bolts/models/gans/dcgan/dcgan_module.py index 01202a25b..e054d3de4 100644 --- a/src/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/src/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -3,14 +3,16 @@ import torch import torch.nn as nn -from pytorch_lightning import LightningModule, Trainer, seed_everything -from torch import Tensor -from torch.utils.data import DataLoader - -from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler +from lightning import LightningModule, Trainer, seed_everything +from pl_bolts.callbacks import ( + LatentDimInterpolator, + TensorboardGenerativeModelImageSampler, +) from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor +from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py b/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py index 7414e4d93..1886a41c9 100644 --- a/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py +++ b/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py @@ -1,9 +1,8 @@ import torch -from pytorch_lightning import LightningModule -from torch import nn - +from lightning import LightningModule from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN from pl_bolts.utils.stability import under_review +from torch import nn @under_review() diff --git a/src/pl_bolts/models/mnist_module.py b/src/pl_bolts/models/mnist_module.py index 83af348c1..a2485c6c0 100644 --- a/src/pl_bolts/models/mnist_module.py +++ b/src/pl_bolts/models/mnist_module.py @@ -2,12 +2,11 @@ from typing import Any import torch -from pytorch_lightning import LightningModule, Trainer +from lightning import LightningModule, Trainer +from pl_bolts.utils import _TORCHVISION_AVAILABLE from torch import Tensor from torch.nn import functional as F # noqa: N812 -from pl_bolts.utils import _TORCHVISION_AVAILABLE - class LitMNIST(LightningModule): """PyTorch Lightning implementation of a two-layer MNIST classification module. diff --git a/src/pl_bolts/models/regression/linear_regression.py b/src/pl_bolts/models/regression/linear_regression.py index 6993d8237..4ca9433e5 100644 --- a/src/pl_bolts/models/regression/linear_regression.py +++ b/src/pl_bolts/models/regression/linear_regression.py @@ -2,14 +2,13 @@ from typing import Any, Dict, List, Tuple, Type import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything +from lightning import LightningModule, Trainer, seed_everything +from pl_bolts.utils.stability import under_review from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 from torch.optim import Adam from torch.optim.optimizer import Optimizer -from pl_bolts.utils.stability import under_review - @under_review() class LinearRegression(LightningModule): diff --git a/src/pl_bolts/models/regression/logistic_regression.py b/src/pl_bolts/models/regression/logistic_regression.py index 26ec01a19..022a01458 100644 --- a/src/pl_bolts/models/regression/logistic_regression.py +++ b/src/pl_bolts/models/regression/logistic_regression.py @@ -4,14 +4,13 @@ from typing import Any, Dict, List, Tuple, Type import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything +from lightning import LightningModule, Trainer, seed_everything +from pl_bolts.utils.stability import under_review from torch import Tensor, nn from torch.optim import Adam from torch.optim.optimizer import Optimizer from torchmetrics import functional -from pl_bolts.utils.stability import under_review - class LogisticRegression(LightningModule): """Logistic Regression Model.""" diff --git a/src/pl_bolts/models/rl/advantage_actor_critic_model.py b/src/pl_bolts/models/rl/advantage_actor_critic_model.py index e4863e32f..7672b2bb8 100644 --- a/src/pl_bolts/models/rl/advantage_actor_critic_model.py +++ b/src/pl_bolts/models/rl/advantage_actor_critic_model.py @@ -1,22 +1,22 @@ """Advantage Actor Critic (A2C)""" + from argparse import ArgumentParser from collections import OrderedDict from typing import Any, Iterator, List, Tuple import numpy as np import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint -from torch import Tensor, optim -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader - +from lightning import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.agents import ActorCriticAgent from pl_bolts.models.rl.common.networks import ActorCriticMLP from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor, optim +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/double_dqn_model.py b/src/pl_bolts/models/rl/double_dqn_model.py index 2d76279c8..954c64eef 100644 --- a/src/pl_bolts/models/rl/double_dqn_model.py +++ b/src/pl_bolts/models/rl/double_dqn_model.py @@ -1,14 +1,14 @@ """Double DQN.""" + import argparse from collections import OrderedDict from typing import Tuple -from pytorch_lightning import Trainer -from torch import Tensor - +from lightning import Trainer from pl_bolts.losses.rl import double_dqn_loss from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review +from torch import Tensor @under_review() diff --git a/src/pl_bolts/models/rl/dqn_model.py b/src/pl_bolts/models/rl/dqn_model.py index 567aa8d18..add95a67d 100644 --- a/src/pl_bolts/models/rl/dqn_model.py +++ b/src/pl_bolts/models/rl/dqn_model.py @@ -1,17 +1,14 @@ """Deep Q Network.""" + import argparse from collections import OrderedDict from typing import Dict, List, Optional, Tuple import numpy as np import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.strategies import DataParallelStrategy -from torch import Tensor, optim -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader - +from lightning import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.strategies import DataParallelStrategy from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.losses.rl import dqn_loss from pl_bolts.models.rl.common.agents import ValueAgent @@ -21,6 +18,9 @@ from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor, optim +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader if _GYM_AVAILABLE: from gym import Env diff --git a/src/pl_bolts/models/rl/dueling_dqn_model.py b/src/pl_bolts/models/rl/dueling_dqn_model.py index 1e072d5ff..8a6d19e5b 100644 --- a/src/pl_bolts/models/rl/dueling_dqn_model.py +++ b/src/pl_bolts/models/rl/dueling_dqn_model.py @@ -1,8 +1,8 @@ """Dueling DQN.""" -import argparse -from pytorch_lightning import Trainer +import argparse +from lightning import Trainer from pl_bolts.models.rl.common.networks import DuelingCNN from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/models/rl/noisy_dqn_model.py b/src/pl_bolts/models/rl/noisy_dqn_model.py index 76b4531c5..3b871ac05 100644 --- a/src/pl_bolts/models/rl/noisy_dqn_model.py +++ b/src/pl_bolts/models/rl/noisy_dqn_model.py @@ -1,15 +1,15 @@ """Noisy DQN.""" + import argparse from typing import Tuple import numpy as np -from pytorch_lightning import Trainer -from torch import Tensor - +from lightning import Trainer from pl_bolts.datamodules.experience_source import Experience from pl_bolts.models.rl.common.networks import NoisyCNN from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review +from torch import Tensor @under_review() diff --git a/src/pl_bolts/models/rl/per_dqn_model.py b/src/pl_bolts/models/rl/per_dqn_model.py index a864afb51..f4497d8b4 100644 --- a/src/pl_bolts/models/rl/per_dqn_model.py +++ b/src/pl_bolts/models/rl/per_dqn_model.py @@ -1,18 +1,18 @@ """Prioritized Experience Replay DQN.""" + import argparse from collections import OrderedDict from typing import Tuple import numpy as np -from pytorch_lightning import Trainer -from torch import Tensor -from torch.utils.data import DataLoader - +from lightning import Trainer from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.losses.rl import per_dqn_loss from pl_bolts.models.rl.common.memory import Experience, PERBuffer from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review +from torch import Tensor +from torch.utils.data import DataLoader @under_review() diff --git a/src/pl_bolts/models/rl/ppo_model.py b/src/pl_bolts/models/rl/ppo_model.py index 21bc0873c..b16d47435 100644 --- a/src/pl_bolts/models/rl/ppo_model.py +++ b/src/pl_bolts/models/rl/ppo_model.py @@ -2,16 +2,15 @@ from typing import Any, List, Tuple import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from torch import Tensor -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader - +from lightning import LightningModule, Trainer, seed_everything from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.networks import MLP, ActorCategorical, ActorContinous from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/reinforce_model.py b/src/pl_bolts/models/rl/reinforce_model.py index 876686fab..3c814791e 100644 --- a/src/pl_bolts/models/rl/reinforce_model.py +++ b/src/pl_bolts/models/rl/reinforce_model.py @@ -3,13 +3,8 @@ from typing import List, Tuple import numpy as np -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint -from torch import Tensor, optim -from torch.nn.functional import log_softmax -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader - +from lightning import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.datamodules.experience_source import Experience from pl_bolts.models.rl.common.agents import PolicyAgent @@ -17,6 +12,10 @@ from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor, optim +from torch.nn.functional import log_softmax +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/sac_model.py b/src/pl_bolts/models/rl/sac_model.py index 8c0bb2b71..eed7ad6d7 100644 --- a/src/pl_bolts/models/rl/sac_model.py +++ b/src/pl_bolts/models/rl/sac_model.py @@ -1,16 +1,12 @@ """Soft Actor Critic.""" + import argparse from typing import Dict, List, Tuple import numpy as np import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint -from torch import Tensor, optim -from torch.nn import functional as F # noqa: N812 -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader - +from lightning import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.models.rl.common.agents import SoftActorCriticAgent from pl_bolts.models.rl.common.memory import MultiStepBuffer @@ -18,6 +14,10 @@ from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor, optim +from torch.nn import functional as F # noqa: N812 +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py index bebfbc763..73d25e698 100644 --- a/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -4,19 +4,18 @@ import numpy as np import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint -from torch import Tensor, optim -from torch.nn.functional import log_softmax, softmax -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader - +from lightning import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg +from torch import Tensor, optim +from torch.nn.functional import log_softmax, softmax +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/self_supervised/amdim/amdim_module.py b/src/pl_bolts/models/self_supervised/amdim/amdim_module.py index 295b2fe09..690a23be0 100644 --- a/src/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/src/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -3,15 +3,14 @@ from typing import Union import torch -from pytorch_lightning import LightningDataModule, LightningModule, Trainer -from torch import optim -from torch.utils.data import DataLoader - +from lightning import LightningDataModule, LightningModule, Trainer from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask from pl_bolts.models.self_supervised.amdim.datasets import AMDIMPretraining from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder from pl_bolts.utils.self_supervised import torchvision_ssl_encoder from pl_bolts.utils.stability import under_review +from torch import optim +from torch.utils.data import DataLoader @under_review() diff --git a/src/pl_bolts/models/self_supervised/byol/byol_module.py b/src/pl_bolts/models/self_supervised/byol/byol_module.py index 93bcd2742..5210c2d66 100644 --- a/src/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/src/pl_bolts/models/self_supervised/byol/byol_module.py @@ -3,14 +3,13 @@ from typing import Any, Union import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from torch import Tensor -from torch.nn import functional as F # noqa: N812 -from torch.optim import Adam - +from lightning import LightningModule, Trainer, seed_everything from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from torch import Tensor +from torch.nn import functional as F # noqa: N812 +from torch.optim import Adam class BYOL(LightningModule): @@ -168,8 +167,15 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule - from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) + from pl_bolts.models.self_supervised.simclr import ( + SimCLREvalDataTransform, + SimCLRTrainDataTransform, + ) seed_everything(1234) diff --git a/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py b/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py index 7d66cc772..424918dfd 100644 --- a/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py +++ b/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py @@ -1,8 +1,7 @@ import os from argparse import ArgumentParser -from pytorch_lightning import Trainer, seed_everything - +from lightning import Trainer, seed_everything from pl_bolts.models.self_supervised import CPC_v2, SSLFineTuner from pl_bolts.transforms.self_supervised.cpc_transforms import ( CPCEvalTransformsCIFAR10, diff --git a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py index 6f60b7a26..8cc613448 100644 --- a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -1,13 +1,12 @@ """CPC V2.""" + import math from argparse import ArgumentParser from typing import Optional import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.utilities import rank_zero_warn -from torch import optim - +from lightning import LightningModule, Trainer, seed_everything +from lightning.fabric.utilities import rank_zero_warn from pl_bolts.datamodules.stl10_datamodule import STL10DataModule from pl_bolts.losses.self_supervised_learning import CPCTask from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet101 @@ -22,6 +21,7 @@ from pl_bolts.utils.pretrained_weights import load_pretrained from pl_bolts.utils.self_supervised import torchvision_ssl_encoder from pl_bolts.utils.stability import under_review +from torch import optim __all__ = ["CPC_v2"] diff --git a/src/pl_bolts/models/self_supervised/moco/callbacks.py b/src/pl_bolts/models/self_supervised/moco/callbacks.py index 24bd5a1bf..6b4bc2dfc 100644 --- a/src/pl_bolts/models/self_supervised/moco/callbacks.py +++ b/src/pl_bolts/models/self_supervised/moco/callbacks.py @@ -1,7 +1,6 @@ import math -from pytorch_lightning import Callback - +from lightning import Callback from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/models/self_supervised/moco/moco_module.py b/src/pl_bolts/models/self_supervised/moco/moco_module.py index 86d01147c..6dc44dc18 100644 --- a/src/pl_bolts/models/self_supervised/moco/moco_module.py +++ b/src/pl_bolts/models/self_supervised/moco/moco_module.py @@ -8,13 +8,14 @@ You may obtain a copy of the License from the LICENSE file present in this folder. """ + from copy import copy, deepcopy from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.strategies import DDPStrategy -from pytorch_lightning.utilities.types import STEP_OUTPUT +from lightning import LightningModule +from lightning.pytorch.strategies import DDPStrategy +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, nn, optim from torch.nn import functional as F # noqa: N812 from torch.utils.data import DataLoader, Dataset @@ -28,7 +29,12 @@ from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.metrics import precision_at_k -from pl_bolts.models.self_supervised.moco.utils import concatenate_all, shuffle_batch, sort_batch, validate_batch +from pl_bolts.models.self_supervised.moco.utils import ( + concatenate_all, + shuffle_batch, + sort_batch, + validate_batch, +) from pl_bolts.transforms.self_supervised.moco_transforms import ( MoCo2EvalCIFAR10Transforms, MoCo2TrainCIFAR10Transforms, @@ -323,7 +329,7 @@ def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: def cli_main() -> None: - from pytorch_lightning.cli import LightningCLI + from lightning.cli import LightningCLI LightningCLI(MoCo, CIFAR10ContrastiveDataModule, seed_everything_default=42) diff --git a/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py b/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py index 2f16509ac..c1e65d6c3 100644 --- a/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py +++ b/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py @@ -1,8 +1,7 @@ import os from argparse import ArgumentParser -from pytorch_lightning import Trainer, seed_everything - +from lightning import Trainer, seed_everything from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner from pl_bolts.transforms.dataset_normalizations import ( @@ -10,13 +9,19 @@ imagenet_normalization, stl10_normalization, ) -from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLRFinetuneTransform +from pl_bolts.transforms.self_supervised.simclr_transforms import ( + SimCLRFinetuneTransform, +) from pl_bolts.utils.stability import under_review @under_review() def cli_main(): # pragma: no cover - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) seed_everything(1234) diff --git a/src/pl_bolts/models/self_supervised/simclr/simclr_module.py b/src/pl_bolts/models/self_supervised/simclr/simclr_module.py index d17ddf7a1..f686c7db1 100644 --- a/src/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/src/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -2,11 +2,8 @@ from argparse import ArgumentParser import torch -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from torch import Tensor, nn -from torch.nn import functional as F # noqa: N812 - +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.optimizers.lars import LARS from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay @@ -16,6 +13,8 @@ stl10_normalization, ) from pl_bolts.utils.stability import under_review +from torch import Tensor, nn +from torch.nn import functional as F # noqa: N812 @under_review() @@ -301,8 +300,15 @@ def add_model_specific_args(parent_parser): @under_review() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule - from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) + from pl_bolts.transforms.self_supervised.simclr_transforms import ( + SimCLREvalDataTransform, + SimCLRTrainDataTransform, + ) parser = ArgumentParser() diff --git a/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 9bd79906f..aa7fd83d3 100644 --- a/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -5,11 +5,10 @@ import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 -from pytorch_lightning import LightningModule, Trainer, seed_everything -from torch import Tensor - +from lightning import LightningModule, Trainer, seed_everything from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from torch import Tensor class SimSiam(LightningModule): @@ -190,8 +189,15 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule - from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) + from pl_bolts.models.self_supervised.simclr import ( + SimCLREvalDataTransform, + SimCLRTrainDataTransform, + ) seed_everything(1234) diff --git a/src/pl_bolts/models/self_supervised/ssl_finetuner.py b/src/pl_bolts/models/self_supervised/ssl_finetuner.py index b5f74e7ab..32266e91d 100644 --- a/src/pl_bolts/models/self_supervised/ssl_finetuner.py +++ b/src/pl_bolts/models/self_supervised/ssl_finetuner.py @@ -1,12 +1,11 @@ from typing import Optional, Tuple import torch -from pytorch_lightning import LightningModule +from lightning import LightningModule +from pl_bolts.models.self_supervised import SSLEvaluator from torch.nn import functional as F # noqa: N812 from torchmetrics import Accuracy -from pl_bolts.models.self_supervised import SSLEvaluator - class SSLFineTuner(LightningModule): """Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP with diff --git a/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py b/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py index b8c6b4ebb..e3e3b9c44 100644 --- a/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py +++ b/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py @@ -1,11 +1,13 @@ import os from argparse import ArgumentParser -from pytorch_lightning import Trainer, seed_everything - +from lightning import Trainer, seed_everything from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner from pl_bolts.models.self_supervised.swav.swav_module import SwAV -from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization +from pl_bolts.transforms.dataset_normalizations import ( + imagenet_normalization, + stl10_normalization, +) from pl_bolts.transforms.self_supervised.swav_transforms import SwAVFinetuneTransform diff --git a/src/pl_bolts/models/self_supervised/swav/swav_module.py b/src/pl_bolts/models/self_supervised/swav/swav_module.py index e212358c4..e8d7dbbba 100644 --- a/src/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/src/pl_bolts/models/self_supervised/swav/swav_module.py @@ -1,12 +1,11 @@ """Adapted from official swav implementation: https://github.com/facebookresearch/swav.""" + import os from argparse import ArgumentParser import torch -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from torch import nn - +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint from pl_bolts.models.self_supervised.swav.loss import SWAVLoss from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 from pl_bolts.optimizers.lars import LARS @@ -16,6 +15,7 @@ imagenet_normalization, stl10_normalization, ) +from torch import nn class SwAV(LightningModule): @@ -381,8 +381,15 @@ def add_model_specific_args(parent_parser): def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule - from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform + from pl_bolts.datamodules import ( + CIFAR10DataModule, + ImagenetDataModule, + STL10DataModule, + ) + from pl_bolts.transforms.self_supervised.swav_transforms import ( + SwAVEvalDataTransform, + SwAVTrainDataTransform, + ) parser = ArgumentParser() diff --git a/src/pl_bolts/models/vision/image_gpt/gpt2.py b/src/pl_bolts/models/vision/image_gpt/gpt2.py index 4d6b4ec2e..5bf946182 100644 --- a/src/pl_bolts/models/vision/image_gpt/gpt2.py +++ b/src/pl_bolts/models/vision/image_gpt/gpt2.py @@ -1,8 +1,7 @@ import torch -from pytorch_lightning import LightningModule -from torch import nn - +from lightning import LightningModule from pl_bolts.utils.stability import under_review +from torch import nn @under_review() diff --git a/src/pl_bolts/models/vision/image_gpt/igpt_module.py b/src/pl_bolts/models/vision/image_gpt/igpt_module.py index 89c48c536..8cbe9fd1b 100644 --- a/src/pl_bolts/models/vision/image_gpt/igpt_module.py +++ b/src/pl_bolts/models/vision/image_gpt/igpt_module.py @@ -2,11 +2,10 @@ from argparse import ArgumentParser import torch -from pytorch_lightning import LightningModule, Trainer -from torch import nn - +from lightning import LightningModule, Trainer from pl_bolts.models.vision.image_gpt.gpt2 import GPT2 from pl_bolts.utils.stability import under_review +from torch import nn @under_review() diff --git a/src/pl_bolts/models/vision/segmentation.py b/src/pl_bolts/models/vision/segmentation.py index 550b6af8b..0ff2354db 100644 --- a/src/pl_bolts/models/vision/segmentation.py +++ b/src/pl_bolts/models/vision/segmentation.py @@ -2,12 +2,11 @@ from typing import Any, Dict, Optional import torch -from pytorch_lightning import LightningModule, Trainer, seed_everything +from lightning import LightningModule, Trainer, seed_everything +from pl_bolts.models.vision.unet import UNet from torch import Tensor from torch.nn import functional as F # noqa: N812 -from pl_bolts.models.vision.unet import UNet - class SemSegment(LightningModule): """Basic model for semantic segmentation. Uses UNet architecture by default. diff --git a/src/pl_bolts/utils/__init__.py b/src/pl_bolts/utils/__init__.py index c2f7e4c98..795555107 100644 --- a/src/pl_bolts/utils/__init__.py +++ b/src/pl_bolts/utils/__init__.py @@ -3,8 +3,9 @@ import torch from lightning_utilities.core.imports import compare_version, module_available - -from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore +from pl_bolts.callbacks.verification.batch_gradient import ( + BatchGradientVerification, # type: ignore +) _NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") _IS_WINDOWS = platform.system() == "Windows" diff --git a/src/pl_bolts/utils/_dependency.py b/src/pl_bolts/utils/_dependency.py index 81ce2d912..32571c478 100644 --- a/src/pl_bolts/utils/_dependency.py +++ b/src/pl_bolts/utils/_dependency.py @@ -2,7 +2,10 @@ import os from typing import Any, Callable -from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache +from lightning.fabric.utilities.core.imports import ( + ModuleAvailableCache, + RequirementCache, +) # ToDo: replace with utils wrapper after 0.10 is released diff --git a/src/pl_bolts/utils/arguments.py b/src/pl_bolts/utils/arguments.py index 0325d6575..1d2f88e1c 100644 --- a/src/pl_bolts/utils/arguments.py +++ b/src/pl_bolts/utils/arguments.py @@ -3,8 +3,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional -from pytorch_lightning import LightningDataModule, LightningModule - +from lightning import LightningDataModule, LightningModule from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/utils/pretrained_weights.py b/src/pl_bolts/utils/pretrained_weights.py index 59561557d..6954c4611 100644 --- a/src/pl_bolts/utils/pretrained_weights.py +++ b/src/pl_bolts/utils/pretrained_weights.py @@ -1,7 +1,6 @@ from typing import Optional -from pytorch_lightning import LightningModule - +from lightning import LightningModule from pl_bolts.utils.stability import under_review vae_imagenet2012 = ( diff --git a/src/pl_bolts/utils/stability.py b/src/pl_bolts/utils/stability.py index 3f0e4e6b0..867d026df 100644 --- a/src/pl_bolts/utils/stability.py +++ b/src/pl_bolts/utils/stability.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Type, Union from warnings import filterwarnings -from pytorch_lightning.utilities import rank_zero_warn +from lightning.fabric.utilities import rank_zero_warn class UnderReviewWarning(Warning): diff --git a/tests/__init__.py b/tests/__init__.py index 726ea0f46..4b61c66a2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,7 @@ import os import torch -from pytorch_lightning import seed_everything +from lightning.pytorch import seed_everything TEST_ROOT = os.path.realpath(os.path.dirname(__file__)) PROJECT_ROOT = os.path.dirname(TEST_ROOT) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 3d3ddc0c4..d5d0ae040 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -3,10 +3,10 @@ import pytest import torch +from lightning.pytorch import Trainer from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor from pl_bolts.datamodules import MNISTDataModule from pl_bolts.models import LitMNIST -from pytorch_lightning import Trainer from torch import nn diff --git a/tests/callbacks/test_ort.py b/tests/callbacks/test_ort.py index 9186c3ad3..23d3976f8 100644 --- a/tests/callbacks/test_ort.py +++ b/tests/callbacks/test_ort.py @@ -13,12 +13,11 @@ # limitations under the License. import pytest +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.core.module import LightningModule +from lightning.pytorch.utilities.exceptions import MisconfigurationException from pl_bolts.callbacks import ORTCallback from pl_bolts.utils import _TORCH_ORT_AVAILABLE -from pytorch_lightning import Callback, Trainer -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.exceptions import MisconfigurationException - from tests.helpers.boring_model import BoringModel if _TORCH_ORT_AVAILABLE: diff --git a/tests/callbacks/test_sparseml.py b/tests/callbacks/test_sparseml.py index fdef95808..fd06e7cad 100644 --- a/tests/callbacks/test_sparseml.py +++ b/tests/callbacks/test_sparseml.py @@ -16,11 +16,10 @@ import pytest import torch +from lightning.pytorch import Callback, LightningModule, Trainer +from lightning.pytorch.utilities.exceptions import MisconfigurationException from pl_bolts.callbacks import SparseMLCallback from pl_bolts.utils import _SPARSEML_TORCH_SATISFIED -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException - from tests.helpers.boring_model import BoringModel if _SPARSEML_TORCH_SATISFIED: diff --git a/tests/callbacks/test_variational_callbacks.py b/tests/callbacks/test_variational_callbacks.py index 6994beb10..d51044a81 100644 --- a/tests/callbacks/test_variational_callbacks.py +++ b/tests/callbacks/test_variational_callbacks.py @@ -2,9 +2,9 @@ from pl_bolts.models.gans import GAN try: - from pytorch_lightning.loggers.logger import DummyLogger # PL v1.9+ + from lightning.pytorch.loggers.logger import DummyLogger # PL v1.9+ except ModuleNotFoundError: - from pytorch_lightning.loggers.base import DummyLogger # PL v1.8 + from lightning.pytorch.loggers.base import DummyLogger # PL v1.8 def test_latent_dim_interpolator(): diff --git a/tests/callbacks/verification/test_base.py b/tests/callbacks/verification/test_base.py index 624d2a920..675d18249 100644 --- a/tests/callbacks/verification/test_base.py +++ b/tests/callbacks/verification/test_base.py @@ -3,11 +3,10 @@ import pytest import torch import torch.nn as nn +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities import move_data_to_device from pl_bolts.callbacks.verification.base import VerificationBase from pl_bolts.utils import _PL_GREATER_EQUAL_1_4 -from pytorch_lightning import LightningModule -from pytorch_lightning.utilities import move_data_to_device - from tests import _MARK_REQUIRE_GPU diff --git a/tests/callbacks/verification/test_batch_gradient.py b/tests/callbacks/verification/test_batch_gradient.py index bc25bf3b4..7fa9bd545 100644 --- a/tests/callbacks/verification/test_batch_gradient.py +++ b/tests/callbacks/verification/test_batch_gradient.py @@ -2,14 +2,17 @@ import pytest import torch +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.utilities.exceptions import MisconfigurationException from pl_bolts.callbacks import BatchGradientVerificationCallback -from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval +from pl_bolts.callbacks.verification.batch_gradient import ( + default_input_mapping, + default_output_mapping, + selective_eval, +) from pl_bolts.utils import BatchGradientVerification -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor, nn - from tests import _MARK_REQUIRE_GPU +from torch import Tensor, nn class TemplateModel(nn.Module): diff --git a/tests/conftest.py b/tests/conftest.py index b1340ac33..d5e920308 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,15 @@ import pytest import torch -from pl_bolts.utils import _IS_WINDOWS, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13 +from lightning.pytorch.trainer.connectors.signal_connector import ( + _SignalConnector as SignalConnector, +) +from pl_bolts.utils import ( + _IS_WINDOWS, + _TORCHVISION_AVAILABLE, + _TORCHVISION_LESS_THAN_0_13, +) from pl_bolts.utils.stability import UnderReviewWarning -from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector # GitHub Actions use this path to cache datasets. # Use `datadir` fixture where possible and use `DATASETS_PATH` in diff --git a/tests/datamodules/test_sklearn_dataloaders.py b/tests/datamodules/test_sklearn_dataloaders.py index f8bad22eb..7c37f280f 100644 --- a/tests/datamodules/test_sklearn_dataloaders.py +++ b/tests/datamodules/test_sklearn_dataloaders.py @@ -2,8 +2,8 @@ import numpy as np import pytest +from lightning.pytorch import seed_everything from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule -from pytorch_lightning import seed_everything try: from sklearn.utils import shuffle as sk_shuffle diff --git a/tests/datasets/test_array_dataset.py b/tests/datasets/test_array_dataset.py index 0bf87bd9c..01ade8071 100644 --- a/tests/datasets/test_array_dataset.py +++ b/tests/datasets/test_array_dataset.py @@ -1,9 +1,9 @@ import numpy as np import pytest import torch +from lightning.pytorch.utilities import exceptions from pl_bolts.datasets import ArrayDataset, DataModel from pl_bolts.datasets.utils import to_tensor -from pytorch_lightning.utilities import exceptions class TestArrayDataset: diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index cb98a0190..4c5316afa 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -14,7 +14,7 @@ from typing import Optional import torch -from pytorch_lightning import LightningDataModule, LightningModule +from lightning.pytorch import LightningDataModule, LightningModule from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset diff --git a/tests/models/gans/integration/test_gans.py b/tests/models/gans/integration/test_gans.py index b04172a0e..8a7a0b70d 100644 --- a/tests/models/gans/integration/test_gans.py +++ b/tests/models/gans/integration/test_gans.py @@ -1,12 +1,12 @@ import warnings import pytest +from lightning.pytorch import Trainer, seed_everything +from lightning.pytorch.utilities.warnings import PossibleUserWarning from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule from pl_bolts.datasets.sr_mnist_dataset import SRMNIST from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data.dataloader import DataLoader from torchvision import transforms as transform_lib diff --git a/tests/models/gans/unit/test_basic_components.py b/tests/models/gans/unit/test_basic_components.py index e761678e0..c9693a312 100644 --- a/tests/models/gans/unit/test_basic_components.py +++ b/tests/models/gans/unit/test_basic_components.py @@ -1,7 +1,7 @@ import pytest import torch +from lightning.pytorch import seed_everything from pl_bolts.models.gans.basic.components import Discriminator, Generator -from pytorch_lightning import seed_everything @pytest.mark.parametrize( diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py index 041da1672..ed555c0ef 100644 --- a/tests/models/rl/integration/test_actor_critic_models.py +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -2,10 +2,10 @@ import pytest import torch.cuda +from lightning.pytorch import Trainer from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic from pl_bolts.models.rl.sac_model import SAC from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20 -from pytorch_lightning import Trainer def test_a2c_cli(): diff --git a/tests/models/rl/integration/test_policy_models.py b/tests/models/rl/integration/test_policy_models.py index 62a8fff3d..22f8a7342 100644 --- a/tests/models/rl/integration/test_policy_models.py +++ b/tests/models/rl/integration/test_policy_models.py @@ -2,9 +2,9 @@ from unittest import TestCase import torch +from lightning.pytorch import Trainer from pl_bolts.models.rl.reinforce_model import Reinforce from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient -from pytorch_lightning import Trainer class TestPolicyModels(TestCase): diff --git a/tests/models/rl/integration/test_value_models.py b/tests/models/rl/integration/test_value_models.py index ac1f5adb1..e127a9880 100644 --- a/tests/models/rl/integration/test_value_models.py +++ b/tests/models/rl/integration/test_value_models.py @@ -3,13 +3,13 @@ import pytest import torch +from lightning.pytorch import Trainer from pl_bolts.models.rl.double_dqn_model import DoubleDQN from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN from pl_bolts.models.rl.per_dqn_model import PERDQN from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import Trainer class TestValueModels(TestCase): diff --git a/tests/models/rl/unit/test_ppo.py b/tests/models/rl/unit/test_ppo.py index eac100dc0..e1e2f40e5 100644 --- a/tests/models/rl/unit/test_ppo.py +++ b/tests/models/rl/unit/test_ppo.py @@ -1,7 +1,7 @@ import numpy as np import torch +from lightning.pytorch import Trainer from pl_bolts.models.rl.ppo_model import PPO -from pytorch_lightning import Trainer from torch import Tensor diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index d5c010848..5b632860d 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -2,18 +2,37 @@ import pytest import torch +from lightning.pytorch import Trainer +from lightning.pytorch.utilities.warnings import PossibleUserWarning from pl_bolts.datamodules import CIFAR10DataModule -from pl_bolts.models.self_supervised import AMDIM, BYOL, CPC_v2, MoCo, SimCLR, SimSiam, SwAV -from pl_bolts.models.self_supervised.cpc import CPCEvalTransformsCIFAR10, CPCTrainTransformsCIFAR10 +from pl_bolts.models.self_supervised import ( + AMDIM, + BYOL, + CPC_v2, + MoCo, + SimCLR, + SimSiam, + SwAV, +) +from pl_bolts.models.self_supervised.cpc import ( + CPCEvalTransformsCIFAR10, + CPCTrainTransformsCIFAR10, +) from pl_bolts.models.self_supervised.moco.callbacks import MoCoLRScheduler from pl_bolts.transforms.dataset_normalizations import cifar10_normalization -from pl_bolts.transforms.self_supervised.moco_transforms import MoCo2EvalCIFAR10Transforms, MoCo2TrainCIFAR10Transforms -from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform -from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform +from pl_bolts.transforms.self_supervised.moco_transforms import ( + MoCo2EvalCIFAR10Transforms, + MoCo2TrainCIFAR10Transforms, +) +from pl_bolts.transforms.self_supervised.simclr_transforms import ( + SimCLREvalDataTransform, + SimCLRTrainDataTransform, +) +from pl_bolts.transforms.self_supervised.swav_transforms import ( + SwAVEvalDataTransform, + SwAVTrainDataTransform, +) from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.warnings import PossibleUserWarning - from tests import _MARK_REQUIRE_GPU diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index bf3162b25..53f8c0112 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -1,8 +1,14 @@ import pytest import torch +from lightning.pytorch import Trainer, seed_everything from pl_bolts.datamodules import CIFAR10DataModule -from pl_bolts.models.autoencoders import AE, VAE, resnet18_decoder, resnet18_encoder, resnet50_encoder -from pytorch_lightning import Trainer, seed_everything +from pl_bolts.models.autoencoders import ( + AE, + VAE, + resnet18_decoder, + resnet18_encoder, + resnet50_encoder, +) @pytest.mark.parametrize("dm_cls", [pytest.param(CIFAR10DataModule, id="cifar10")]) diff --git a/tests/models/test_classic_ml.py b/tests/models/test_classic_ml.py index f4eac0955..caac82d1d 100644 --- a/tests/models/test_classic_ml.py +++ b/tests/models/test_classic_ml.py @@ -1,8 +1,8 @@ import numpy as np import pytest +from lightning.pytorch import Trainer, seed_everything from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset from pl_bolts.models.regression import LinearRegression -from pytorch_lightning import Trainer, seed_everything from torch.utils.data import DataLoader diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index fb207b343..762ac2bf0 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -3,6 +3,8 @@ import pytest import torch +from lightning.pytorch import Trainer +from lightning.pytorch.utilities.warnings import PossibleUserWarning from pl_bolts.datasets import DummyDetectionDataset from pl_bolts.models.detection import ( YOLO, @@ -18,11 +20,8 @@ ) from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.warnings import PossibleUserWarning -from torch.utils.data import DataLoader - from tests import TEST_ROOT +from torch.utils.data import DataLoader def _collate_fn(batch): diff --git a/tests/models/test_mnist_templates.py b/tests/models/test_mnist_templates.py index 58323cf35..42062c203 100644 --- a/tests/models/test_mnist_templates.py +++ b/tests/models/test_mnist_templates.py @@ -1,9 +1,9 @@ import warnings +from lightning.pytorch import Trainer, seed_everything +from lightning.pytorch.utilities.warnings import PossibleUserWarning from pl_bolts.datamodules import MNISTDataModule from pl_bolts.models import LitMNIST -from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.utilities.warnings import PossibleUserWarning def test_mnist(tmpdir, datadir, catch_warnings): diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 3f0b9a7d0..715ef3d13 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -2,16 +2,17 @@ import pytest import torch +from lightning.pytorch import LightningDataModule, Trainer +from lightning.pytorch import __version__ as pl_version +from lightning.pytorch import seed_everything +from lightning.pytorch.callbacks.progress import TQDMProgressBar +from lightning.pytorch.utilities.warnings import PossibleUserWarning from packaging import version from pl_bolts.datamodules import FashionMNISTDataModule, MNISTDataModule from pl_bolts.datasets import DummyDataset from pl_bolts.models.vision import GPT2, ImageGPT, SemSegment, UNet from pl_bolts.models.vision.unet import DoubleConv, Down, Up from pl_bolts.utils import _IS_WINDOWS -from pytorch_lightning import LightningDataModule, Trainer, seed_everything -from pytorch_lightning import __version__ as pl_version -from pytorch_lightning.callbacks.progress import TQDMProgressBar -from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader diff --git a/tests/models/yolo/unit/test_darknet_network.py b/tests/models/yolo/unit/test_darknet_network.py index 18020e27b..1c51d33c4 100644 --- a/tests/models/yolo/unit/test_darknet_network.py +++ b/tests/models/yolo/unit/test_darknet_network.py @@ -2,13 +2,13 @@ import pytest import torch.nn as nn +from lightning.pytorch.utilities.warnings import PossibleUserWarning from pl_bolts.models.detection.yolo.darknet_network import ( _create_convolutional, _create_maxpool, _create_shortcut, _create_upsample, ) -from pytorch_lightning.utilities.warnings import PossibleUserWarning @pytest.mark.parametrize( diff --git a/tests/models/yolo/unit/test_utils.py b/tests/models/yolo/unit/test_utils.py index b883d2d60..6ff07ffea 100644 --- a/tests/models/yolo/unit/test_utils.py +++ b/tests/models/yolo/unit/test_utils.py @@ -2,6 +2,7 @@ import pytest import torch +from lightning.pytorch.utilities.warnings import PossibleUserWarning from pl_bolts.models.detection.yolo.utils import ( aligned_iou, box_size_ratio, @@ -11,7 +12,6 @@ iou_below, is_inside_box, ) -from pytorch_lightning.utilities.warnings import PossibleUserWarning @pytest.mark.parametrize(("width", "height"), [(10, 5)]) diff --git a/tests/optimizers/test_lr_scheduler.py b/tests/optimizers/test_lr_scheduler.py index 477aae5b7..f2e06dae3 100644 --- a/tests/optimizers/test_lr_scheduler.py +++ b/tests/optimizers/test_lr_scheduler.py @@ -2,8 +2,8 @@ import numpy as np import torch +from lightning.pytorch import seed_everything from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR -from pytorch_lightning import seed_everything from torch.nn import functional as F # noqa: N812 from torch.optim import SGD from torch.optim.lr_scheduler import _LRScheduler diff --git a/tests/transforms/test_normalizations.py b/tests/transforms/test_normalizations.py index b0c22cd07..6f2c22f7f 100644 --- a/tests/transforms/test_normalizations.py +++ b/tests/transforms/test_normalizations.py @@ -1,12 +1,12 @@ import pytest import torch +from lightning.pytorch import seed_everything from pl_bolts.transforms.dataset_normalizations import ( cifar10_normalization, emnist_normalization, imagenet_normalization, stl10_normalization, ) -from pytorch_lightning import seed_everything @pytest.mark.parametrize( diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index a2082a19a..df004d594 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -2,7 +2,7 @@ import pytest import torch -from pytorch_lightning import seed_everything +from lightning.pytorch import seed_everything try: from torchvision import transforms @@ -35,7 +35,10 @@ MoCo2TrainImagenetTransforms, MoCo2TrainSTL10Transforms, ) -from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform +from pl_bolts.transforms.self_supervised.simclr_transforms import ( + SimCLREvalDataTransform, + SimCLRTrainDataTransform, +) @pytest.mark.parametrize( diff --git a/tests/utils/test_arguments.py b/tests/utils/test_arguments.py index e4e293a22..e1629e102 100644 --- a/tests/utils/test_arguments.py +++ b/tests/utils/test_arguments.py @@ -1,8 +1,8 @@ from dataclasses import FrozenInstanceError import pytest +from lightning.pytorch import LightningDataModule, LightningModule from pl_bolts.utils.arguments import LightningArgumentParser, LitArg, gather_lit_args -from pytorch_lightning import LightningDataModule, LightningModule class DummyParentModel(LightningModule): From fbc8f2fa652faac8ec24babf0b942e1c23117d6f Mon Sep 17 00:00:00 2001 From: James Bishop Date: Sun, 5 May 2024 14:23:21 +0100 Subject: [PATCH 3/6] uplift torchmetrics, torchvision deps --- requirements/base.txt | 4 ++-- requirements/models.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index 74fcae877..ac8da8558 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,6 +1,6 @@ numpy <1.26.0 lightning >=2.0.0 # strict -torchmetrics >=0.10.0, <0.12.0 +torchmetrics >=0.7.0, <1.3.0 lightning-utilities >0.3.1 # this is needed for PL 1.7 -torchvision >=0.10.0 # todo: move to topic related extras +torchvision >=0.15.0, <0.19.0 # todo: move to topic related extras tensorboard >=2.9.1, <2.14.0 # for `TensorBoardLogger` diff --git a/requirements/models.txt b/requirements/models.txt index ec3dd912c..9d97235b6 100644 --- a/requirements/models.txt +++ b/requirements/models.txt @@ -1,4 +1,4 @@ -torchvision >=0.10.0 +torchvision >=0.15.0, <0.19.0 scikit-learn >=1.0.2 Pillow >9.0.0 gym[atari] >=0.17.2, <0.22.0 # strict From 9820f586ff7c282df12afab2fb9485a18625115c Mon Sep 17 00:00:00 2001 From: James Bishop Date: Sun, 5 May 2024 14:31:49 +0100 Subject: [PATCH 4/6] refactor missed imports --- src/pl_bolts/callbacks/knn_online.py | 2 +- src/pl_bolts/utils/_dependency.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/pl_bolts/callbacks/knn_online.py b/src/pl_bolts/callbacks/knn_online.py index 0d0cb8896..9c392a39e 100644 --- a/src/pl_bolts/callbacks/knn_online.py +++ b/src/pl_bolts/callbacks/knn_online.py @@ -2,7 +2,7 @@ import torch from lightning import Callback, LightningModule, Trainer -from lightning.accelerators import Accelerator +from lightning.pytorch.accelerators import Accelerator from pl_bolts.utils.stability import under_review from torch import Tensor from torch.nn import functional as F # noqa: N812 diff --git a/src/pl_bolts/utils/_dependency.py b/src/pl_bolts/utils/_dependency.py index 32571c478..81ce2d912 100644 --- a/src/pl_bolts/utils/_dependency.py +++ b/src/pl_bolts/utils/_dependency.py @@ -2,10 +2,7 @@ import os from typing import Any, Callable -from lightning.fabric.utilities.core.imports import ( - ModuleAvailableCache, - RequirementCache, -) +from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache # ToDo: replace with utils wrapper after 0.10 is released From 1fa9ef599325b504bb38ed26ca24bdf3a926fad9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jun 2024 08:24:40 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pl_bolts/callbacks/data_monitor.py | 9 +++++---- src/pl_bolts/callbacks/knn_online.py | 3 ++- src/pl_bolts/callbacks/printing.py | 1 + src/pl_bolts/callbacks/sparseml.py | 1 + src/pl_bolts/callbacks/ssl_online.py | 5 +++-- src/pl_bolts/callbacks/torch_ort.py | 1 + src/pl_bolts/callbacks/variational.py | 3 ++- src/pl_bolts/callbacks/verification/base.py | 1 + src/pl_bolts/callbacks/verification/batch_gradient.py | 3 ++- src/pl_bolts/callbacks/vision/confused_logit.py | 3 ++- src/pl_bolts/callbacks/vision/image_generation.py | 1 + src/pl_bolts/callbacks/vision/sr_image_logger.py | 1 + src/pl_bolts/datamodules/cityscapes_datamodule.py | 3 ++- src/pl_bolts/datamodules/imagenet_datamodule.py | 3 ++- src/pl_bolts/datamodules/kitti_datamodule.py | 5 +++-- src/pl_bolts/datamodules/sklearn_datamodule.py | 3 ++- src/pl_bolts/datamodules/sr_datamodule.py | 3 ++- src/pl_bolts/datamodules/ssl_imagenet_datamodule.py | 3 ++- src/pl_bolts/datamodules/stl10_datamodule.py | 3 ++- src/pl_bolts/datamodules/vocdetection_datamodule.py | 5 +++-- src/pl_bolts/datasets/array_dataset.py | 3 ++- src/pl_bolts/losses/self_supervised_learning.py | 3 ++- .../models/autoencoders/basic_ae/basic_ae_module.py | 5 +++-- .../models/autoencoders/basic_vae/basic_vae_module.py | 5 +++-- .../models/detection/faster_rcnn/faster_rcnn_module.py | 1 + .../models/detection/retinanet/retinanet_module.py | 2 ++ src/pl_bolts/models/gans/basic/basic_gan_module.py | 3 ++- src/pl_bolts/models/gans/dcgan/dcgan_module.py | 5 +++-- src/pl_bolts/models/gans/pix2pix/pix2pix_module.py | 3 ++- src/pl_bolts/models/mnist_module.py | 3 ++- src/pl_bolts/models/regression/linear_regression.py | 3 ++- src/pl_bolts/models/regression/logistic_regression.py | 3 ++- src/pl_bolts/models/rl/advantage_actor_critic_model.py | 7 ++++--- src/pl_bolts/models/rl/double_dqn_model.py | 3 ++- src/pl_bolts/models/rl/dqn_model.py | 7 ++++--- src/pl_bolts/models/rl/dueling_dqn_model.py | 1 + src/pl_bolts/models/rl/noisy_dqn_model.py | 3 ++- src/pl_bolts/models/rl/per_dqn_model.py | 5 +++-- src/pl_bolts/models/rl/ppo_model.py | 7 ++++--- src/pl_bolts/models/rl/reinforce_model.py | 9 +++++---- src/pl_bolts/models/rl/sac_model.py | 9 +++++---- src/pl_bolts/models/rl/vanilla_policy_gradient_model.py | 9 +++++---- .../models/self_supervised/amdim/amdim_module.py | 5 +++-- src/pl_bolts/models/self_supervised/byol/byol_module.py | 7 ++++--- src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py | 1 + src/pl_bolts/models/self_supervised/cpc/cpc_module.py | 3 ++- src/pl_bolts/models/self_supervised/moco/callbacks.py | 1 + .../models/self_supervised/simclr/simclr_finetuner.py | 1 + .../models/self_supervised/simclr/simclr_module.py | 5 +++-- .../models/self_supervised/simsiam/simsiam_module.py | 3 ++- src/pl_bolts/models/self_supervised/ssl_finetuner.py | 3 ++- .../models/self_supervised/swav/swav_finetuner.py | 1 + src/pl_bolts/models/self_supervised/swav/swav_module.py | 3 ++- src/pl_bolts/models/vision/image_gpt/gpt2.py | 3 ++- src/pl_bolts/models/vision/image_gpt/igpt_module.py | 3 ++- src/pl_bolts/models/vision/segmentation.py | 3 ++- src/pl_bolts/utils/__init__.py | 1 + src/pl_bolts/utils/arguments.py | 1 + src/pl_bolts/utils/pretrained_weights.py | 1 + tests/callbacks/test_ort.py | 1 + tests/callbacks/test_sparseml.py | 1 + tests/callbacks/verification/test_base.py | 1 + tests/callbacks/verification/test_batch_gradient.py | 3 ++- tests/models/self_supervised/test_models.py | 1 + tests/models/test_detection.py | 3 ++- tests/models/test_vision.py | 3 +-- 66 files changed, 141 insertions(+), 76 deletions(-) diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py index 7a6466821..93a320968 100644 --- a/src/pl_bolts/callbacks/data_monitor.py +++ b/src/pl_bolts/callbacks/data_monitor.py @@ -6,13 +6,14 @@ from lightning.fabric.utilities import rank_zero_warn from lightning.fabric.utilities.apply_func import apply_to_collection from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger -from pl_bolts.utils import _WANDB_AVAILABLE -from pl_bolts.utils.stability import under_review -from pl_bolts.utils.warnings import warn_missing_pkg from torch import Tensor, nn from torch.nn import Module from torch.utils.hooks import RemovableHandle +from pl_bolts.utils import _WANDB_AVAILABLE +from pl_bolts.utils.stability import under_review +from pl_bolts.utils.warnings import warn_missing_pkg + # Backward compatibility for Lightning Logger try: from lightning.pytorch.loggers import Logger @@ -111,7 +112,7 @@ def _is_logger_available(self, logger: Logger) -> bool: if not isinstance(logger, self.supported_loggers): rank_zero_warn( f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}." - f" Supported loggers are: {', '.join((str(x.__name__) for x in self.supported_loggers))}" + f" Supported loggers are: {', '.join(str(x.__name__) for x in self.supported_loggers)}" ) available = False return available diff --git a/src/pl_bolts/callbacks/knn_online.py b/src/pl_bolts/callbacks/knn_online.py index 9c392a39e..630c38e8a 100644 --- a/src/pl_bolts/callbacks/knn_online.py +++ b/src/pl_bolts/callbacks/knn_online.py @@ -3,10 +3,11 @@ import torch from lightning import Callback, LightningModule, Trainer from lightning.pytorch.accelerators import Accelerator -from pl_bolts.utils.stability import under_review from torch import Tensor from torch.nn import functional as F # noqa: N812 +from pl_bolts.utils.stability import under_review + @under_review() class KNNOnlineEvaluator(Callback): diff --git a/src/pl_bolts/callbacks/printing.py b/src/pl_bolts/callbacks/printing.py index 555250585..a3fab6db8 100644 --- a/src/pl_bolts/callbacks/printing.py +++ b/src/pl_bolts/callbacks/printing.py @@ -5,6 +5,7 @@ from lightning import LightningModule, Trainer from lightning.fabric.utilities import rank_zero_info from lightning.pytorch.callbacks import Callback + from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/callbacks/sparseml.py b/src/pl_bolts/callbacks/sparseml.py index 430b867c4..64d594036 100644 --- a/src/pl_bolts/callbacks/sparseml.py +++ b/src/pl_bolts/callbacks/sparseml.py @@ -16,6 +16,7 @@ import torch from lightning import Callback, LightningModule, Trainer from lightning.fabric.utilities.exceptions import MisconfigurationException + from pl_bolts.utils import ( _SPARSEML_AVAILABLE, _SPARSEML_TORCH_SATISFIED, diff --git a/src/pl_bolts/callbacks/ssl_online.py b/src/pl_bolts/callbacks/ssl_online.py index 22bbd718f..791889bc8 100644 --- a/src/pl_bolts/callbacks/ssl_online.py +++ b/src/pl_bolts/callbacks/ssl_online.py @@ -4,13 +4,14 @@ import torch from lightning import Callback, LightningModule, Trainer from lightning.fabric.utilities import rank_zero_warn -from pl_bolts.models.self_supervised.evaluator import SSLEvaluator -from pl_bolts.utils.stability import under_review from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 from torch.optim import Optimizer from torchmetrics.functional import accuracy +from pl_bolts.models.self_supervised.evaluator import SSLEvaluator +from pl_bolts.utils.stability import under_review + @under_review() class SSLOnlineEvaluator(Callback): # pragma: no cover diff --git a/src/pl_bolts/callbacks/torch_ort.py b/src/pl_bolts/callbacks/torch_ort.py index c9feef774..455365cd0 100644 --- a/src/pl_bolts/callbacks/torch_ort.py +++ b/src/pl_bolts/callbacks/torch_ort.py @@ -13,6 +13,7 @@ # limitations under the License. from lightning import Callback, LightningModule, Trainer from lightning.fabric.utilities.exceptions import MisconfigurationException + from pl_bolts.utils import _TORCH_ORT_AVAILABLE if _TORCH_ORT_AVAILABLE: diff --git a/src/pl_bolts/callbacks/variational.py b/src/pl_bolts/callbacks/variational.py index eac65a9d9..e76e81114 100644 --- a/src/pl_bolts/callbacks/variational.py +++ b/src/pl_bolts/callbacks/variational.py @@ -4,10 +4,11 @@ import torch from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback +from torch import Tensor + from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor if _TORCHVISION_AVAILABLE: import torchvision diff --git a/src/pl_bolts/callbacks/verification/base.py b/src/pl_bolts/callbacks/verification/base.py index 1134831f6..019b1ab4c 100644 --- a/src/pl_bolts/callbacks/verification/base.py +++ b/src/pl_bolts/callbacks/verification/base.py @@ -7,6 +7,7 @@ from lightning import Callback, LightningModule from lightning.fabric.utilities import move_data_to_device, rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature + from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/callbacks/verification/batch_gradient.py b/src/pl_bolts/callbacks/verification/batch_gradient.py index 91299337e..eb09b4bc1 100644 --- a/src/pl_bolts/callbacks/verification/batch_gradient.py +++ b/src/pl_bolts/callbacks/verification/batch_gradient.py @@ -7,12 +7,13 @@ from lightning import LightningModule, Trainer from lightning.fabric.utilities.apply_func import apply_to_collection from lightning.fabric.utilities.exceptions import MisconfigurationException +from torch import Tensor + from pl_bolts.callbacks.verification.base import ( VerificationBase, VerificationCallbackBase, ) from pl_bolts.utils.stability import under_review -from torch import Tensor @under_review() diff --git a/src/pl_bolts/callbacks/vision/confused_logit.py b/src/pl_bolts/callbacks/vision/confused_logit.py index 1c23c423f..a7efd8abd 100644 --- a/src/pl_bolts/callbacks/vision/confused_logit.py +++ b/src/pl_bolts/callbacks/vision/confused_logit.py @@ -2,10 +2,11 @@ import torch from lightning import Callback, LightningModule, Trainer +from torch import Tensor, nn + from pl_bolts.utils import _MATPLOTLIB_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor, nn if _MATPLOTLIB_AVAILABLE: from matplotlib import pyplot as plt diff --git a/src/pl_bolts/callbacks/vision/image_generation.py b/src/pl_bolts/callbacks/vision/image_generation.py index 3a7cc3df5..b629cc205 100644 --- a/src/pl_bolts/callbacks/vision/image_generation.py +++ b/src/pl_bolts/callbacks/vision/image_generation.py @@ -2,6 +2,7 @@ import torch from lightning import Callback, LightningModule, Trainer + from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/src/pl_bolts/callbacks/vision/sr_image_logger.py b/src/pl_bolts/callbacks/vision/sr_image_logger.py index d4b178a8e..4412bf948 100644 --- a/src/pl_bolts/callbacks/vision/sr_image_logger.py +++ b/src/pl_bolts/callbacks/vision/sr_image_logger.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F # noqa: N812 from lightning import Callback + from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg diff --git a/src/pl_bolts/datamodules/cityscapes_datamodule.py b/src/pl_bolts/datamodules/cityscapes_datamodule.py index a1933eb7d..2842c3b1a 100644 --- a/src/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/src/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,10 +1,11 @@ from typing import Any, Callable, Optional from lightning import LightningDataModule +from torch.utils.data import DataLoader + from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/imagenet_datamodule.py b/src/pl_bolts/datamodules/imagenet_datamodule.py index 601186619..299df78d0 100644 --- a/src/pl_bolts/datamodules/imagenet_datamodule.py +++ b/src/pl_bolts/datamodules/imagenet_datamodule.py @@ -3,12 +3,13 @@ from typing import Any, Callable, Optional from lightning import LightningDataModule +from torch.utils.data import DataLoader + from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/kitti_datamodule.py b/src/pl_bolts/datamodules/kitti_datamodule.py index e7416bebb..72994e76f 100644 --- a/src/pl_bolts/datamodules/kitti_datamodule.py +++ b/src/pl_bolts/datamodules/kitti_datamodule.py @@ -3,12 +3,13 @@ import torch from lightning import LightningDataModule +from torch.utils.data import DataLoader +from torch.utils.data.dataset import random_split + from pl_bolts.datasets import KittiDataset from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch.utils.data import DataLoader -from torch.utils.data.dataset import random_split if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/src/pl_bolts/datamodules/sklearn_datamodule.py b/src/pl_bolts/datamodules/sklearn_datamodule.py index 3c179d5b3..5f5c6f90e 100644 --- a/src/pl_bolts/datamodules/sklearn_datamodule.py +++ b/src/pl_bolts/datamodules/sklearn_datamodule.py @@ -3,10 +3,11 @@ import numpy as np from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset + from pl_bolts.utils import _SKLEARN_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch.utils.data import DataLoader, Dataset if _SKLEARN_AVAILABLE: from sklearn.utils import shuffle as sk_shuffle diff --git a/src/pl_bolts/datamodules/sr_datamodule.py b/src/pl_bolts/datamodules/sr_datamodule.py index 657e5d320..ad7637058 100644 --- a/src/pl_bolts/datamodules/sr_datamodule.py +++ b/src/pl_bolts/datamodules/sr_datamodule.py @@ -1,9 +1,10 @@ from typing import Any from lightning import LightningDataModule -from pl_bolts.utils.stability import under_review from torch.utils.data import DataLoader, Dataset +from pl_bolts.utils.stability import under_review + @under_review() class TVTDataModule(LightningDataModule): diff --git a/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 5a489d850..01eab27c5 100644 --- a/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/src/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -2,12 +2,13 @@ from typing import Any, Callable, Optional from lightning import LightningDataModule +from torch.utils.data import DataLoader + from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/stl10_datamodule.py b/src/pl_bolts/datamodules/stl10_datamodule.py index fa77682be..851b045eb 100644 --- a/src/pl_bolts/datamodules/stl10_datamodule.py +++ b/src/pl_bolts/datamodules/stl10_datamodule.py @@ -4,12 +4,13 @@ import torch from lightning import LightningDataModule +from torch.utils.data import DataLoader, random_split + from pl_bolts.datasets import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datamodules/vocdetection_datamodule.py b/src/pl_bolts/datamodules/vocdetection_datamodule.py index 98af7df88..32137485d 100644 --- a/src/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/src/pl_bolts/datamodules/vocdetection_datamodule.py @@ -3,11 +3,12 @@ import torch from lightning import LightningDataModule +from torch import Tensor +from torch.utils.data import DataLoader, Dataset + from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor -from torch.utils.data import DataLoader, Dataset if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/datasets/array_dataset.py b/src/pl_bolts/datasets/array_dataset.py index 476cd4efa..1df4fe65c 100644 --- a/src/pl_bolts/datasets/array_dataset.py +++ b/src/pl_bolts/datasets/array_dataset.py @@ -1,9 +1,10 @@ from typing import Tuple, Union from lightning.fabric.utilities import exceptions -from pl_bolts.datasets.base_dataset import DataModel, TArrays from torch.utils.data import Dataset +from pl_bolts.datasets.base_dataset import DataModel, TArrays + class ArrayDataset(Dataset): """Dataset wrapping tensors, lists, numpy arrays. diff --git a/src/pl_bolts/losses/self_supervised_learning.py b/src/pl_bolts/losses/self_supervised_learning.py index 8fbe6f5f3..81a3f46f1 100644 --- a/src/pl_bolts/losses/self_supervised_learning.py +++ b/src/pl_bolts/losses/self_supervised_learning.py @@ -1,8 +1,9 @@ import numpy as np import torch +from torch import nn + from pl_bolts.models.vision.pixel_cnn import PixelCNN from pl_bolts.utils.stability import under_review -from torch import nn @under_review() diff --git a/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 112f40a6f..d4476f262 100644 --- a/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/src/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -3,6 +3,9 @@ import torch from lightning import LightningModule, Trainer +from torch import nn +from torch.nn import functional as F # noqa: N812 + from pl_bolts import _HTTPS_AWS_HUB from pl_bolts.models.autoencoders.components import ( resnet18_decoder, @@ -11,8 +14,6 @@ resnet50_encoder, ) from pl_bolts.utils.stability import under_review -from torch import nn -from torch.nn import functional as F # noqa: N812 @under_review() diff --git a/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index 5f6295492..15dcbdb9e 100644 --- a/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/src/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -3,6 +3,9 @@ import torch from lightning import LightningModule, Trainer, seed_everything +from torch import nn +from torch.nn import functional as F # noqa: N812 + from pl_bolts import _HTTPS_AWS_HUB from pl_bolts.models.autoencoders.components import ( resnet18_decoder, @@ -11,8 +14,6 @@ resnet50_encoder, ) from pl_bolts.utils.stability import under_review -from torch import nn -from torch.nn import functional as F # noqa: N812 @under_review() diff --git a/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index 1cc2d22e1..30ad007d1 100644 --- a/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/src/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -3,6 +3,7 @@ import torch from lightning import LightningModule, Trainer, seed_everything + from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/models/detection/retinanet/retinanet_module.py b/src/pl_bolts/models/detection/retinanet/retinanet_module.py index 589b179bd..4fd11c003 100644 --- a/src/pl_bolts/models/detection/retinanet/retinanet_module.py +++ b/src/pl_bolts/models/detection/retinanet/retinanet_module.py @@ -2,6 +2,7 @@ import torch from lightning import LightningModule + from pl_bolts.models.detection.retinanet import create_retinanet_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13 from pl_bolts.utils.stability import under_review @@ -141,6 +142,7 @@ def configure_optimizers(self): @under_review() def cli_main(): from lightning.cli import LightningCLI + from pl_bolts.datamodules import VOCDetectionDataModule LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42) diff --git a/src/pl_bolts/models/gans/basic/basic_gan_module.py b/src/pl_bolts/models/gans/basic/basic_gan_module.py index 45e7ecc15..a48885c5e 100644 --- a/src/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/src/pl_bolts/models/gans/basic/basic_gan_module.py @@ -3,9 +3,10 @@ import torch from lightning import LightningModule, Trainer, seed_everything from lightning.pytorch.callbacks.progress import TQDMProgressBar -from pl_bolts.models.gans.basic.components import Discriminator, Generator from torch.nn import functional as F # noqa: N812 +from pl_bolts.models.gans.basic.components import Discriminator, Generator + class GAN(LightningModule): """Vanilla GAN implementation. diff --git a/src/pl_bolts/models/gans/dcgan/dcgan_module.py b/src/pl_bolts/models/gans/dcgan/dcgan_module.py index e054d3de4..ed026e42b 100644 --- a/src/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/src/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -4,6 +4,9 @@ import torch import torch.nn as nn from lightning import LightningModule, Trainer, seed_everything +from torch import Tensor +from torch.utils.data import DataLoader + from pl_bolts.callbacks import ( LatentDimInterpolator, TensorboardGenerativeModelImageSampler, @@ -11,8 +14,6 @@ from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor -from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib diff --git a/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py b/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py index 1886a41c9..7831b0e60 100644 --- a/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py +++ b/src/pl_bolts/models/gans/pix2pix/pix2pix_module.py @@ -1,8 +1,9 @@ import torch from lightning import LightningModule +from torch import nn + from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN from pl_bolts.utils.stability import under_review -from torch import nn @under_review() diff --git a/src/pl_bolts/models/mnist_module.py b/src/pl_bolts/models/mnist_module.py index a2485c6c0..eb26a2299 100644 --- a/src/pl_bolts/models/mnist_module.py +++ b/src/pl_bolts/models/mnist_module.py @@ -3,10 +3,11 @@ import torch from lightning import LightningModule, Trainer -from pl_bolts.utils import _TORCHVISION_AVAILABLE from torch import Tensor from torch.nn import functional as F # noqa: N812 +from pl_bolts.utils import _TORCHVISION_AVAILABLE + class LitMNIST(LightningModule): """PyTorch Lightning implementation of a two-layer MNIST classification module. diff --git a/src/pl_bolts/models/regression/linear_regression.py b/src/pl_bolts/models/regression/linear_regression.py index 4ca9433e5..533242f09 100644 --- a/src/pl_bolts/models/regression/linear_regression.py +++ b/src/pl_bolts/models/regression/linear_regression.py @@ -3,12 +3,13 @@ import torch from lightning import LightningModule, Trainer, seed_everything -from pl_bolts.utils.stability import under_review from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 from torch.optim import Adam from torch.optim.optimizer import Optimizer +from pl_bolts.utils.stability import under_review + @under_review() class LinearRegression(LightningModule): diff --git a/src/pl_bolts/models/regression/logistic_regression.py b/src/pl_bolts/models/regression/logistic_regression.py index 022a01458..0a001240b 100644 --- a/src/pl_bolts/models/regression/logistic_regression.py +++ b/src/pl_bolts/models/regression/logistic_regression.py @@ -5,12 +5,13 @@ import torch from lightning import LightningModule, Trainer, seed_everything -from pl_bolts.utils.stability import under_review from torch import Tensor, nn from torch.optim import Adam from torch.optim.optimizer import Optimizer from torchmetrics import functional +from pl_bolts.utils.stability import under_review + class LogisticRegression(LightningModule): """Logistic Regression Model.""" diff --git a/src/pl_bolts/models/rl/advantage_actor_critic_model.py b/src/pl_bolts/models/rl/advantage_actor_critic_model.py index 7672b2bb8..d0f371dd2 100644 --- a/src/pl_bolts/models/rl/advantage_actor_critic_model.py +++ b/src/pl_bolts/models/rl/advantage_actor_critic_model.py @@ -8,15 +8,16 @@ import torch from lightning import LightningModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint +from torch import Tensor, optim +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.agents import ActorCriticAgent from pl_bolts.models.rl.common.networks import ActorCriticMLP from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor, optim -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/double_dqn_model.py b/src/pl_bolts/models/rl/double_dqn_model.py index 954c64eef..9f31dbccb 100644 --- a/src/pl_bolts/models/rl/double_dqn_model.py +++ b/src/pl_bolts/models/rl/double_dqn_model.py @@ -5,10 +5,11 @@ from typing import Tuple from lightning import Trainer +from torch import Tensor + from pl_bolts.losses.rl import double_dqn_loss from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review -from torch import Tensor @under_review() diff --git a/src/pl_bolts/models/rl/dqn_model.py b/src/pl_bolts/models/rl/dqn_model.py index add95a67d..f54d146f9 100644 --- a/src/pl_bolts/models/rl/dqn_model.py +++ b/src/pl_bolts/models/rl/dqn_model.py @@ -9,6 +9,10 @@ from lightning import LightningModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.strategies import DataParallelStrategy +from torch import Tensor, optim +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.losses.rl import dqn_loss from pl_bolts.models.rl.common.agents import ValueAgent @@ -18,9 +22,6 @@ from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor, optim -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader if _GYM_AVAILABLE: from gym import Env diff --git a/src/pl_bolts/models/rl/dueling_dqn_model.py b/src/pl_bolts/models/rl/dueling_dqn_model.py index 8a6d19e5b..ffa091c66 100644 --- a/src/pl_bolts/models/rl/dueling_dqn_model.py +++ b/src/pl_bolts/models/rl/dueling_dqn_model.py @@ -3,6 +3,7 @@ import argparse from lightning import Trainer + from pl_bolts.models.rl.common.networks import DuelingCNN from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/models/rl/noisy_dqn_model.py b/src/pl_bolts/models/rl/noisy_dqn_model.py index 3b871ac05..ff16c499d 100644 --- a/src/pl_bolts/models/rl/noisy_dqn_model.py +++ b/src/pl_bolts/models/rl/noisy_dqn_model.py @@ -5,11 +5,12 @@ import numpy as np from lightning import Trainer +from torch import Tensor + from pl_bolts.datamodules.experience_source import Experience from pl_bolts.models.rl.common.networks import NoisyCNN from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review -from torch import Tensor @under_review() diff --git a/src/pl_bolts/models/rl/per_dqn_model.py b/src/pl_bolts/models/rl/per_dqn_model.py index f4497d8b4..4c913b0a1 100644 --- a/src/pl_bolts/models/rl/per_dqn_model.py +++ b/src/pl_bolts/models/rl/per_dqn_model.py @@ -6,13 +6,14 @@ import numpy as np from lightning import Trainer +from torch import Tensor +from torch.utils.data import DataLoader + from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.losses.rl import per_dqn_loss from pl_bolts.models.rl.common.memory import Experience, PERBuffer from pl_bolts.models.rl.dqn_model import DQN from pl_bolts.utils.stability import under_review -from torch import Tensor -from torch.utils.data import DataLoader @under_review() diff --git a/src/pl_bolts/models/rl/ppo_model.py b/src/pl_bolts/models/rl/ppo_model.py index b16d47435..33bc2fe79 100644 --- a/src/pl_bolts/models/rl/ppo_model.py +++ b/src/pl_bolts/models/rl/ppo_model.py @@ -3,14 +3,15 @@ import torch from lightning import LightningModule, Trainer, seed_everything +from torch import Tensor +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.networks import MLP, ActorCategorical, ActorContinous from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/reinforce_model.py b/src/pl_bolts/models/rl/reinforce_model.py index 3c814791e..6072c4ec3 100644 --- a/src/pl_bolts/models/rl/reinforce_model.py +++ b/src/pl_bolts/models/rl/reinforce_model.py @@ -5,6 +5,11 @@ import numpy as np from lightning import LightningModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint +from torch import Tensor, optim +from torch.nn.functional import log_softmax +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.datamodules.experience_source import Experience from pl_bolts.models.rl.common.agents import PolicyAgent @@ -12,10 +17,6 @@ from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor, optim -from torch.nn.functional import log_softmax -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/sac_model.py b/src/pl_bolts/models/rl/sac_model.py index eed7ad6d7..1833783fc 100644 --- a/src/pl_bolts/models/rl/sac_model.py +++ b/src/pl_bolts/models/rl/sac_model.py @@ -7,6 +7,11 @@ import torch from lightning import LightningModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint +from torch import Tensor, optim +from torch.nn import functional as F # noqa: N812 +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.models.rl.common.agents import SoftActorCriticAgent from pl_bolts.models.rl.common.memory import MultiStepBuffer @@ -14,10 +19,6 @@ from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor, optim -from torch.nn import functional as F # noqa: N812 -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py index 73d25e698..9a8087667 100644 --- a/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/src/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -6,16 +6,17 @@ import torch from lightning import LightningModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint +from torch import Tensor, optim +from torch.nn.functional import log_softmax, softmax +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -from torch import Tensor, optim -from torch.nn.functional import log_softmax, softmax -from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader if _GYM_AVAILABLE: import gym diff --git a/src/pl_bolts/models/self_supervised/amdim/amdim_module.py b/src/pl_bolts/models/self_supervised/amdim/amdim_module.py index 690a23be0..da17a45ad 100644 --- a/src/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/src/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -4,13 +4,14 @@ import torch from lightning import LightningDataModule, LightningModule, Trainer +from torch import optim +from torch.utils.data import DataLoader + from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask from pl_bolts.models.self_supervised.amdim.datasets import AMDIMPretraining from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder from pl_bolts.utils.self_supervised import torchvision_ssl_encoder from pl_bolts.utils.stability import under_review -from torch import optim -from torch.utils.data import DataLoader @under_review() diff --git a/src/pl_bolts/models/self_supervised/byol/byol_module.py b/src/pl_bolts/models/self_supervised/byol/byol_module.py index 5210c2d66..ac85edfee 100644 --- a/src/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/src/pl_bolts/models/self_supervised/byol/byol_module.py @@ -4,13 +4,14 @@ import torch from lightning import LightningModule, Trainer, seed_everything -from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate -from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm -from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from torch import Tensor from torch.nn import functional as F # noqa: N812 from torch.optim import Adam +from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate +from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR + class BYOL(LightningModule): """PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL_)_ diff --git a/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py b/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py index 424918dfd..72781d236 100644 --- a/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py +++ b/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py @@ -2,6 +2,7 @@ from argparse import ArgumentParser from lightning import Trainer, seed_everything + from pl_bolts.models.self_supervised import CPC_v2, SSLFineTuner from pl_bolts.transforms.self_supervised.cpc_transforms import ( CPCEvalTransformsCIFAR10, diff --git a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py index 8cc613448..db85bb017 100644 --- a/src/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/src/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -7,6 +7,8 @@ import torch from lightning import LightningModule, Trainer, seed_everything from lightning.fabric.utilities import rank_zero_warn +from torch import optim + from pl_bolts.datamodules.stl10_datamodule import STL10DataModule from pl_bolts.losses.self_supervised_learning import CPCTask from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet101 @@ -21,7 +23,6 @@ from pl_bolts.utils.pretrained_weights import load_pretrained from pl_bolts.utils.self_supervised import torchvision_ssl_encoder from pl_bolts.utils.stability import under_review -from torch import optim __all__ = ["CPC_v2"] diff --git a/src/pl_bolts/models/self_supervised/moco/callbacks.py b/src/pl_bolts/models/self_supervised/moco/callbacks.py index 6b4bc2dfc..c840baec3 100644 --- a/src/pl_bolts/models/self_supervised/moco/callbacks.py +++ b/src/pl_bolts/models/self_supervised/moco/callbacks.py @@ -1,6 +1,7 @@ import math from lightning import Callback + from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py b/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py index c1e65d6c3..4ecebc49c 100644 --- a/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py +++ b/src/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py @@ -2,6 +2,7 @@ from argparse import ArgumentParser from lightning import Trainer, seed_everything + from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner from pl_bolts.transforms.dataset_normalizations import ( diff --git a/src/pl_bolts/models/self_supervised/simclr/simclr_module.py b/src/pl_bolts/models/self_supervised/simclr/simclr_module.py index f686c7db1..751615655 100644 --- a/src/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/src/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -4,6 +4,9 @@ import torch from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from torch import Tensor, nn +from torch.nn import functional as F # noqa: N812 + from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.optimizers.lars import LARS from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay @@ -13,8 +16,6 @@ stl10_normalization, ) from pl_bolts.utils.stability import under_review -from torch import Tensor, nn -from torch.nn import functional as F # noqa: N812 @under_review() diff --git a/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index aa7fd83d3..9390949b0 100644 --- a/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/src/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -6,9 +6,10 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from lightning import LightningModule, Trainer, seed_everything +from torch import Tensor + from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR -from torch import Tensor class SimSiam(LightningModule): diff --git a/src/pl_bolts/models/self_supervised/ssl_finetuner.py b/src/pl_bolts/models/self_supervised/ssl_finetuner.py index 32266e91d..48eb8dd34 100644 --- a/src/pl_bolts/models/self_supervised/ssl_finetuner.py +++ b/src/pl_bolts/models/self_supervised/ssl_finetuner.py @@ -2,10 +2,11 @@ import torch from lightning import LightningModule -from pl_bolts.models.self_supervised import SSLEvaluator from torch.nn import functional as F # noqa: N812 from torchmetrics import Accuracy +from pl_bolts.models.self_supervised import SSLEvaluator + class SSLFineTuner(LightningModule): """Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP with diff --git a/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py b/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py index e3e3b9c44..23d17e42e 100644 --- a/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py +++ b/src/pl_bolts/models/self_supervised/swav/swav_finetuner.py @@ -2,6 +2,7 @@ from argparse import ArgumentParser from lightning import Trainer, seed_everything + from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner from pl_bolts.models.self_supervised.swav.swav_module import SwAV from pl_bolts.transforms.dataset_normalizations import ( diff --git a/src/pl_bolts/models/self_supervised/swav/swav_module.py b/src/pl_bolts/models/self_supervised/swav/swav_module.py index e8d7dbbba..3da02a35d 100644 --- a/src/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/src/pl_bolts/models/self_supervised/swav/swav_module.py @@ -6,6 +6,8 @@ import torch from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from torch import nn + from pl_bolts.models.self_supervised.swav.loss import SWAVLoss from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 from pl_bolts.optimizers.lars import LARS @@ -15,7 +17,6 @@ imagenet_normalization, stl10_normalization, ) -from torch import nn class SwAV(LightningModule): diff --git a/src/pl_bolts/models/vision/image_gpt/gpt2.py b/src/pl_bolts/models/vision/image_gpt/gpt2.py index 5bf946182..5f816a8d2 100644 --- a/src/pl_bolts/models/vision/image_gpt/gpt2.py +++ b/src/pl_bolts/models/vision/image_gpt/gpt2.py @@ -1,8 +1,9 @@ import torch from lightning import LightningModule -from pl_bolts.utils.stability import under_review from torch import nn +from pl_bolts.utils.stability import under_review + @under_review() class Block(nn.Module): diff --git a/src/pl_bolts/models/vision/image_gpt/igpt_module.py b/src/pl_bolts/models/vision/image_gpt/igpt_module.py index 8cbe9fd1b..6a738d730 100644 --- a/src/pl_bolts/models/vision/image_gpt/igpt_module.py +++ b/src/pl_bolts/models/vision/image_gpt/igpt_module.py @@ -3,9 +3,10 @@ import torch from lightning import LightningModule, Trainer +from torch import nn + from pl_bolts.models.vision.image_gpt.gpt2 import GPT2 from pl_bolts.utils.stability import under_review -from torch import nn @under_review() diff --git a/src/pl_bolts/models/vision/segmentation.py b/src/pl_bolts/models/vision/segmentation.py index 0ff2354db..b0a9ed05d 100644 --- a/src/pl_bolts/models/vision/segmentation.py +++ b/src/pl_bolts/models/vision/segmentation.py @@ -3,10 +3,11 @@ import torch from lightning import LightningModule, Trainer, seed_everything -from pl_bolts.models.vision.unet import UNet from torch import Tensor from torch.nn import functional as F # noqa: N812 +from pl_bolts.models.vision.unet import UNet + class SemSegment(LightningModule): """Basic model for semantic segmentation. Uses UNet architecture by default. diff --git a/src/pl_bolts/utils/__init__.py b/src/pl_bolts/utils/__init__.py index 795555107..b19909f65 100644 --- a/src/pl_bolts/utils/__init__.py +++ b/src/pl_bolts/utils/__init__.py @@ -3,6 +3,7 @@ import torch from lightning_utilities.core.imports import compare_version, module_available + from pl_bolts.callbacks.verification.batch_gradient import ( BatchGradientVerification, # type: ignore ) diff --git a/src/pl_bolts/utils/arguments.py b/src/pl_bolts/utils/arguments.py index 1d2f88e1c..ec1d79ac4 100644 --- a/src/pl_bolts/utils/arguments.py +++ b/src/pl_bolts/utils/arguments.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional from lightning import LightningDataModule, LightningModule + from pl_bolts.utils.stability import under_review diff --git a/src/pl_bolts/utils/pretrained_weights.py b/src/pl_bolts/utils/pretrained_weights.py index 6954c4611..f4b51347f 100644 --- a/src/pl_bolts/utils/pretrained_weights.py +++ b/src/pl_bolts/utils/pretrained_weights.py @@ -1,6 +1,7 @@ from typing import Optional from lightning import LightningModule + from pl_bolts.utils.stability import under_review vae_imagenet2012 = ( diff --git a/tests/callbacks/test_ort.py b/tests/callbacks/test_ort.py index 23d3976f8..ed3197879 100644 --- a/tests/callbacks/test_ort.py +++ b/tests/callbacks/test_ort.py @@ -18,6 +18,7 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from pl_bolts.callbacks import ORTCallback from pl_bolts.utils import _TORCH_ORT_AVAILABLE + from tests.helpers.boring_model import BoringModel if _TORCH_ORT_AVAILABLE: diff --git a/tests/callbacks/test_sparseml.py b/tests/callbacks/test_sparseml.py index fd06e7cad..e5c038b72 100644 --- a/tests/callbacks/test_sparseml.py +++ b/tests/callbacks/test_sparseml.py @@ -20,6 +20,7 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from pl_bolts.callbacks import SparseMLCallback from pl_bolts.utils import _SPARSEML_TORCH_SATISFIED + from tests.helpers.boring_model import BoringModel if _SPARSEML_TORCH_SATISFIED: diff --git a/tests/callbacks/verification/test_base.py b/tests/callbacks/verification/test_base.py index 675d18249..343b99992 100644 --- a/tests/callbacks/verification/test_base.py +++ b/tests/callbacks/verification/test_base.py @@ -7,6 +7,7 @@ from lightning.pytorch.utilities import move_data_to_device from pl_bolts.callbacks.verification.base import VerificationBase from pl_bolts.utils import _PL_GREATER_EQUAL_1_4 + from tests import _MARK_REQUIRE_GPU diff --git a/tests/callbacks/verification/test_batch_gradient.py b/tests/callbacks/verification/test_batch_gradient.py index 7fa9bd545..3a883105b 100644 --- a/tests/callbacks/verification/test_batch_gradient.py +++ b/tests/callbacks/verification/test_batch_gradient.py @@ -11,9 +11,10 @@ selective_eval, ) from pl_bolts.utils import BatchGradientVerification -from tests import _MARK_REQUIRE_GPU from torch import Tensor, nn +from tests import _MARK_REQUIRE_GPU + class TemplateModel(nn.Module): def __init__(self, mix_data=False) -> None: diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 5b632860d..7e7d511e0 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -33,6 +33,7 @@ SwAVTrainDataTransform, ) from pl_bolts.utils import _IS_WINDOWS + from tests import _MARK_REQUIRE_GPU diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 762ac2bf0..9d29de5d9 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -20,9 +20,10 @@ ) from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _IS_WINDOWS -from tests import TEST_ROOT from torch.utils.data import DataLoader +from tests import TEST_ROOT + def _collate_fn(batch): return tuple(zip(*batch)) diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 715ef3d13..0e4835d16 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -2,9 +2,8 @@ import pytest import torch -from lightning.pytorch import LightningDataModule, Trainer +from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch import __version__ as pl_version -from lightning.pytorch import seed_everything from lightning.pytorch.callbacks.progress import TQDMProgressBar from lightning.pytorch.utilities.warnings import PossibleUserWarning from packaging import version From 2281be0c62b5b85a24519d08ad3b1818eb32d3de Mon Sep 17 00:00:00 2001 From: James Bishop Date: Fri, 7 Jun 2024 09:28:10 +0100 Subject: [PATCH 6/6] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b68a5ea0..997371c43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Revision of the MoCo SSL model ([#928](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/928)) +- Updated lightning dependency to support lightning 2.x ([#1094](https://github.com/Lightning-AI/pytorch-lightning/pull/2671)) + ### Deprecated