Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide file hashes in the URLs to avoid unnecessary file downloads (bandwidth saver) #1433

Merged
merged 15 commits into from
Sep 23, 2023
Merged
59 changes: 44 additions & 15 deletions s3_management/manage.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
#!/usr/bin/env python

import argparse
import base64
import dataclasses
matteius marked this conversation as resolved.
Show resolved Hide resolved
import functools
import time

from os import path, makedirs
from datetime import datetime
from collections import defaultdict
from typing import Iterator, List, Type, Dict, Set, TypeVar, Optional
from typing import Iterable, List, Type, Dict, Set, TypeVar, Optional
from re import sub, match, search
from packaging.version import parse

import boto3


S3 = boto3.resource('s3')
CLIENT = boto3.client('s3')
BUCKET = S3.Bucket('pytorch')

ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
Expand Down Expand Up @@ -104,6 +106,23 @@

S3IndexType = TypeVar('S3IndexType', bound='S3Index')


@dataclasses.dataclass(frozen=True)
@functools.total_ordering
class S3Object:
key: str
checksum: str | None

def __str__(self):
return self.key

def __eq__(self, other):
return self.key == other.key

def __lt__(self, other):
return self.key < other.key


def extract_package_build_time(full_package_name: str) -> datetime:
result = search(PACKAGE_DATE_REGEX, full_package_name)
if result is not None:
Expand All @@ -121,7 +140,7 @@ def between_bad_dates(package_build_time: datetime):


class S3Index:
def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
self.objects = objects
self.prefix = prefix.rstrip("/")
self.html_name = PREFIXES_WITH_HTML[self.prefix]
Expand All @@ -131,7 +150,7 @@ def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
path.dirname(obj) for obj in objects if path.dirname != prefix
}

def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
def nightly_packages_to_show(self: S3IndexType) -> Set[S3Object]:
"""Finding packages to show based on a threshold we specify

Basically takes our S3 packages, normalizes the version for easier
Expand Down Expand Up @@ -171,8 +190,8 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
if self.normalize_package_version(obj) in to_hide
})

def is_obj_at_root(self, obj:str) -> bool:
return path.dirname(obj) == self.prefix
def is_obj_at_root(self, obj: S3Object) -> bool:
return path.dirname(str(obj)) == self.prefix

def _resolve_subdir(self, subdir: Optional[str] = None) -> str:
if not subdir:
Expand All @@ -184,7 +203,7 @@ def gen_file_list(
self,
subdir: Optional[str]=None,
package_name: Optional[str] = None
) -> Iterator[str]:
) -> Iterable[S3Object]:
objects = (
self.nightly_packages_to_show() if self.prefix == 'whl/nightly'
else self.objects
Expand All @@ -194,23 +213,23 @@ def gen_file_list(
if package_name is not None:
if self.obj_to_package_name(obj) != package_name:
continue
if self.is_obj_at_root(obj) or obj.startswith(subdir):
if self.is_obj_at_root(obj) or str(obj).startswith(subdir):
yield obj

def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
return sorted(set(self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)))

def normalize_package_version(self: S3IndexType, obj: str) -> str:
def normalize_package_version(self: S3IndexType, obj: S3Object) -> str:
# removes the GPU specifier from the package name as well as
# unnecessary things like the file extension, architecture name, etc.
return sub(
r"%2B.*",
"",
"-".join(path.basename(obj).split("-")[:2])
"-".join(path.basename(str(obj)).split("-")[:2])
)

def obj_to_package_name(self, obj: str) -> str:
return path.basename(obj).split('-', 1)[0]
def obj_to_package_name(self, obj: S3Object) -> str:
return path.basename(str(obj)).split('-', 1)[0]

def to_legacy_html(
self,
Expand Down Expand Up @@ -255,7 +274,8 @@ def to_simple_package_html(
out.append(' <body>')
out.append(' <h1>Links for {}</h1>'.format(package_name.lower().replace("_","-")))
for obj in sorted(self.gen_file_list(subdir, package_name)):
out.append(f' <a href="/{obj}">{path.basename(obj).replace("%2B","+")}</a><br/>')
maybe_fragment = f"#sha256={obj.checksum}" if obj.checksum else ""
out.append(f' <a href="/{obj}{maybe_fragment}">{path.basename(obj).replace("%2B","+")}</a><br/>')
# Adding html footer
out.append(' </body>')
out.append('</html>')
Expand Down Expand Up @@ -316,7 +336,6 @@ def upload_pep503_htmls(self) -> None:
Body=self.to_simple_package_html(subdir=subdir, package_name=pkg_name)
)


def save_legacy_html(self) -> None:
for subdir in self.subdirs:
print(f"INFO Saving {subdir}/{self.html_name}")
Expand Down Expand Up @@ -348,10 +367,18 @@ def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
for pattern in ACCEPTED_SUBDIR_PATTERNS
]) and obj.key.endswith(ACCEPTED_FILE_EXTENSIONS)
if is_acceptable:
# Add PEP 503-compatible hashes to URLs to allow clients to avoid spurious downloads, if possible.
response = obj.meta.client.head_object(Bucket=BUCKET.name, Key=obj.key, ChecksumMode="ENABLED")
sha256 = (_b64 := response.get("ChecksumSHA256")) and base64.b64decode(_b64).hex()
sanitized_key = obj.key.replace("+", "%2B")
objects.append(sanitized_key)
s3_object = S3Object(
key=sanitized_key,
checksum=sha256,
)
objects.append(s3_object)
return cls(objects, prefix)


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
parser.add_argument(
Expand All @@ -363,6 +390,7 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument("--generate-pep503", action="store_true")
return parser


def main():
parser = create_parser()
args = parser.parse_args()
Expand All @@ -387,5 +415,6 @@ def main():
if args.generate_pep503:
idx.upload_pep503_htmls()


if __name__ == "__main__":
main()