diff --git a/tools/submission/submission_checker.py b/tools/submission/submission_checker.py index 95ee7d03d..a02b925b9 100755 --- a/tools/submission/submission_checker.py +++ b/tools/submission/submission_checker.py @@ -962,12 +962,13 @@ def check_accuracy_dir(config, model, path, verbose): is_valid = False all_accuracy_valid = True acc = None - result_acc = None + result_acc = {} hash_val = None target = config.get_accuracy_target(model) acc_upper_limit = config.get_accuracy_upper_limit(model) patterns = [] acc_targets = [] + acc_types = [] if acc_upper_limit is not None: acc_limits = [] up_patterns = [] @@ -981,10 +982,11 @@ def check_accuracy_dir(config, model, path, verbose): acc_type, acc_target = target[i:i+2] patterns.append(ACC_PATTERN[acc_type]) acc_targets.append(acc_target) + acc_types.append(acc_type) acc_seen = [False for _ in acc_targets] with open(os.path.join(path, "accuracy.txt"), "r", encoding="utf-8") as f: for line in f: - for i, (pattern, acc_target) in enumerate(zip(patterns, acc_targets)): + for i, (pattern, acc_target, acc_type) in enumerate(zip(patterns, acc_targets, acc_types)): m = re.match(pattern, line) if m: acc = m.group(1) @@ -997,8 +999,8 @@ def check_accuracy_dir(config, model, path, verbose): elif acc is not None: all_accuracy_valid = False log.warning("%s accuracy not met: expected=%f, found=%s", path, acc_target, acc) - if i == 0 and acc: - result_acc = acc + if acc: + result_acc[acc_type] = acc acc = None if acc_upper_limit is not None: for i, (pattern, acc_limit) in enumerate(zip(up_patterns, acc_limits)): @@ -2012,6 +2014,7 @@ def log_result( acc_path, debug or is_closed_or_network, ) + acc = json.dumps(acc).replace(",", " ").replace('"', "").replace("{", "").replace("}", "") if mlperf_model in REQUIRED_ACC_BENCHMARK: if config.version in REQUIRED_ACC_BENCHMARK[mlperf_model]: extra_files_pass, missing_files = check_extra_files(acc_path, REQUIRED_ACC_BENCHMARK[mlperf_model][config.version])