Skip to content

Commit

Permalink
fix incorrect implicit return types
Browse files Browse the repository at this point in the history
  • Loading branch information
Avasam committed May 23, 2024
1 parent 52d7324 commit f7f3833
Showing 1 changed file with 63 additions and 41 deletions.
104 changes: 63 additions & 41 deletions pkg_resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import types
from typing import (
Any,
Generator,
Iterator,
Mapping,
MutableSequence,
NamedTuple,
Expand Down Expand Up @@ -99,7 +101,16 @@
from pkg_resources.extern.platformdirs import user_cache_dir as _user_cache_dir

if TYPE_CHECKING:
from _typeshed import StrPath
from typing_extensions import Self
from _typeshed import StrPath, StrOrBytesPath
from itertools import chain

# yield_lines is exported, we should type it, but extern is too dynamic
def yield_lines(iterable: "_NestedStr") -> chain[str]: ...

# Trick type-checkers into seeing the original instances for type comparisons
from packaging import version as _packaging_version
from packaging import requirements as _packaging_requirements

warnings.warn(
"pkg_resources is deprecated as an API. "
Expand All @@ -109,7 +120,7 @@
)


T = TypeVar("T")
_T = TypeVar("_T")
# Type aliases
_NestedStr = Union[str, Iterable[Union[str, Iterable["_NestedStr"]]]]
_InstallerType = Callable[["Requirement"], Optional["Distribution"]]
Expand Down Expand Up @@ -142,20 +153,20 @@ class PEP440Warning(RuntimeWarning):
_state_vars: Dict[str, str] = {}


def _declare_state(vartype: str, varname: str, initial_value: T) -> T:
def _declare_state(vartype: str, varname: str, initial_value: _T) -> _T:
_state_vars[varname] = vartype
return initial_value


def __getstate__():
def __getstate__() -> Dict[str, Any]:
state = {}
g = globals()
for k, v in _state_vars.items():
state[k] = g['_sget_' + v](g[k])
return state


def __setstate__(state):
def __setstate__(state: Dict[str, Any]) -> Dict[str, Any]:
g = globals()
for k, v in state.items():
g['_sset_' + _state_vars[k]](k, g[k], v)
Expand Down Expand Up @@ -310,11 +321,11 @@ class VersionConflict(ResolutionError):
_template = "{self.dist} is installed but {self.req} is required"

@property
def dist(self):
def dist(self) -> "Distribution":
return self.args[0]

@property
def req(self):
def req(self) -> "Requirement":
return self.args[1]

def report(self):
Expand All @@ -340,7 +351,7 @@ class ContextualVersionConflict(VersionConflict):
_template = VersionConflict._template + ' by {self.required_by}'

@property
def required_by(self):
def required_by(self) -> Set[str]:
return self.args[2]


Expand All @@ -353,11 +364,11 @@ class DistributionNotFound(ResolutionError):
)

@property
def req(self):
def req(self) -> "Requirement":
return self.args[0]

@property
def requirers(self):
def requirers(self) -> Optional[Set[str]]:
return self.args[1]

@property
Expand Down Expand Up @@ -663,11 +674,11 @@ def add_entry(self, entry: str):
for dist in find_distributions(entry, True):
self.add(dist, entry, False)

def __contains__(self, dist: "Distribution"):
def __contains__(self, dist: "Distribution") -> bool:
"""True if `dist` is the active distribution for its project"""
return self.by_key.get(dist.key) == dist

def find(self, req: "Requirement"):
def find(self, req: "Requirement") -> Optional["Distribution"]:
"""Find a distribution matching requirement `req`
If there is an active distribution for the requested project, this
Expand All @@ -691,7 +702,9 @@ def find(self, req: "Requirement"):
raise VersionConflict(dist, req)
return dist

def iter_entry_points(self, group: str, name: Optional[str] = None):
def iter_entry_points(
self, group: str, name: Optional[str] = None
) -> Generator["EntryPoint", None, None]:
"""Yield entry point objects from `group` matching `name`
If `name` is None, yields all entry points in `group` from all
Expand All @@ -713,7 +726,7 @@ def run_script(self, requires: str, script_name: str):
ns['__name__'] = name
self.require(requires)[0].run_script(script_name, ns)

def __iter__(self):
def __iter__(self) -> Iterator["Distribution"]:
"""Yield distributions for non-duplicate projects in the working set
The yield order is the order in which the items' path entries were
Expand Down Expand Up @@ -1099,7 +1112,7 @@ def scan(self, search_path: Optional[Sequence[str]] = None):
for dist in find_distributions(item):
self.add(dist)

def __getitem__(self, project_name: str):
def __getitem__(self, project_name: str) -> List["Distribution"]:
"""Return a newest-to-oldest list of distributions for `project_name`
Uses case-insensitive `project_name` comparison, assuming all the
Expand Down Expand Up @@ -1166,7 +1179,7 @@ def obtain(
to the `installer` argument."""
return installer(requirement) if installer else None

def __iter__(self):
def __iter__(self) -> Iterator[str]:
"""Yield the unique project names of the available distributions"""
for key in self._distmap.keys():
if self[key]:
Expand Down Expand Up @@ -1399,7 +1412,7 @@ def cleanup_resources(self, force: bool = False) -> List[str]:
return []


def get_default_cache():
def get_default_cache() -> str:
"""
Return the ``PYTHON_EGG_CACHE`` environment variable
or a platform-relevant user cache dir for an app
Expand Down Expand Up @@ -1491,7 +1504,7 @@ def invalid_marker(text: str):
return False


def evaluate_marker(text: str, extra: Optional[str] = None):
def evaluate_marker(text: str, extra: Optional[str] = None) -> bool:
"""
Evaluate a PEP 508 environment marker.
Return a boolean indicating the marker result in this environment.
Expand Down Expand Up @@ -1829,7 +1842,7 @@ class manifest_mod(NamedTuple):
manifest: Dict[str, zipfile.ZipInfo]
mtime: float

def load(self, path: str): # type: ignore[override] # ZipManifests.load is a classmethod
def load(self, path: str) -> Dict[str, zipfile.ZipInfo]: # type: ignore[override] # ZipManifests.load is a classmethod
"""
Load a manifest at path or return a suitable manifest already loaded.
"""
Expand Down Expand Up @@ -2112,7 +2125,9 @@ def register_finder(importer_type: type, distribution_finder: _AdapterType):
_distribution_finders[importer_type] = distribution_finder


def find_distributions(path_item: str, only: bool = False):
def find_distributions(
path_item: str, only: bool = False
) -> Generator["Distribution", None, None]:
"""Yield distributions accessible via `path_item`"""
importer = get_importer(path_item)
finder = _find_adapter(_distribution_finders, importer)
Expand All @@ -2121,7 +2136,7 @@ def find_distributions(path_item: str, only: bool = False):

def find_eggs_in_zip(
importer: zipimport.zipimporter, path_item: str, only: bool = False
):
) -> Generator["Distribution", None, None]:
"""
Find eggs in zip files; possibly multiple nested eggs.
"""
Expand Down Expand Up @@ -2159,7 +2174,9 @@ def find_nothing(
register_finder(object, find_nothing)


def find_on_path(importer: Optional[object], path_item, only=False):
def find_on_path(
importer: Optional[object], path_item, only=False
) -> Generator["Distribution", None, None]:
"""Yield distributions accessible on a sys.path directory"""
path_item = _normalize_cached(path_item)

Expand Down Expand Up @@ -2214,7 +2231,7 @@ def __call__(self, fullpath):
return iter(())


def safe_listdir(path):
def safe_listdir(path: "StrOrBytesPath"):
"""
Attempt to list contents of path, but suppress some exceptions.
"""
Expand All @@ -2230,13 +2247,13 @@ def safe_listdir(path):
return ()


def distributions_from_metadata(path):
def distributions_from_metadata(path) -> Generator["Distribution", None, None]:
root = os.path.dirname(path)
if os.path.isdir(path):
if len(os.listdir(path)) == 0:
# empty metadata dir; skip
return
metadata = PathMetadata(root, path)
metadata: _MetadataType = PathMetadata(root, path)
else:
metadata = FileMetadata(path)
entry = os.path.basename(path)
Expand Down Expand Up @@ -2430,8 +2447,8 @@ def fixup_namespace_packages(path_item: str, parent: Optional[str] = None):

def file_ns_handler(
importer: Optional[importlib.abc.PathEntryFinder],
path_item,
packageName,
path_item: "StrPath",
packageName: str,
module: types.ModuleType,
):
"""Compute an ns-package subpath for a filesystem or zipfile importer"""
Expand Down Expand Up @@ -2661,7 +2678,7 @@ def parse_group(
"""Parse an entry point group"""
if not MODULE(group):
raise ValueError("Invalid group name", group)
this = {}
this: Dict[str, Self] = {}
for line in yield_lines(lines):
ep = cls.parse(line, dist)
if ep.name in this:
Expand All @@ -2676,11 +2693,12 @@ def parse_map(
dist: Optional["Distribution"] = None,
):
"""Parse a map of entry point groups"""
_data: Iterable[Tuple[Optional[str], Union[str, Iterable[str]]]]
if isinstance(data, dict):
_data = data.items()
else:
_data = split_sections(data)
maps: Dict[str, Dict[str, "EntryPoint"]] = {}
maps: Dict[str, Dict[str, "Self"]] = {}
for group, lines in _data:
if group is None:
if not lines:
Expand Down Expand Up @@ -2739,7 +2757,7 @@ def from_location(
basename: str,
metadata: _MetadataType = None,
**kw: int, # We could set `precedence` explicitly, but keeping this as `**kw` for full backwards and subclassing compatibility
):
) -> "Distribution":
project_name, version, py_version, platform = [None] * 4
basename, ext = os.path.splitext(basename)
if ext.lower() in _distributionImpl:
Expand Down Expand Up @@ -2878,14 +2896,14 @@ def _dep_map(self):
return self.__dep_map

@staticmethod
def _filter_extras(dm):
def _filter_extras(dm: Dict[Union[str, None], List["Requirement"]]):
"""
Given a mapping of extras to dependencies, strip off
environment markers and filter out any dependencies
not matching the markers.
"""
for extra in list(filter(None, dm)):
new_extra = extra
new_extra: Optional[str] = extra
reqs = dm.pop(extra)
new_extra, _, marker = extra.partition(':')
fails_marker = marker and (
Expand All @@ -2908,7 +2926,7 @@ def _build_dep_map(self):
def requires(self, extras: Iterable[str] = ()):
"""List of Requirements needed for this distro if `extras` are used"""
dm = self._dep_map
deps = []
deps: List[Requirement] = []
deps.extend(dm.get(None, ()))
for ext in extras:
try:
Expand Down Expand Up @@ -3205,11 +3223,11 @@ def _dep_map(self):
self.__dep_map = self._compute_dependencies()
return self.__dep_map

def _compute_dependencies(self):
def _compute_dependencies(self) -> Dict[Union[str, None], List["Requirement"]]:
"""Recompute this distribution's dependencies."""
dm = self.__dep_map = {None: []}
self.__dep_map: Dict[Union[str, None], List["Requirement"]] = {None: []}

reqs = []
reqs: List[Requirement] = []
# Including any condition expressions
for req in self._parsed_pkg_info.get_all('Requires-Dist') or []:
reqs.extend(parse_requirements(req))
Expand All @@ -3220,13 +3238,15 @@ def reqs_for_extra(extra):
yield req

common = types.MappingProxyType(dict.fromkeys(reqs_for_extra(None)))
dm[None].extend(common)
self.__dep_map[None].extend(common)

for extra in self._parsed_pkg_info.get_all('Provides-Extra') or []:
s_extra = safe_extra(extra.strip())
dm[s_extra] = [r for r in reqs_for_extra(extra) if r not in common]
self.__dep_map[s_extra] = [
r for r in reqs_for_extra(extra) if r not in common
]

return dm
return self.__dep_map


_distributionImpl = {
Expand Down Expand Up @@ -3287,7 +3307,7 @@ def __eq__(self, other: object):
def __ne__(self, other):
return not self == other

def __contains__(self, item: Union[Distribution, str, Tuple[str, ...]]):
def __contains__(self, item: Union[Distribution, str, Tuple[str, ...]]) -> bool:
if isinstance(item, Distribution):
if item.key != self.key:
return False
Expand Down Expand Up @@ -3351,7 +3371,9 @@ def _bypass_ensure_directory(path):
pass


def split_sections(s: _NestedStr):
def split_sections(
s: _NestedStr,
) -> Generator[Tuple[Optional[str], List[str]], None, None]:
"""Split a string or iterable thereof into (section, content) pairs
Each ``section`` is a stripped version of the section header ("[section]")
Expand Down

0 comments on commit f7f3833

Please sign in to comment.