diff --git a/metrics/cuad/cuad.py b/metrics/cuad/cuad.py
new file mode 100644
index 00000000000..f8a23367631
--- /dev/null
+++ b/metrics/cuad/cuad.py
@@ -0,0 +1,116 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" CUAD metric. """
+
+import datasets
+
+from .evaluate import evaluate
+
+
+_CITATION = """\
+@article{hendrycks2021cuad,
+    title={CUAD: An Expert-Annotated NLP Dataset for Legal Contract Review},
+    author={Dan Hendrycks and Collin Burns and Anya Chen and Spencer Ball},
+    journal={arXiv preprint arXiv:2103.06268},
+    year={2021}
+}
+"""
+
+_DESCRIPTION = """
+This metric wraps the official scoring script for version 1 of the Contract
+Understanding Atticus Dataset (CUAD).
+Contract Understanding Atticus Dataset (CUAD) v1 is a corpus of more than 13,000 labels in 510
+commercial legal contracts that have been manually labeled to identify 41 categories of important
+clauses that lawyers look for when reviewing contracts in connection with corporate transactions.
+"""
+
+_KWARGS_DESCRIPTION = """
+Computes CUAD scores (EM, F1, AUPR, Precision@80%Recall, and Precision@90%Recall).
+Args:
+    predictions: List of question-answers dictionaries with the following key-values:
+        - 'id': id of the question-answer pair as given in the references (see below)
+        - 'prediction_text': list of possible texts for the answer, as a list of strings
+          depending on a threshold on the confidence probability of each prediction.
+    references: List of question-answers dictionaries with the following key-values:
+        - 'id': id of the question-answer pair (see above),
+        - 'answers': a Dict in the CUAD dataset format
+            {
+                'text': list of possible texts for the answer, as a list of strings
+                'answer_start': list of start positions for the answer, as a list of ints
+            }
+        Note that answer_start values are not taken into account to compute the metric.
+Returns:
+    'exact_match': Exact match (the normalized answer exactly matches the gold answer)
+    'f1': The F-score of predicted tokens versus the gold answer
+    'aupr': Area Under the Precision-Recall curve
+    'prec_at_80_recall': Precision at 80% recall
+    'prec_at_90_recall': Precision at 90% recall
+Examples:
+    >>> predictions = [{'prediction_text': ['The seller:', 'The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd.'], 'id': 'LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties'}]
+    >>> references = [{'answers': {'answer_start': [143, 49], 'text': ['The seller:', 'The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd.']}, 'id': 'LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties'}]
+    >>> cuad_metric = datasets.load_metric("cuad")
+    >>> results = cuad_metric.compute(predictions=predictions, references=references)
+    >>> print(results)
+    {'exact_match': 100.0, 'f1': 100.0, 'aupr': 0.0, 'prec_at_80_recall': 1.0, 'prec_at_90_recall': 1.0}
+"""
+
+
+@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+class CUAD(datasets.Metric):
+    def _info(self):
+        return datasets.MetricInfo(
+            description=_DESCRIPTION,
+            citation=_CITATION,
+            inputs_description=_KWARGS_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "predictions": {
+                        "id": datasets.Value("string"),
+                        "prediction_text": datasets.features.Sequence(datasets.Value("string")),
+                    },
+                    "references": {
+                        "id": datasets.Value("string"),
+                        "answers": datasets.features.Sequence(
+                            {
+                                "text": datasets.Value("string"),
+                                "answer_start": datasets.Value("int32"),
+                            }
+                        ),
+                    },
+                }
+            ),
+            codebase_urls=["https://www.atticusprojectai.org/cuad"],
+            reference_urls=["https://www.atticusprojectai.org/cuad"],
+        )
+
+    def _compute(self, predictions, references):
+        pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
+        dataset = [
+            {
+                "paragraphs": [
+                    {
+                        "qas": [
+                            {
+                                "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
+                                "id": ref["id"],
+                            }
+                            for ref in references
+                        ]
+                    }
+                ]
+            }
+        ]
+        score = evaluate(dataset=dataset, predictions=pred_dict)
+        return score
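The doctest in `_KWARGS_DESCRIPTION` above is the intended entry point for this metric. A condensed usage sketch follows: it assumes the metric is resolved by name with `datasets.load_metric("cuad")` exactly as in that doctest, and the id and answer strings are illustrative placeholders rather than real CUAD annotations.

    import datasets

    # Load the CUAD metric by name, as in the doctest above.
    cuad_metric = datasets.load_metric("cuad")

    # One entry per question-answer pair; ids must match between predictions and references.
    # The 'answer_start' offsets are accepted but ignored by the scoring script.
    predictions = [{"id": "example_contract__Parties", "prediction_text": ["The seller:"]}]
    references = [
        {"id": "example_contract__Parties", "answers": {"text": ["The seller:"], "answer_start": [143]}}
    ]

    results = cuad_metric.compute(predictions=predictions, references=references)
    print(results)  # keys: exact_match, f1, aupr, prec_at_80_recall, prec_at_90_recall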
diff --git a/metrics/cuad/evaluate.py b/metrics/cuad/evaluate.py
new file mode 100644
index 00000000000..3fbe88cfbd5
--- /dev/null
+++ b/metrics/cuad/evaluate.py
@@ -0,0 +1,205 @@
+""" Official evaluation script for CUAD dataset. """
+
+import argparse
+import json
+import re
+import string
+import sys
+
+import numpy as np
+
+
+IOU_THRESH = 0.5
+
+
+def get_jaccard(prediction, ground_truth):
+    remove_tokens = [".", ",", ";", ":"]
+
+    for token in remove_tokens:
+        ground_truth = ground_truth.replace(token, "")
+        prediction = prediction.replace(token, "")
+
+    ground_truth, prediction = ground_truth.lower(), prediction.lower()
+    ground_truth, prediction = ground_truth.replace("/", " "), prediction.replace("/", " ")
+    ground_truth, prediction = set(ground_truth.split(" ")), set(prediction.split(" "))
+
+    intersection = ground_truth.intersection(prediction)
+    union = ground_truth.union(prediction)
+    jaccard = len(intersection) / len(union)
+    return jaccard
+
+
+def normalize_answer(s):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def compute_precision_recall(predictions, ground_truths, qa_id):
+    tp, fp, fn = 0, 0, 0
+
+    substr_ok = "Parties" in qa_id
+
+    # first check if ground truth is empty
+    if len(ground_truths) == 0:
+        if len(predictions) > 0:
+            fp += len(predictions)  # false positive for each one
+    else:
+        for ground_truth in ground_truths:
+            assert len(ground_truth) > 0
+            # check if there is a match
+            match_found = False
+            for pred in predictions:
+                if substr_ok:
+                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
+                else:
+                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
+                if is_match:
+                    match_found = True
+
+            if match_found:
+                tp += 1
+            else:
+                fn += 1
+
+        # now also get any fps by looping through preds
+        for pred in predictions:
+            # Check if there's a match. if so, don't count (don't want to double count based on the above)
+            # but if there's no match, then this is a false positive.
+            # (Note: we get the true positives in the above loop instead of this loop so that we don't double count
+            # multiple predictions that are matched with the same answer.)
+            match_found = False
+            for ground_truth in ground_truths:
+                assert len(ground_truth) > 0
+                if substr_ok:
+                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
+                else:
+                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
+                if is_match:
+                    match_found = True
+
+            if not match_found:
+                fp += 1
+
+    precision = tp / (tp + fp) if tp + fp > 0 else np.nan
+    recall = tp / (tp + fn) if tp + fn > 0 else np.nan
+
+    return precision, recall
+
+
+def process_precisions(precisions):
+    """
+    Processes precisions to ensure that precision and recall don't both get worse.
+    Assumes the list of precisions is sorted in order of recalls
+    """
+    precision_best = precisions[::-1]
+    for i in range(1, len(precision_best)):
+        precision_best[i] = max(precision_best[i - 1], precision_best[i])
+    precisions = precision_best[::-1]
+    return precisions
+
+
+def get_aupr(precisions, recalls):
+    processed_precisions = process_precisions(precisions)
+    aupr = np.trapz(processed_precisions, recalls)
+    if np.isnan(aupr):
+        return 0
+    return aupr
+
+
+def get_prec_at_recall(precisions, recalls, recall_thresh):
+    """Assumes recalls are sorted in increasing order"""
+    processed_precisions = process_precisions(precisions)
+    prec_at_recall = 0
+    for prec, recall in zip(processed_precisions, recalls):
+        if recall >= recall_thresh:
+            prec_at_recall = prec
+            break
+    return prec_at_recall
+
+
+def exact_match_score(prediction, ground_truth):
+    return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def metric_max_over_ground_truths(metric_fn, predictions, ground_truths):
+    score = 0
+    for pred in predictions:
+        for ground_truth in ground_truths:
+            score = metric_fn(pred, ground_truth)
+            if score == 1:  # break the loop when one prediction matches the ground truth
+                break
+        if score == 1:
+            break
+    return score
+
+
+def evaluate(dataset, predictions):
+    f1 = exact_match = total = 0
+    precisions = []
+    recalls = []
+    for article in dataset:
+        for paragraph in article["paragraphs"]:
+            for qa in paragraph["qas"]:
+                total += 1
+                if qa["id"] not in predictions:
+                    message = "Unanswered question " + qa["id"] + " will receive score 0."
+                    print(message, file=sys.stderr)
+                    continue
+                ground_truths = list(map(lambda x: x["text"], qa["answers"]))
+                prediction = predictions[qa["id"]]
+                precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"])
+
+                precisions.append(precision)
+                recalls.append(recall)
+
+                if precision == 0 and recall == 0:
+                    f1 += 0
+                else:
+                    f1 += 2 * (precision * recall) / (precision + recall)
+
+                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
+
+    precisions = [x for _, x in sorted(zip(recalls, precisions))]
+    recalls.sort()
+
+    f1 = 100.0 * f1 / total
+    exact_match = 100.0 * exact_match / total
+    aupr = get_aupr(precisions, recalls)
+
+    prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9)
+    prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8)
+
+    return {
+        "exact_match": exact_match,
+        "f1": f1,
+        "aupr": aupr,
+        "prec_at_80_recall": prec_at_80_recall,
+        "prec_at_90_recall": prec_at_90_recall,
+    }
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluation for CUAD")
+    parser.add_argument("dataset_file", help="Dataset file")
+    parser.add_argument("prediction_file", help="Prediction File")
+    args = parser.parse_args()
+    with open(args.dataset_file) as dataset_file:
+        dataset_json = json.load(dataset_file)
+        dataset = dataset_json["data"]
+    with open(args.prediction_file) as prediction_file:
+        predictions = json.load(prediction_file)
+    print(json.dumps(evaluate(dataset, predictions)))
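For a standalone sanity check of the scoring logic above, the script can be run as `python evaluate.py <dataset_file> <prediction_file>` on CUAD-format JSON (the dataset file carries the usual top-level "data" key), or `evaluate` can be called directly. A minimal sketch, assuming the script is importable from the working directory and using illustrative ids and answer strings:

    # Illustrative only: a single question whose lone prediction matches the gold span.
    from evaluate import evaluate

    dataset = [
        {
            "paragraphs": [
                {"qas": [{"id": "example_contract__Parties", "answers": [{"text": "The seller:"}]}]}
            ]
        }
    ]
    predictions = {"example_contract__Parties": ["The seller:"]}

    print(evaluate(dataset, predictions))
    # A single matched example gives 100.0 exact_match/f1 and precision@recall of 1.0, but an AUPR
    # of 0.0, since one (precision, recall) point spans no recall interval under np.trapz.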