Skip to content

Commit

Permalink
Do not require omegaconf to run tests (#10832)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Nov 30, 2021
1 parent a81accb commit 38ed26e
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 58 deletions.
17 changes: 10 additions & 7 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch import optim

import pytorch_lightning as pl
Expand All @@ -39,9 +38,13 @@
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf


def test_model_checkpoint_state_key():
early_stopping = ModelCheckpoint(monitor="val_loss")
Expand Down Expand Up @@ -1094,8 +1097,8 @@ def training_step(self, *args):
assert model_checkpoint.current_score == expected


@pytest.mark.parametrize("hparams_type", [dict, Container])
def test_hparams_type(tmpdir, hparams_type):
@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
def test_hparams_type(tmpdir, use_omegaconf):
class TestModel(BoringModel):
def __init__(self, hparams):
super().__init__()
Expand All @@ -1113,15 +1116,15 @@ def __init__(self, hparams):
enable_model_summary=False,
)
hp = {"test_hp_0": 1, "test_hp_1": 2}
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
hp = OmegaConf.create(hp) if use_omegaconf else Namespace(**hp)
model = TestModel(hp)
trainer.fit(model)
ckpt = trainer.checkpoint_connector.dump_checkpoint()
if hparams_type == Container:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
if use_omegaconf:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], Container)
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is hparams_type
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is dict


def test_ckpt_version_after_rerun_new_trainer(tmpdir):
Expand Down
11 changes: 7 additions & 4 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@

import pytest
import torch
from omegaconf import OmegaConf

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel
Expand All @@ -34,6 +33,9 @@
from tests.helpers.simple_models import ClassificationModel
from tests.helpers.utils import reset_seed

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf


@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
Expand Down Expand Up @@ -440,8 +442,9 @@ def test_hyperparameters_saving():
data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"hello": "world"})

data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == OmegaConf.create({"hello": "world"})
if _OMEGACONF_AVAILABLE:
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == OmegaConf.create({"hello": "world"})


def test_define_as_dataclass():
Expand Down
9 changes: 8 additions & 1 deletion tests/helpers/runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_OMEGACONF_AVAILABLE,
_RICH_AVAILABLE,
_TORCH_QUANTIZE_AVAILABLE,
_TPU_AVAILABLE,
Expand Down Expand Up @@ -70,6 +71,7 @@ def __new__(
deepspeed: bool = False,
rich: bool = False,
skip_49370: bool = False,
omegaconf: bool = False,
**kwargs,
):
"""
Expand All @@ -89,9 +91,10 @@ def __new__(
standalone: Mark the test as standalone, our CI will run it in a separate process.
fairscale: Require that facebookresearch/fairscale is installed.
fairscale_fully_sharded: Require that `fairscale` fully sharded support is available.
deepspeed: Require that Microsoft/DeepSpeed is installed.
deepspeed: Require that microsoft/DeepSpeed is installed.
rich: Require that willmcgugan/rich is installed.
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
omegaconf: Require that omry/omegaconf is installed.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
"""
conditions = []
Expand Down Expand Up @@ -177,6 +180,10 @@ def __new__(
conditions.append(ge_3_9 and old_torch)
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")

if omegaconf:
conditions.append(not _OMEGACONF_AVAILABLE)
reasons.append("omegaconf")

reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
Expand Down
10 changes: 6 additions & 4 deletions tests/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
import pytest
import torch
import yaml
from omegaconf import OmegaConf

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf


@pytest.mark.skipif(
Expand Down Expand Up @@ -205,6 +208,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
logger.log_hyperparams(hparams, metrics)


@RunIf(omegaconf=True)
def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
hparams = {
Expand All @@ -214,8 +218,6 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
"bool": True,
"dict": {"a": {"b": "c"}},
"list": [1, 2, 3],
# "namespace": Namespace(foo=Namespace(bar="buzz")),
# "layer": torch.nn.BatchNorm1d,
}
hparams = OmegaConf.create(hparams)

Expand Down
58 changes: 23 additions & 35 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,24 @@
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

if _HYDRA_EXPERIMENTAL_AVAILABLE:
from hydra.experimental import compose, initialize

if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig


class SaveHparamsModel(BoringModel):
"""Tests that a model can take an object."""
Expand Down Expand Up @@ -117,6 +120,7 @@ def test_dict_hparams(tmpdir, cls):
_run_standard_hparams_test(tmpdir, model, cls)


@RunIf(omegaconf=True)
@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel])
def test_omega_conf_hparams(tmpdir, cls):
# init model
Expand Down Expand Up @@ -275,10 +279,18 @@ def __init__(obj, *more_args, other_arg=300, **more_kwargs):
obj.save_hyperparameters()


class DictConfSubClassBoringModel(SubClassBoringModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()
if _OMEGACONF_AVAILABLE:

class DictConfSubClassBoringModel(SubClassBoringModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()


else:

class DictConfSubClassBoringModel:
...


@pytest.mark.parametrize(
Expand All @@ -290,7 +302,7 @@ def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something"))
SubSubClassBoringModel,
AggSubClassBoringModel,
UnconventionalArgsBoringModel,
DictConfSubClassBoringModel,
pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
],
)
def test_collect_init_arguments(tmpdir, cls):
Expand Down Expand Up @@ -383,31 +395,6 @@ def test_collect_init_arguments_with_local_vars(cls):
assert model.hparams["arg2"] == 2


# @pytest.mark.parametrize("cls,config", [
# (SaveHparamsModel, Namespace(my_arg=42)),
# (SaveHparamsModel, dict(my_arg=42)),
# (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))),
# (AssignHparamsModel, Namespace(my_arg=42)),
# (AssignHparamsModel, dict(my_arg=42)),
# (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))),
# ])
# def test_single_config_models(tmpdir, cls, config):
# """ Test that the model automatically saves the arguments passed into the constructor """
# model = cls(config)
#
# # no matter how you do it, it should be assigned
# assert model.hparams.my_arg == 42
#
# # verify that the checkpoint saved the correct values
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
# trainer.fit(model)
#
# # verify that model loads correctly
# raw_checkpoint_path = _raw_checkpoint_path(trainer)
# model = cls.load_from_checkpoint(raw_checkpoint_path)
# assert model.hparams.my_arg == 42


class AnotherArgModel(BoringModel):
def __init__(self, arg1):
super().__init__()
Expand Down Expand Up @@ -511,8 +498,9 @@ def _compare_params(loaded_params, default_params: dict):
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
if _OMEGACONF_AVAILABLE:
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
_compare_params(load_hparams_from_yaml(path_yaml), hparams)


class NoArgsSubClassBoringModel(CustomBoringModel):
Expand Down
17 changes: 10 additions & 7 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import cloudpickle
import pytest
import torch
from omegaconf import OmegaConf
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD
from torch.utils.data import DataLoader, IterableDataset
Expand All @@ -51,6 +50,7 @@
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.seed import seed_everything
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel, RandomDataset
Expand All @@ -59,6 +59,9 @@
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf


@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
Expand Down Expand Up @@ -1271,12 +1274,12 @@ def __init__(self, **kwargs):
TrainerSubclass(abcdefg="unknown_arg")


@pytest.mark.parametrize(
"trainer_params", [OmegaConf.create(dict(max_epochs=1, gpus=1)), OmegaConf.create(dict(max_epochs=1, gpus=[0]))]
)
@RunIf(min_gpus=1)
def test_trainer_omegaconf(trainer_params):
Trainer(**trainer_params)
@RunIf(omegaconf=True)
@pytest.mark.parametrize("trainer_params", [{"max_epochs": 1, "gpus": 1}, {"max_epochs": 1, "gpus": [0]}])
@mock.patch("torch.cuda.device_count", return_value=1)
def test_trainer_omegaconf(_, trainer_params):
config = OmegaConf.create(trainer_params)
Trainer(**config)


def test_trainer_pickle(tmpdir):
Expand Down

0 comments on commit 38ed26e

Please sign in to comment.