Do not require omegaconf to run tests #10832

Merged · 2 commits · Nov 30, 2021
Changes from all commits
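This PR removes the hard dependency on omegaconf in the test suite: top-level omegaconf imports are replaced by imports guarded with _OMEGACONF_AVAILABLE, and tests (or individual parametrized cases) that genuinely need omegaconf are marked with the new RunIf(omegaconf=True) requirement so they are skipped when the package is not installed.

A minimal sketch of the pattern, assuming the guarded import and the marker exactly as added in this diff (the test name and body are illustrative, not part of the PR):

    from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
    from tests.helpers.runif import RunIf

    # Import omegaconf only when it is installed; tests that need it are skipped otherwise.
    if _OMEGACONF_AVAILABLE:
        from omegaconf import OmegaConf


    @RunIf(omegaconf=True)  # skipped with reason "Requires: [omegaconf]" when omegaconf is missing
    def test_uses_omegaconf():  # hypothetical test, for illustration only
        hparams = OmegaConf.create({"lr": 0.1})
        assert hparams.lr == 0.1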
17 changes: 10 additions & 7 deletions tests/checkpointing/test_model_checkpoint.py
@@ -29,7 +29,6 @@
import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch import optim

import pytorch_lightning as pl
@@ -39,9 +38,13 @@
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf


def test_model_checkpoint_state_key():
early_stopping = ModelCheckpoint(monitor="val_loss")
@@ -1101,8 +1104,8 @@ def training_step(self, *args):
assert model_checkpoint.current_score == expected


@pytest.mark.parametrize("hparams_type", [dict, Container])
def test_hparams_type(tmpdir, hparams_type):
@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
def test_hparams_type(tmpdir, use_omegaconf):
class TestModel(BoringModel):
def __init__(self, hparams):
super().__init__()
Expand All @@ -1120,15 +1123,15 @@ def __init__(self, hparams):
enable_model_summary=False,
)
hp = {"test_hp_0": 1, "test_hp_1": 2}
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
hp = OmegaConf.create(hp) if use_omegaconf else Namespace(**hp)
model = TestModel(hp)
trainer.fit(model)
ckpt = trainer.checkpoint_connector.dump_checkpoint()
if hparams_type == Container:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
if use_omegaconf:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], Container)
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is hparams_type
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is dict


def test_ckpt_version_after_rerun_new_trainer(tmpdir):
11 changes: 7 additions & 4 deletions tests/core/test_datamodules.py
@@ -20,12 +20,11 @@

import pytest
import torch
from omegaconf import OmegaConf

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel
@@ -34,6 +33,9 @@
from tests.helpers.simple_models import ClassificationModel
from tests.helpers.utils import reset_seed

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf


@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
@@ -440,8 +442,9 @@ def test_hyperparameters_saving():
data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"hello": "world"})

data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == OmegaConf.create({"hello": "world"})
if _OMEGACONF_AVAILABLE:
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == OmegaConf.create({"hello": "world"})


def test_define_as_dataclass():
9 changes: 8 additions & 1 deletion tests/helpers/runif.py
@@ -27,6 +27,7 @@
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_OMEGACONF_AVAILABLE,
_RICH_AVAILABLE,
_TORCH_QUANTIZE_AVAILABLE,
_TPU_AVAILABLE,
@@ -70,6 +71,7 @@ def __new__(
deepspeed: bool = False,
rich: bool = False,
skip_49370: bool = False,
omegaconf: bool = False,
**kwargs,
):
"""
@@ -89,9 +91,10 @@ def __new__(
standalone: Mark the test as standalone, our CI will run it in a separate process.
fairscale: Require that facebookresearch/fairscale is installed.
fairscale_fully_sharded: Require that `fairscale` fully sharded support is available.
deepspeed: Require that Microsoft/DeepSpeed is installed.
deepspeed: Require that microsoft/DeepSpeed is installed.
rich: Require that willmcgugan/rich is installed.
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
omegaconf: Require that omry/omegaconf is installed.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
"""
conditions = []
@@ -177,6 +180,10 @@ def __new__(
conditions.append(ge_3_9 and old_torch)
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")

if omegaconf:
conditions.append(not _OMEGACONF_AVAILABLE)
reasons.append("omegaconf")

reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
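The marker can also be attached to a single parametrized case through pytest.param, which is how this diff handles tests that cover both the plain-dict and the omegaconf code paths (sketch only; the test name is illustrative):

    import pytest

    from tests.helpers.runif import RunIf


    @pytest.mark.parametrize(
        "use_omegaconf",
        [False, pytest.param(True, marks=RunIf(omegaconf=True))],
    )
    def test_both_paths(use_omegaconf):  # hypothetical test name
        ...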
10 changes: 6 additions & 4 deletions tests/loggers/test_tensorboard.py
@@ -21,13 +21,16 @@
import pytest
import torch
import yaml
from omegaconf import OmegaConf

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf


@pytest.mark.skipif(
@@ -205,6 +208,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
logger.log_hyperparams(hparams, metrics)


@RunIf(omegaconf=True)
def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
hparams = {
@@ -214,8 +218,6 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
"bool": True,
"dict": {"a": {"b": "c"}},
"list": [1, 2, 3],
# "namespace": Namespace(foo=Namespace(bar="buzz")),
# "layer": torch.nn.BatchNorm1d,
}
hparams = OmegaConf.create(hparams)

58 changes: 23 additions & 35 deletions tests/models/test_hparams.py
@@ -24,21 +24,24 @@
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

if _HYDRA_EXPERIMENTAL_AVAILABLE:
from hydra.experimental import compose, initialize

if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig


class SaveHparamsModel(BoringModel):
"""Tests that a model can take an object."""
@@ -117,6 +120,7 @@ def test_dict_hparams(tmpdir, cls):
_run_standard_hparams_test(tmpdir, model, cls)


@RunIf(omegaconf=True)
@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel])
def test_omega_conf_hparams(tmpdir, cls):
# init model
@@ -275,10 +279,18 @@ def __init__(obj, *more_args, other_arg=300, **more_kwargs):
obj.save_hyperparameters()


class DictConfSubClassBoringModel(SubClassBoringModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()
if _OMEGACONF_AVAILABLE:

class DictConfSubClassBoringModel(SubClassBoringModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()


else:

class DictConfSubClassBoringModel:
...


@pytest.mark.parametrize(
@@ -290,7 +302,7 @@
SubSubClassBoringModel,
AggSubClassBoringModel,
UnconventionalArgsBoringModel,
DictConfSubClassBoringModel,
pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
],
)
def test_collect_init_arguments(tmpdir, cls):
@@ -383,31 +395,6 @@ def test_collect_init_arguments_with_local_vars(cls):
assert model.hparams["arg2"] == 2


# @pytest.mark.parametrize("cls,config", [
# (SaveHparamsModel, Namespace(my_arg=42)),
# (SaveHparamsModel, dict(my_arg=42)),
# (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))),
# (AssignHparamsModel, Namespace(my_arg=42)),
# (AssignHparamsModel, dict(my_arg=42)),
# (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))),
# ])
# def test_single_config_models(tmpdir, cls, config):
# """ Test that the model automatically saves the arguments passed into the constructor """
# model = cls(config)
#
# # no matter how you do it, it should be assigned
# assert model.hparams.my_arg == 42
#
# # verify that the checkpoint saved the correct values
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
# trainer.fit(model)
#
# # verify that model loads correctly
# raw_checkpoint_path = _raw_checkpoint_path(trainer)
# model = cls.load_from_checkpoint(raw_checkpoint_path)
# assert model.hparams.my_arg == 42


class AnotherArgModel(BoringModel):
def __init__(self, arg1):
super().__init__()
@@ -511,8 +498,9 @@ def _compare_params(loaded_params, default_params: dict):
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
if _OMEGACONF_AVAILABLE:
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
_compare_params(load_hparams_from_yaml(path_yaml), hparams)


class NoArgsSubClassBoringModel(CustomBoringModel):
17 changes: 10 additions & 7 deletions tests/trainer/test_trainer.py
@@ -26,7 +26,6 @@
import cloudpickle
import pytest
import torch
from omegaconf import OmegaConf
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD
from torch.utils.data import DataLoader, IterableDataset
@@ -51,6 +50,7 @@
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.seed import seed_everything
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel, RandomDataset
@@ -59,6 +59,9 @@
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf


@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
Expand Down Expand Up @@ -1271,12 +1274,12 @@ def __init__(self, **kwargs):
TrainerSubclass(abcdefg="unknown_arg")


@pytest.mark.parametrize(
"trainer_params", [OmegaConf.create(dict(max_epochs=1, gpus=1)), OmegaConf.create(dict(max_epochs=1, gpus=[0]))]
)
@RunIf(min_gpus=1)
def test_trainer_omegaconf(trainer_params):
Trainer(**trainer_params)
@RunIf(omegaconf=True)
@pytest.mark.parametrize("trainer_params", [{"max_epochs": 1, "gpus": 1}, {"max_epochs": 1, "gpus": [0]}])
@mock.patch("torch.cuda.device_count", return_value=1)
def test_trainer_omegaconf(_, trainer_params):
config = OmegaConf.create(trainer_params)
Trainer(**config)


def test_trainer_pickle(tmpdir):