Skip to content

Commit

Permalink
Replace use of pkg_resources with importlib.metadata
Browse files Browse the repository at this point in the history
Use iter_entry_point() workaround until Python 3.9 support is dropped
  • Loading branch information
tysmith committed Jun 27, 2024
1 parent 90584bc commit e2e2320
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 46 deletions.
3 changes: 1 addition & 2 deletions grizzly/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from FTB.ProgramConfiguration import ProgramConfiguration

from .common.fuzzmanager import FM_CONFIG
from .common.plugins import scan as scan_plugins
from .common.plugins import scan_target_assets
from .common.plugins import scan_plugins, scan_target_assets
from .common.utils import DEFAULT_TIME_LIMIT, TIMEOUT_DELAY, package_version


Expand Down
20 changes: 10 additions & 10 deletions grizzly/common/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from logging import getLogger
from typing import Any, Dict, List, Tuple

from pkg_resources import iter_entry_points
from .utils import iter_entry_point

__all__ = ("load", "scan", "PluginLoadError")
__all__ = ("load_plugin", "scan_plugins", "PluginLoadError")


LOG = getLogger(__name__)
Expand All @@ -16,7 +16,7 @@ class PluginLoadError(Exception):
"""Raised if loading a plug-in fails"""


def load(name: str, group: str, base_type: type) -> Any:
def load_plugin(name: str, group: str, base_type: type) -> Any:
"""Load a plug-in.
Args:
Expand All @@ -28,10 +28,10 @@ def load(name: str, group: str, base_type: type) -> Any:
Loaded plug-in object.
"""
assert isinstance(base_type, type)
for entry in iter_entry_points(group):
for entry in iter_entry_point(group):
if entry.name == name:
LOG.debug("loading %r (%s)", name, base_type.__name__)
plugin = entry.load()
LOG.debug("loading %r (%s)", name, base_type.__name__)
break
else:
raise PluginLoadError(f"{name!r} not found in {group!r}")
Expand All @@ -40,7 +40,7 @@ def load(name: str, group: str, base_type: type) -> Any:
return plugin


def scan(group: str) -> List[str]:
def scan_plugins(group: str) -> List[str]:
"""Scan for installed plug-ins.
Args:
Expand All @@ -49,9 +49,9 @@ def scan(group: str) -> List[str]:
Returns:
Names of installed entry points.
"""
found = []
found: List[str] = []
LOG.debug("scanning %r", group)
for entry in iter_entry_points(group):
for entry in iter_entry_point(group):
if entry.name in found:
# not sure if this can even happen
raise PluginLoadError(f"Duplicate entry {entry.name!r} in {group!r}")
Expand All @@ -68,7 +68,7 @@ def scan_target_assets() -> Dict[str, Tuple[str, ...]]:
Returns:
Name of target and list of supported assets.
"""
assets = {}
for entry in iter_entry_points("grizzly_targets"):
assets: Dict[str, Tuple[str, ...]] = {}
for entry in iter_entry_point("grizzly_targets"):
assets[entry.name] = entry.load().SUPPORTED_ASSETS
return assets
46 changes: 25 additions & 21 deletions grizzly/common/test_plugins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from pkg_resources import EntryPoint
from importlib.metadata import EntryPoint

from pytest import raises

from ..target import Target
from .plugins import PluginLoadError, load, scan, scan_target_assets
from .plugins import PluginLoadError, load_plugin, scan_plugins, scan_target_assets


class FakeType1:
Expand All @@ -19,22 +20,23 @@ class FakeType2:
def test_load_01(mocker):
"""test load() - nothing to load"""
mocker.patch(
"grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[]
"grizzly.common.plugins.iter_entry_point", autospec=True, return_value=()
)
with raises(PluginLoadError, match="'test-name' not found in 'test-group'"):
load("test-name", "test-group", FakeType1)
load_plugin("test-name", "test-group", FakeType1)


def test_load_02(mocker):
"""test load() - successful load"""
# Note: Mock.name cannot be set via the constructor so spec_set cannot be used
entry = mocker.Mock(spec=EntryPoint)
entry.name = "test-name"
entry.load.return_value = FakeType1
mocker.patch(
"grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[entry]
"grizzly.common.plugins.iter_entry_point",
autospec=True,
return_value=(entry,),
)
assert load("test-name", "test-group", FakeType1)
assert load_plugin("test-name", "test-group", FakeType1)


def test_load_03(mocker):
Expand All @@ -43,60 +45,62 @@ def test_load_03(mocker):
entry.name = "test-name"
entry.load.return_value = FakeType1
mocker.patch(
"grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[entry]
"grizzly.common.plugins.iter_entry_point",
autospec=True,
return_value=(entry,),
)
with raises(PluginLoadError, match="'test-name' doesn't inherit from FakeType2"):
load("test-name", "test-group", FakeType2)
load_plugin("test-name", "test-group", FakeType2)


def test_scan_01(mocker):
"""test scan() - no entries found"""
mocker.patch(
"grizzly.common.plugins.iter_entry_points", autospec=True, return_value=[]
"grizzly.common.plugins.iter_entry_point", autospec=True, return_value=()
)
assert not scan("test_group")
assert not scan_plugins("test_group")


def test_scan_02(mocker):
"""test scan() - duplicate entry"""
entry = mocker.Mock(spec=EntryPoint)
entry.name = "test_entry"
mocker.patch(
"grizzly.common.plugins.iter_entry_points",
"grizzly.common.plugins.iter_entry_point",
autospec=True,
return_value=[entry, entry],
return_value=(entry, entry),
)
with raises(PluginLoadError, match="Duplicate entry 'test_entry' in 'test_group'"):
scan("test_group")
scan_plugins("test_group")


def test_scan_03(mocker):
"""test scan() - success"""
entry = mocker.Mock(spec=EntryPoint)
entry.name = "test-name"
mocker.patch(
"grizzly.common.plugins.iter_entry_points",
"grizzly.common.plugins.iter_entry_point",
autospec=True,
return_value=[entry],
return_value=(entry,),
)
assert "test-name" in scan("test_group")
assert "test-name" in scan_plugins("test_group")


def test_scan_target_assets_01(mocker):
"""test scan_target_assets() - success"""
targ1 = mocker.Mock(spec=EntryPoint)
targ1.name = "t1"
targ1.load.return_value = mocker.Mock(spec_set=Target, SUPPORTED_ASSETS=None)
targ1.load.return_value = mocker.Mock(spec_set=Target, SUPPORTED_ASSETS=())
targ2 = mocker.Mock(spec=EntryPoint)
targ2.name = "t2"
targ2.load.return_value = mocker.Mock(spec_set=Target, SUPPORTED_ASSETS=("a", "B"))
mocker.patch(
"grizzly.common.plugins.iter_entry_points",
"grizzly.common.plugins.iter_entry_point",
autospec=True,
return_value=[targ1, targ2],
return_value=(targ1, targ2),
)
assets = scan_target_assets()
assert "t1" in assets
assert assets["t1"] is None
assert not assets["t1"]
assert "t2" in assets
assert "B" in assets["t2"]
25 changes: 24 additions & 1 deletion grizzly/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from enum import IntEnum, unique
from importlib.metadata import PackageNotFoundError, version
from importlib.metadata import EntryPoint, PackageNotFoundError, entry_points, version
from logging import DEBUG, basicConfig, getLogger
from os import getenv
from pathlib import Path
from sys import version_info
from tempfile import gettempdir
from typing import Any, Iterable, Optional, Tuple, Union

Expand All @@ -17,6 +18,7 @@
"Exit",
"grz_tmp",
"HARNESS_FILE",
"iter_entry_point",
"package_version",
"time_limits",
"TIMEOUT_DELAY",
Expand Down Expand Up @@ -105,6 +107,27 @@ def display_time_limits(time_limit: int, timeout: int, no_harness: bool) -> None
LOG.warning("TIMEOUT DISABLED, not recommended for automation")


def iter_entry_point(group: str) -> Tuple[EntryPoint, ...]:
"""Compatibility wrapper code for importlib.metadata.entry_points()
Args:
group: See entry_points().
Returns:
EntryPoints
"""
# TODO: remove this function when support for Python 3.9 is dropped
assert group
if version_info.minor >= 10:
eps = entry_points().select(group=group)
else:
try:
eps = entry_points()[group]
except KeyError:
eps = () # type: ignore
return tuple(x for x in eps)


def package_version(name: str, default: str = "unknown") -> str:
"""Get version of an installed package.
Expand Down
2 changes: 1 addition & 1 deletion grizzly/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sapphire import CertificateBundle, Sapphire

from .adapter import Adapter
from .common.plugins import load as load_plugin
from .common.plugins import load_plugin
from .common.reporter import (
FailedLaunchReporter,
FilesystemReporter,
Expand Down
2 changes: 1 addition & 1 deletion grizzly/reduce/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sapphire import CertificateBundle, Sapphire

from ..common.fuzzmanager import CrashEntry
from ..common.plugins import load as load_plugin
from ..common.plugins import load_plugin
from ..common.reporter import (
FailedLaunchReporter,
FilesystemReporter,
Expand Down
6 changes: 2 additions & 4 deletions grizzly/reduce/strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@
cast,
)

from pkg_resources import iter_entry_points

from ...common.storage import TestCase
from ...common.utils import grz_tmp
from ...common.utils import grz_tmp, iter_entry_point

LOG = getLogger(__name__)

Expand Down Expand Up @@ -204,7 +202,7 @@ def _load_strategies() -> Dict[str, Type[Strategy]]:
A mapping of strategy names to strategy class.
"""
strategies: Dict[str, Type[Strategy]] = {}
for entry_point in iter_entry_points("grizzly_reduce_strategies"):
for entry_point in iter_entry_point("grizzly_reduce_strategies"):
try:
strategy_cls = cast(Type[Strategy], entry_point.load())
assert (
Expand Down
9 changes: 4 additions & 5 deletions grizzly/reduce/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,10 @@ class _BadStrategy:
def load(cls):
raise RuntimeError("oops")

def entries(_):
yield _BadStrategy
yield _GoodStrategy

mocker.patch("grizzly.reduce.strategies.iter_entry_points", side_effect=entries)
mocker.patch(
"grizzly.reduce.strategies.iter_entry_point",
return_value=(_BadStrategy, _GoodStrategy),
)
mocker.patch("grizzly.reduce.strategies.DEFAULT_STRATEGIES", new=("good",))
result = _load_strategies()
assert result == {"good": _GoodStrategy}
Expand Down
2 changes: 1 addition & 1 deletion grizzly/replay/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from sapphire import CertificateBundle, Sapphire, ServerMap

from ..common.plugins import load as load_plugin
from ..common.plugins import load_plugin
from ..common.report import Report
from ..common.reporter import (
FailedLaunchReporter,
Expand Down

0 comments on commit e2e2320

Please sign in to comment.