Add Balanced Accuracy (Stability-AI#95)
* First implementation of balanced accuracy

* Add comment

* Make JNLI a balanced acc task

* Add mcc and balanced f1 scores

---------

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>
polm-stability and polm committed Oct 11, 2023
1 parent 100827e commit 936504d
Showing 4 changed files with 76 additions and 5 deletions.
44 changes: 44 additions & 0 deletions lm_eval/base.py
@@ -13,6 +13,7 @@
import torch.nn.functional as F

from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval.metrics import balanced_mean, matthews_corrcoef, macro_f1
from lm_eval import utils
from abc import abstractmethod

@@ -709,6 +710,49 @@ def aggregation(self):
            "acc_norm": mean,
        }

class BalancedMultipleChoiceTask(MultipleChoiceTask):
    """A task where the choices are the same for every example, and accuracy
    should be calculated separately for each class.

    Originally created for marc-ja, which is severely imbalanced, though it is
    also useful with less skewed datasets. Not suitable for tasks where the
    choices change from question to question.
    """

    def process_results(self, doc, results):
        gold = doc["gold"]

        pred = np.argmax(results)
        acc = 1.0 if pred == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
            "balanced_acc": (acc, gold),
            "mcc": (gold, pred),
            "macro_f1": (gold, pred),
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
            "balanced_acc": True,
            "mcc": True,
            "macro_f1": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
            "balanced_acc": balanced_mean,
            "mcc": matthews_corrcoef,
            "macro_f1": macro_f1,
        }



class PerplexityTask(Task, abc.ABC):
def should_decontaminate(self):
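A quick sketch of how these pieces fit together (not part of the commit; the documents and scores below are made up): process_results emits one tuple per example for each metric key, and the harness later hands the collected tuples for each key to the matching aggregation function.

# Hypothetical end-to-end check of the metric plumbing.
from lm_eval.metrics import balanced_mean, matthews_corrcoef, macro_f1

# Three scored documents with (gold, pred) pairs (0, 0), (1, 0), (1, 1).
per_example = {
    "balanced_acc": [(1.0, 0), (0.0, 1), (1.0, 1)],  # (acc, gold class)
    "mcc": [(0, 0), (1, 0), (1, 1)],                 # (gold, pred)
    "macro_f1": [(0, 0), (1, 0), (1, 1)],            # (gold, pred)
}

aggregators = {
    "balanced_acc": balanced_mean,  # 0.75: class 0 scores 1.0, class 1 scores 0.5
    "mcc": matthews_corrcoef,       # 0.5
    "macro_f1": macro_f1,           # ~0.67: both classes have F1 = 2/3
}

for key, items in per_example.items():
    print(key, aggregators[key](items))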
27 changes: 27 additions & 0 deletions lm_eval/metrics.py
@@ -5,6 +5,7 @@
import sacrebleu
import sklearn.metrics
import random
from collections import defaultdict


def mean(arr):
@@ -29,6 +30,22 @@ def median(arr):
    return arr[len(arr) // 2]


def balanced_mean(arr):
    # each entry is of the form (acc score, class label)
    # first group the results by class
    by_class = defaultdict(list)
    for acc, label in arr:
        by_class[label].append(acc)

    # calculate the average accuracy within each class
    avgs = [sum(vals) / len(vals) for vals in by_class.values()]

    # the balanced mean is the unweighted mean of the per-class averages
    return sum(avgs) / len(avgs)
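To see why marc-ja needs this, compare the plain mean with the balanced mean when a model just predicts the majority class every time (a hypothetical 90/10 split; the labels are made up):

# Hypothetical: 9 "positive" docs, 1 "negative" doc, model always says "positive".
items = [(1.0, "positive")] * 9 + [(0.0, "negative")]

print(sum(acc for acc, _ in items) / len(items))  # 0.9 -- plain accuracy looks strong
print(balanced_mean(items))                       # 0.5 -- (1.0 + 0.0) / 2, chance level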


def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
@@ -45,6 +62,16 @@ def f1_score(items):
    return np.max(fscore)


def macro_f1(items):
    # unlike f1_score above, which uses the default binary average, this
    # takes the unweighted mean of the per-class F1 scores
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds, average="macro")

    return fscore
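As a contrast with the binary f1_score above, a small hypothetical check (labels made up): because the per-class scores are averaged with equal weight, the rare class 0 counts exactly as much as the frequent classes.

golds = [0, 1, 1, 2]
preds = [0, 1, 2, 2]
print(macro_f1(list(zip(golds, preds))))  # ~0.778: mean of per-class F1s 1.0, 2/3, 2/3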


def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
4 changes: 2 additions & 2 deletions lm_eval/tasks/ja/jnli.py
@@ -7,7 +7,7 @@
Homepage: https://github.com/yahoojapan/JGLUE
"""
-from lm_eval.base import MultipleChoiceTask, rf
+from lm_eval.base import BalancedMultipleChoiceTask, rf

_CITATION = """
@inproceedings{kurihara-etal-2022-jglue,
@@ -28,7 +28,7 @@



-class JNLIWithFintanPrompt(MultipleChoiceTask):
+class JNLIWithFintanPrompt(BalancedMultipleChoiceTask):
    """
    prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
    """
6 changes: 3 additions & 3 deletions lm_eval/tasks/ja/marc_ja.py
@@ -7,7 +7,7 @@
Homepage: https://github.com/yahoojapan/JGLUE
"""
-from lm_eval.base import MultipleChoiceTask, rf
+from lm_eval.base import BalancedMultipleChoiceTask, rf

_CITATION = """
@inproceedings{kurihara-etal-2022-jglue,
@@ -28,7 +28,7 @@



-class MARCJaWithFintanPrompt(MultipleChoiceTask):
+class MARCJaWithFintanPrompt(BalancedMultipleChoiceTask):
    """
    prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
    """
@@ -162,4 +162,4 @@ def construct_tasks():
    tasks = {}
    for version_class in VERSIONS:
        tasks[f"marc_ja-{version_class.VERSION}-{version_class.PROMPT_VERSION}"] = version_class
-    return tasks
\ No newline at end of file
+    return tasks
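Because process_results, higher_is_better, and aggregation all live on the shared base class, converting a task is just the one-line parent swap shown in the diffs above. A minimal sketch of verifying the registry (assumes lm_eval is importable; the printed names depend on VERSIONS):

from lm_eval.base import BalancedMultipleChoiceTask
from lm_eval.tasks.ja.marc_ja import construct_tasks

# Every registered marc_ja variant should now report the balanced metrics.
for name, cls in construct_tasks().items():
    print(name, issubclass(cls, BalancedMultipleChoiceTask))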
