Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lazier lazy_wheel #11481

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions news/11481.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Implement a lazier ``lazy_wheel`` that avoids the initial HEAD request, fetching
the metadata for a wheel in as few as one range request.
181 changes: 146 additions & 35 deletions src/pip/_internal/network/lazy_wheel.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,51 @@
"""Lazy ZIP over HTTP"""

from __future__ import annotations

__all__ = ["HTTPRangeRequestUnsupported", "dist_from_wheel_url"]

import logging
from bisect import bisect_left, bisect_right
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import Any, Iterator
from zipfile import BadZipfile, ZipFile

from pip._vendor.packaging.utils import canonicalize_name
from pip._vendor.requests.models import CONTENT_CHUNK_SIZE, Response
from pip._vendor.requests.models import CONTENT_CHUNK_SIZE, HTTPError, Response

from pip._internal.exceptions import InvalidWheel
from pip._internal.metadata import BaseDistribution, MemoryWheel, get_wheel_distribution
from pip._internal.network.session import PipSession
from pip._internal.network.utils import HEADERS, raise_for_status, response_chunks
from pip._internal.network.session import PipSession as Session
from pip._internal.network.utils import HEADERS

log = logging.getLogger(__name__)


class HTTPRangeRequestUnsupported(Exception):
    """Raised when the remote server does not support HTTP range requests.

    Callers fall back to downloading the whole wheel when they catch this.
    """

    pass


def dist_from_wheel_url(name: str, url: str, session: Session) -> BaseDistribution:
    """Return a distribution object from the given wheel URL.

    This uses HTTP range requests to only fetch the portion of the wheel
    containing metadata, just enough for the object to be constructed.

    :param name: canonicalizable project name, used to locate the metadata.
    :param url: direct URL of the remote ``.whl`` file.
    :param session: PipSession used for the range requests.
    :raises HTTPRangeRequestUnsupported: if range requests are unavailable.
    :raises InvalidWheel: if the remote zip is corrupt or truncated.
    """
    try:
        with LazyZipOverHTTP(url, session) as zf:
            # Fetch the whole .dist-info section up front so the ZipFile
            # reads below hit the local buffer instead of the network.
            zf.prefetch_dist_info()

            # For read-only ZIP files, ZipFile only needs methods read,
            # seek, seekable and tell, not the whole IO protocol.
            wheel = MemoryWheel(zf.name, zf)  # type: ignore
            # After context manager exit, wheel.name
            # is an invalid file by intention.
            return get_wheel_distribution(wheel, canonicalize_name(name))
    except BadZipfile as e:
        # Chain the original error so the zip-level cause stays visible
        # in tracebacks instead of being silently discarded.
        raise InvalidWheel(url, name) from e


class LazyZipOverHTTP:
Expand All @@ -47,20 +58,60 @@ class LazyZipOverHTTP:
"""

def __init__(
    self, url: str, session: Session, chunk_size: int = CONTENT_CHUNK_SIZE
) -> None:
    """Probe ``url`` with a suffix range request and map the file lazily.

    A single suffix range request replaces the former HEAD + probe pair:
    its response both proves range support and delivers the end of the
    file (where the zip central directory lives) in one round trip.

    :raises HTTPRangeRequestUnsupported: if the server answers with a
        plain 200 for a file larger than one chunk (no range support).
    """
    self._request_count = 0
    self._session, self._url, self._chunk_size = session, url, chunk_size

    # Initial range request for the end of the file ("bytes=-N").
    try:
        tail = self._stream_response(start="", end=CONTENT_CHUNK_SIZE)
    except HTTPError as e:
        if e.response.status_code != 416:
            raise
        # A 416 Content-Range carries an unsatisfied range ('*') followed
        # by '/' and the current resource length, e.g. "bytes */12777":
        # the file is smaller than one chunk, so fetch it whole.
        content_length = int(
            e.response.headers["content-range"].rsplit("/", 1)[-1]
        )
        tail = self._stream_response(start=0, end=content_length)

    # e.g. {'accept-ranges': 'bytes', 'content-length': '10240',
    # 'content-range': 'bytes 12824-23063/23064', 'last-modified': 'Sat, 16
    # Apr 2022 13:03:02 GMT', 'date': 'Thu, 21 Apr 2022 11:34:04 GMT'}
    if tail.status_code != 206:
        if (
            tail.status_code == 200
            and int(tail.headers["content-length"]) <= CONTENT_CHUNK_SIZE
        ):
            # Small file served whole: synthesize the Content-Range the
            # parsing below expects. Include the "bytes " unit so the
            # stored header is well-formed per RFC 9110 (the partition
            # parse below only looks after the '/', so this is safe).
            content_length = len(tail.content)
            tail.headers["content-range"] = (
                f"bytes 0-{content_length-1}/{content_length}"
            )
        else:
            raise HTTPRangeRequestUnsupported("range request is not supported")

    # Total length is the text after the '/' of Content-Range. The
    # lower-cased key is kept deliberately for case-sensitive requests
    # adapters (e.g. an s3 adapter) even though requests itself is
    # case-insensitive.
    self._length = int(tail.headers["content-range"].partition("/")[-1])
    self._file = NamedTemporaryFile()
    self.truncate(self._length)

    # Persist the already-downloaded tail into the sparse local file and
    # record it as the initially-fetched interval.
    with self._stay():
        content_length = int(tail.headers["content-length"])
        if hasattr(tail, "content"):
            # length is also in Content-Length and Content-Range headers
            assert content_length == len(tail.content)
        self.seek(self._length - content_length)
        for chunk in tail.iter_content(self._chunk_size):
            self._file.write(chunk)
    self._left: list[int] = [self._length - content_length]
    self._right: list[int] = [self._length - 1]

@property
def mode(self) -> str:
Expand Down Expand Up @@ -92,7 +143,8 @@ def read(self, size: int = -1) -> bytes:
all bytes until EOF are returned. Fewer than
size bytes may be returned if EOF is reached.
"""
download_size = max(size, self._chunk_size)
# BUG does not download correctly if size is unspecified
download_size = size
start, length = self.tell(), self._length
stop = length if size < 0 else min(start + download_size, length)
start = max(0, stop - download_size)
Expand All @@ -117,7 +169,7 @@ def tell(self) -> int:
"""Return the current position."""
return self._file.tell()

def truncate(self, size: Optional[int] = None) -> int:
def truncate(self, size: int | None = None) -> int:
"""Resize the stream to the given size in bytes.

If size is unspecified resize to the current position.
Expand All @@ -131,15 +183,16 @@ def writable(self) -> bool:
"""Return False."""
return False

def __enter__(self) -> "LazyZipOverHTTP":
def __enter__(self) -> LazyZipOverHTTP:
self._file.__enter__()
return self

def __exit__(self, *exc: Any) -> bool | None:
    """Close the backing temporary file, logging the request count.

    The count is logged (not printed) for diagnostics; a well-formed
    wheel should need only a handful of range requests.
    """
    # Was a stray debugging print() with a magic self._url[107:] slice;
    # use the standard module logger and the full URL instead.
    log.debug(
        "%d requests to fetch metadata from %s", self._request_count, self._url
    )
    return self._file.__exit__(*exc)  # type: ignore

@contextmanager
def _stay(self) -> Generator[None, None, None]:
def _stay(self) -> Iterator[None]:
"""Return a context manager keeping the position.

At the end of the block, seek back to original position.
Expand All @@ -166,19 +219,27 @@ def _check_zip(self) -> None:
break

def _stream_response(
    self, start: int | str, end: int, base_headers: dict[str, str] = HEADERS
) -> Response:
    """Return a streamed HTTP response for a range request.

    :param start: first byte offset, or ``""`` to request the last
        ``end`` bytes of the file (a suffix range, ``bytes=-end``).
    :param end: last byte offset (inclusive), or the suffix length.
    :raises HTTPError: on any non-2xx/3xx response status.
    """
    headers = base_headers.copy()
    headers["Range"] = f"bytes={start}-{end}"
    # Include the URL so the debug log stays useful when several wheels
    # are probed in one run (the range alone is ambiguous).
    log.debug("Fetching %s (%s)", self._url, headers["Range"])
    # TODO: Get range requests to be correctly cached
    headers["Cache-Control"] = "no-cache"
    # TODO: If-Match (etag) to detect file changed during fetch would be a
    # good addition to HEADERS
    self._request_count += 1
    response = self._session.get(self._url, headers=headers, stream=True)
    response.raise_for_status()
    return response

def _merge(
self, start: int, end: int, left: int, right: int
) -> Generator[Tuple[int, int], None, None]:
"""Return a generator of intervals to be fetched.
) -> Iterator[tuple[int, int]]:
"""Return an iterator of intervals to be fetched.

Args:
start (int): Start of needed interval
Expand All @@ -204,7 +265,57 @@ def _download(self, start: int, end: int) -> None:
right = bisect_right(self._left, end)
for start, end in self._merge(start, end, left, right):
response = self._stream_response(start, end)
response.raise_for_status()
self.seek(start)
for chunk in response_chunks(response, self._chunk_size):
for chunk in response.iter_content(self._chunk_size):
self._file.write(chunk)

def prefetch(self, target_file: str) -> None:
    """
    Prefetch a specific file from the remote ZIP in one request.

    Reads the member's full span (local header through compressed data)
    into the local buffer so later ZipFile reads need no extra requests.
    No-op (with a debug log) if ``target_file`` is not in the archive.
    """
    with self._stay():  # not strictly necessary
        # Locate the member via the central directory, then read its
        # whole on-disk span in one go.
        zf = ZipFile(self)  # type: ignore
        infolist = zf.infolist()
        for i, info in enumerate(infolist):
            if info.filename == target_file:
                # header_offset could be wrong if the zipfile was
                # concatenated onto another file (unusual for wheels)
                start = info.header_offset
                try:
                    # End at the next member's local header...
                    end = infolist[i + 1].header_offset
                    # or info.header_offset
                    # + len(info.filename)
                    # + len(info.extra)
                    # + info.compress_size
                    # (unless Zip64)
                except IndexError:
                    # ...or at the central directory for the last member.
                    end = zf.start_dir
                self.seek(start)
                self.read(end - start)
                log.debug(
                    "prefetch %s-%s",
                    info.header_offset,
                    end,
                )
                break
        else:
            log.debug("no zip prefetch")

def prefetch_dist_info(self) -> None:
    """
    Read contents of entire dist-info section of wheel.

    pip wants to read WHEEL and METADATA.
    """
    with self._stay():
        zf = ZipFile(self)  # type: ignore
        infolist = zf.infolist()
        for info in infolist:
            # should be (wheel filename without extension etc) + (.dist-info/)
            # NOTE(review): a substring match can hit an unrelated member;
            # the exact dir name is derivable from the wheel filename —
            # confirm against pip's wheel-installation code.
            if ".dist-info/" in info.filename:
                # Read from the first .dist-info member through the start
                # of the central directory; assumes .dist-info members are
                # contiguous at the end of the archive — TODO confirm.
                start = info.header_offset
                end = zf.start_dir
                self.seek(start)
                self.read(end - start)
                break