diff --git a/src/pip/_internal/resolution/resolvelib/resolver.py b/src/pip/_internal/resolution/resolvelib/resolver.py index 12f96702024..78317fb7162 100644 --- a/src/pip/_internal/resolution/resolvelib/resolver.py +++ b/src/pip/_internal/resolution/resolvelib/resolver.py @@ -1,7 +1,7 @@ import functools import logging import os -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, cast from pip._vendor.packaging.utils import canonicalize_name from pip._vendor.resolvelib import BaseReporter, ResolutionImpossible @@ -19,6 +19,7 @@ PipDebuggingReporter, PipReporter, ) +from pip._internal.utils.parallel import LACK_SEM_OPEN, map_multithread from .base import Candidate, Requirement from .factory import Factory @@ -66,6 +67,7 @@ def __init__( self.ignore_dependencies = ignore_dependencies self.upgrade_strategy = upgrade_strategy self._result: Optional[Result] = None + self._finder = finder def resolve( self, root_reqs: List[InstallRequirement], check_supported_wheels: bool @@ -87,6 +89,8 @@ def resolve( reporter, ) + self.prime_finder_cache(provider.identify(r) for r in collected.requirements) + try: try_to_avoid_resolution_too_deep = 2000000 result = self._result = resolver.resolve( @@ -161,6 +165,22 @@ def resolve( self.factory.preparer.prepare_linked_requirements_more(reqs) return req_set + def prime_finder_cache(self, project_names: Iterator[str]) -> None: + if LACK_SEM_OPEN: + return + + if not hasattr(self._finder.find_all_candidates, "cache_info"): + return + + def _maybe_find_candidates(project_name: str) -> None: + try: + self._finder.find_all_candidates(project_name) + except AttributeError: + pass + + for _ in map_multithread(_maybe_find_candidates, project_names): + pass + def get_installation_order( self, req_set: RequirementSet ) -> List[InstallRequirement]: diff --git a/tests/unit/test_finder.py b/tests/unit/test_finder.py index 34720d54ee8..1513dee2246 100644 --- a/tests/unit/test_finder.py +++ b/tests/unit/test_finder.py @@ -558,3 +558,23 @@ def test_find_all_candidates_find_links_and_index(data: TestData) -> None: versions = finder.find_all_candidates("simple") # first the find-links versions then the page versions assert [str(v.version) for v in versions] == ["3.0", "2.0", "1.0", "1.0"] + +def test_finder_caching(data: TestData) -> None: + finder = make_test_finder( + find_links=[data.find_links], + index_urls=[data.index_url("simple")], + ) + def get_findall_cacheinfo(): + cacheinfo = finder.find_all_candidates.cache_info() + return {k: getattr(cacheinfo, k) for k in ['currsize', 'hits', 'misses']} + + # empty before any calls + assert get_findall_cacheinfo() == {"currsize": 0, "hits": 0, "misses": 0} + + # first findall is a miss + finder.find_all_candidates("simple") + assert get_findall_cacheinfo() == {"currsize": 1, "hits": 0, "misses": 1} + + # find best following a find all is a hit + finder.find_best_candidate("simple") + assert get_findall_cacheinfo() == {"currsize": 1, "hits": 1, "misses": 1}