From cf4178f968407a0641676c7d3cf5c5894cf550f8 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 16 Aug 2023 18:23:09 +0000 Subject: [PATCH] Reordering tests experiment (#106347) Companion with https://github.com/pytorch/test-infra/pull/4424 Uses the file rating generated by the test infra PR to re order tests. For each test file, sum the file ratings from the changed files in the PR, and put the tests in order of sum. A lot of tests are probably going to end up as "prioritized" since it takes anything with a rating > 0 right now. Sharding is done twice, once on the prioritized tests, and once on the general/non prioritized tests. Prioritized tests have an order, so they should be sharded according to that order, while general tests don't have an order and are sharded by test time, which should result in more balanced shards. I'll change the metric name before I merge, i want to quarantine my testing stuff from actual results Pull Request resolved: https://github.com/pytorch/pytorch/pull/106347 Approved by: https://github.com/ZainRizvi --- .../win-test-helpers/build_pytorch.bat | 1 + .../test_python_jit_legacy.bat | 1 + .../win-test-helpers/test_python_shard.bat | 1 + .circleci/config.yml | 2 +- .../job-specs/job-specs-custom.yml | 2 +- .github/workflows/_linux-build.yml | 2 +- .github/workflows/_mac-build.yml | 2 +- .gitignore | 1 + test/run_test.py | 185 ++++++++++-------- tools/stats/export_test_times.py | 4 +- tools/stats/import_test_stats.py | 10 + tools/stats/upload_stats_lib.py | 2 +- tools/test/test_test_selections.py | 28 ++- tools/testing/test_selections.py | 129 ++++++------ 14 files changed, 210 insertions(+), 160 deletions(-) diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 5e74062613f63..57f59febe7293 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -128,6 +128,7 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps :: export test times so that potential sharded tests that'll branch off this build will use consistent data python tools/stats/export_test_times.py copy /Y ".pytorch-test-times.json" "%PYTORCH_FINAL_PACKAGE_DIR%" + copy /Y ".pytorch-test-file-ratings.json" "%PYTORCH_FINAL_PACKAGE_DIR%" :: Also save build/.ninja_log as an artifact copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\" diff --git a/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat b/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat index c18151d65c023..7277de33dcf85 100644 --- a/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat +++ b/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat @@ -2,6 +2,7 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat echo Copying over test times file copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%" +copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%" pushd test diff --git a/.ci/pytorch/win-test-helpers/test_python_shard.bat b/.ci/pytorch/win-test-helpers/test_python_shard.bat index 5313bc0078d5f..ec7e78bac9eff 100644 --- a/.ci/pytorch/win-test-helpers/test_python_shard.bat +++ b/.ci/pytorch/win-test-helpers/test_python_shard.bat @@ -23,6 +23,7 @@ if "%SHARD_NUMBER%" == "1" ( echo Copying over test times file copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%" +copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%" echo Run nn tests python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose diff --git a/.circleci/config.yml b/.circleci/config.yml index 5cb89ac2c1403..36149c4f745b7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -652,7 +652,7 @@ jobs: - run: name: Archive artifacts into zip command: | - zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json + zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json cp artifacts.zip /Users/distiller/workspace - persist_to_workspace: diff --git a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml index f03e173ccece9..0a2aee3116a3f 100644 --- a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml +++ b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml @@ -177,7 +177,7 @@ - run: name: Archive artifacts into zip command: | - zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json + zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json cp artifacts.zip /Users/distiller/workspace - persist_to_workspace: diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index 269260c34aae1..7031e4e6f9aaf 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -170,7 +170,7 @@ jobs: - name: Archive artifacts into zip if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped' run: | - zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json .pytorch-test-file-ratings.json - name: Store PyTorch Build Artifacts on S3 uses: seemethere/upload-artifact-s3@v5 diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 9ba093e6c7ede..2585709b516ee 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -182,7 +182,7 @@ jobs: - name: Archive artifacts into zip if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped' run: | - zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json + zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json - name: Store PyTorch Build Artifacts on GHA uses: actions/upload-artifact@v3 diff --git a/.gitignore b/.gitignore index 9ffab8fffac30..424cc4b769352 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ coverage.xml **/.pytorch-disabled-tests.json **/.pytorch-slow-tests.json **/.pytorch-test-times.json +**/.pytorch-test-file-ratings.json */*.pyc */*.so* */**/__pycache__ diff --git a/test/run_test.py b/test/run_test.py index cd62621ca7439..2c27197343541 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -11,9 +11,10 @@ import subprocess import sys import tempfile +import time from datetime import datetime from distutils.version import LooseVersion -from typing import Any, cast, Dict, List, Optional, Union +from typing import Any, cast, Dict, List, NamedTuple, Optional, Union import pkg_resources @@ -40,11 +41,11 @@ # using tools/ to optimize test run. sys.path.insert(0, str(REPO_ROOT)) from tools.stats.export_test_times import TEST_TIMES_FILE + from tools.stats.upload_stats_lib import emit_metric from tools.testing.test_selections import ( calculate_shards, get_reordered_tests, get_test_case_configs, - log_time_savings, NUM_PROCS, ShardedTest, THRESHOLD, @@ -1278,7 +1279,9 @@ def exclude_tests( return selected_tests -def must_serial(file: str) -> bool: +def must_serial(file: Union[str, ShardedTest]) -> bool: + if isinstance(file, ShardedTest): + file = file.name return ( os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1" or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "") @@ -1408,20 +1411,10 @@ def get_selected_tests(options) -> List[ShardedTest]: ) selected_tests = [parse_test_module(x) for x in selected_tests] + return selected_tests - # sharding - which_shard, num_shards = 1, 1 - if options.shard: - assert len(options.shard) == 2, "Unexpected shard format" - assert min(options.shard) > 0, "Shards must be positive numbers" - which_shard, num_shards = options.shard - assert ( - which_shard <= num_shards - ), "Selected shard must be less than or equal to total number of shards" - assert num_shards <= len( - selected_tests - ), f"Number of shards must be less than {len(selected_tests)}" +def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]: # Download previous test times to make sharding decisions path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE) if os.path.exists(path): @@ -1434,14 +1427,35 @@ def get_selected_tests(options) -> List[ShardedTest]: print( "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan." ) + return {} else: print("Found test time stats from artifacts") + return test_file_times[test_config] + + +def do_sharding( + options, + selected_tests: List[str], + test_file_times: Dict[str, float], + sort_by_time: bool = True, +) -> List[ShardedTest]: + which_shard, num_shards = 1, 1 + if options.shard: + assert len(options.shard) == 2, "Unexpected shard format" + assert min(options.shard) > 0, "Shards must be positive numbers" + which_shard, num_shards = options.shard + assert ( + which_shard <= num_shards + ), "Selected shard must be less than or equal to total number of shards" if HAVE_TEST_SELECTION_TOOLS: # Do sharding - test_file_times_config = test_file_times.get(test_config, {}) shards = calculate_shards( - num_shards, selected_tests, test_file_times_config, must_serial=must_serial + num_shards, + selected_tests, + test_file_times, + must_serial=must_serial, + sort_by_time=sort_by_time, ) _, tests_from_shard = shards[which_shard - 1] selected_tests = tests_from_shard @@ -1449,9 +1463,14 @@ def get_selected_tests(options) -> List[ShardedTest]: return selected_tests +class TestFailure(NamedTuple): + test: str + message: str + + def run_test_module( test: Union[ShardedTest, str], test_directory: str, options -) -> Optional[str]: +) -> Optional[TestFailure]: maybe_set_hip_visible_devies() # Printing the date here can help diagnose which tests are slow @@ -1472,39 +1491,24 @@ def run_test_module( # return code -N, where N is the signal number. signal_name = SIGNALS_TO_NAMES_DICT[-return_code] message += f" Received signal: {signal_name}" - return message + return TestFailure(test, message) def run_tests( - selected_tests: List[ShardedTest], test_directory: str, options, group_name: str + selected_tests: List[ShardedTest], + test_directory: str, + options, + failures: List[TestFailure], ) -> None: - failure_messages = [] - if len(selected_tests) == 0: - print_to_stderr(f"No tests in group `{group_name}`") - return failure_messages + return # parallel = in parallel with other files # serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops) - selected_tests_parallel = [ - x - for x in selected_tests - if not must_serial(x.name if isinstance(x, ShardedTest) else x) - ] + selected_tests_parallel = [x for x in selected_tests if not must_serial(x)] selected_tests_serial = [ x for x in selected_tests if x not in selected_tests_parallel ] - print(f"TEST GROUP: {group_name}") - print_to_stderr( - "parallel (file granularity) tests :\n {}".format( - "\n".join(str(x) for x in selected_tests_parallel) - ) - ) - print_to_stderr( - "serial (file granularity) tests:\n {}".format( - "\n ".join(str(x) for x in selected_tests_serial) - ) - ) # See Note [ROCm parallel CI testing] pool = get_context("spawn").Pool( @@ -1523,15 +1527,15 @@ def run_tests( # Take the conftest file from the test directory shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file) - def handle_error_messages(err_message): - if err_message is None: + def handle_error_messages(failure: Optional[TestFailure]): + if failure is None: return False - failure_messages.append(err_message) - print_to_stderr(err_message) + failures.append(failure) + print_to_stderr(failure.message) return True - def parallel_test_completion_callback(err_message): - test_failed = handle_error_messages(err_message) + def parallel_test_completion_callback(failure): + test_failed = handle_error_messages(failure) if ( test_failed and not options.continue_through_error @@ -1557,10 +1561,10 @@ def parallel_test_completion_callback(err_message): if ( not options.continue_through_error and not RERUN_DISABLED_TESTS - and len(failure_messages) != 0 + and len(failures) != 0 ): raise RuntimeError( - "\n".join(failure_messages) + "\n".join(x.message for x in failures) + "\n\nTip: You can keep running tests even on failure by " "passing --keep-going to run_test.py.\n" "If running on CI, add the 'keep-going' label to " @@ -1571,20 +1575,20 @@ def parallel_test_completion_callback(err_message): options_clone = copy.deepcopy(options) if can_run_in_pytest(test): options_clone.pytest = True - err_message = run_test_module(test, test_directory, options_clone) - test_failed = handle_error_messages(err_message) + failure = run_test_module(test, test_directory, options_clone) + test_failed = handle_error_messages(failure) if ( test_failed and not options.continue_through_error and not RERUN_DISABLED_TESTS ): - raise RuntimeError(err_message) + raise RuntimeError(failure.message) finally: pool.terminate() pool.join() - return failure_messages + return def check_pip_packages() -> None: @@ -1611,30 +1615,47 @@ def main(): test_directory = str(REPO_ROOT / "test") selected_tests = get_selected_tests(options) - if options.verbose: - print_to_stderr( - "Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests)) - ) - - if options.dry_run: - return - if options.coverage and not PYTORCH_COLLECT_COVERAGE: shell(["coverage", "erase"]) prioritized_tests = [] - remaining_tests = selected_tests + general_tests = selected_tests if IS_CI and HAVE_TEST_SELECTION_TOOLS: - (prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests) - log_time_savings( - selected_tests, - prioritized_tests, - is_serial_test_fn=must_serial, - num_procs=NUM_PROCS, - ) - # downloading test cases configuration to local environment get_test_case_configs(dirpath=test_directory) + (prioritized_tests, general_tests) = get_reordered_tests(general_tests) + + metrics_dict = { + "prioritized_tests": prioritized_tests, + "general_tests": general_tests, + "cpp": options.cpp, + } + + test_times_dict = download_test_times(TEST_TIMES_FILE) + prioritized_tests = do_sharding( + options, prioritized_tests, test_times_dict, sort_by_time=False + ) + general_tests = do_sharding(options, general_tests, test_times_dict) + + if options.verbose: + + def print_tests(category, tests): + tests_str = "\n ".join(str(x) for x in tests) + print_to_stderr(f"{category} tests:\n {tests_str}") + + print_tests( + "Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)] + ) + print_tests( + "Prioritized serial", [x for x in prioritized_tests if must_serial(x)] + ) + print_tests( + "General parallel", [x for x in general_tests if not must_serial(x)] + ) + print_tests("General serial", [x for x in general_tests if must_serial(x)]) + + if options.dry_run: + return if options.dynamo: os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1" @@ -1646,17 +1667,17 @@ def main(): os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True) - failure_messages = [] - + prioritized_failures: List[TestFailure] = [] + general_failures: List[TestFailure] = [] + start_time = time.time() # First run the prioritized tests, then the remaining tests. try: - failure_messages = run_tests( - prioritized_tests, test_directory, options, "Prioritized tests" - ) - - failure_messages += run_tests( - remaining_tests, test_directory, options, "General tests" - ) + run_tests(prioritized_tests, test_directory, options, prioritized_failures) + metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures] + metrics_dict["general_start_time"] = time.time() - start_time + run_tests(general_tests, test_directory, options, general_failures) + metrics_dict["general_end_time"] = time.time() - start_time + metrics_dict["all_failures"] = [x.test for x in general_failures] finally: if options.coverage: @@ -1671,8 +1692,12 @@ def main(): if not PYTORCH_COLLECT_COVERAGE: cov.html_report() - if len(failure_messages) != 0: - for err in failure_messages: + if IS_CI and HAVE_TEST_SELECTION_TOOLS: + emit_metric("td_experiment_1", metrics_dict) + + all_failures = prioritized_failures + general_failures + if len(all_failures) != 0: + for _, err in all_failures: print_to_stderr(err) # A disabled test is expected to fail, so there is no need to report a failure here diff --git a/tools/stats/export_test_times.py b/tools/stats/export_test_times.py index 4554f546ee050..6e60158a7ea9b 100644 --- a/tools/stats/export_test_times.py +++ b/tools/stats/export_test_times.py @@ -3,14 +3,16 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent sys.path.append(str(REPO_ROOT)) -from tools.stats.import_test_stats import get_test_times +from tools.stats.import_test_stats import get_test_file_ratings, get_test_times TEST_TIMES_FILE = ".pytorch-test-times.json" +TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json" def main() -> None: print(f"Exporting test times from test-infra to {TEST_TIMES_FILE}") get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE) + get_test_file_ratings(str(REPO_ROOT), filename=TEST_FILE_RATINGS_FILE) if __name__ == "__main__": diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index 28d8ee0961bd9..a0c0190580748 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -20,6 +20,7 @@ def get_disabled_issues() -> List[str]: SLOW_TESTS_FILE = ".pytorch-slow-tests.json" DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json" + FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds @@ -116,3 +117,12 @@ def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]: except Exception: print("Couldn't download test skip set, leaving all tests enabled...") return {} + + +def get_test_file_ratings(dirpath: str, filename: str) -> Optional[Dict[str, Any]]: + url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/file_test_rating.json" + try: + return fetch_and_cache(dirpath, filename, url, lambda x: x) + except Exception: + print("Couldn't download test file ratings file, not reordering...") + return {} diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index dd48e78ab7a40..ab8d8734c8da0 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -263,7 +263,7 @@ def value(self) -> Any: value = os.environ.get(self.env_var) if value is None and self.required: raise ValueError( - f"Missing {self.name}. Please set the {self.env_var}" + f"Missing {self.name}. Please set the {self.env_var} " "environment variable to pass in this value." ) if self.type_conversion_fn: diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index 04f5f8899a1c2..c0f646c40f49e 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -394,28 +394,24 @@ def test_dedupes_failing_test_files(self, mock_exists: Any, mock_open: Any) -> N "tools.testing.test_selections._get_modified_tests", return_value={"test2", "test4"}, ) + @mock.patch( + "tools.testing.test_selections._get_file_rating_tests", return_value=["test1"] + ) def test_get_reordered_tests( - self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any + self, + mock_get_prev_failing_tests: Any, + mock_get_modified_tests: Any, + mock_get_file_rating_tests: Any, ) -> None: - tests = [ - ShardedTest(name="test1", shard=1, num_shards=2, time=600.0), - ShardedTest(name="test2", shard=1, num_shards=2, time=500.0), - ShardedTest(name="test3", shard=1, num_shards=2, time=400.0), - ShardedTest(name="test4", shard=1, num_shards=2, time=300.0), - ShardedTest(name="test5", shard=1, num_shards=2, time=200.0), - ] + tests = ["test1", "test2", "test3", "test4", "test5"] - expected_prioritized_tests = {"test4", "test2"} - expected_remaining_tests = {"test1", "test3", "test5"} + expected_prioritized_tests = ["test4", "test2", "test1"] + expected_remaining_tests = {"test3", "test5"} prioritized_tests, remaining_tests = get_reordered_tests(tests) - # Just want to check the names of the tests - prioritized_tests_name = {test.name for test in prioritized_tests} - remaining_tests_name = {test.name for test in remaining_tests} - - self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name) - self.assertSetEqual(expected_remaining_tests, remaining_tests_name) + self.assertListEqual(expected_prioritized_tests, prioritized_tests) + self.assertSetEqual(expected_remaining_tests, set(remaining_tests)) def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None: tests = [ diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py index 76f841b8902bd..57ba611c7df72 100644 --- a/tools/testing/test_selections.py +++ b/tools/testing/test_selections.py @@ -3,16 +3,20 @@ import math import os import subprocess +from collections import defaultdict from pathlib import Path -from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple +from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple from warnings import warn from tools.shared.logging_utils import duration_to_str, pluralize +from tools.stats.export_test_times import TEST_FILE_RATINGS_FILE from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests from tools.stats.upload_stats_lib import emit_metric +REPO_ROOT = Path(__file__).resolve().parent.parent.parent + IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1" # NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job @@ -81,8 +85,8 @@ def get_with_pytest_shard( ) -> List[ShardedTest]: sharded_tests: List[ShardedTest] = [] for test in tests: - duration = test_file_times[test] - if duration > THRESHOLD: + duration = test_file_times.get(test, None) + if duration and duration > THRESHOLD: num_shards = math.ceil(duration / THRESHOLD) for i in range(num_shards): sharded_tests.append( @@ -98,20 +102,24 @@ def calculate_shards( tests: List[str], test_file_times: Dict[str, float], must_serial: Optional[Callable[[str], bool]] = None, + sort_by_time: bool = True, ) -> List[Tuple[float, List[ShardedTest]]]: must_serial = must_serial or (lambda x: True) - known_tests = [x for x in tests if x in test_file_times] - unknown_tests: List[str] = [x for x in tests if x not in known_tests] + known_tests = tests + unknown_tests = [] - sorted_tests = sorted( - get_with_pytest_shard(known_tests, test_file_times), - key=lambda j: j.get_time(), - reverse=True, - ) + if sort_by_time: + known_tests = [x for x in tests if x in test_file_times] + unknown_tests = [x for x in tests if x not in known_tests] + + known_tests = get_with_pytest_shard(known_tests, test_file_times) + + if sort_by_time: + known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True) sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)] - for test in sorted_tests: + for test in known_tests: if must_serial(test.name): min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time()) min_sharded_job.serial.append(test) @@ -127,7 +135,7 @@ def calculate_shards( return [job.convert_to_tuple() for job in sharded_jobs] -def _query_changed_test_files() -> List[str]: +def _query_changed_files() -> List[str]: default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}" merge_base = ( subprocess.check_output(["git", "merge-base", default_branch, "HEAD"]) @@ -186,7 +194,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st def _get_modified_tests() -> Set[str]: try: - changed_files = _query_changed_test_files() + changed_files = _query_changed_files() except Exception as e: warn(f"Can't query changed test files due to {e}") # If unable to get changed files from git, quit without doing any sorting @@ -271,76 +279,81 @@ def log_time_savings( return max_time_savings_sec +def _get_file_rating_tests() -> List[str]: + path = REPO_ROOT / TEST_FILE_RATINGS_FILE + if not os.path.exists(path): + print(f"could not find path {path}") + return [] + with open(path) as f: + test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f)) + try: + changed_files = _query_changed_files() + except Exception as e: + warn(f"Can't query changed test files due to {e}") + return [] + ratings: Dict[str, float] = defaultdict(float) + for file in changed_files: + for test_file, score in test_file_ratings.get(file, {}).items(): + ratings[test_file] += score + prioritize = sorted(ratings, key=lambda x: ratings[x]) + return prioritize + + def get_reordered_tests( - tests: List[ShardedTest], -) -> Tuple[List[ShardedTest], List[ShardedTest]]: + tests: List[str], +) -> Tuple[List[str], List[str]]: """ Get the reordered test filename list based on github PR history or git changed file. We prioritize running test files that were changed. """ + prioritized_tests: List[str] = [] - def print_tests(tests: Set[str], test_group_description: str) -> None: - if not tests: + def add_tests(tests_to_add: List[str], test_group_description: str) -> None: + if not tests_to_add: return print(f"{test_group_description}:") - for test in tests: - print(f" {test}") - - prioritized_tests: Set[str] = set() - - pri_test = _get_previously_failing_tests() - print_tests( - pri_test, "If run, these tests will prioritized because they previously failed" + for test in tests_to_add: + if test in tests: + print(f" {test}") + if test not in prioritized_tests: + prioritized_tests.append(test) + + add_tests( + sorted(_get_previously_failing_tests()), + "If run, these tests will prioritized because they previously failed", ) - prioritized_tests |= pri_test - pri_test |= _get_modified_tests() - print_tests( - pri_test, "If run, these tests will be prioritized because they were modified" + add_tests( + sorted(_get_modified_tests()), + "If run, these tests will be prioritized because they were modified", ) - prioritized_tests |= pri_test - bring_to_front = [] - the_rest = [] + add_tests( + _get_file_rating_tests(), + "If run, these tests will be preioritized for an experiment in TD", + ) - for test in tests: - if test.name in prioritized_tests: - bring_to_front.append(test) - else: - the_rest.append(test) + prioritized_tests = [x for x in prioritized_tests if x in tests] + the_rest = [x for x in tests if x not in prioritized_tests] - if len(tests) != len(bring_to_front) + len(the_rest): + if prioritized_tests: + test_cnt_str = pluralize(len(tests), "test") print( - f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n" - f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n" + f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}" ) - return ([], tests) - - prioritized_test_names = [] - remaining_test_names = [] - if bring_to_front: - test_cnt_str = pluralize(len(tests), "test") - print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}") - - prioritized_test_names = [t.name for t in bring_to_front] - print(f"Prioritized: {prioritized_test_names}") - remaining_test_names = [t.name for t in the_rest] - print(f"The Rest: {remaining_test_names}") - else: - print("Didn't find any tests to prioritize") emit_metric( "test_reordering_prioritized_tests", { - "prioritized_test_cnt": len(bring_to_front), + "prioritized_test_cnt": len(prioritized_tests), "total_test_cnt": len(tests), - "prioritized_tests": prioritized_test_names, - "remaining_tests": remaining_test_names, + "prioritized_tests": prioritized_tests, + "remaining_tests": the_rest, }, ) - return (bring_to_front, the_rest) + return (prioritized_tests, the_rest) def get_test_case_configs(dirpath: str) -> None: