diff --git a/mypy.ini b/mypy.ini index c5b13942c43..b8825fa8da8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -16,6 +16,8 @@ exclude = (?x)( | ^setuptools/config/_validate_pyproject/ # Auto-generated | ^setuptools/tests/bdist_wheel_testdata/ # Duplicate module name ) +# Too many false-positives +disable_error_code = overload-overlap # Ignoring attr-defined because setuptools wraps a lot of distutils classes, adding new attributes, # w/o updating all the attributes and return types from the base classes for type-checkers to understand diff --git a/pkg_resources/__init__.py b/pkg_resources/__init__.py index c86d9f095cc..67bdaff7a5c 100644 --- a/pkg_resources/__init__.py +++ b/pkg_resources/__init__.py @@ -32,6 +32,8 @@ import types from typing import ( Any, + Iterator, + Literal, Mapping, MutableSequence, NamedTuple, @@ -49,6 +51,7 @@ Iterable, Optional, TypeVar, + overload, ) import zipfile import zipimport @@ -99,7 +102,7 @@ from pkg_resources.extern.platformdirs import user_cache_dir as _user_cache_dir if TYPE_CHECKING: - from _typeshed import StrPath + from _typeshed import StrPath, BytesPath, StrOrBytesPath warnings.warn( "pkg_resources is deprecated as an API. " @@ -110,12 +113,16 @@ T = TypeVar("T") +_DistributionT = TypeVar("_DistributionT", bound="Distribution") # Type aliases _NestedStr = Union[str, Iterable[Union[str, Iterable["_NestedStr"]]]] -_InstallerType = Callable[["Requirement"], Optional["Distribution"]] +_StrictInstallerType = Callable[["Requirement"], "_DistributionT"] +_InstallerType = Optional[Callable[["Requirement"], Optional["Distribution"]]] _PkgReqType = Union[str, "Requirement"] _EPDistType = Union["Distribution", _PkgReqType] _MetadataType = Optional["IResourceProvider"] +_ResolvedEntryPoint = Any # Can be any attribute in the module +_ResourceStream = Any # Incomplete: A readable file-like object # Any object works, but let's indicate we expect something like a module (optionally has __loader__ or __file__) _ModuleLike = Union[object, types.ModuleType] _AdapterType = Callable[..., Any] # Incomplete @@ -126,6 +133,10 @@ class _LoaderProtocol(Protocol): def load_module(self, fullname: str, /) -> types.ModuleType: ... +class ZipLoaderModule(Protocol): + __loader__: zipimport.zipimporter + + _PEP440_FALLBACK = re.compile(r"^v?(?P(?:[0-9]+!)?[0-9]+(?:\.[0-9]+)*)", re.I) @@ -399,7 +410,13 @@ def register_loader_type( _provider_factories[loader_type] = provider_factory -def get_provider(moduleOrReq: Union[str, "Requirement"]): +@overload +def get_provider(moduleOrReq: str) -> "IResourceProvider": ... +@overload +def get_provider(moduleOrReq: "Requirement") -> "Distribution": ... +def get_provider( + moduleOrReq: Union[str, "Requirement"], +) -> Union["IResourceProvider", "Distribution"]: """Return an IResourceProvider for the named module or requirement""" if isinstance(moduleOrReq, Requirement): return working_set.find(moduleOrReq) or require(str(moduleOrReq))[0] @@ -510,22 +527,35 @@ def compatible_platforms(provided: Optional[str], required: Optional[str]): return False -def get_distribution(dist: _EPDistType): +@overload +def get_distribution(dist: _DistributionT) -> _DistributionT: ... +@overload +def get_distribution(dist: _PkgReqType) -> "Distribution": ... +def get_distribution( + dist: Union[_DistributionT, _PkgReqType], +) -> Union[_DistributionT, "Distribution"]: """Return a current distribution object for a Requirement or string""" if isinstance(dist, str): dist = Requirement.parse(dist) if isinstance(dist, Requirement): - dist = get_provider(dist) + # Bad type narrowing, dist has to be a Requirement here, so get_provider has to return Distribution + dist = get_provider(dist) # type: ignore[assignment] if not isinstance(dist, Distribution): - raise TypeError("Expected string, Requirement, or Distribution", dist) + raise TypeError("Expected str, Requirement, or Distribution", dist) return dist -def load_entry_point(dist: _EPDistType, group: str, name: str): +def load_entry_point(dist: _EPDistType, group: str, name: str) -> _ResolvedEntryPoint: """Return `name` entry point of `group` for `dist` or raise ImportError""" return get_distribution(dist).load_entry_point(group, name) +@overload +def get_entry_map( + dist: _EPDistType, group: None = None +) -> Dict[str, Dict[str, "EntryPoint"]]: ... +@overload +def get_entry_map(dist: _EPDistType, group: str) -> Dict[str, "EntryPoint"]: ... def get_entry_map(dist: _EPDistType, group: Optional[str] = None): """Return the entry point map for `group`, or the full entry map""" return get_distribution(dist).get_entry_map(group) @@ -540,10 +570,10 @@ class IMetadataProvider(Protocol): def has_metadata(self, name: str) -> bool: """Does the package's distribution contain the named metadata?""" - def get_metadata(self, name: str): + def get_metadata(self, name: str) -> str: """The named metadata resource as a string""" - def get_metadata_lines(self, name: str): + def get_metadata_lines(self, name: str) -> Iterator[str]: """Yield named metadata resource as list of non-blank non-comment lines Leading and trailing whitespace is stripped from each line, and lines @@ -552,22 +582,26 @@ def get_metadata_lines(self, name: str): def metadata_isdir(self, name: str) -> bool: """Is the named metadata a directory? (like ``os.path.isdir()``)""" - def metadata_listdir(self, name: str): + def metadata_listdir(self, name: str) -> List[str]: """List of metadata names in the directory (like ``os.listdir()``)""" - def run_script(self, script_name: str, namespace: Dict[str, Any]): + def run_script(self, script_name: str, namespace: Dict[str, Any]) -> None: """Execute the named script in the supplied namespace dictionary""" class IResourceProvider(IMetadataProvider, Protocol): """An object that provides access to package resources""" - def get_resource_filename(self, manager: "ResourceManager", resource_name: str): + def get_resource_filename( + self, manager: "ResourceManager", resource_name: str + ) -> str: """Return a true filesystem path for `resource_name` `manager` must be a ``ResourceManager``""" - def get_resource_stream(self, manager: "ResourceManager", resource_name: str): + def get_resource_stream( + self, manager: "ResourceManager", resource_name: str + ) -> _ResourceStream: """Return a readable file-like object for `resource_name` `manager` must be a ``ResourceManager``""" @@ -579,13 +613,13 @@ def get_resource_string( `manager` must be a ``ResourceManager``""" - def has_resource(self, resource_name: str): + def has_resource(self, resource_name: str) -> bool: """Does the package contain the named resource?""" - def resource_isdir(self, resource_name: str): + def resource_isdir(self, resource_name: str) -> bool: """Is the named resource a directory? (like ``os.path.isdir()``)""" - def resource_listdir(self, resource_name: str): + def resource_listdir(self, resource_name: str) -> List[str]: """List of resource names in the directory (like ``os.listdir()``)""" @@ -768,14 +802,42 @@ def add( keys2.append(dist.key) self._added_new(dist) + @overload + def resolve( + self, + requirements: Iterable["Requirement"], + env: Optional["Environment"], + installer: _StrictInstallerType[_DistributionT], + replace_conflicting: bool = False, + extras: Optional[Tuple[str, ...]] = None, + ) -> List[_DistributionT]: ... + @overload def resolve( self, requirements: Iterable["Requirement"], env: Optional["Environment"] = None, - installer: Optional[_InstallerType] = None, + *, + installer: _StrictInstallerType[_DistributionT], replace_conflicting: bool = False, extras: Optional[Tuple[str, ...]] = None, - ): + ) -> List[_DistributionT]: ... + @overload + def resolve( + self, + requirements: Iterable["Requirement"], + env: Optional["Environment"] = None, + installer: _InstallerType = None, + replace_conflicting: bool = False, + extras: Optional[Tuple[str, ...]] = None, + ) -> List["Distribution"]: ... + def resolve( + self, + requirements: Iterable["Requirement"], + env: Optional["Environment"] = None, + installer: Union[_InstallerType, _StrictInstallerType[_DistributionT]] = None, + replace_conflicting: bool = False, + extras: Optional[Tuple[str, ...]] = None, + ) -> Union[List["Distribution"], List[_DistributionT]]: """List all distributions needed to (recursively) meet `requirements` `requirements` must be a sequence of ``Requirement`` objects. `env`, @@ -873,13 +935,47 @@ def _resolve_dist( raise VersionConflict(dist, req).with_context(dependent_req) return dist + @overload + def find_plugins( + self, + plugin_env: "Environment", + full_env: Optional["Environment"], + installer: _StrictInstallerType[_DistributionT], + fallback: bool = True, + ) -> Tuple[List[_DistributionT], Dict["Distribution", Exception]]: ... + @overload def find_plugins( self, plugin_env: "Environment", full_env: Optional["Environment"] = None, - installer: Optional[_InstallerType] = None, + *, + installer: _StrictInstallerType[_DistributionT], fallback: bool = True, - ): + ) -> Tuple[List[_DistributionT], Dict["Distribution", Exception]]: ... + @overload + def find_plugins( + self, + plugin_env: "Environment", + full_env: Optional["Environment"] = None, + installer: _InstallerType = None, + fallback: bool = True, + ) -> Tuple[List["Distribution"], Dict["Distribution", Exception]]: ... + def find_plugins( + self, + plugin_env: "Environment", + full_env: Optional["Environment"] = None, + installer: Union[_InstallerType, _StrictInstallerType[_DistributionT]] = None, + fallback: bool = True, + ) -> Union[ + Tuple[ + List["Distribution"], + Dict["Distribution", Exception], + ], + Tuple[ + List["_DistributionT"], + Dict["Distribution", Exception], + ], + ]: """Find all activatable distributions in `plugin_env` Example usage:: @@ -918,8 +1014,8 @@ def find_plugins( # scan project names in alphabetic order plugin_projects.sort() - error_info = {} - distributions = {} + error_info: Dict["Distribution", Exception] = {} + distributions: Dict["Distribution", Optional[Exception]] = {} if full_env is None: env = Environment(self.entries) @@ -1118,13 +1214,29 @@ def add(self, dist: "Distribution"): dists.append(dist) dists.sort(key=operator.attrgetter('hashcmp'), reverse=True) + @overload def best_match( self, req: "Requirement", working_set: WorkingSet, - installer: Optional[Callable[["Requirement"], Any]] = None, + installer: _StrictInstallerType[_DistributionT], replace_conflicting: bool = False, - ): + ) -> _DistributionT: ... + @overload + def best_match( + self, + req: "Requirement", + working_set: WorkingSet, + installer: _InstallerType = None, + replace_conflicting: bool = False, + ) -> Optional["Distribution"]: ... + def best_match( + self, + req: "Requirement", + working_set: WorkingSet, + installer: Union[_InstallerType, _StrictInstallerType[_DistributionT]] = None, + replace_conflicting: bool = False, + ) -> Optional[Union["Distribution", _DistributionT]]: """Find distribution best matching `req` and usable on `working_set` This calls the ``find(req)`` method of the `working_set` to see if a @@ -1151,11 +1263,31 @@ def best_match( # try to download/install return self.obtain(req, installer) + @overload def obtain( self, requirement: "Requirement", - installer: Optional[Callable[["Requirement"], Any]] = None, - ): + installer: _StrictInstallerType[_DistributionT], + ) -> _DistributionT: ... + @overload + def obtain( + self, + requirement: "Requirement", + installer: Optional[Callable[["Requirement"], None]] = None, + ) -> None: ... + @overload + def obtain( + self, requirement: "Requirement", installer: _InstallerType = None + ) -> Optional["Distribution"]: ... + def obtain( + self, + requirement: "Requirement", + installer: Union[ + Optional[Callable[["Requirement"], None]], + _InstallerType, + _StrictInstallerType[_DistributionT], + ] = None, + ) -> Optional[Union["Distribution", _DistributionT]]: """Obtain a distribution matching `requirement` (e.g. via download) Obtain a distro that matches requirement (e.g. via download). In the @@ -1512,7 +1644,6 @@ class NullProvider: egg_name: Optional[str] = None egg_info: Optional[str] = None loader: Optional[_LoaderProtocol] = None - module_path: Optional[str] # Some subclasses can have a None module_path def __init__(self, module: _ModuleLike): self.loader = getattr(module, '__loader__', None) @@ -1555,7 +1686,7 @@ def get_metadata(self, name: str): exc.reason += ' in {} file at path: {}'.format(name, path) raise - def get_metadata_lines(self, name: str): + def get_metadata_lines(self, name: str) -> Iterator[str]: return yield_lines(self.get_metadata(name)) def resource_isdir(self, resource_name: str): @@ -1567,7 +1698,7 @@ def metadata_isdir(self, name: str) -> bool: def resource_listdir(self, resource_name: str): return self._listdir(self._fn(self.module_path, resource_name)) - def metadata_listdir(self, name: str): + def metadata_listdir(self, name: str) -> List[str]: if self.egg_info: return self._listdir(self._fn(self.egg_info, name)) return [] @@ -1580,6 +1711,8 @@ def run_script(self, script_name: str, namespace: Dict[str, Any]): **locals() ), ) + if not self.egg_info: + raise TypeError("Provider is missing egg_info", self.egg_info) script_text = self.get_metadata(script).replace('\r\n', '\n') script_text = script_text.replace('\r', '\n') script_filename = self._fn(self.egg_info, script) @@ -1610,12 +1743,12 @@ def _isdir(self, path) -> bool: "Can't perform this operation for unregistered loader type" ) - def _listdir(self, path): + def _listdir(self, path) -> List[str]: raise NotImplementedError( "Can't perform this operation for unregistered loader type" ) - def _fn(self, base, resource_name: str): + def _fn(self, base: str, resource_name: str): self._validate_resource_path(resource_name) if resource_name: return os.path.join(base, *resource_name.split('/')) @@ -1775,7 +1908,8 @@ def _register(cls): class EmptyProvider(NullProvider): """Provider that returns nothing for all requests""" - module_path = None + # A special case, we don't want all Providers inheriting from NullProvider to have a potentially None module_path + module_path: Optional[str] = None # type: ignore[assignment] _isdir = _has = lambda self, path: False @@ -1851,7 +1985,7 @@ class ZipProvider(EggProvider): # ZipProvider's loader should always be a zipimporter or equivalent loader: zipimport.zipimporter - def __init__(self, module: _ModuleLike): + def __init__(self, module: ZipLoaderModule): super().__init__(module) self.zip_pre = self.loader.archive + os.sep @@ -1900,7 +2034,7 @@ def _get_date_and_size(zip_stat): return timestamp, size # FIXME: 'ZipProvider._extract_resource' is too complex (12) - def _extract_resource(self, manager: ResourceManager, zip_path): # noqa: C901 + def _extract_resource(self, manager: ResourceManager, zip_path) -> str: # noqa: C901 if zip_path in self._index(): for name in self._index()[zip_path]: last = self._extract_resource(manager, os.path.join(zip_path, name)) @@ -2037,7 +2171,7 @@ def _get_metadata_path(self, name): def has_metadata(self, name: str) -> bool: return name == 'PKG-INFO' and os.path.isfile(self.path) - def get_metadata(self, name): + def get_metadata(self, name: str): if name != 'PKG-INFO': raise KeyError("No metadata except PKG-INFO is available") @@ -2053,7 +2187,7 @@ def _warn_on_replacement(self, metadata): msg = tmpl.format(**locals()) warnings.warn(msg) - def get_metadata_lines(self, name): + def get_metadata_lines(self, name: str) -> Iterator[str]: return yield_lines(self.get_metadata(name)) @@ -2465,12 +2599,16 @@ def null_ns_handler( register_namespace_handler(object, null_ns_handler) -def normalize_path(filename: "StrPath"): +@overload +def normalize_path(filename: "StrPath") -> str: ... +@overload +def normalize_path(filename: "BytesPath") -> bytes: ... +def normalize_path(filename: "StrOrBytesPath"): """Normalize a file/dir name for comparison purposes""" return os.path.normcase(os.path.realpath(os.path.normpath(_cygwin_patch(filename)))) -def _cygwin_patch(filename: "StrPath"): # pragma: nocover +def _cygwin_patch(filename: "StrOrBytesPath"): # pragma: nocover """ Contrary to POSIX 2008, on Cygwin, getcwd (3) contains symlink components. Using @@ -2563,12 +2701,23 @@ def __str__(self): def __repr__(self): return "EntryPoint.parse(%r)" % str(self) + @overload + def load( + self, + require: Literal[True] = True, + env: Optional[Environment] = None, + installer: _InstallerType = None, + ) -> _ResolvedEntryPoint: ... + @overload + def load( + self, require: Literal[False], *args: Any, **kwargs: Any + ) -> _ResolvedEntryPoint: ... def load( self, require: bool = True, *args: Optional[Union[Environment, _InstallerType]], **kwargs: Optional[Union[Environment, _InstallerType]], - ): + ) -> _ResolvedEntryPoint: """ Require packages for this EntryPoint, then resolve it. """ @@ -2585,7 +2734,7 @@ def load( self.require(*args, **kwargs) # type: ignore return self.resolve() - def resolve(self): + def resolve(self) -> _ResolvedEntryPoint: """ Resolve the entry point from its module and attrs. """ @@ -2598,7 +2747,7 @@ def resolve(self): def require( self, env: Optional[Environment] = None, - installer: Optional[_InstallerType] = None, + installer: _InstallerType = None, ): if not self.dist: error_cls = UnknownExtra if self.extras else AttributeError @@ -3013,13 +3162,17 @@ def as_requirement(self): return Requirement.parse(spec) - def load_entry_point(self, group: str, name: str): + def load_entry_point(self, group: str, name: str) -> _ResolvedEntryPoint: """Return the `name` entry point of `group` or raise ImportError""" ep = self.get_entry_info(group, name) if ep is None: raise ImportError("Entry point %r not found" % ((group, name),)) return ep.load() + @overload + def get_entry_map(self, group: None = None) -> Dict[str, Dict[str, EntryPoint]]: ... + @overload + def get_entry_map(self, group: str) -> Dict[str, EntryPoint]: ... def get_entry_map(self, group: Optional[str] = None): """Return the entry point map for `group`, or the full entry map""" if not hasattr(self, "_ep_map"):