From 629e4a08a905c2b6d8af1b6501eda933997132c0 Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Wed, 4 Oct 2023 18:16:26 +0900
Subject: [PATCH 1/4] First implementation of balanced accuracy

---
 lm_eval/base.py             | 30 ++++++++++++++++++++++++++++++
 lm_eval/metrics.py          | 17 +++++++++++++++++
 lm_eval/tasks/ja/marc_ja.py |  6 +++---
 3 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/lm_eval/base.py b/lm_eval/base.py
index 0e40f08d23..68b7ac4536 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
+from lm_eval.metrics import balanced_mean
 from lm_eval import utils
 from abc import abstractmethod
 
@@ -709,6 +710,35 @@ def aggregation(self):
             "acc_norm": mean,
         }
 
+class BalancedMultipleChoiceTask(MultipleChoiceTask):
+    def process_results(self, doc, results):
+        gold = doc["gold"]
+
+        acc = 1.0 if np.argmax(results) == gold else 0.0
+        completion_len = np.array([float(len(i)) for i in doc["choices"]])
+        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
+
+        return {
+            "acc": acc,
+            "acc_norm": acc_norm,
+            "balanced_acc": (acc, gold)
+        }
+
+    def higher_is_better(self):
+        return {
+            "acc": True,
+            "acc_norm": True,
+            "balanced_acc": True,
+        }
+
+    def aggregation(self):
+        return {
+            "acc": mean,
+            "acc_norm": mean,
+            "balanced_acc": balanced_mean,
+        }
+
+
 
 class PerplexityTask(Task, abc.ABC):
     def should_decontaminate(self):
diff --git a/lm_eval/metrics.py b/lm_eval/metrics.py
index 8f30a42695..ed6308343e 100644
--- a/lm_eval/metrics.py
+++ b/lm_eval/metrics.py
@@ -5,6 +5,7 @@
 import sacrebleu
 import sklearn.metrics
 import random
+from collections import defaultdict
 
 
 def mean(arr):
@@ -29,6 +30,22 @@ def median(arr):
     return arr[len(arr) // 2]
 
 
+def balanced_mean(arr):
+    # each entry is of the form (acc score, class label)
+    # first group the results
+    by_class = defaultdict(list)
+    for acc, label in arr:
+        by_class[label].append(acc)
+
+    # calculate class averages
+    avgs = []
+    for key, vals in by_class.items():
+        avgs.append(sum(vals) / len(vals))
+
+    # average the class values
+    return sum(avgs) / len(avgs)
+
+
 def matthews_corrcoef(items):
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
diff --git a/lm_eval/tasks/ja/marc_ja.py b/lm_eval/tasks/ja/marc_ja.py
index 1b15f3cfd2..cb9405b75a 100644
--- a/lm_eval/tasks/ja/marc_ja.py
+++ b/lm_eval/tasks/ja/marc_ja.py
@@ -7,7 +7,7 @@
 Homepage: https://github.com/yahoojapan/JGLUE
 """
 
-from lm_eval.base import MultipleChoiceTask, rf
+from lm_eval.base import BalancedMultipleChoiceTask, rf
 
 _CITATION = """
 @inproceedings{kurihara-etal-2022-jglue,
@@ -28,7 +28,7 @@
 
 
 
-class MARCJaWithFintanPrompt(MultipleChoiceTask):
+class MARCJaWithFintanPrompt(BalancedMultipleChoiceTask):
     """
     prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
     """
@@ -162,4 +162,4 @@ def construct_tasks():
     tasks = {}
     for version_class in VERSIONS:
         tasks[f"marc_ja-{version_class.VERSION}-{version_class.PROMPT_VERSION}"] = version_class
-    return tasks
\ No newline at end of file
+    return tasks
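As a quick illustration of what this first patch wires together (not part of the patch itself): process_results emits one (acc, gold) pair per document, and balanced_mean averages accuracy within each gold class before averaging across classes, so a dominant class cannot inflate the score. The sketch below is a condensed, standalone restatement of the balanced_mean added above, run on made-up class labels.

from collections import defaultdict


def balanced_mean(arr):
    # arr holds (accuracy score, gold class label) pairs, one per document
    by_class = defaultdict(list)
    for acc, label in arr:
        by_class[label].append(acc)
    # mean accuracy within each class, then an unweighted mean across classes
    return sum(sum(v) / len(v) for v in by_class.values()) / len(by_class)


# Toy results for a heavily imbalanced dataset: the majority class is always
# predicted correctly, the minority class never is.
scores = [(1.0, "positive")] * 9 + [(0.0, "negative")]
print(sum(acc for acc, _ in scores) / len(scores))  # plain accuracy: 0.9
print(balanced_mean(scores))                        # balanced accuracy: 0.5
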
From 7819403f770bbc58b457af5f1a2ce6a87a76a81b Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Thu, 5 Oct 2023 14:55:49 +0900
Subject: [PATCH 2/4] Add comment

---
 lm_eval/base.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/lm_eval/base.py b/lm_eval/base.py
index 68b7ac4536..330135c4b5 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -711,6 +711,13 @@ def aggregation(self):
         }
 
 class BalancedMultipleChoiceTask(MultipleChoiceTask):
+    """A task where the choices are the same every time, and accuracy should be
+    calculated separately for each class.
+
+    Originally created for marc-ja, which is severely imbalanced, though also
+    useful with more balanced datasets. Not suitable for datasets where the choices
+    change for every question.
+    """
     def process_results(self, doc, results):
         gold = doc["gold"]
 

From 9fa383be5a6c606d67fe75bf52ce3a3066c1315e Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Thu, 5 Oct 2023 14:56:02 +0900
Subject: [PATCH 3/4] Make JNLI a balanced acc task

---
 lm_eval/tasks/ja/jnli.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lm_eval/tasks/ja/jnli.py b/lm_eval/tasks/ja/jnli.py
index 1dd5da3a67..61fdeee18f 100644
--- a/lm_eval/tasks/ja/jnli.py
+++ b/lm_eval/tasks/ja/jnli.py
@@ -7,7 +7,7 @@
 Homepage: https://github.com/yahoojapan/JGLUE
 """
 
-from lm_eval.base import MultipleChoiceTask, rf
+from lm_eval.base import BalancedMultipleChoiceTask, rf
 
 _CITATION = """
 @inproceedings{kurihara-etal-2022-jglue,
@@ -28,7 +28,7 @@
 
 
 
-class JNLIWithFintanPrompt(MultipleChoiceTask):
+class JNLIWithFintanPrompt(BalancedMultipleChoiceTask):
     """
     prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
     """
From c6cb3be3e823e2de9cfeb9d6b61cdc01ebc25d64 Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Thu, 5 Oct 2023 15:41:26 +0900
Subject: [PATCH 4/4] Add mcc and balanced f1 scores

---
 lm_eval/base.py    | 11 +++++++++--
 lm_eval/metrics.py | 10 ++++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/lm_eval/base.py b/lm_eval/base.py
index 330135c4b5..d5d7a516f0 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -13,7 +13,7 @@
 import torch.nn.functional as F
 
 from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
-from lm_eval.metrics import balanced_mean
+from lm_eval.metrics import balanced_mean, matthews_corrcoef, macro_f1
 from lm_eval import utils
 from abc import abstractmethod
 
@@ -721,6 +721,7 @@ class BalancedMultipleChoiceTask(MultipleChoiceTask):
     def process_results(self, doc, results):
         gold = doc["gold"]
 
+        pred = np.argmax(results)
         acc = 1.0 if np.argmax(results) == gold else 0.0
         completion_len = np.array([float(len(i)) for i in doc["choices"]])
         acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
@@ -728,7 +729,9 @@ def process_results(self, doc, results):
         return {
             "acc": acc,
             "acc_norm": acc_norm,
-            "balanced_acc": (acc, gold)
+            "balanced_acc": (acc, gold),
+            "mcc": (gold, pred),
+            "macro_f1": (gold, pred),
         }
 
     def higher_is_better(self):
@@ -736,6 +739,8 @@ def higher_is_better(self):
             "acc": True,
             "acc_norm": True,
             "balanced_acc": True,
+            "mcc": True,
+            "macro_f1": True,
         }
 
     def aggregation(self):
@@ -743,6 +748,8 @@ def aggregation(self):
             "acc": mean,
             "acc_norm": mean,
             "balanced_acc": balanced_mean,
+            "mcc": matthews_corrcoef,
+            "macro_f1": macro_f1,
         }
 
 
diff --git a/lm_eval/metrics.py b/lm_eval/metrics.py
index ed6308343e..9a00655fee 100644
--- a/lm_eval/metrics.py
+++ b/lm_eval/metrics.py
@@ -62,6 +62,16 @@ def f1_score(items):
     return np.max(fscore)
 
 
+def macro_f1(items):
+    # unlike f1_score above, this uses a macro average rather than the default binary average
+    unzipped_list = list(zip(*items))
+    golds = unzipped_list[0]
+    preds = unzipped_list[1]
+    fscore = sklearn.metrics.f1_score(golds, preds, average="macro")
+
+    return fscore
+
+
 def acc_all(items):
     # Only count as correct if all answers are labeled correctly for each question
     question_scoring_dict = {}
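A final illustration, outside the patch series: the new "mcc" and "macro_f1" entries emitted by process_results are (gold, pred) pairs, which the aggregation functions unzip and hand to scikit-learn, as the new macro_f1 does and as the existing matthews_corrcoef helper appears to do with items of the same shape. The sketch below uses invented three-class labels (e.g. JNLI's 0/1/2) purely to show the shape of that flow.

import sklearn.metrics

# Invented (gold, pred) pairs for a three-class task.
items = [(0, 0), (1, 1), (2, 0), (1, 1), (0, 2), (2, 2)]
golds, preds = zip(*items)

# Macro F1 averages the per-class F1 scores with equal weight per class,
# matching the average="macro" call added in metrics.py.
print(sklearn.metrics.f1_score(golds, preds, average="macro"))

# scikit-learn's Matthews correlation coefficient handles the multiclass
# case directly, returning a value in [-1, 1].
print(sklearn.metrics.matthews_corrcoef(golds, preds))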