Added CUAD metrics #2273
Merged
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" CUAD metric. """

import datasets

from .evaluate import evaluate


_CITATION = """\
@article{hendrycks2021cuad,
    title={CUAD: An Expert-Annotated NLP Dataset for Legal Contract Review},
    author={Dan Hendrycks and Collin Burns and Anya Chen and Spencer Ball},
    journal={arXiv preprint arXiv:2103.06268},
    year={2021}
}
"""

_DESCRIPTION = """
This metric wraps the official scoring script for version 1 of the Contract
Understanding Atticus Dataset (CUAD).
Contract Understanding Atticus Dataset (CUAD) v1 is a corpus of more than 13,000 labels in 510
commercial legal contracts that have been manually labeled to identify 41 categories of important
clauses that lawyers look for when reviewing contracts in connection with corporate transactions.
"""

_KWARGS_DESCRIPTION = """
Computes CUAD scores (EM, F1, AUPR, Precision@80%Recall, and Precision@90%Recall).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': list of possible texts for the answer, as a list of strings
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the CUAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for the answer, as a list of ints
            }
            Note that answer_start values are not taken into account to compute the metric.
Returns:
    'exact_match': Exact match (the normalized answer exactly matches the gold answer)
    'f1': The F-score of predicted tokens versus the gold answer
    'aupr': Area Under the Precision-Recall curve
    'prec_at_80_recall': Precision at 80% recall
    'prec_at_90_recall': Precision at 90% recall
Examples:
    >>> predictions = [{'prediction_text': ['The seller:', 'The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd.'], 'id': 'LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties'}]
    >>> references = [{'answers': {'answer_start': [143, 49], 'text': ['The seller:', 'The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd.']}, 'id': 'LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties'}]
    >>> cuad_metric = datasets.load_metric("cuad")
    >>> results = cuad_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 100.0, 'f1': 100.0, 'aupr': 0.0, 'prec_at_80_recall': 1.0, 'prec_at_90_recall': 1.0}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class CUAD(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {
                        "id": datasets.Value("string"),
                        "prediction_text": datasets.features.Sequence(datasets.Value("string")),
                    },
                    "references": {
                        "id": datasets.Value("string"),
                        "answers": datasets.features.Sequence(
                            {
                                "text": datasets.Value("string"),
                                "answer_start": datasets.Value("int32"),
                            }
                        ),
                    },
                }
            ),
            codebase_urls=["https://www.atticusprojectai.org/cuad"],
            reference_urls=["https://www.atticusprojectai.org/cuad"],
        )

    def _compute(self, predictions, references):
        pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
        dataset = [
            {
                "paragraphs": [
                    {
                        "qas": [
                            {
                                "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
                                "id": ref["id"],
                            }
                            for ref in references
                        ]
                    }
                ]
            }
        ]
        score = evaluate(dataset=dataset, predictions=pred_dict)
        return score
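For the docstring example above, `_compute` reshapes the flat `predictions`/`references` lists into the nested SQuAD-style structure that `evaluate()` consumes, roughly as in this illustration (not code from the PR itself):

```python
qa_id = "LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties"

# id -> list of predicted answer texts
pred_dict = {
    qa_id: ["The seller:", "The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd."],
}

# references become a single-article, single-paragraph SQuAD-style dataset;
# answer_start offsets are dropped because the metric ignores them
dataset = [
    {
        "paragraphs": [
            {
                "qas": [
                    {
                        "answers": [
                            {"text": "The seller:"},
                            {"text": "The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd."},
                        ],
                        "id": qa_id,
                    }
                ]
            }
        ]
    }
]
```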
@@ -0,0 +1,205 @@
""" Official evaluation script for CUAD dataset. """

import argparse
import json
import re
import string
import sys

import numpy as np


IOU_THRESH = 0.5


def get_jaccard(prediction, ground_truth):
    remove_tokens = [".", ",", ";", ":"]

    for token in remove_tokens:
        ground_truth = ground_truth.replace(token, "")
        prediction = prediction.replace(token, "")

    ground_truth, prediction = ground_truth.lower(), prediction.lower()
    ground_truth, prediction = ground_truth.replace("/", " "), prediction.replace("/", " ")
    ground_truth, prediction = set(ground_truth.split(" ")), set(prediction.split(" "))

    intersection = ground_truth.intersection(prediction)
    union = ground_truth.union(prediction)
    jaccard = len(intersection) / len(union)
    return jaccard


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def compute_precision_recall(predictions, ground_truths, qa_id):
    tp, fp, fn = 0, 0, 0

    substr_ok = "Parties" in qa_id

    # first check if ground truth is empty
    if len(ground_truths) == 0:
        if len(predictions) > 0:
            fp += len(predictions)  # false positive for each one
    else:
        for ground_truth in ground_truths:
            assert len(ground_truth) > 0
            # check if there is a match
            match_found = False
            for pred in predictions:
                if substr_ok:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
                else:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
                if is_match:
                    match_found = True

            if match_found:
                tp += 1
            else:
                fn += 1

        # now also get any fps by looping through preds
        for pred in predictions:
            # Check if there's a match. if so, don't count (don't want to double count based on the above)
            # but if there's no match, then this is a false positive.
            # (Note: we get the true positives in the above loop instead of this loop so that we don't double count
            # multiple predictions that are matched with the same answer.)
            match_found = False
            for ground_truth in ground_truths:
                assert len(ground_truth) > 0
                if substr_ok:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
                else:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
                if is_match:
                    match_found = True

            if not match_found:
                fp += 1

    precision = tp / (tp + fp) if tp + fp > 0 else np.nan
    recall = tp / (tp + fn) if tp + fn > 0 else np.nan

    return precision, recall


def process_precisions(precisions):
    """
    Processes precisions to ensure that precision and recall don't both get worse.
    Assumes the list precision is sorted in order of recalls
    """
    precision_best = precisions[::-1]
    for i in range(1, len(precision_best)):
        precision_best[i] = max(precision_best[i - 1], precision_best[i])
    precisions = precision_best[::-1]
    return precisions


def get_aupr(precisions, recalls):
    processed_precisions = process_precisions(precisions)
    aupr = np.trapz(processed_precisions, recalls)
    if np.isnan(aupr):
        return 0
    return aupr


def get_prec_at_recall(precisions, recalls, recall_thresh):
    """Assumes recalls are sorted in increasing order"""
    processed_precisions = process_precisions(precisions)
    prec_at_recall = 0
    for prec, recall in zip(processed_precisions, recalls):
        if recall >= recall_thresh:
            prec_at_recall = prec
            break
    return prec_at_recall


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, predictions, ground_truths):
    score = 0
    for pred in predictions:
        for ground_truth in ground_truths:
            score = metric_fn(pred, ground_truth)
            if score == 1:  # break the loop when one prediction matches the ground truth
                break
        if score == 1:
            break
    return score


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    precisions = []
    recalls = []
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x["text"], qa["answers"]))
                prediction = predictions[qa["id"]]
                precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"])

                precisions.append(precision)
                recalls.append(recall)

                if precision == 0 and recall == 0:
                    f1 += 0
                else:
                    f1 += 2 * (precision * recall) / (precision + recall)

                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)

    precisions = [x for _, x in sorted(zip(recalls, precisions))]
    recalls.sort()

    f1 = 100.0 * f1 / total
    exact_match = 100.0 * exact_match / total
    aupr = get_aupr(precisions, recalls)

    prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9)
    prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8)

    return {
        "exact_match": exact_match,
        "f1": f1,
        "aupr": aupr,
        "prec_at_80_recall": prec_at_80_recall,
        "prec_at_90_recall": prec_at_90_recall,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluation for CUAD")
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        dataset = dataset_json["data"]
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(evaluate(dataset, predictions)))
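Run standalone, the script follows the argparse setup above; a typical invocation (the file names here are placeholders) would be:

```
python evaluate.py cuad_dataset.json predictions.json
```

It prints the resulting metrics dict (`exact_match`, `f1`, `aupr`, `prec_at_80_recall`, `prec_at_90_recall`) as JSON to stdout.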
Maybe mention in the docstring that users can provide multiple answers, depending on a threshold on the confidence probability of each prediction?
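For illustration, such a user-side filtering step could look like the sketch below; the `confidence` field and threshold value are hypothetical, since the metric itself only consumes `prediction_text`:

```python
# Hypothetical post-processing: keep only candidate answer spans whose model
# confidence clears a chosen threshold before handing them to the CUAD metric.
# The confidence scores below are illustrative; the metric never sees them.
CONF_THRESH = 0.5  # assumed value, tuned per application

qa_id = "LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties"
raw_candidates = [
    {"text": "The seller:", "confidence": 0.92},
    {"text": "The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd.", "confidence": 0.81},
    {"text": "an unrelated low-confidence span", "confidence": 0.12},
]

predictions = [
    {
        "id": qa_id,
        "prediction_text": [c["text"] for c in raw_candidates if c["confidence"] >= CONF_THRESH],
    }
]
```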
Added changes in the docstring for `prediction_text`. Let me know if anything else is required.
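For reference, an illustrative wording for that docstring entry could be the following (not necessarily the exact text committed in this PR):

```
- 'prediction_text': list of possible texts for the answer, as a list of strings
  depending on a threshold on the confidence probability of each prediction
```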