diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat
index 5e74062613f63..57f59febe7293 100644
--- a/.ci/pytorch/win-test-helpers/build_pytorch.bat
+++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat
@@ -128,6 +128,7 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps
 :: export test times so that potential sharded tests that'll branch off this build will use consistent data
 python tools/stats/export_test_times.py
 copy /Y ".pytorch-test-times.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
+copy /Y ".pytorch-test-file-ratings.json" "%PYTORCH_FINAL_PACKAGE_DIR%"
 
 :: Also save build/.ninja_log as an artifact
 copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"
diff --git a/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat b/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat
index c18151d65c023..7277de33dcf85 100644
--- a/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat
+++ b/.ci/pytorch/win-test-helpers/test_python_jit_legacy.bat
@@ -2,6 +2,7 @@ call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
 
 echo Copying over test times file
 copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
+copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
 
 pushd test
diff --git a/.ci/pytorch/win-test-helpers/test_python_shard.bat b/.ci/pytorch/win-test-helpers/test_python_shard.bat
index 5313bc0078d5f..ec7e78bac9eff 100644
--- a/.ci/pytorch/win-test-helpers/test_python_shard.bat
+++ b/.ci/pytorch/win-test-helpers/test_python_shard.bat
@@ -23,6 +23,7 @@ if "%SHARD_NUMBER%" == "1" (
 
 echo Copying over test times file
 copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%PROJECT_DIR_WIN%"
+copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-file-ratings.json" "%PROJECT_DIR_WIN%"
 
 echo Run nn tests
 python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5cb89ac2c1403..36149c4f745b7 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -652,7 +652,7 @@ jobs:
       - run:
           name: Archive artifacts into zip
           command: |
-            zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+            zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
             cp artifacts.zip /Users/distiller/workspace
 
       - persist_to_workspace:
diff --git a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
index f03e173ccece9..0a2aee3116a3f 100644
--- a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
+++ b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
@@ -177,7 +177,7 @@
       - run:
           name: Archive artifacts into zip
           command: |
-            zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+            zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
             cp artifacts.zip /Users/distiller/workspace
 
       - persist_to_workspace:
diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml
index 269260c34aae1..7031e4e6f9aaf 100644
--- a/.github/workflows/_linux-build.yml
+++ b/.github/workflows/_linux-build.yml
@@ -170,7 +170,7 @@ jobs:
       - name: Archive artifacts into zip
         if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
         run: |
-          zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
+          zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json .pytorch-test-file-ratings.json
 
       - name: Store PyTorch Build Artifacts on S3
         uses: seemethere/upload-artifact-s3@v5
diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml
index 9ba093e6c7ede..2585709b516ee 100644
--- a/.github/workflows/_mac-build.yml
+++ b/.github/workflows/_mac-build.yml
@@ -182,7 +182,7 @@ jobs:
      - name: Archive artifacts into zip
        if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
        run: |
-          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json
+          zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json .pytorch-test-file-ratings.json
 
      - name: Store PyTorch Build Artifacts on GHA
        uses: actions/upload-artifact@v3
diff --git a/.gitignore b/.gitignore
index 9ffab8fffac30..424cc4b769352 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@ coverage.xml
 **/.pytorch-disabled-tests.json
 **/.pytorch-slow-tests.json
 **/.pytorch-test-times.json
+**/.pytorch-test-file-ratings.json
 */*.pyc
 */*.so*
 */**/__pycache__
diff --git a/test/run_test.py b/test/run_test.py
index cd62621ca7439..2c27197343541 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -11,9 +11,10 @@
 import subprocess
 import sys
 import tempfile
+import time
 from datetime import datetime
 from distutils.version import LooseVersion
-from typing import Any, cast, Dict, List, Optional, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Union
 
 import pkg_resources
 
@@ -40,11 +41,11 @@
     # using tools/ to optimize test run.
     sys.path.insert(0, str(REPO_ROOT))
     from tools.stats.export_test_times import TEST_TIMES_FILE
+    from tools.stats.upload_stats_lib import emit_metric
     from tools.testing.test_selections import (
         calculate_shards,
         get_reordered_tests,
         get_test_case_configs,
-        log_time_savings,
         NUM_PROCS,
         ShardedTest,
         THRESHOLD,
@@ -1278,7 +1279,9 @@ def exclude_tests(
     return selected_tests
 
 
-def must_serial(file: str) -> bool:
+def must_serial(file: Union[str, ShardedTest]) -> bool:
+    if isinstance(file, ShardedTest):
+        file = file.name
     return (
         os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
         or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
@@ -1408,20 +1411,10 @@ def get_selected_tests(options) -> List[ShardedTest]:
     )
 
     selected_tests = [parse_test_module(x) for x in selected_tests]
+    return selected_tests
 
-    # sharding
-    which_shard, num_shards = 1, 1
-    if options.shard:
-        assert len(options.shard) == 2, "Unexpected shard format"
-        assert min(options.shard) > 0, "Shards must be positive numbers"
-        which_shard, num_shards = options.shard
-        assert (
-            which_shard <= num_shards
-        ), "Selected shard must be less than or equal to total number of shards"
-        assert num_shards <= len(
-            selected_tests
-        ), f"Number of shards must be less than {len(selected_tests)}"
 
+
+def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
     # Download previous test times to make sharding decisions
     path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
     if os.path.exists(path):
@@ -1434,14 +1427,35 @@ def get_selected_tests(options) -> List[ShardedTest]:
         print(
             "::warning:: Gathered no stats from artifacts. Proceeding with default sharding plan."
         )
+        return {}
     else:
         print("Found test time stats from artifacts")
+        return test_file_times[test_config]
+
+
+def do_sharding(
+    options,
+    selected_tests: List[str],
+    test_file_times: Dict[str, float],
+    sort_by_time: bool = True,
+) -> List[ShardedTest]:
+    which_shard, num_shards = 1, 1
+    if options.shard:
+        assert len(options.shard) == 2, "Unexpected shard format"
+        assert min(options.shard) > 0, "Shards must be positive numbers"
+        which_shard, num_shards = options.shard
+        assert (
+            which_shard <= num_shards
+        ), "Selected shard must be less than or equal to total number of shards"
 
     if HAVE_TEST_SELECTION_TOOLS:
         # Do sharding
-        test_file_times_config = test_file_times.get(test_config, {})
         shards = calculate_shards(
-            num_shards, selected_tests, test_file_times_config, must_serial=must_serial
+            num_shards,
+            selected_tests,
+            test_file_times,
+            must_serial=must_serial,
+            sort_by_time=sort_by_time,
         )
         _, tests_from_shard = shards[which_shard - 1]
         selected_tests = tests_from_shard
@@ -1449,9 +1463,14 @@ def get_selected_tests(options) -> List[ShardedTest]:
     return selected_tests
 
 
+class TestFailure(NamedTuple):
+    test: str
+    message: str
+
+
 def run_test_module(
     test: Union[ShardedTest, str], test_directory: str, options
-) -> Optional[str]:
+) -> Optional[TestFailure]:
     maybe_set_hip_visible_devies()
 
     # Printing the date here can help diagnose which tests are slow
@@ -1472,39 +1491,24 @@ def run_test_module(
         # return code -N, where N is the signal number.
         signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
         message += f" Received signal: {signal_name}"
-    return message
+    return TestFailure(test, message)
 
 
 def run_tests(
-    selected_tests: List[ShardedTest], test_directory: str, options, group_name: str
+    selected_tests: List[ShardedTest],
+    test_directory: str,
+    options,
+    failures: List[TestFailure],
 ) -> None:
-    failure_messages = []
-
     if len(selected_tests) == 0:
-        print_to_stderr(f"No tests in group `{group_name}`")
-        return failure_messages
+        return
 
     # parallel = in parallel with other files
     # serial = this file on it's own.  The file might still be run in parallel with itself (ex test_ops)
-    selected_tests_parallel = [
-        x
-        for x in selected_tests
-        if not must_serial(x.name if isinstance(x, ShardedTest) else x)
-    ]
+    selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
     selected_tests_serial = [
         x for x in selected_tests if x not in selected_tests_parallel
     ]
-    print(f"TEST GROUP: {group_name}")
-    print_to_stderr(
-        "parallel (file granularity) tests:\n {}".format(
-            "\n".join(str(x) for x in selected_tests_parallel)
-        )
-    )
-    print_to_stderr(
-        "serial (file granularity) tests:\n {}".format(
-            "\n ".join(str(x) for x in selected_tests_serial)
-        )
-    )
 
     # See Note [ROCm parallel CI testing]
     pool = get_context("spawn").Pool(
@@ -1523,15 +1527,15 @@ def run_tests(
             # Take the conftest file from the test directory
             shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)
 
-    def handle_error_messages(err_message):
-        if err_message is None:
+    def handle_error_messages(failure: Optional[TestFailure]):
+        if failure is None:
             return False
-        failure_messages.append(err_message)
-        print_to_stderr(err_message)
+        failures.append(failure)
+        print_to_stderr(failure.message)
         return True
 
-    def parallel_test_completion_callback(err_message):
-        test_failed = handle_error_messages(err_message)
+    def parallel_test_completion_callback(failure):
+        test_failed = handle_error_messages(failure)
         if (
             test_failed
             and not options.continue_through_error
@@ -1557,10 +1561,10 @@ def parallel_test_completion_callback(err_message):
         if (
             not options.continue_through_error
             and not RERUN_DISABLED_TESTS
-            and len(failure_messages) != 0
+            and len(failures) != 0
         ):
             raise RuntimeError(
-                "\n".join(failure_messages)
+                "\n".join(x.message for x in failures)
                 + "\n\nTip: You can keep running tests even on failure by "
                 "passing --keep-going to run_test.py.\n"
                 "If running on CI, add the 'keep-going' label to "
@@ -1571,20 +1575,20 @@ def parallel_test_completion_callback(err_message):
             options_clone = copy.deepcopy(options)
             if can_run_in_pytest(test):
                 options_clone.pytest = True
-            err_message = run_test_module(test, test_directory, options_clone)
-            test_failed = handle_error_messages(err_message)
+            failure = run_test_module(test, test_directory, options_clone)
+            test_failed = handle_error_messages(failure)
             if (
                 test_failed
                 and not options.continue_through_error
                 and not RERUN_DISABLED_TESTS
             ):
-                raise RuntimeError(err_message)
+                raise RuntimeError(failure.message)
     finally:
         pool.terminate()
         pool.join()
 
-    return failure_messages
+    return
 
 
 def check_pip_packages() -> None:
@@ -1611,30 +1615,47 @@ def main():
     test_directory = str(REPO_ROOT / "test")
     selected_tests = get_selected_tests(options)
 
-    if options.verbose:
-        print_to_stderr(
-            "Selected tests:\n {}".format("\n ".join(str(x) for x in selected_tests))
-        )
-
-    if options.dry_run:
-        return
-
     if options.coverage and not PYTORCH_COLLECT_COVERAGE:
         shell(["coverage", "erase"])
 
     prioritized_tests = []
-    remaining_tests = selected_tests
+    general_tests = selected_tests
     if IS_CI and HAVE_TEST_SELECTION_TOOLS:
-        (prioritized_tests, remaining_tests) = get_reordered_tests(selected_tests)
-        log_time_savings(
-            selected_tests,
-            prioritized_tests,
-            is_serial_test_fn=must_serial,
-            num_procs=NUM_PROCS,
-        )
-
         # downloading test cases configuration to local environment
         get_test_case_configs(dirpath=test_directory)
+        (prioritized_tests, general_tests) = get_reordered_tests(general_tests)
+
+    metrics_dict = {
+        "prioritized_tests": prioritized_tests,
+        "general_tests": general_tests,
+        "cpp": options.cpp,
+    }
+
+    test_times_dict = download_test_times(TEST_TIMES_FILE)
+    prioritized_tests = do_sharding(
+        options, prioritized_tests, test_times_dict, sort_by_time=False
+    )
+    general_tests = do_sharding(options, general_tests, test_times_dict)
+
+    if options.verbose:
+
+        def print_tests(category, tests):
+            tests_str = "\n ".join(str(x) for x in tests)
+            print_to_stderr(f"{category} tests:\n {tests_str}")
+
+        print_tests(
+            "Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)]
+        )
+        print_tests(
+            "Prioritized serial", [x for x in prioritized_tests if must_serial(x)]
+        )
+        print_tests(
+            "General parallel", [x for x in general_tests if not must_serial(x)]
+        )
+        print_tests("General serial", [x for x in general_tests if must_serial(x)])
+
+    if options.dry_run:
+        return
 
     if options.dynamo:
         os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
"cpp": options.cpp, + } + + test_times_dict = download_test_times(TEST_TIMES_FILE) + prioritized_tests = do_sharding( + options, prioritized_tests, test_times_dict, sort_by_time=False + ) + general_tests = do_sharding(options, general_tests, test_times_dict) + + if options.verbose: + + def print_tests(category, tests): + tests_str = "\n ".join(str(x) for x in tests) + print_to_stderr(f"{category} tests:\n {tests_str}") + + print_tests( + "Prioritized parallel", [x for x in prioritized_tests if not must_serial(x)] + ) + print_tests( + "Prioritized serial", [x for x in prioritized_tests if must_serial(x)] + ) + print_tests( + "General parallel", [x for x in general_tests if not must_serial(x)] + ) + print_tests("General serial", [x for x in general_tests if must_serial(x)]) + + if options.dry_run: + return if options.dynamo: os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1" @@ -1646,17 +1667,17 @@ def main(): os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True) - failure_messages = [] - + prioritized_failures: List[TestFailure] = [] + general_failures: List[TestFailure] = [] + start_time = time.time() # First run the prioritized tests, then the remaining tests. try: - failure_messages = run_tests( - prioritized_tests, test_directory, options, "Prioritized tests" - ) - - failure_messages += run_tests( - remaining_tests, test_directory, options, "General tests" - ) + run_tests(prioritized_tests, test_directory, options, prioritized_failures) + metrics_dict["prioritized_failures"] = [x.test for x in prioritized_failures] + metrics_dict["general_start_time"] = time.time() - start_time + run_tests(general_tests, test_directory, options, general_failures) + metrics_dict["general_end_time"] = time.time() - start_time + metrics_dict["all_failures"] = [x.test for x in general_failures] finally: if options.coverage: @@ -1671,8 +1692,12 @@ def main(): if not PYTORCH_COLLECT_COVERAGE: cov.html_report() - if len(failure_messages) != 0: - for err in failure_messages: + if IS_CI and HAVE_TEST_SELECTION_TOOLS: + emit_metric("td_experiment_1", metrics_dict) + + all_failures = prioritized_failures + general_failures + if len(all_failures) != 0: + for _, err in all_failures: print_to_stderr(err) # A disabled test is expected to fail, so there is no need to report a failure here diff --git a/tools/stats/export_test_times.py b/tools/stats/export_test_times.py index 4554f546ee050..6e60158a7ea9b 100644 --- a/tools/stats/export_test_times.py +++ b/tools/stats/export_test_times.py @@ -3,14 +3,16 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent sys.path.append(str(REPO_ROOT)) -from tools.stats.import_test_stats import get_test_times +from tools.stats.import_test_stats import get_test_file_ratings, get_test_times TEST_TIMES_FILE = ".pytorch-test-times.json" +TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json" def main() -> None: print(f"Exporting test times from test-infra to {TEST_TIMES_FILE}") get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE) + get_test_file_ratings(str(REPO_ROOT), filename=TEST_FILE_RATINGS_FILE) if __name__ == "__main__": diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index 28d8ee0961bd9..a0c0190580748 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -20,6 +20,7 @@ def get_disabled_issues() -> List[str]: SLOW_TESTS_FILE = ".pytorch-slow-tests.json" DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json" + FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds @@ -116,3 +117,12 
diff --git a/tools/stats/export_test_times.py b/tools/stats/export_test_times.py
index 4554f546ee050..6e60158a7ea9b 100644
--- a/tools/stats/export_test_times.py
+++ b/tools/stats/export_test_times.py
@@ -3,14 +3,16 @@
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
 sys.path.append(str(REPO_ROOT))
 
-from tools.stats.import_test_stats import get_test_times
+from tools.stats.import_test_stats import get_test_file_ratings, get_test_times
 
 TEST_TIMES_FILE = ".pytorch-test-times.json"
+TEST_FILE_RATINGS_FILE = ".pytorch-test-file-ratings.json"
 
 
 def main() -> None:
     print(f"Exporting test times from test-infra to {TEST_TIMES_FILE}")
     get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
+    get_test_file_ratings(str(REPO_ROOT), filename=TEST_FILE_RATINGS_FILE)
 
 
 if __name__ == "__main__":
diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py
index 28d8ee0961bd9..a0c0190580748 100644
--- a/tools/stats/import_test_stats.py
+++ b/tools/stats/import_test_stats.py
@@ -20,6 +20,7 @@ def get_disabled_issues() -> List[str]:
 
 SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
 DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
+
 FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
 
 
@@ -116,3 +117,12 @@ def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
     except Exception:
         print("Couldn't download test skip set, leaving all tests enabled...")
         return {}
+
+
+def get_test_file_ratings(dirpath: str, filename: str) -> Optional[Dict[str, Any]]:
+    url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/file_test_rating.json"
+    try:
+        return fetch_and_cache(dirpath, filename, url, lambda x: x)
+    except Exception:
+        print("Couldn't download test file ratings file, not reordering...")
+        return {}
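
Editor's sketch (not part of the patch): get_test_file_ratings() above only downloads and caches a JSON file; the consumer is _get_file_rating_tests() in tools/testing/test_selections.py further down. Judging from how that consumer reads it, the file appears to map changed source files to per-test relevance scores. The snippet below assumes that shape (the concrete paths and scores are invented) and shows how scores for changed files would be folded into a single rating per test file.

from collections import defaultdict
from typing import Dict

# Assumed shape of .pytorch-test-file-ratings.json:
# {changed source file -> {test file -> relevance score}}.
test_file_ratings: Dict[str, Dict[str, float]] = {
    "torch/nn/modules/linear.py": {"test_nn": 0.9, "test_modules": 0.4},
    "torch/cuda/__init__.py": {"test_cuda": 0.8},
}
changed_files = ["torch/nn/modules/linear.py", "torch/cuda/__init__.py"]

ratings: Dict[str, float] = defaultdict(float)
for file in changed_files:
    for test_file, score in test_file_ratings.get(file, {}).items():
        ratings[test_file] += score

# sorted() with this key lists the lowest-scored tests first, matching the
# ascending sort used by _get_file_rating_tests().
print(sorted(ratings, key=lambda x: ratings[x]))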
diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py
index dd48e78ab7a40..ab8d8734c8da0 100644
--- a/tools/stats/upload_stats_lib.py
+++ b/tools/stats/upload_stats_lib.py
@@ -263,7 +263,7 @@ def value(self) -> Any:
         value = os.environ.get(self.env_var)
         if value is None and self.required:
             raise ValueError(
-                f"Missing {self.name}. Please set the {self.env_var}"
+                f"Missing {self.name}. Please set the {self.env_var} "
                 "environment variable to pass in this value."
             )
         if self.type_conversion_fn:
diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py
index 04f5f8899a1c2..c0f646c40f49e 100644
--- a/tools/test/test_test_selections.py
+++ b/tools/test/test_test_selections.py
@@ -394,28 +394,24 @@ def test_dedupes_failing_test_files(self, mock_exists: Any, mock_open: Any) -> None:
         "tools.testing.test_selections._get_modified_tests",
         return_value={"test2", "test4"},
     )
+    @mock.patch(
+        "tools.testing.test_selections._get_file_rating_tests", return_value=["test1"]
+    )
     def test_get_reordered_tests(
-        self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
+        self,
+        mock_get_prev_failing_tests: Any,
+        mock_get_modified_tests: Any,
+        mock_get_file_rating_tests: Any,
     ) -> None:
-        tests = [
-            ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
-            ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
-            ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
-            ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
-            ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
-        ]
+        tests = ["test1", "test2", "test3", "test4", "test5"]
 
-        expected_prioritized_tests = {"test4", "test2"}
-        expected_remaining_tests = {"test1", "test3", "test5"}
+        expected_prioritized_tests = ["test4", "test2", "test1"]
+        expected_remaining_tests = {"test3", "test5"}
 
         prioritized_tests, remaining_tests = get_reordered_tests(tests)
-        # Just want to check the names of the tests
-        prioritized_tests_name = {test.name for test in prioritized_tests}
-        remaining_tests_name = {test.name for test in remaining_tests}
-
-        self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
-        self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
+        self.assertListEqual(expected_prioritized_tests, prioritized_tests)
+        self.assertSetEqual(expected_remaining_tests, set(remaining_tests))
 
     def test_compute_prioritization_time_savings_with_multiple_threads(self) -> None:
         tests = [
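
Editor's sketch (not part of the patch): the rewritten test above pins down the ordering contract of the new get_reordered_tests(): previously failing tests first, then modified tests, then file-rating suggestions, with duplicates dropped and anything outside the selected list filtered out. The stand-alone illustration below uses assumed stand-in return values for the mocked helpers, chosen so that the output reproduces the expected lists in that test.

from typing import List

tests = ["test1", "test2", "test3", "test4", "test5"]
previously_failing = ["test4"]   # stand-in for _get_previously_failing_tests()
modified = ["test2", "test4"]    # stand-in for _get_modified_tests()
file_rating = ["test1"]          # stand-in for _get_file_rating_tests()

prioritized: List[str] = []
for group in (sorted(previously_failing), sorted(modified), file_rating):
    for test in group:
        if test in tests and test not in prioritized:
            prioritized.append(test)

the_rest = [t for t in tests if t not in prioritized]
print(prioritized)  # ['test4', 'test2', 'test1']
print(the_rest)     # ['test3', 'test5']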
diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py
index 76f841b8902bd..57ba611c7df72 100644
--- a/tools/testing/test_selections.py
+++ b/tools/testing/test_selections.py
@@ -3,16 +3,20 @@
 import math
 import os
 import subprocess
+from collections import defaultdict
 from pathlib import Path
-from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
+from typing import Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple
 from warnings import warn
 
 from tools.shared.logging_utils import duration_to_str, pluralize
+from tools.stats.export_test_times import TEST_FILE_RATINGS_FILE
 from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
 from tools.stats.upload_stats_lib import emit_metric
 
+REPO_ROOT = Path(__file__).resolve().parent.parent.parent
+
 IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
 
 # NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
@@ -81,8 +85,8 @@ def get_with_pytest_shard(
 ) -> List[ShardedTest]:
     sharded_tests: List[ShardedTest] = []
     for test in tests:
-        duration = test_file_times[test]
-        if duration > THRESHOLD:
+        duration = test_file_times.get(test, None)
+        if duration and duration > THRESHOLD:
             num_shards = math.ceil(duration / THRESHOLD)
             for i in range(num_shards):
                 sharded_tests.append(
@@ -98,20 +102,24 @@ def calculate_shards(
     tests: List[str],
     test_file_times: Dict[str, float],
     must_serial: Optional[Callable[[str], bool]] = None,
+    sort_by_time: bool = True,
 ) -> List[Tuple[float, List[ShardedTest]]]:
     must_serial = must_serial or (lambda x: True)
 
-    known_tests = [x for x in tests if x in test_file_times]
-    unknown_tests: List[str] = [x for x in tests if x not in known_tests]
+    known_tests = tests
+    unknown_tests = []
 
-    sorted_tests = sorted(
-        get_with_pytest_shard(known_tests, test_file_times),
-        key=lambda j: j.get_time(),
-        reverse=True,
-    )
+    if sort_by_time:
+        known_tests = [x for x in tests if x in test_file_times]
+        unknown_tests = [x for x in tests if x not in known_tests]
+
+    known_tests = get_with_pytest_shard(known_tests, test_file_times)
+
+    if sort_by_time:
+        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)
 
     sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
-    for test in sorted_tests:
+    for test in known_tests:
         if must_serial(test.name):
             min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
             min_sharded_job.serial.append(test)
@@ -127,7 +135,7 @@ def calculate_shards(
     return [job.convert_to_tuple() for job in sharded_jobs]
 
 
-def _query_changed_test_files() -> List[str]:
+def _query_changed_files() -> List[str]:
     default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
     merge_base = (
         subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
@@ -186,7 +194,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]:
 
 def _get_modified_tests() -> Set[str]:
     try:
-        changed_files = _query_changed_test_files()
+        changed_files = _query_changed_files()
     except Exception as e:
         warn(f"Can't query changed test files due to {e}")
         # If unable to get changed files from git, quit without doing any sorting
@@ -271,76 +279,81 @@ def log_time_savings(
     return max_time_savings_sec
 
 
+def _get_file_rating_tests() -> List[str]:
+    path = REPO_ROOT / TEST_FILE_RATINGS_FILE
+    if not os.path.exists(path):
+        print(f"could not find path {path}")
+        return []
+    with open(path) as f:
+        test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
+    try:
+        changed_files = _query_changed_files()
+    except Exception as e:
+        warn(f"Can't query changed test files due to {e}")
+        return []
+    ratings: Dict[str, float] = defaultdict(float)
+    for file in changed_files:
+        for test_file, score in test_file_ratings.get(file, {}).items():
+            ratings[test_file] += score
+    prioritize = sorted(ratings, key=lambda x: ratings[x])
+    return prioritize
+
+
 def get_reordered_tests(
-    tests: List[ShardedTest],
-) -> Tuple[List[ShardedTest], List[ShardedTest]]:
+    tests: List[str],
+) -> Tuple[List[str], List[str]]:
     """
     Get the reordered test filename list based on github PR history or git changed file.
     We prioritize running test files that were changed.
     """
+    prioritized_tests: List[str] = []
 
-    def print_tests(tests: Set[str], test_group_description: str) -> None:
-        if not tests:
+    def add_tests(tests_to_add: List[str], test_group_description: str) -> None:
+        if not tests_to_add:
             return
 
         print(f"{test_group_description}:")
-        for test in tests:
-            print(f"  {test}")
-
-    prioritized_tests: Set[str] = set()
-
-    pri_test = _get_previously_failing_tests()
-    print_tests(
-        pri_test, "If run, these tests will prioritized because they previously failed"
+        for test in tests_to_add:
+            if test in tests:
+                print(f"  {test}")
+                if test not in prioritized_tests:
+                    prioritized_tests.append(test)
+
+    add_tests(
+        sorted(_get_previously_failing_tests()),
+        "If run, these tests will be prioritized because they previously failed",
     )
-    prioritized_tests |= pri_test
 
-    pri_test |= _get_modified_tests()
-    print_tests(
-        pri_test, "If run, these tests will be prioritized because they were modified"
+    add_tests(
+        sorted(_get_modified_tests()),
+        "If run, these tests will be prioritized because they were modified",
    )
-    prioritized_tests |= pri_test
 
-    bring_to_front = []
-    the_rest = []
+    add_tests(
+        _get_file_rating_tests(),
+        "If run, these tests will be prioritized for an experiment in TD",
+    )
 
-    for test in tests:
-        if test.name in prioritized_tests:
-            bring_to_front.append(test)
-        else:
-            the_rest.append(test)
+    prioritized_tests = [x for x in prioritized_tests if x in tests]
+    the_rest = [x for x in tests if x not in prioritized_tests]
 
-    if len(tests) != len(bring_to_front) + len(the_rest):
+    if prioritized_tests:
+        test_cnt_str = pluralize(len(tests), "test")
         print(
-            f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
-            f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
+            f"Reordering tests: Prioritizing {len(prioritized_tests)} of {test_cnt_str}"
        )
-        return ([], tests)
-
-    prioritized_test_names = []
-    remaining_test_names = []
-    if bring_to_front:
-        test_cnt_str = pluralize(len(tests), "test")
-        print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
-
-        prioritized_test_names = [t.name for t in bring_to_front]
-        print(f"Prioritized: {prioritized_test_names}")
-        remaining_test_names = [t.name for t in the_rest]
-        print(f"The Rest: {remaining_test_names}")
-    else:
-        print("Didn't find any tests to prioritize")
 
     emit_metric(
         "test_reordering_prioritized_tests",
         {
-            "prioritized_test_cnt": len(bring_to_front),
+            "prioritized_test_cnt": len(prioritized_tests),
             "total_test_cnt": len(tests),
-            "prioritized_tests": prioritized_test_names,
-            "remaining_tests": remaining_test_names,
+            "prioritized_tests": prioritized_tests,
+            "remaining_tests": the_rest,
         },
     )
 
-    return (bring_to_front, the_rest)
+    return (prioritized_tests, the_rest)
 
 
 def get_test_case_configs(dirpath: str) -> None:
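
Editor's sketch (not part of the patch): a self-contained approximation of what the new sort_by_time flag changes in calculate_shards(). With the default True, tests with known times are ordered longest-first and greedily placed on the currently lightest shard; with False (used for the prioritized bucket in run_test.py), the incoming priority order is preserved. This toy version flattens ShardedTest/ShardJob into plain lists and ignores the serial/parallel split, pytest sharding, and unknown-test handling, so it only approximates the real routine.

from typing import Dict, List, Tuple


def toy_calculate_shards(
    tests: List[str],
    times: Dict[str, float],
    num_shards: int,
    sort_by_time: bool = True,
) -> List[Tuple[float, List[str]]]:
    if sort_by_time:
        # Default behaviour: keep only tests with known times, longest first.
        tests = sorted(
            (t for t in tests if t in times), key=lambda t: times[t], reverse=True
        )
    totals = [0.0] * num_shards
    members: List[List[str]] = [[] for _ in range(num_shards)]
    for test in tests:
        # Greedy: put the next test on the shard with the smallest running total.
        i = min(range(num_shards), key=lambda j: totals[j])
        totals[i] += times.get(test, 0.0)
        members[i].append(test)
    return list(zip(totals, members))


times = {"test_a": 600.0, "test_b": 400.0, "test_c": 300.0}
print(toy_calculate_shards(["test_c", "test_a", "test_b"], times, num_shards=2))
# sort_by_time=False keeps the caller's (priority) order instead of re-sorting.
print(toy_calculate_shards(["test_c", "test_a", "test_b"], times, num_shards=2, sort_by_time=False))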