Reordering tests experiment (pytorch#106347)
Companion to pytorch/test-infra#4424

Uses the file ratings generated by the test-infra PR to reorder tests. For each test file, sum the ratings contributed by the files changed in the PR and put the tests in order of that sum (a sketch of the idea follows).
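
A minimal sketch of that scoring, with hypothetical names (the real logic lives in tools/testing/test_selections.py and the test-infra PR); it assumes a ratings mapping shaped like the .pytorch-test-file-ratings.json artifact, i.e. {changed_file: {test_file: rating}}:

# Hypothetical sketch of the prioritization described above, not the actual
# implementation in tools/testing/test_selections.py.
from collections import defaultdict
from typing import Dict, List, Tuple

def reorder_tests(
    tests: List[str],
    changed_files: List[str],
    ratings: Dict[str, Dict[str, float]],
) -> Tuple[List[str], List[str]]:
    # Sum, for every test file, the ratings contributed by each changed file.
    scores: Dict[str, float] = defaultdict(float)
    for changed_file in changed_files:
        for test_file, rating in ratings.get(changed_file, {}).items():
            scores[test_file] += rating
    # Anything with a positive summed rating is treated as prioritized,
    # highest score first; everything else stays in the general bucket.
    prioritized = sorted(
        (t for t in tests if scores[t] > 0), key=lambda t: scores[t], reverse=True
    )
    general = [t for t in tests if scores[t] <= 0]
    return prioritized, general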

A lot of tests will probably end up as "prioritized", since anything with a rating > 0 is currently included.

Sharding is done twice: once on the prioritized tests and once on the general (non-prioritized) tests. Prioritized tests have an order, so they are sharded according to that order; general tests have no order and are sharded by test time, which should produce more balanced shards (see the sketch below).
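
A rough sketch of the two sharding modes under the same caveat (the actual code is calculate_shards() plus the new do_sharding() in the diff below): prioritized tests keep their order and are dealt out round-robin, while general tests are packed slowest-first onto the least-loaded shard.

# Rough sketch only; hypothetical helper, not the real calculate_shards().
from typing import Dict, List, Tuple

def shard(
    tests: List[str],
    num_shards: int,
    test_times: Dict[str, float],
    sort_by_time: bool,
) -> List[List[str]]:
    if not sort_by_time:
        # Prioritized tests: preserve the given order, dealing them out round-robin.
        return [tests[i::num_shards] for i in range(num_shards)]
    # General tests: slowest-first greedy packing onto the least-loaded shard.
    ordered = sorted(tests, key=lambda t: test_times.get(t, 0.0), reverse=True)
    shards: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]
    for t in ordered:
        idx = min(range(num_shards), key=lambda i: shards[i][0])
        total, items = shards[idx]
        shards[idx] = (total + test_times.get(t, 0.0), items + [t])
    return [items for _, items in shards]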

I'll change the metric name before I merge; I want to keep my testing data separate from the actual results.

Pull Request resolved: pytorch#106347
Approved by: https://github.com/ZainRizvi
clee2000 authored and summerdo committed Aug 17, 2023
1 parent dde1a63 commit cf4178f
Showing 14 changed files with 210 additions and 160 deletions.
1 change: 1 addition & 0 deletions .ci/pytorch/win-test-helpers/build_pytorch.bat
@@ -128,6 +128,7 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps
:: export test times so that potential sharded tests that'll branch off this build will use consistent data
python tools/stats/export_test_times.py
copy /Y ".pytorch-test-times.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
copy /Y ".pytorch-test-file-ratings.json" "%PYTORCH_FINAL_PACKAGE_DIR%"

:: Also save build/.ninja_log as an artifact
copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"
1 change: 1 addition & 0 deletions .ci/pytorch/win-test-helpers/test_python_jit_legacy.bat
@@ -2,6 +2,7 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat

echo Copying over test times file
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"

pushd test

1 change: 1 addition & 0 deletions .ci/pytorch/win-test-helpers/test_python_shard.bat
@@ -23,6 +23,7 @@ if "%SHARD_NUMBER%" == "1" (

echo Copying over test times file
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"

echo Run nn tests
python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
2 changes: 1 addition & 1 deletion .circleci/config.yml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion .circleci/verbatim-sources/job-specs/job-specs-custom.yml
@@ -177,7 +177,7 @@
- run:
name: Archive artifacts into zip
command: |
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
cp artifacts.zip /Users/distiller/workspace
- persist_to_workspace:
2 changes: 1 addition & 1 deletion .github/workflows/_linux-build.yml
@@ -170,7 +170,7 @@ jobs:
- name: Archive artifacts into zip
if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
run: |
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json .pytorch-test-file-ratings.json
- name: Store PyTorch Build Artifacts on S3
uses: seemethere/upload-artifact-s3@v5
2 changes: 1 addition & 1 deletion .github/workflows/_mac-build.yml
@@ -182,7 +182,7 @@ jobs:
- name: Archive artifacts into zip
if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
run: |
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
- name: Store PyTorch Build Artifacts on GHA
uses: actions/upload-artifact@v3
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ coverage.xml
**/.pytorch-disabled-tests.json
**/.pytorch-slow-tests.json
**/.pytorch-test-times.json
**/.pytorch-test-file-ratings.json
*/*.pyc
*/*.so*
*/**/__pycache__
185 changes: 105 additions & 80 deletions test/run_test.py
@@ -11,9 +11,10 @@
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from distutils.version import LooseVersion
from typing import Any, cast, Dict, List, Optional, Union
from typing import Any, cast, Dict, List, NamedTuple, Optional, Union

import pkg_resources

@@ -40,11 +41,11 @@
# using tools/ to optimize test run.
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.export_test_times import TEST_TIMES_FILE
from tools.stats.upload_stats_lib import emit_metric
from tools.testing.test_selections import (
calculate_shards,
get_reordered_tests,
get_test_case_configs,
log_time_savings,
NUM_PROCS,
ShardedTest,
THRESHOLD,
@@ -1278,7 +1279,9 @@ def exclude_tests(
return selected_tests


def must_serial(file: str) -> bool:
def must_serial(file: Union[str, ShardedTest]) -> bool:
if isinstance(file, ShardedTest):
file = file.name
return (
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
@@ -1408,20 +1411,10 @@ def get_selected_tests(options) -> List[ShardedTest]:
)

selected_tests = [parse_test_module(x) for x in selected_tests]
return selected_tests

# sharding
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"
assert num_shards <= len(
selected_tests
), f"Number of shards must be less than {len(selected_tests)}"

def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
# Download previous test times to make sharding decisions
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
if os.path.exists(path):
@@ -1434,24 +1427,50 @@ def get_selected_tests(options) -> List[ShardedTest]:
print(
"::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
)
return {}
else:
print("Found test time stats from artifacts")
return test_file_times[test_config]


def do_sharding(
options,
selected_tests: List[str],
test_file_times: Dict[str, float],
sort_by_time: bool = True,
) -> List[ShardedTest]:
which_shard, num_shards = 1, 1
if options.shard:
assert len(options.shard) == 2, "Unexpected shard format"
assert min(options.shard) > 0, "Shards must be positive numbers"
which_shard, num_shards = options.shard
assert (
which_shard <= num_shards
), "Selected shard must be less than or equal to total number of shards"

if HAVE_TEST_SELECTION_TOOLS:
# Do sharding
test_file_times_config = test_file_times.get(test_config, {})
shards = calculate_shards(
num_shards, selected_tests, test_file_times_config, must_serial=must_serial
num_shards,
selected_tests,
test_file_times,
must_serial=must_serial,
sort_by_time=sort_by_time,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard

return selected_tests


class TestFailure(NamedTuple):
test: str
message: str


def run_test_module(
test: Union[ShardedTest, str], test_directory: str, options
) -> Optional[str]:
) -> Optional[TestFailure]:
maybe_set_hip_visible_devies()

# Printing the date here can help diagnose which tests are slow
@@ -1472,39 +1491,24 @@ def run_test_module(
# return code -N, where N is the signal number.
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
message += f" Received signal: {signal_name}"
return message
return TestFailure(test, message)


def run_tests(
selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
selected_tests: List[ShardedTest],
test_directory: str,
options,
failures: List[TestFailure],
) -> None:
failure_messages = []

if len(selected_tests) == 0:
print_to_stderr(f"No tests in group `{group_name}`")
return failure_messages
return

# parallel = in parallel with other files
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
selected_tests_parallel = [
x
for x in selected_tests
if not must_serial(x.name if isinstance(x, ShardedTest) else x)
]
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
print(f"TEST GROUP: {group_name}")
print_to_stderr(
"parallel (file granularity) tests :\n {}".format(
"\n".join(str(x) for x in selected_tests_parallel)
)
)
print_to_stderr(
"serial (file granularity) tests:\n {}".format(
"\n ".join(str(x) for x in selected_tests_serial)
)
)

# See Note [ROCm parallel CI testing]
pool = get_context("spawn").Pool(
@@ -1523,15 +1527,15 @@ def run_tests(
# Take the conftest file from the test directory
shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)

def handle_error_messages(err_message):
if err_message is None:
def handle_error_messages(failure: Optional[TestFailure]):
if failure is None:
return False
failure_messages.append(err_message)
print_to_stderr(err_message)
failures.append(failure)
print_to_stderr(failure.message)
return True

def parallel_test_completion_callback(err_message):
test_failed = handle_error_messages(err_message)
def parallel_test_completion_callback(failure):
test_failed = handle_error_messages(failure)
if (
test_failed
and not options.continue_through_error
@@ -1557,10 +1561,10 @@ def parallel_test_completion_callback(err_message):
if (
not options.continue_through_error
and not RERUN_DISABLED_TESTS
and len(failure_messages) != 0
and len(failures) != 0
):
raise RuntimeError(
"\n".join(failure_messages)
"\n".join(x.message for x in failures)
+ "\n\nTip: You can keep running tests even on failure by "
"passing --keep-going to run_test.py.\n"
"If running on CI, add the 'keep-going' label to "
@@ -1571,20 +1575,20 @@
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
err_message = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(err_message)
failure = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(failure)
if (
test_failed
and not options.continue_through_error
and not RERUN_DISABLED_TESTS
):
raise RuntimeError(err_message)
raise RuntimeError(failure.message)

finally:
pool.terminate()
pool.join()

return failure_messages
return


def check_pip_packages() -> None:
@@ -1611,30 +1615,47 @@ def main():
test_directory = str(REPO_ROOT / "test")
selected_tests = get_selected_tests(options)

if options.verbose:
print_to_stderr(
"Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
)

if options.dry_run:
return

if options.coverage and not PYTORCH_COLLECT_COVERAGE:
shell(["coverage", "erase"])

prioritized_tests = []
remaining_tests = selected_tests
general_tests = selected_tests
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
(prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
log_time_savings(
selected_tests,
prioritized_tests,
is_serial_test_fn=must_serial,
num_procs=NUM_PROCS,
)

# downloading test cases configuration to local environment
get_test_case_configs(dirpath=test_directory)
(prioritized_tests, general_tests) = get_reordered_tests(general_tests)

metrics_dict = {
"prioritized_tests": prioritized_tests,
"general_tests": general_tests,
"cpp": options.cpp,
}

test_times_dict = download_test_times(TEST_TIMES_FILE)
prioritized_tests = do_sharding(
options, prioritized_tests, test_times_dict, sort_by_time=False
)
general_tests = do_sharding(options, general_tests, test_times_dict)

if options.verbose:

def print_tests(category, tests):
tests_str = "\n ".join(str(x) for x in tests)
print_to_stderr(f"{category} tests:\n {tests_str}")

print_tests(
"Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)]
)
print_tests(
"Prioritized serial", [x for x in prioritized_tests if must_serial(x)]
)
print_tests(
"General parallel", [x for x in general_tests if not must_serial(x)]
)
print_tests("General serial", [x for x in general_tests if must_serial(x)])

if options.dry_run:
return

if options.dynamo:
os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
@@ -1646,17 +1667,17 @@ def main():

os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)

failure_messages = []

prioritized_failures: List[TestFailure] = []
general_failures: List[TestFailure] = []
start_time = time.time()
# First run the prioritized tests, then the remaining tests.
try:
failure_messages = run_tests(
prioritized_tests, test_directory, options, "Prioritized tests"
)

failure_messages += run_tests(
remaining_tests, test_directory, options, "General tests"
)
run_tests(prioritized_tests, test_directory, options, prioritized_failures)
metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures]
metrics_dict["general_start_time"] = time.time() - start_time
run_tests(general_tests, test_directory, options, general_failures)
metrics_dict["general_end_time"] = time.time() - start_time
metrics_dict["all_failures"] = [x.test for x in general_failures]

finally:
if options.coverage:
@@ -1671,8 +1692,12 @@ def main():
if not PYTORCH_COLLECT_COVERAGE:
cov.html_report()

if len(failure_messages) != 0:
for err in failure_messages:
if IS_CI and HAVE_TEST_SELECTION_TOOLS:
emit_metric("td_experiment_1", metrics_dict)

all_failures = prioritized_failures + general_failures
if len(all_failures) != 0:
for _, err in all_failures:
print_to_stderr(err)

# A disabled test is expected to fail, so there is no need to report a failure here
(Diffs for the remaining changed files not shown.)
