Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added CUAD metrics #2273

Merged
merged 2 commits into from
Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions metrics/cuad/cuad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" CUAD metric. """

import datasets

from .evaluate import evaluate


_CITATION = """\
@article{hendrycks2021cuad,
title={CUAD: An Expert-Annotated NLP Dataset for Legal Contract Review},
author={Dan Hendrycks and Collin Burns and Anya Chen and Spencer Ball},
journal={arXiv preprint arXiv:2103.06268},
year={2021}
}
"""

_DESCRIPTION = """
This metric wrap the official scoring script for version 1 of the Contract
Understanding Atticus Dataset (CUAD).
Contract Understanding Atticus Dataset (CUAD) v1 is a corpus of more than 13,000 labels in 510
commercial legal contracts that have been manually labeled to identify 41 categories of important
clauses that lawyers look for when reviewing contracts in connection with corporate transactions.
"""

_KWARGS_DESCRIPTION = """
Computes CUAD scores (EM, F1, AUPR, Precision@80%Recall, and Precision@90%Recall).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe mention in the docstring that users can use multiple answers depending on a threshold on the confidence probability of each prediction ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added changes in docstring for prediction_text. Let me know if something is also required

Args:
predictions: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair as given in the references (see below)
- 'prediction_text': list of possible texts for the answer, as a list of strings
references: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair (see above),
- 'answers': a Dict in the CUAD dataset format
{
'text': list of possible texts for the answer, as a list of strings
'answer_start': list of start positions for the answer, as a list of ints
}
Note that answer_start values are not taken into account to compute the metric.
Returns:
'exact_match': Exact match (the normalized answer exactly match the gold answer)
'f1': The F-score of predicted tokens versus the gold answer
'aupr': Area Under the Precision-Recall curve
'prec_at_80_recall': Precision at 80% recall
'prec_at_90_recall': Precision at 90% recall
Examples:
>>> predictions = [{'prediction_text': ['The seller:', 'The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd.'], 'id': 'LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties'}]
>>> references = [{'answers': {'answer_start': [143, 49], 'text': ['The seller:', 'The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd.']}, 'id': 'LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Parties'}]
>>> cuad_metric = datasets.load_metric("cuad")
>>> results = cuad_metric.compute(predictions=predictions, references=references)
>>> print(results)
{'exact_match': 100.0, 'f1': 100.0, 'aupr': 0.0, 'prec_at_80_recall': 1.0, 'prec_at_90_recall': 1.0}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class CUAD(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": {
"id": datasets.Value("string"),
"prediction_text": datasets.features.Sequence(datasets.Value("string")),
},
"references": {
"id": datasets.Value("string"),
"answers": datasets.features.Sequence(
{
"text": datasets.Value("string"),
"answer_start": datasets.Value("int32"),
}
),
},
}
),
codebase_urls=["https://www.atticusprojectai.org/cuad"],
reference_urls=["https://www.atticusprojectai.org/cuad"],
)

def _compute(self, predictions, references):
pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
dataset = [
{
"paragraphs": [
{
"qas": [
{
"answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
"id": ref["id"],
}
for ref in references
]
}
]
}
]
score = evaluate(dataset=dataset, predictions=pred_dict)
return score
205 changes: 205 additions & 0 deletions metrics/cuad/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
""" Official evaluation script for CUAD dataset. """

import argparse
import json
import re
import string
import sys

import numpy as np


IOU_THRESH = 0.5


def get_jaccard(prediction, ground_truth):
remove_tokens = [".", ",", ";", ":"]

for token in remove_tokens:
ground_truth = ground_truth.replace(token, "")
prediction = prediction.replace(token, "")

ground_truth, prediction = ground_truth.lower(), prediction.lower()
ground_truth, prediction = ground_truth.replace("/", " "), prediction.replace("/", " ")
ground_truth, prediction = set(ground_truth.split(" ")), set(prediction.split(" "))

intersection = ground_truth.intersection(prediction)
union = ground_truth.union(prediction)
jaccard = len(intersection) / len(union)
return jaccard


def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""

def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)

def white_space_fix(text):
return " ".join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def compute_precision_recall(predictions, ground_truths, qa_id):
tp, fp, fn = 0, 0, 0

substr_ok = "Parties" in qa_id

# first check if ground truth is empty
if len(ground_truths) == 0:
if len(predictions) > 0:
fp += len(predictions) # false positive for each one
else:
for ground_truth in ground_truths:
assert len(ground_truth) > 0
# check if there is a match
match_found = False
for pred in predictions:
if substr_ok:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
else:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
if is_match:
match_found = True

if match_found:
tp += 1
else:
fn += 1

# now also get any fps by looping through preds
for pred in predictions:
# Check if there's a match. if so, don't count (don't want to double count based on the above)
# but if there's no match, then this is a false positive.
# (Note: we get the true positives in the above loop instead of this loop so that we don't double count
# multiple predictions that are matched with the same answer.)
match_found = False
for ground_truth in ground_truths:
assert len(ground_truth) > 0
if substr_ok:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
else:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
if is_match:
match_found = True

if not match_found:
fp += 1

precision = tp / (tp + fp) if tp + fp > 0 else np.nan
recall = tp / (tp + fn) if tp + fn > 0 else np.nan

return precision, recall


def process_precisions(precisions):
"""
Processes precisions to ensure that precision and recall don't both get worse.
Assumes the list precision is sorted in order of recalls
"""
precision_best = precisions[::-1]
for i in range(1, len(precision_best)):
precision_best[i] = max(precision_best[i - 1], precision_best[i])
precisions = precision_best[::-1]
return precisions


def get_aupr(precisions, recalls):
processed_precisions = process_precisions(precisions)
aupr = np.trapz(processed_precisions, recalls)
if np.isnan(aupr):
return 0
return aupr


def get_prec_at_recall(precisions, recalls, recall_thresh):
"""Assumes recalls are sorted in increasing order"""
processed_precisions = process_precisions(precisions)
prec_at_recall = 0
for prec, recall in zip(processed_precisions, recalls):
if recall >= recall_thresh:
prec_at_recall = prec
break
return prec_at_recall


def exact_match_score(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, predictions, ground_truths):
score = 0
for pred in predictions:
for ground_truth in ground_truths:
score = metric_fn(pred, ground_truth)
if score == 1: # break the loop when one prediction matches the ground truth
break
if score == 1:
break
return score


def evaluate(dataset, predictions):
f1 = exact_match = total = 0
precisions = []
recalls = []
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
print(message, file=sys.stderr)
continue
ground_truths = list(map(lambda x: x["text"], qa["answers"]))
prediction = predictions[qa["id"]]
precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"])

precisions.append(precision)
recalls.append(recall)

if precision == 0 and recall == 0:
f1 += 0
else:
f1 += 2 * (precision * recall) / (precision + recall)

exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)

precisions = [x for _, x in sorted(zip(recalls, precisions))]
recalls.sort()

f1 = 100.0 * f1 / total
exact_match = 100.0 * exact_match / total
aupr = get_aupr(precisions, recalls)

prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9)
prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8)

return {
"exact_match": exact_match,
"f1": f1,
"aupr": aupr,
"prec_at_80_recall": prec_at_80_recall,
"prec_at_90_recall": prec_at_90_recall,
}


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluation for CUAD")
parser.add_argument("dataset_file", help="Dataset file")
parser.add_argument("prediction_file", help="Prediction File")
args = parser.parse_args()
with open(args.dataset_file) as dataset_file:
dataset_json = json.load(dataset_file)
dataset = dataset_json["data"]
with open(args.prediction_file) as prediction_file:
predictions = json.load(prediction_file)
print(json.dumps(evaluate(dataset, predictions)))