diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index b3188a21b7f04..a2010a89f4461 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -22,9 +22,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 01a5dca0de3c7..3546bee9ad129 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -21,9 +21,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index b4bf1407a9b26..da5b1e4fd9e9c 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -31,9 +31,10 @@ cli_lightning_logo, ) -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index a50f67cdab301..a6d59c64d9aa0 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -20,8 +20,9 @@ from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE from pytorch_lightning import LightningDataModule -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 285fba8b93f1b..e65ede17dac7a 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -32,9 +32,10 @@ from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: import torchvision - import torchvision.transforms as transforms + from torchvision import transforms +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index 5af3fbfbc4a11..e7bdad0f1538c 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -69,6 +69,7 @@ def __init__( train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, + **kwargs, ): super().__init__() self.root = root