Skip to content

Commit

Permalink
fix: allow import from Registryconfig with optional dependencies (#180)
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
  • Loading branch information
ab93 committed May 10, 2023
1 parent e6a6811 commit 58b3cd4
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 453 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ jobs:
- name: Install dependencies
run: |
poetry env use ${{ matrix.python-version }}
poetry install --all-extras --with dev,torch
poetry install --all-extras --with dev
poetry run pip install "torch<3.0" -i https://download.pytorch.org/whl/cpu
poetry run pip install "pytorch-lightning<3.0"
- name: Test with pytest
run: make test
4 changes: 3 additions & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ jobs:
- name: Install dependencies
run: |
poetry env use ${{ matrix.python-version }}
poetry install --all-extras --with dev,torch
poetry install --all-extras --with dev
poetry run pip install "torch<3.0" -i https://download.pytorch.org/whl/cpu
poetry run pip install "pytorch-lightning<3.0"
- name: Run Coverage
run: |
Expand Down
2 changes: 2 additions & 0 deletions numalogic/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
PreprocessFactory,
PostprocessFactory,
ThresholdFactory,
RegistryFactory,
)


Expand All @@ -28,4 +29,5 @@
"PreprocessFactory",
"PostprocessFactory",
"ThresholdFactory",
"RegistryFactory",
]
34 changes: 29 additions & 5 deletions numalogic/config/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,33 @@ class ModelFactory(_ObjectFactory):


class RegistryFactory(_ObjectFactory):
import numalogic.registry as reg
_CLS_SET = {"RedisRegistry", "MLflowRegistry"}

_CLS_MAP = {
"RedisRegistry": getattr(reg, "RedisRegistry"),
"MLflowRegistry": getattr(reg, "MLflowRegistry"),
}
def get_instance(self, object_info: Union[ModelInfo, RegistryInfo]):
import numalogic.registry as reg

try:
_cls = getattr(reg, object_info.name)
except AttributeError as err:
if object_info.name in self._CLS_SET:
raise ImportError(
"Please install the required dependencies for the registry you want to use."
) from err
raise UnknownConfigArgsError(
f"Invalid model info instance provided: {object_info}"
) from err
return _cls(**object_info.conf)

def get_cls(self, object_info: Union[ModelInfo, RegistryInfo]):
import numalogic.registry as reg

try:
return getattr(reg, object_info.name)
except AttributeError as err:
if object_info.name in self._CLS_SET:
raise ImportError(
"Please install the required dependencies for the registry you want to use."
) from err
raise UnknownConfigArgsError(
f"Invalid model info instance provided: {object_info}"
) from err
25 changes: 14 additions & 11 deletions numalogic/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache
from numalogic.registry.localcache import LocalLRUCache


__all__ = ["ArtifactManager", "ArtifactData", "ArtifactCache", "LocalLRUCache"]


try:
from numalogic.registry.mlflow_registry import MLflowRegistry # noqa: F401
except ImportError:
pass
else:
__all__.append("MLflowRegistry")

try:
from numalogic.registry.mlflow_registry import MLflowRegistry
from numalogic.registry.redis_registry import RedisRegistry
from numalogic.registry.redis_registry import RedisRegistry # noqa: F401
except ImportError:
__all__ = ["ArtifactManager", "ArtifactData", "ArtifactCache", "LocalLRUCache"]
pass
else:
__all__ = [
"ArtifactManager",
"ArtifactData",
"MLflowRegistry",
"ArtifactCache",
"LocalLRUCache",
"RedisRegistry",
]
__all__.append("RedisRegistry")
586 changes: 154 additions & 432 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.4.dev4"
version = "0.4.dev5"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down Expand Up @@ -92,7 +92,7 @@ exclude = '''
line-length = 100
src = ["numalogic", "tests"]
select = ["E", "F", "W", "C901", "NPY", "RUF", "TRY", "G", "PLE", "PLW", "UP", "ICN", "RET", "Q"]
ignore = ["TRY003", "TRY301"]
ignore = ["TRY003", "TRY301", "RUF100"]
target-version = "py39"
show-fixes = true
show-source = true
Expand Down
Empty file added tests/config/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/test_config.py → tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
ThresholdFactory,
NumalogicConf,
ModelInfo,
RegistryFactory,
)
from numalogic.config.factory import RegistryFactory
from numalogic.models.autoencoder import AutoencoderTrainer
from numalogic.models.autoencoder.variants import SparseVanillaAE, SparseConv1dAE, LSTMAE
from numalogic.models.threshold import StdDevThreshold
Expand Down
64 changes: 64 additions & 0 deletions tests/config/test_optdeps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from unittest.mock import patch

import fakeredis
from omegaconf import OmegaConf

from numalogic._constants import TESTS_DIR
from numalogic.config import RegistryInfo
from numalogic.tools.exceptions import UnknownConfigArgsError


class TestOptionalDependencies(unittest.TestCase):
def setUp(self) -> None:
self._given_conf = OmegaConf.load(os.path.join(TESTS_DIR, "resources", "config.yaml"))
from numalogic.config import NumalogicConf

self.schema: NumalogicConf = OmegaConf.structured(NumalogicConf)
self.conf = OmegaConf.merge(self.schema, self._given_conf)

@patch("numalogic.config.factory.getattr", side_effect=AttributeError)
def test_not_installed_dep_01(self, _):
from numalogic.config.factory import RegistryFactory

model_factory = RegistryFactory()
server = fakeredis.FakeServer()
redis_cli = fakeredis.FakeStrictRedis(server=server, decode_responses=False)
with self.assertRaises(ImportError):
model_factory.get_cls(self.conf.registry)(redis_cli, **self.conf.registry.conf)

@patch("numalogic.config.factory.getattr", side_effect=AttributeError)
def test_not_installed_dep_02(self, _):
from numalogic.config.factory import RegistryFactory

model_factory = RegistryFactory()
server = fakeredis.FakeServer()
redis_cli = fakeredis.FakeStrictRedis(server=server, decode_responses=False)
with self.assertRaises(ImportError):
model_factory.get_instance(self.conf.registry)(redis_cli, **self.conf.registry.conf)

def test_unknown_registry(self):
from numalogic.config.factory import RegistryFactory

model_factory = RegistryFactory()
reg_conf = RegistryInfo(name="UnknownRegistry")
with self.assertRaises(UnknownConfigArgsError):
model_factory.get_cls(reg_conf)
with self.assertRaises(UnknownConfigArgsError):
model_factory.get_instance(reg_conf)


if __name__ == "__main__":
unittest.main()

0 comments on commit 58b3cd4

Please sign in to comment.