Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Move *_TESTING variables to flash/core/utilities/imports and use fo…
Browse files Browse the repository at this point in the history
…r doctests (#1134)
  • Loading branch information
ethanwharris authored Jan 25, 2022
1 parent e693129 commit fd1a2e4
Show file tree
Hide file tree
Showing 69 changed files with 121 additions and 162 deletions.
4 changes: 2 additions & 2 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.imports import _AUDIO_TESTING
from flash.core.utilities.stages import RunningStage
from flash.image.classification.data import MatplotlibVisualization

# Skip doctests if requirements aren't available
if not _AUDIO_AVAILABLE:
if not _AUDIO_TESTING:
__doctest_skip__ = ["AudioClassificationData", "AudioClassificationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.imports import _AUDIO_TESTING
from flash.core.utilities.stages import RunningStage

# Skip doctests if requirements aren't available
if not _AUDIO_AVAILABLE:
if not _AUDIO_TESTING:
__doctest_skip__ = ["SpeechRecognitionData", "SpeechRecognitionData.*"]


Expand Down
23 changes: 23 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import functools
import importlib
import operator
import os
import types
from importlib.util import find_spec
from typing import List, Tuple, Union
Expand Down Expand Up @@ -272,3 +273,25 @@ def _import_module(self):
# Update this object's dict so that attribute references are efficient
# (__getattr__ is only called on lookups that fail)
self.__dict__.update(module.__dict__)


# Global variables used for testing purposes (e.g. to only run doctests in the correct CI job)
_IMAGE_TESTING = _IMAGE_AVAILABLE
_VIDEO_TESTING = _VIDEO_AVAILABLE
_TABULAR_TESTING = _TABULAR_AVAILABLE
_TEXT_TESTING = _TEXT_AVAILABLE
_SERVE_TESTING = _SERVE_AVAILABLE
_POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE
_GRAPH_TESTING = _GRAPH_AVAILABLE
_AUDIO_TESTING = _AUDIO_AVAILABLE

if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
_IMAGE_TESTING = topic == "image"
_VIDEO_TESTING = topic == "video"
_TABULAR_TESTING = topic == "tabular"
_TEXT_TESTING = topic == "text"
_SERVE_TESTING = topic == "serve"
_POINTCLOUD_TESTING = topic == "pointcloud"
_GRAPH_TESTING = topic == "graph"
_AUDIO_TESTING = topic == "audio"
4 changes: 2 additions & 2 deletions flash/graph/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.imports import _GRAPH_TESTING
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.graph.classification.input import GraphClassificationDatasetInput
from flash.graph.classification.input_transform import GraphClassificationInputTransform

# Skip doctests if requirements aren't available
if not _GRAPH_AVAILABLE:
if not _GRAPH_TESTING:
__doctest_skip__ = ["GraphClassificationData", "GraphClassificationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioImageClassificationInput
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, Image, requires
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_TESTING, _MATPLOTLIB_AVAILABLE, Image, requires
from flash.core.utilities.stages import RunningStage
from flash.image.classification.input import (
ImageClassificationCSVInput,
Expand All @@ -46,7 +46,7 @@
SampleCollection = None

# Skip doctests if requirements aren't available
if not _IMAGE_AVAILABLE:
if not _IMAGE_TESTING:
__doctest_skip__ = ["ImageClassificationData", "ImageClassificationData.*"]

if _MATPLOTLIB_AVAILABLE:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, lazy_import
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_TESTING, lazy_import
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.segmentation.input import (
Expand All @@ -42,7 +42,7 @@
SampleCollection = object

# Skip doctests if requirements aren't available
if not _IMAGE_AVAILABLE:
if not _IMAGE_TESTING:
__doctest_skip__ = ["SemanticSegmentationData", "SemanticSegmentationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/image/style_transfer/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.utilities.imports import _IMAGE_AVAILABLE
from flash.core.utilities.imports import _IMAGE_TESTING
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.classification.input import ImageClassificationFilesInput, ImageClassificationFolderInput
from flash.image.data import ImageNumpyInput, ImageTensorInput
from flash.image.style_transfer.input_transform import StyleTransferInputTransform

# Skip doctests if requirements aren't available
if not _IMAGE_AVAILABLE:
if not _IMAGE_TESTING:
__doctest_skip__ = ["StyleTransferData", "StyleTransferData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/tabular/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING
from flash.core.utilities.stages import RunningStage
from flash.tabular.classification.input import TabularClassificationCSVInput, TabularClassificationDataFrameInput
from flash.tabular.data import TabularData
Expand All @@ -27,7 +27,7 @@
DataFrame = object

# Skip doctests if requirements aren't available
if not _TABULAR_AVAILABLE:
if not _TABULAR_TESTING:
__doctest_skip__ = ["TabularClassificationData", "TabularClassificationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/tabular/forecasting/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.data.io.output_transform import OutputTransform
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING
from flash.core.utilities.stages import RunningStage
from flash.tabular.forecasting.input import TabularForecastingDataFrameInput

Expand All @@ -32,7 +32,7 @@


# Skip doctests if requirements aren't available
if not _TABULAR_AVAILABLE:
if not _TABULAR_TESTING:
__doctest_skip__ = ["TabularForecastingData", "TabularForecastingData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/tabular/regression/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING
from flash.core.utilities.stages import RunningStage
from flash.tabular.data import TabularData
from flash.tabular.regression.input import TabularRegressionCSVInput, TabularRegressionDataFrameInput
Expand All @@ -27,7 +27,7 @@
DataFrame = object

# Skip doctests if requirements aren't available
if not _TABULAR_AVAILABLE:
if not _TABULAR_TESTING:
__doctest_skip__ = ["TabularRegressionData", "TabularRegressionData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioTextClassificationInput
from flash.core.integrations.transformers.input_transform import TransformersInputTransform
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING
from flash.core.utilities.stages import RunningStage
from flash.text.classification.input import (
TextClassificationCSVInput,
Expand All @@ -38,7 +38,7 @@
Dataset = object

# Skip doctests if requirements aren't available
if not _TEXT_AVAILABLE:
if not _TEXT_TESTING:
__doctest_skip__ = ["TextClassificationData", "TextClassificationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/text/seq2seq/summarization/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flash.core.data.io.input import Input
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.transformers.input_transform import TransformersInputTransform
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.text.seq2seq.core.input import Seq2SeqCSVInput, Seq2SeqInputBase, Seq2SeqJSONInput, Seq2SeqListInput
Expand All @@ -30,7 +30,7 @@
Dataset = object

# Skip doctests if requirements aren't available
if not _TEXT_AVAILABLE:
if not _TEXT_TESTING:
__doctest_skip__ = ["SummarizationData", "SummarizationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/text/seq2seq/translation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flash.core.data.io.input import Input
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.transformers.input_transform import TransformersInputTransform
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.text.seq2seq.core.input import Seq2SeqCSVInput, Seq2SeqInputBase, Seq2SeqJSONInput, Seq2SeqListInput
Expand All @@ -30,7 +30,7 @@
Dataset = object

# Skip doctests if requirements aren't available
if not _TEXT_AVAILABLE:
if not _TEXT_TESTING:
__doctest_skip__ = ["TranslationData", "TranslationData.*"]


Expand Down
4 changes: 2 additions & 2 deletions flash/video/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioVideoClassificationInput
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, _VIDEO_AVAILABLE, requires
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, _VIDEO_TESTING, requires
from flash.core.utilities.stages import RunningStage
from flash.video.classification.input import (
VideoClassificationCSVInput,
Expand All @@ -48,7 +48,7 @@
ClipSampler = None

# Skip doctests if requirements aren't available
if not _VIDEO_AVAILABLE:
if not _VIDEO_TESTING:
__doctest_skip__ = ["VideoClassificationData", "VideoClassificationData.*"]


Expand Down
3 changes: 1 addition & 2 deletions tests/audio/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from flash.audio import AudioClassificationData
from flash.core.data.io.input import DataKeys
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from tests.helpers.utils import _AUDIO_TESTING
from flash.core.utilities.imports import _AUDIO_TESTING, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision.transforms as T
Expand Down
3 changes: 1 addition & 2 deletions tests/audio/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import pytest

from flash.__main__ import main
from flash.core.utilities.imports import _IMAGE_AVAILABLE
from tests.helpers.utils import _AUDIO_TESTING
from flash.core.utilities.imports import _AUDIO_TESTING, _IMAGE_AVAILABLE


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
Expand Down
3 changes: 1 addition & 2 deletions tests/audio/speech_recognition/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
import flash
from flash.audio import SpeechRecognitionData
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from tests.helpers.utils import _AUDIO_TESTING
from flash.core.utilities.imports import _AUDIO_AVAILABLE, _AUDIO_TESTING

path = str(Path(flash.ASSETS_ROOT) / "example.wav")
sample = {"file": path, "text": "example input."}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import flash
from flash import Trainer
from flash.audio import SpeechRecognition, SpeechRecognitionData
from tests.helpers.utils import _AUDIO_TESTING
from flash.core.utilities.imports import _AUDIO_TESTING

TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing

Expand Down
3 changes: 1 addition & 2 deletions tests/audio/speech_recognition/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
from flash.audio import SpeechRecognition
from flash.audio.speech_recognition.data import InputTransform, SpeechRecognitionData
from flash.core.data.io.input import DataKeys, Input
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.imports import _AUDIO_AVAILABLE, _AUDIO_TESTING, _SERVE_TESTING
from flash.core.utilities.stages import RunningStage
from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING

# ======== Mock functions ========

Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from pytest_mock import MockerFixture

from flash.core.serve.decorators import uuid4 # noqa (used in mocker.patch)
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers.utils import _SERVE_TESTING
from flash.core.utilities.imports import _SERVE_TESTING, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision
Expand Down
3 changes: 1 addition & 2 deletions tests/core/data/test_base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
from flash.core.data.base_viz import BaseVisualization
from flash.core.data.io.input import DataKeys
from flash.core.data.utils import _CALLBACK_FUNCS
from flash.core.utilities.imports import _PIL_AVAILABLE
from flash.core.utilities.imports import _IMAGE_TESTING, _PIL_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.image import ImageClassificationData
from tests.helpers.utils import _IMAGE_TESTING

if _PIL_AVAILABLE:
from PIL import Image
Expand Down
3 changes: 1 addition & 2 deletions tests/core/data/test_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.states import PerBatchTransformOnDevice, PerSampleTransform
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _IMAGE_TESTING, _TORCHVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage
from tests.helpers.utils import _IMAGE_TESTING

if _TORCHVISION_AVAILABLE:
import torchvision.transforms as T
Expand Down
2 changes: 1 addition & 1 deletion tests/core/integrations/labelstudio/test_labelstudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
)
from flash.core.integrations.labelstudio.visualizer import launch_app
from flash.core.integrations.transformers.states import TransformersBackboneState
from flash.core.utilities.imports import _IMAGE_TESTING, _TEXT_TESTING, _VIDEO_TESTING
from flash.core.utilities.stages import RunningStage
from flash.image.classification.data import ImageClassificationData
from flash.text.classification.data import TextClassificationData
from flash.video.classification.data import LabelStudioVideoClassificationInput, VideoClassificationData
from tests.helpers.utils import _IMAGE_TESTING, _TEXT_TESTING, _VIDEO_TESTING


def test_utility_load():
Expand Down
2 changes: 1 addition & 1 deletion tests/core/serve/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch

from flash.core.serve.types import Label
from flash.core.utilities.imports import _SERVE_TESTING
from tests.core.serve.models import ClassificationInferenceComposable, LightningSqueezenet
from tests.helpers.utils import _SERVE_TESTING


@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
Expand Down
3 changes: 1 addition & 2 deletions tests/core/serve/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pytest

from flash.core.serve import Composition, Endpoint
from flash.core.utilities.imports import _FASTAPI_AVAILABLE
from tests.helpers.utils import _SERVE_TESTING
from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_TESTING

if _FASTAPI_AVAILABLE:
from fastapi.testclient import TestClient
Expand Down
3 changes: 1 addition & 2 deletions tests/core/serve/test_gridbase_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from flash.core.serve import expose, ModelComponent
from flash.core.serve.types import Number
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE
from tests.helpers.utils import _SERVE_TESTING
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_TESTING


@pytest.mark.skipif(not _CYTOOLZ_AVAILABLE, reason="the library cytoolz is not installed.")
Expand Down
3 changes: 1 addition & 2 deletions tests/core/serve/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import pytest

from flash.core.serve import Composition, Endpoint
from flash.core.utilities.imports import _FASTAPI_AVAILABLE
from tests.helpers.utils import _SERVE_TESTING
from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_TESTING

if _FASTAPI_AVAILABLE:
from fastapi.testclient import TestClient
Expand Down
Loading

0 comments on commit fd1a2e4

Please sign in to comment.