Skip to content

Commit

Permalink
adding required wrapper (#1056)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] committed Jul 12, 2023
1 parent c6f6d3b commit b8966d1
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `requires` wrapper ([#1056](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/1056))


### Changed
Expand Down
28 changes: 28 additions & 0 deletions src/pl_bolts/utils/_dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import functools
import os
from typing import Any, Callable

from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache


# ToDo: replace with utils wrapper after 0.10 is released
def requires(*module_path_version: str) -> Callable:
"""Wrapper for enforcing certain requirements for a particular class or function."""

def decorator(func: Callable) -> Callable:
reqs = [
ModuleAvailableCache(mod_ver) if "." in mod_ver else RequirementCache(mod_ver)
for mod_ver in module_path_version
]
available = all(map(bool, reqs))
if not available:

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
msg = os.linesep.join([repr(r) for r in reqs if not bool(r)])
raise ModuleNotFoundError(f"Required dependencies not available: \n{msg}")

return wrapper
return func

return decorator
28 changes: 28 additions & 0 deletions tests/utils/test_dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from pl_bolts.utils._dependency import requires


@requires("torch")
def using_torch():
return True


@requires("torch.anything.wrong")
def using_torch_wrong_path():
return True


@requires("torch>99.0")
def using_torch_bad_version():
return True


def test_requires_pass():
assert using_torch() is True


def test_requires_fail():
with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"):
assert using_torch_wrong_path()
with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"):
assert using_torch_bad_version()

0 comments on commit b8966d1

Please sign in to comment.