Provide file hashes in the URLs to avoid unnecessary file downloads (bandwidth saver) #1433

Merged: 15 commits, Sep 23, 2023

s3_management/manage.py (46 changes: 27 additions & 19 deletions)
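For context on the mechanism this PR relies on: PEP 503 allows each link on a simple-index page to carry a `#sha256=<hex digest>` fragment, which an installer can compare against an already-downloaded file before fetching it again. A minimal sketch of that client-side check (hypothetical helper, not part of this change):

```python
import hashlib

# A generated index link would look roughly like:
#   <a href="/whl/<wheel filename>#sha256=<hex digest>"><wheel filename></a>
def needs_download(local_path: str, expected_sha256: str) -> bool:
    """Return True if the cached file is missing or its sha256 digest differs
    from the hash carried in the link's #sha256= fragment."""
    try:
        with open(local_path, "rb") as f:
            digest = hashlib.sha256(f.read()).hexdigest()
    except FileNotFoundError:
        return True
    return digest != expected_sha256
```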
@@ -14,7 +14,6 @@


S3 = boto3.resource('s3')
CLIENT = boto3.client('s3')
BUCKET = S3.Bucket('pytorch')

ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
@@ -121,8 +120,8 @@ def between_bad_dates(package_build_time: datetime):


class S3Index:
def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
self.objects = objects
def __init__(self: S3IndexType, objects: Dict[str, str | None], prefix: str) -> None:
self.objects = objects # s3 key to checksum mapping
self.prefix = prefix.rstrip("/")
self.html_name = PREFIXES_WITH_HTML[self.prefix]
# should dynamically grab subdirectories like whl/test/cu101
@@ -131,7 +130,7 @@ def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
path.dirname(obj) for obj in objects if path.dirname != prefix
}

def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
def nightly_packages_to_show(self: S3IndexType) -> Dict[str, str | None]:
"""Finding packages to show based on a threshold we specify

Basically takes our S3 packages, normalizes the version for easier
@@ -146,7 +145,7 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
# also includes versions without GPU specifier (i.e. cu102) for easier
# sorting, sorts in reverse to put the most recent versions first
all_sorted_packages = sorted(
{self.normalize_package_version(obj) for obj in self.objects},
{self.normalize_package_version(s3_key) for s3_key in self.objects},
key=lambda name_ver: parse(name_ver.split('-', 1)[-1]),
reverse=True,
)
@@ -166,10 +165,11 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
to_hide.add(obj)
else:
packages[package_name] += 1
return set(self.objects).difference({
obj for obj in self.objects
if self.normalize_package_version(obj) in to_hide
})
return {
s3_key: checksum
for s3_key, checksum in self.objects.items()
if self.normalize_package_version(s3_key) not in to_hide
}

def is_obj_at_root(self, obj:str) -> bool:
return path.dirname(obj) == self.prefix
@@ -184,21 +184,21 @@ def gen_file_list(
self,
subdir: Optional[str]=None,
package_name: Optional[str] = None
) -> Iterator[str]:
) -> Iterator[str, str | None]:
objects = (
self.nightly_packages_to_show() if self.prefix == 'whl/nightly'
else self.objects
)
subdir = self._resolve_subdir(subdir) + '/'
for obj in objects:
for obj, checksum in objects.items():
if package_name is not None:
if self.obj_to_package_name(obj) != package_name:
continue
if self.is_obj_at_root(obj) or obj.startswith(subdir):
yield obj
yield obj, checksum

Review comment: The return type of this function needs updating.

Contributor Author: I believe this is fixed now.
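For reference, a type-correct annotation for a generator yielding (key, checksum) pairs would use a single `Tuple` argument; a sketch of what that could look like, not necessarily what the final revision used:

```python
from typing import Dict, Iterator, Optional, Tuple

def gen_file_list(objects: Dict[str, Optional[str]]) -> Iterator[Tuple[str, Optional[str]]]:
    # Iterator[...] takes one type argument, so a generator that yields
    # (s3_key, checksum) pairs is annotated with a Tuple, not two arguments.
    for s3_key, checksum in objects.items():
        yield s3_key, checksum
```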


def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
return sorted(set(self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)))
return sorted(set(self.obj_to_package_name(obj) for obj, _ in self.gen_file_list(subdir)))

def normalize_package_version(self: S3IndexType, obj: str) -> str:
# removes the GPU specifier from the package name as well as
@@ -226,7 +226,7 @@ def to_legacy_html(
out: List[str] = []
subdir = self._resolve_subdir(subdir)
is_root = subdir == self.prefix
for obj in self.gen_file_list(subdir):
for obj, _ in self.gen_file_list(subdir):
# Strip our prefix
sanitized_obj = obj.replace(subdir, "", 1)
if sanitized_obj.startswith('/'):
@@ -254,8 +254,11 @@ def to_simple_package_html(
out.append('<html>')
out.append(' <body>')
out.append(' <h1>Links for {}</h1>'.format(package_name.lower().replace("_","-")))
for obj in sorted(self.gen_file_list(subdir, package_name)):
out.append(f' <a href="/{obj}">{path.basename(obj).replace("%2B","+")}</a><br/>')
for obj, checksum in sorted(self.gen_file_list(subdir, package_name)):
maybe_fragment = ""
if checksum:
maybe_fragment = f"#sha256={checksum}"
out.append(f' <a href="/{obj}{maybe_fragment}">{path.basename(obj).replace("%2B","+")}</a><br/>')
# Adding html footer
out.append(' </body>')
out.append('</html>')
@@ -316,7 +319,6 @@ def upload_pep503_htmls(self) -> None:
Body=self.to_simple_package_html(subdir=subdir, package_name=pkg_name)
)


def save_legacy_html(self) -> None:
for subdir in self.subdirs:
print(f"INFO Saving {subdir}/{self.html_name}")
@@ -337,7 +339,7 @@ def save_pep503_htmls(self) -> None:

@classmethod
def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
objects = []
objects = {}
prefix = prefix.rstrip("/")
for obj in BUCKET.objects.filter(Prefix=prefix):
is_acceptable = any([path.dirname(obj.key) == prefix] + [
@@ -348,10 +350,14 @@ def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
for pattern in ACCEPTED_SUBDIR_PATTERNS
]) and obj.key.endswith(ACCEPTED_FILE_EXTENSIONS)
if is_acceptable:
# Add PEP 503-compatible hashes to URLs to allow clients to avoid spurious downloads, if possible.
response = obj.meta.client.head_object(Bucket=BUCKET.name, Key=obj.key, ChecksumMode="ENABLED")
sha256 = response.get("ChecksumSHA256")
sanitized_key = obj.key.replace("+", "%2B")
objects.append(sanitized_key)
objects[sanitized_key] = sha256
return cls(objects, prefix)
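One encoding detail worth flagging (an assumption about the APIs involved, not something visible in this diff): boto3's `head_object` reports `ChecksumSHA256` base64-encoded, whereas a PEP 503 `#sha256=` fragment expects a hex digest, so a conversion roughly like the following might be needed before embedding the value in the link:

```python
import base64

def checksum_b64_to_hex(checksum_sha256: str) -> str:
    # S3 returns additional checksums base64-encoded; PEP 503 fragments use hex.
    return base64.b64decode(checksum_sha256).hex()
```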


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
parser.add_argument(
@@ -363,6 +369,7 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument("--generate-pep503", action="store_true")
return parser


def main():
parser = create_parser()
args = parser.parse_args()
@@ -387,5 +394,6 @@ def main():
if args.generate_pep503:
idx.upload_pep503_htmls()


if __name__ == "__main__":
main()