Skip to content

Commit

Permalink
add a script that fetches and validate binary sizes of the wheels using
Browse files Browse the repository at this point in the history
https://download.pytorch.org/whl/ index using specified rules
  • Loading branch information
izaitsevfb committed Feb 9, 2023
1 parent c212b85 commit 8f2e8cf
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 0 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/test-binary-size-validation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Runs the unit tests of the wheel binary size validation script.
name: Test binary size validation script
on:
  pull_request:
    paths:
      - .github/workflows/binary-size-validation.yml
      # Also re-run when this workflow itself or the pinned dependencies change,
      # otherwise edits to either are never exercised before merging.
      - .github/workflows/test-binary-size-validation.yml
      - tools/binary_size_validation/requirements.txt
      - tools/binary_size_validation/test_binary_size_validation.py
      - tools/binary_size_validation/binary_size_validation.py
  workflow_dispatch:

jobs:
  test-binary-size-validation:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Install requirements
        run: |
          pip3 install -r tools/binary_size_validation/requirements.txt
      - name: Run pytest
        run: |
          pytest tools/binary_size_validation/test_binary_size_validation.py
27 changes: 27 additions & 0 deletions tools/binary_size_validation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# PyTorch Wheel Binary Size Validation

A script to fetch and validate the binary size of PyTorch wheels
in the given channel (test, nightly) against the given threshold.


### Installation

```bash
pip install -r requirements.txt
```

### Usage

```bash
# print help
python binary_size_validation.py --help

# print the sizes of all items in the index
python binary_size_validation.py --url https://download.pytorch.org/whl/nightly/torch/

# fail if any of the torch 2.0 wheels are larger than 900MB
python binary_size_validation.py --url https://download.pytorch.org/whl/nightly/torch/ --include "torch-2\.0" --threshold 900

# fail if any of the latest nightly pypi wheels are larger than 750MB
python binary_size_validation.py --include "pypi" --only-latest-version --threshold 750
```
91 changes: 91 additions & 0 deletions tools/binary_size_validation/binary_size_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Script that parses wheel index (e.g. https://download.pytorch.org/whl/test/torch/),
# fetches and validates binary size for the files that match the given regex.

import requests
import re
from collections import namedtuple
import click
from bs4 import BeautifulSoup
from urllib.parse import urljoin

# A single entry of the wheel index: the link text (wheel file name) and its absolute URL.
Wheel = namedtuple("Wheel", ["name", "url"])


def parse_index(html: str,
                base_url: str,
                include_regex: str = "",
                exclude_regex: str = "",
                latest_version_only=False) -> list[Wheel]:
    """
    Parse a wheel index html page and return the wheels it links to.

    :param html: html page
    :param base_url: base url of the page, used to resolve relative links
    :param include_regex: regex to filter the wheel names. If empty, all wheels are included
    :param exclude_regex: regex to exclude the matching wheel names. If empty, no wheels are excluded
    :param latest_version_only: if True, return the wheels of the latest version only
    :return: list of wheels, in the order they appear on the page
    :raises RuntimeError: if latest_version_only is set and a wheel name has no
        "<package>-<version>" prefix followed by '+' or '-'
    """
    soup = BeautifulSoup(html, "html.parser")

    wheels = []
    for a in soup.find_all("a"):
        href = a.get("href")
        if not href:
            # an anchor without a target is not a downloadable file
            continue
        wheel_name = a.text
        if include_regex and not re.search(include_regex, wheel_name):
            continue
        if exclude_regex and re.search(exclude_regex, wheel_name):
            continue
        wheels.append(Wheel(name=wheel_name, url=urljoin(base_url, href)))

    if not wheels or not latest_version_only:
        return wheels

    # the version prefix is everything up to the second '+'/'-' sign,
    # e.g. "torch-2.0.0.dev20230207" for "torch-2.0.0.dev20230207+cpu-cp310-..."
    prefix_pattern = re.compile(r"^([^-+]+[-+][^-+]+)[-+]")
    prefixed_wheels = []
    for wheel in wheels:
        match = prefix_pattern.search(wheel.name)
        if match is None:
            raise RuntimeError(
                f"Failed to get version prefix of {wheel.name}. "
                "Please check the include/exclude regexes or don't use --only-latest-version")
        prefixed_wheels.append((match.group(1), wheel))

    # compare numeric components as numbers so that e.g. 1.13 ranks above 1.9
    latest_version = max((prefix for prefix, _ in prefixed_wheels), key=_natural_sort_key)
    print(f"Latest version prefix: {latest_version}")

    # keep only the wheels whose prefix is exactly the latest one
    return [wheel for prefix, wheel in prefixed_wheels if prefix == latest_version]


def _natural_sort_key(text: str) -> list:
    """Split text into alternating non-digit / integer chunks for numeric-aware ordering."""
    # re.split with a capturing group always yields text chunks at even positions
    # and digit runs at odd positions, so keys of two strings compare type-safely.
    chunks = re.split(r"(\d+)", text)
    return [int(chunk) if index % 2 else chunk for index, chunk in enumerate(chunks)]


def get_binary_size(file_url: str, timeout: float = 60) -> int:
    """
    Get the binary size of the given file without downloading it.

    :param file_url: url of the file
    :param timeout: seconds to wait for the server before giving up
    :return: binary size in bytes, as reported by the Content-Length header
    :raises requests.HTTPError: if the server answers with an error status
    :raises RuntimeError: if the response carries no Content-Length header
    """
    # requests.head() does not follow redirects by default; follow them so the
    # reported size is that of the actual file and not of a redirect response.
    response = requests.head(file_url, allow_redirects=True, timeout=timeout)
    # fail loudly on 4xx/5xx instead of reading the size of an error page
    response.raise_for_status()
    try:
        return int(response.headers["Content-Length"])
    except KeyError as err:
        raise RuntimeError(f"No Content-Length header returned for {file_url}") from err


@click.command(
    help="Validate the binary sizes of the given wheel index."
)
@click.option("--url", help="url of the wheel index",
              default="https://download.pytorch.org/whl/nightly/torch/")
@click.option("--include", help="regex to filter the wheel names. Only the matching wheel names will be checked.",
              default="")
@click.option("--exclude", help="regex to exclude wheel names. Matching wheel names will NOT be checked.",
              default="")
@click.option("--threshold", help="threshold in MB, optional", default=0)
@click.option("--only-latest-version", help="only validate the latest version",
              is_flag=True, show_default=True, default=False)
def main(url, include, exclude, threshold, only_latest_version):
    """
    Fetch the wheel index, print the size of every selected wheel and fail
    on the first wheel that is larger than the threshold.

    :param url: url of the wheel index page
    :param include: regex the wheel names must match; empty selects all wheels
    :param exclude: regex of wheel names to skip; empty skips nothing
    :param threshold: maximum allowed size in MB; 0 disables the check
    :param only_latest_version: if True, only the latest version is validated
    :raises RuntimeError: if a wheel exceeds the threshold
    """
    bytes_per_mb = 1024 * 1024

    page = requests.get(url, timeout=60)
    # do not try to parse an error page as a wheel index
    page.raise_for_status()
    wheels = parse_index(page.text, url, include, exclude, only_latest_version)

    # the threshold is given in MB while sizes are reported in bytes
    threshold_bytes = threshold * bytes_per_mb
    for wheel in wheels:
        print(f"Validating {wheel.url}...")
        size = get_binary_size(wheel.url)
        size_mb = size / bytes_per_mb
        print(f"{wheel.name}: {size_mb:.2f} MB")
        if threshold and size > threshold_bytes:
            raise RuntimeError(
                f"Binary size of {wheel.name} {size_mb:.2f} MB exceeds the threshold {threshold} MB")


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions tools/binary_size_validation/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
beautifulsoup4==4.11.2
click==8.0.4
pytest==7.1.1
requests==2.27.1
47 changes: 47 additions & 0 deletions tools/binary_size_validation/test_binary_size_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from binary_size_validation import parse_index

# ignore long lines in this file
# flake8: noqa: E501
test_html = """
<!DOCTYPE html>
<html>
<body>
<h1>Links for torch</h1>
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp310-cp310-linux_x86_64.whl">torch-1.13.0.dev20220728+cpu-cp310-cp310-linux_x86_64.whl</a><br/>
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp310-cp310-win_amd64.whl">torch-1.13.0.dev20220728+cpu-cp310-cp310-win_amd64.whl</a><br/>
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp37-cp37m-linux_x86_64.whl">torch-1.13.0.dev20220728+cpu-cp37-cp37m-linux_x86_64.whl</a><br/>
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp37-cp37m-win_amd64.whl">torch-1.13.0.dev20220728+cpu-cp37-cp37m-win_amd64.whl</a><br/>
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230206%2Brocm5.3-cp39-cp39-linux_x86_64.whl">torch-2.0.0.dev20230206+rocm5.3-cp39-cp39-linux_x86_64.whl</a><br/>
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230207%2Brocm5.3-cp310-cp310-linux_x86_64.whl">torch-2.0.0.dev20230207+rocm5.3-cp310-cp310-linux_x86_64.whl</a><br/>
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230207%2Brocm5.3-cp38-cp38-linux_x86_64.whl">torch-2.0.0.dev20230207+rocm5.3-cp38-cp38-linux_x86_64.whl</a><br/>
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230207%2Brocm5.3-cp39-cp39-linux_x86_64.whl">torch-2.0.0.dev20230207+rocm5.3-cp39-cp39-linux_x86_64.whl</a><br/>
</body>
</html>
<!--TIMESTAMP 1675892605-->
"""

base_url = "https://download.pytorch.org/whl/nightly/torch/"


def test_get_whl_links():
    """Without filters every link is returned and relative hrefs are resolved against the base url."""
    expected_first_url = (
        "https://download.pytorch.org/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp310-cp310-linux_x86_64.whl"
    )

    parsed = parse_index(test_html, base_url)

    assert len(parsed) == 8
    assert parsed[0].url == expected_first_url


def test_include_exclude():
    """The include regex selects wheels and the exclude regex removes matches from that selection."""
    cp310_win = "torch-1.13.0.dev20220728+cpu-cp310-cp310-win_amd64.whl"
    cp37_win = "torch-1.13.0.dev20220728+cpu-cp37-cp37m-win_amd64.whl"

    # include only
    included = parse_index(test_html, base_url, "amd6\\d")
    assert [wheel.name for wheel in included] == [cp310_win, cp37_win]

    # include combined with exclude
    remaining = parse_index(test_html, base_url, "amd6\\d", "cp37")
    assert [wheel.name for wheel in remaining] == [cp310_win]


def test_latest_version_only():
    """Only the wheels of the newest version prefix survive when latest_version_only is set."""
    newest_prefix = "torch-2.0.0.dev20230207"

    latest = parse_index(test_html, base_url, latest_version_only=True)

    assert len(latest) == 3
    for wheel in latest:
        assert wheel.name.startswith(newest_prefix)

0 comments on commit 8f2e8cf

Please sign in to comment.