Skip to content

Commit

Permalink
Fix accuracy in table, report multiple metrics (#1823)
Browse files Browse the repository at this point in the history
  • Loading branch information
pgmpablo157321 authored Aug 15, 2024
1 parent 725b3c0 commit 34884c5
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tools/submission/submission_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,12 +962,13 @@ def check_accuracy_dir(config, model, path, verbose):
is_valid = False
all_accuracy_valid = True
acc = None
result_acc = None
result_acc = {}
hash_val = None
target = config.get_accuracy_target(model)
acc_upper_limit = config.get_accuracy_upper_limit(model)
patterns = []
acc_targets = []
acc_types = []
if acc_upper_limit is not None:
acc_limits = []
up_patterns = []
Expand All @@ -981,10 +982,11 @@ def check_accuracy_dir(config, model, path, verbose):
acc_type, acc_target = target[i:i+2]
patterns.append(ACC_PATTERN[acc_type])
acc_targets.append(acc_target)
acc_types.append(acc_type)
acc_seen = [False for _ in acc_targets]
with open(os.path.join(path, "accuracy.txt"), "r", encoding="utf-8") as f:
for line in f:
for i, (pattern, acc_target) in enumerate(zip(patterns, acc_targets)):
for i, (pattern, acc_target, acc_type) in enumerate(zip(patterns, acc_targets, acc_types)):
m = re.match(pattern, line)
if m:
acc = m.group(1)
Expand All @@ -997,8 +999,8 @@ def check_accuracy_dir(config, model, path, verbose):
elif acc is not None:
all_accuracy_valid = False
log.warning("%s accuracy not met: expected=%f, found=%s", path, acc_target, acc)
if i == 0 and acc:
result_acc = acc
if acc:
result_acc[acc_type] = acc
acc = None
if acc_upper_limit is not None:
for i, (pattern, acc_limit) in enumerate(zip(up_patterns, acc_limits)):
Expand Down Expand Up @@ -2012,6 +2014,7 @@ def log_result(
acc_path,
debug or is_closed_or_network,
)
acc = json.dumps(acc).replace(",", " ").replace('"', "").replace("{", "").replace("}", "")
if mlperf_model in REQUIRED_ACC_BENCHMARK:
if config.version in REQUIRED_ACC_BENCHMARK[mlperf_model]:
extra_files_pass, missing_files = check_extra_files(acc_path, REQUIRED_ACC_BENCHMARK[mlperf_model][config.version])
Expand Down

0 comments on commit 34884c5

Please sign in to comment.