Add Precision, Recall, F-measure, Confusion Matrix to Taggers (#2862)
* Add Precision, Recall, F-measure, Confusion Matrix and per-tag evaluation to Taggers

And add precision, recall and f-measure to ConfusionMatrix.

Includes large doctests, and some small doctest fixes throughout the tag module

* Move evaluation of ConfusionMatrix into nltk/metrics/confusionmatrix.py

* Add self as author in significantly updated files

* Deprecate tagger evaluate(gold) in favor of accuracy(gold)

* Missed one case of Tagger evaluate still being used - fixed now

* Deprecate ChunkParser's evaluate(gold) in favor of accuracy(gold)

Co-authored-by: Steven Bird <stevenbird1@gmail.com>
tomaarsen and stevenbird authored Dec 15, 2021
1 parent 72d9885 commit a28d256
Showing 12 changed files with 833 additions and 25 deletions.
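In user-facing terms, the tagger change means that evaluate(gold) survives only as a deprecated alias of the new accuracy(gold). A minimal migration sketch, assuming a build of NLTK that includes this commit; the UnigramTagger and Brown corpus slices are illustrative choices mirroring the evaluation slice in the nltk/tag/__init__.py doctest further down:

from nltk.corpus import brown
from nltk.tag import UnigramTagger

train_sents = brown.tagged_sents(categories='news')[:500]
test_sents = brown.tagged_sents(categories='news')[500:600]
tagger = UnigramTagger(train_sents)

tagger.accuracy(test_sents)   # preferred name introduced by this commit
tagger.evaluate(test_sents)   # same result, but should now emit a DeprecationWarning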
5 changes: 5 additions & 0 deletions nltk/chunk/api.py
@@ -11,6 +11,7 @@
##//////////////////////////////////////////////////////

from nltk.chunk.util import ChunkScore
from nltk.internals import deprecated
from nltk.parse import ParserI


@@ -34,7 +35,11 @@ def parse(self, tokens):
"""
raise NotImplementedError()

@deprecated("Use accuracy(gold) instead.")
def evaluate(self, gold):
return self.accuracy(gold)

def accuracy(self, gold):
"""
Score the accuracy of the chunker against the gold standard.
Remove the chunking from the gold standard text, rechunk it using
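The rename works the same way for chunkers: ChunkParserI.evaluate(gold) now just delegates to the new accuracy(gold), and the scoring behaviour itself is unchanged. A hedged usage sketch, with a toy RegexpParser grammar and a CoNLL-2000 slice chosen purely for illustration:

from nltk.corpus import conll2000
from nltk.chunk import RegexpParser

chunker = RegexpParser(r"NP: {<DT>?<JJ>*<NN.*>+}")
gold = conll2000.chunked_sents('test.txt', chunk_types=['NP'])[:50]

score = chunker.accuracy(gold)   # replaces the deprecated chunker.evaluate(gold)
print(score)                     # same ChunkScore-style summary as the old evaluate(gold)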
137 changes: 137 additions & 0 deletions nltk/metrics/confusionmatrix.py
@@ -3,6 +3,7 @@
# Copyright (C) 2001-2021 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# Steven Bird <stevenbird1@gmail.com>
# Tom Aarsen <>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

@@ -201,6 +202,140 @@ def key(self):

return str

def recall(self, value):
"""Given a value in the confusion matrix, return the recall
that corresponds to this value. The recall is defined as:
- *r* = true positive / (true positive + false negative)
and can loosely be considered the ratio of how often ``value``
was predicted correctly relative to how often ``value`` was
the true result.
:param value: value used in the ConfusionMatrix
:return: the recall corresponding to ``value``.
:rtype: float
"""
# Number of times `value` was correct, and also predicted
TP = self[value, value]
# Number of times `value` was the true (gold) value
TP_FN = sum(self[value, pred_value] for pred_value in self._values)
if TP_FN == 0:
return 0.0
return TP / TP_FN
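For illustration, recall on the toy tagging data used in the evaluate() doctest further down: "NN" is the true tag of four tokens, three of which are also predicted "NN", so its recall is 0.75.

>>> reference = "DET NN VB DET JJ NN NN IN DET NN".split()
>>> test = "DET VB VB DET NN NN NN IN DET NN".split()
>>> cm = ConfusionMatrix(reference, test)
>>> cm.recall("NN")
0.75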

def precision(self, value):
"""Given a value in the confusion matrix, return the precision
that corresponds to this value. The precision is defined as:
- *p* = true positive / (true positive + false positive)
and can loosely be considered the ratio of how often ``value``
was predicted correctly relative to the number of predictions
for ``value``.
:param value: value used in the ConfusionMatrix
:return: the precision corresponding to ``value``.
:rtype: float
"""
# Number of times `value` was correct, and also predicted
TP = self[value, value]
# Number of times `value` was predicted
TP_FP = sum(self[real_value, value] for real_value in self._values)
if TP_FP == 0:
return 0.0
return TP / TP_FP
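Precision reads down a column of the matrix instead of across a row. Continuing the illustrative example from recall() above: "NN" is predicted four times but correct only three of those times, and only one of the two "VB" predictions is right.

>>> cm.precision("NN")
0.75
>>> cm.precision("VB")
0.5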

def f_measure(self, value, alpha=0.5):
"""
Given a value used in the confusion matrix, return the f-measure
that corresponds to this value. The f-measure is the harmonic mean
of the ``precision`` and ``recall``, weighted by ``alpha``.
In particular, given the precision *p* and recall *r* defined by:
- *p* = true positive / (true positive + false positive)
- *r* = true positive / (true positive + false negative)
The f-measure is:
- *1/(alpha/p + (1-alpha)/r)*
With ``alpha = 0.5``, this reduces to:
- *2pr / (p + r)*
:param value: value used in the ConfusionMatrix
:param alpha: Ratio of the cost of false negatives compared to false
positives. Defaults to 0.5, where the costs are equal.
:type alpha: float
:return: the F-measure corresponding to ``value``.
:rtype: float
"""
p = self.precision(value)
r = self.recall(value)
if p == 0.0 or r == 0.0:
return 0.0
return 1.0 / (alpha / p + (1 - alpha) / r)
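A worked example on the same illustrative data: "VB" has precision 0.5 and recall 1.0, so with the default alpha = 0.5 the f-measure is 1 / (0.5/0.5 + 0.5/1.0) = 1 / 1.5 ≈ 0.6667, matching the table produced by evaluate() below.

>>> cm.f_measure("VB")
0.6666666666666666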

def evaluate(self, alpha=0.5, truncate=None, sort_by_count=False):
"""
Tabulate the **recall**, **precision** and **f-measure**
for each value in this confusion matrix.
>>> reference = "DET NN VB DET JJ NN NN IN DET NN".split()
>>> test = "DET VB VB DET NN NN NN IN DET NN".split()
>>> cm = ConfusionMatrix(reference, test)
>>> print(cm.evaluate())
Tag | Prec.  | Recall | F-measure
----+--------+--------+-----------
DET | 1.0000 | 1.0000 | 1.0000
 IN | 1.0000 | 1.0000 | 1.0000
 JJ | 0.0000 | 0.0000 | 0.0000
 NN | 0.7500 | 0.7500 | 0.7500
 VB | 0.5000 | 1.0000 | 0.6667
<BLANKLINE>
:param alpha: Ratio of the cost of false negatives compared to false
positives, as used in the f-measure computation. Defaults to 0.5,
where the costs are equal.
:type alpha: float
:param truncate: If specified, then only show the specified
number of values. Any sorting (e.g., sort_by_count)
will be performed before truncation. Defaults to None.
:type truncate: int, optional
:param sort_by_count: Whether to sort the outputs on frequency
in the reference label. Defaults to False.
:type sort_by_count: bool, optional
:return: A tabulated recall, precision and f-measure string
:rtype: str
"""
tags = self._values

# Apply keyword parameters
if sort_by_count:
tags = sorted(tags, key=lambda v: -sum(self._confusion[self._indices[v]]))
if truncate:
tags = tags[:truncate]

tag_column_len = max(max(len(tag) for tag in tags), 3)

# Construct the header
s = (
f"{' ' * (tag_column_len - 3)}Tag | Prec. | Recall | F-measure\n"
f"{'-' * tag_column_len}-+--------+--------+-----------\n"
)

# Construct the body
for tag in tags:
s += (
f"{tag:>{tag_column_len}} | "
f"{self.precision(tag):<6.4f} | "
f"{self.recall(tag):<6.4f} | "
f"{self.f_measure(tag, alpha=alpha):.4f}\n"
)

return s
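The keyword arguments compose as described: sorting by reference frequency and then truncating keeps only the most frequent tags. A sketch on the same doctest data (exact spacing aside, "NN" with 4 reference occurrences and "DET" with 3 should be the two surviving rows):

print(cm.evaluate(sort_by_count=True, truncate=2))
# Expected: a two-row table for NN and DET, with the same 0.7500 and 1.0000
# precision/recall/f-measure values as in the full table above.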


def demo():
reference = "DET NN VB DET JJ NN NN IN DET NN".split()
@@ -211,6 +346,8 @@ def demo():
print(ConfusionMatrix(reference, test))
print(ConfusionMatrix(reference, test).pretty_format(sort_by_count=True))

print(ConfusionMatrix(reference, test).recall("VB"))


if __name__ == "__main__":
demo()
8 changes: 4 additions & 4 deletions nltk/tag/__init__.py
@@ -21,7 +21,7 @@
An off-the-shelf tagger is available for English. It uses the Penn Treebank tagset:
>>> from nltk import pos_tag, word_tokenize
>>> pos_tag(word_tokenize("John's big idea isn't all that bad."))
>>> pos_tag(word_tokenize("John's big idea isn't all that bad.")) # doctest: +NORMALIZE_WHITESPACE
[('John', 'NNP'), ("'s", 'POS'), ('big', 'JJ'), ('idea', 'NN'), ('is', 'VBZ'),
("n't", 'RB'), ('all', 'PDT'), ('that', 'DT'), ('bad', 'JJ'), ('.', '.')]
@@ -57,7 +57,7 @@
We evaluate a tagger on data that was not seen during training:
>>> tagger.evaluate(brown.tagged_sents(categories='news')[500:600])
>>> tagger.accuracy(brown.tagged_sents(categories='news')[500:600])
0.7...
For more information, please consult chapter 5 of the NLTK Book.
@@ -144,10 +144,10 @@ def pos_tag(tokens, tagset=None, lang="eng"):
>>> from nltk.tag import pos_tag
>>> from nltk.tokenize import word_tokenize
>>> pos_tag(word_tokenize("John's big idea isn't all that bad."))
>>> pos_tag(word_tokenize("John's big idea isn't all that bad.")) # doctest: +NORMALIZE_WHITESPACE
[('John', 'NNP'), ("'s", 'POS'), ('big', 'JJ'), ('idea', 'NN'), ('is', 'VBZ'),
("n't", 'RB'), ('all', 'PDT'), ('that', 'DT'), ('bad', 'JJ'), ('.', '.')]
>>> pos_tag(word_tokenize("John's big idea isn't all that bad."), tagset='universal')
>>> pos_tag(word_tokenize("John's big idea isn't all that bad."), tagset='universal') # doctest: +NORMALIZE_WHITESPACE
[('John', 'NOUN'), ("'s", 'PRT'), ('big', 'ADJ'), ('idea', 'NOUN'), ('is', 'VERB'),
("n't", 'ADV'), ('all', 'DET'), ('that', 'DET'), ('bad', 'ADJ'), ('.', '.')]
(Diffs for the remaining 9 changed files are not shown.)
