Skip to content

Commit

Permalink
Make the call to actual download a bit less nested
Browse files Browse the repository at this point in the history
  • Loading branch information
McSinyx committed Aug 1, 2020
1 parent 68608d9 commit 0af3a4c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 68 deletions.
44 changes: 14 additions & 30 deletions src/pip/_internal/network/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pip._internal.utils.typing import MYPY_CHECK_RUNNING

if MYPY_CHECK_RUNNING:
from typing import Iterable, Optional
from typing import Iterable, Optional, Tuple

from pip._vendor.requests.models import Response

Expand Down Expand Up @@ -133,27 +133,6 @@ def _get_http_response_filename(resp, link):
return filename


def _http_get_download(session, link):
# type: (PipSession, Link) -> Response
target_url = link.url.split('#', 1)[0]
resp = session.get(target_url, headers=HEADERS, stream=True)
raise_for_status(resp)
return resp


class Download(object):
def __init__(
self,
response, # type: Response
filename, # type: str
chunks, # type: Iterable[bytes]
):
# type: (...) -> None
self.response = response
self.filename = filename
self.chunks = chunks


class Downloader(object):
def __init__(
self,
Expand All @@ -164,19 +143,24 @@ def __init__(
self._session = session
self._progress_bar = progress_bar

def __call__(self, link):
# type: (Link) -> Download
def __call__(self, link, tmpdir):
# type: (Link, str) -> Tuple[str, str]
url = link.url.split('#', 1)[0]
response = self._session.get(url, headers=HEADERS, stream=True)
try:
resp = _http_get_download(self._session, link)
raise_for_status(response)
except NetworkConnectionError as e:
assert e.response is not None
logger.critical(
"HTTP error %s while getting %s", e.response.status_code, link
)
raise

return Download(
resp,
_get_http_response_filename(resp, link),
_prepare_download(resp, link, self._progress_bar),
)
chunks = _prepare_download(response, link, self._progress_bar)
filename = _get_http_response_filename(response, link)
file_path = os.path.join(tmpdir, filename)

with open(file_path, 'wb') as content_file:
for chunk in chunks:
content_file.write(chunk)
return file_path, response.headers.get('content-type', '')
31 changes: 4 additions & 27 deletions src/pip/_internal/operations/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@
from pip._internal.vcs import vcs

if MYPY_CHECK_RUNNING:
from typing import (
Callable, List, Optional, Tuple,
)
from typing import Callable, List, Optional

from mypy_extensions import TypedDict

Expand Down Expand Up @@ -126,9 +124,9 @@ def get_http_url(
content_type = mimetypes.guess_type(from_path)[0]
else:
# let's download to a tmp dir
from_path, content_type = _download_http_url(
link, downloader, temp_dir.path, hashes
)
from_path, content_type = downloader(link, temp_dir.path)
if hashes:
hashes.check_against_path(from_path)

return File(from_path, content_type)

Expand Down Expand Up @@ -267,27 +265,6 @@ def unpack_url(
return file


def _download_http_url(
link, # type: Link
downloader, # type: Downloader
temp_dir, # type: str
hashes, # type: Optional[Hashes]
):
# type: (...) -> Tuple[str, str]
"""Download link url into temp_dir using provided session"""
download = downloader(link)

file_path = os.path.join(temp_dir, download.filename)
with open(file_path, 'wb') as content_file:
for chunk in download.chunks:
content_file.write(chunk)

if hashes:
hashes.check_against_path(file_path)

return file_path, download.response.headers.get('content-type', '')


def _check_download_dir(link, download_dir, hashes):
# type: (Link, str, Optional[Hashes]) -> Optional[str]
""" Check download_dir for previously downloaded file with correct hash
Expand Down
13 changes: 2 additions & 11 deletions tests/unit/test_operations_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from pip._internal.models.link import Link
from pip._internal.network.download import Downloader
from pip._internal.network.session import PipSession
from pip._internal.operations.prepare import (
_copy_source_tree,
_download_http_url,
unpack_url,
)
from pip._internal.operations.prepare import _copy_source_tree, unpack_url
from pip._internal.utils.hashes import Hashes
from pip._internal.utils.urls import path_to_url
from tests.lib.filesystem import (
Expand Down Expand Up @@ -83,12 +79,7 @@ def test_download_http_url__no_directory_traversal(mock_raise_for_status,

download_dir = tmpdir.joinpath('download')
os.mkdir(download_dir)
file_path, content_type = _download_http_url(
link,
downloader,
download_dir,
hashes=None,
)
file_path, content_type = downloader(link, download_dir)
# The file should be downloaded to download_dir.
actual = os.listdir(download_dir)
assert actual == ['out_dir_file']
Expand Down

0 comments on commit 0af3a4c

Please sign in to comment.