Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Skip core tests when not core testing (#1330)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored May 6, 2022
1 parent 3faf9f6 commit 07d63e3
Show file tree
Hide file tree
Showing 60 changed files with 343 additions and 61 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Set Swap Space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10

# Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646
- name: Setup macOS
if: runner.os == 'macOS'
Expand Down
5 changes: 5 additions & 0 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,14 @@
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _CORE_TESTING
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["DataModule"]


class DatasetInput(Input):
"""The ``DatasetInput`` implements default behaviours for data sources which expect the input to
Expand Down
5 changes: 5 additions & 0 deletions flash/core/data/utilities/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import torch

from flash.core.data.utilities.sort import sorted_alphanumeric
from flash.core.utilities.imports import _CORE_TESTING

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["*"]


def _is_list_like(x: Any) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection
from tqdm.auto import tqdm as tq

from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _CORE_TESTING, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["download_data"]

if _PIL_AVAILABLE:
from PIL.Image import Image
else:
Expand Down
6 changes: 5 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from flash.core.registry import FlashRegistry
from flash.core.serve.composition import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.imports import _CORE_TESTING, _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.providers import _HUGGINGFACE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import (
Expand All @@ -62,6 +62,10 @@
OUTPUT_TRANSFORM_TYPE,
)

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["Task", "Task.*"]


class ModuleWrapperBase:
"""The ``ModuleWrapperBase`` is a base for classes which wrap a ``LightningModule`` or an instance of
Expand Down
6 changes: 6 additions & 0 deletions flash/core/optimizers/lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
from torch import nn
from torch.optim.optimizer import Optimizer

from flash.core.utilities.imports import _CORE_TESTING

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["LAMB"]


class LAMB(Optimizer):
r"""Extends ADAM in pytorch to incorporate LAMB algorithm from the paper:
Expand Down
6 changes: 6 additions & 0 deletions flash/core/optimizers/lars.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
from torch import nn
from torch.optim.optimizer import Optimizer, required

from flash.core.utilities.imports import _CORE_TESTING

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["LARS"]


class LARS(Optimizer):
r"""Extends SGD in PyTorch with LARS scaling from the paper
Expand Down
6 changes: 6 additions & 0 deletions flash/core/optimizers/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.utilities.imports import _CORE_TESTING

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["LinearWarmupCosineAnnealingLR"]


class LinearWarmupCosineAnnealingLR(_LRScheduler):
"""Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
Expand Down
5 changes: 5 additions & 0 deletions flash/core/serve/dag/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from flash.core.serve.dag.task import flatten, get, get_dependencies, ishashable, istask, reverse_dict, subs, toposort
from flash.core.serve.dag.utils import key_split
from flash.core.serve.dag.utils_test import add, inc, mul
from flash.core.utilities.imports import _SERVE_TESTING

# Skip doctests if requirements aren't available
if not _SERVE_TESTING:
__doctest_skip__ = ["*"]


def cull(dsk, keys):
Expand Down
5 changes: 5 additions & 0 deletions flash/core/serve/dag/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@

from flash.core.serve.dag.task import get_dependencies, get_deps, getcycle, reverse_dict
from flash.core.serve.dag.utils_test import add, inc
from flash.core.utilities.imports import _SERVE_TESTING

# Skip doctests if requirements aren't available
if not _SERVE_TESTING:
__doctest_skip__ = ["*"]


def order(dsk, dependencies=None):
Expand Down
5 changes: 5 additions & 0 deletions flash/core/serve/dag/rewrite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from collections import deque

from flash.core.serve.dag.task import istask, subs
from flash.core.utilities.imports import _SERVE_TESTING

# Skip doctests if requirements aren't available
if not _SERVE_TESTING:
__doctest_skip__ = ["*"]


def head(task):
Expand Down
5 changes: 5 additions & 0 deletions flash/core/serve/dag/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from typing import List, Sequence

from flash.core.serve.dag.utils_test import add, inc
from flash.core.utilities.imports import _SERVE_TESTING

# Skip doctests if requirements aren't available
if not _SERVE_TESTING:
__doctest_skip__ = ["*"]

no_default = "__no_default__"

Expand Down
6 changes: 6 additions & 0 deletions flash/core/serve/dag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import re
from operator import methodcaller

from flash.core.utilities.imports import _SERVE_TESTING

# Skip doctests if requirements aren't available
if not _SERVE_TESTING:
__doctest_skip__ = ["*"]


def funcname(func):
"""Get the name of a function."""
Expand Down
6 changes: 5 additions & 1 deletion flash/core/serve/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from flash.core.serve.core import Connection, make_param_dict, make_parameter_container, ParameterContainer, Servable
from flash.core.serve.types.base import BaseType
from flash.core.serve.utils import fn_outputs_to_keyed_map
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_TESTING

# Skip doctests if requirements aren't available
if not _SERVE_TESTING:
__doctest_skip__ = ["*"]

if _CYTOOLZ_AVAILABLE:
from cytoolz import compose
Expand Down
6 changes: 5 additions & 1 deletion flash/core/serve/interfaces/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from flash.core.serve.component import ModelComponent
from flash.core.serve.core import Endpoint
from flash.core.serve.types import Repeated
from flash.core.utilities.imports import _PYDANTIC_AVAILABLE
from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, _SERVE_TESTING

# Skip doctests if requirements aren't available
if not _SERVE_TESTING:
__doctest_skip__ = ["EndpointProtocol.*"]

if _PYDANTIC_AVAILABLE:
from pydantic import BaseModel, create_model
Expand Down
19 changes: 0 additions & 19 deletions flash/core/serve/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -46,21 +45,3 @@ def download_file(url: str, *, download_path: Optional[Path] = None) -> str:
f.write(chunk)

return fpath


def _module_available(module_path: str) -> bool:
"""Check if a path is available in your environment.
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
try:
return find_spec(module_path) is not None
except AttributeError:
# Python 3.6
return False
except ModuleNotFoundError:
# Python 3.7+
return False
2 changes: 2 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def _import_module(self):


# Global variables used for testing purposes (e.g. to only run doctests in the correct CI job)
_CORE_TESTING = True
_IMAGE_TESTING = _IMAGE_AVAILABLE
_IMAGE_EXTRAS_TESTING = False # Not for normal use
_VIDEO_TESTING = _VIDEO_AVAILABLE
Expand All @@ -288,6 +289,7 @@ def _import_module(self):

if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
_CORE_TESTING = topic == "core"
_IMAGE_TESTING = topic == "image"
_IMAGE_EXTRAS_TESTING = topic == "image,image_extras" or topic == "icevision" or topic == "vissl"
_VIDEO_TESTING = topic == "video"
Expand Down
6 changes: 6 additions & 0 deletions flash/core/utilities/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

from pytorch_lightning.utilities import rank_zero_warn

from flash.core.utilities.imports import _CORE_TESTING

# Skip doctests if requirements aren't available
if not _CORE_TESTING:
__doctest_skip__ = ["beta"]


@functools.lru_cache() # Trick to only warn once for each message
def _raise_beta_warning(message: str, stacklevel: int = 6):
Expand Down
9 changes: 1 addition & 8 deletions flash_examples/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,15 @@
# limitations under the License.
from flash import Trainer
from flash.core.data.utils import download_data
from flash.core.utilities.imports import example_requires
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

example_requires("text")

import nltk # noqa: E402

nltk.download("punkt")

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/")

datamodule = QuestionAnsweringData.from_squad_v2(
train_file="./data/squad_tiny/train.json",
val_file="./data/squad_tiny/val.json",
batch_size=4,
batch_size=1,
)

# 2. Build the task
Expand Down
4 changes: 4 additions & 0 deletions tests/core/data/io/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.data.io.input import Input, IterableInput, ServeInput
from flash.core.utilities.imports import _CORE_TESTING
from flash.core.utilities.stages import RunningStage


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_input_validation():
with pytest.raises(RuntimeError, match="Use `IterableInput` instead."):

Expand All @@ -38,6 +40,7 @@ def __init__(self, *args, **kwargs):
ValidInput(RunningStage.TRAINING)


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_iterable_input_validation():
with pytest.raises(RuntimeError, match="Use `Input` instead."):

Expand All @@ -58,6 +61,7 @@ def __init__(self, *args, **kwargs):
ValidIterableInput(RunningStage.TRAINING)


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_serve_input():

server_input = ServeInput()
Expand Down
4 changes: 4 additions & 0 deletions tests/core/data/io/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# limitations under the License.
from unittest.mock import Mock

import pytest

from flash.core.data.io.output import Output
from flash.core.utilities.imports import _CORE_TESTING


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_output():
"""Tests basic ``Output`` methods."""
my_output = Output()
Expand Down
3 changes: 3 additions & 0 deletions tests/core/data/io/test_output_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

from flash.core.data.io.output_transform import OutputTransform
from flash.core.utilities.imports import _CORE_TESTING


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
def test_output_transform():
class CustomOutputTransform(OutputTransform):
@staticmethod
Expand Down
3 changes: 3 additions & 0 deletions tests/core/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch

from flash.core.data.batch import default_uncollate
from flash.core.utilities.imports import _CORE_TESTING

Case = namedtuple("Case", ["collated_batch", "uncollated_batch"])

Expand Down Expand Up @@ -46,6 +47,7 @@
]


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
@pytest.mark.parametrize("case", cases)
def test_default_uncollate(case):
assert default_uncollate(case.collated_batch) == case.uncollated_batch
Expand All @@ -60,6 +62,7 @@ def test_default_uncollate(case):
]


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
@pytest.mark.parametrize("error_case", error_cases)
def test_default_uncollate_raises(error_case):
with pytest.raises(ValueError, match=error_case.match):
Expand Down
3 changes: 3 additions & 0 deletions tests/core/data/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@
# limitations under the License.
from unittest import mock

import pytest
import torch

from flash import DataKeys
from flash.core.data.data_module import DataModule, DatasetInput
from flash.core.data.io.input_transform import InputTransform
from flash.core.model import Task
from flash.core.trainer import Trainer
from flash.core.utilities.imports import _CORE_TESTING
from flash.core.utilities.stages import RunningStage


@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
@mock.patch("pickle.dumps") # need to mock pickle or we get pickle error
@mock.patch("torch.save") # need to mock torch.save, or we get pickle error
def test_flash_callback(_, __, tmpdir):
Expand Down
Loading

0 comments on commit 07d63e3

Please sign in to comment.