metrics: add BLEU #2535

Merged Jul 22, 2020
Commits (29)
5cc619c
metrics: added bleu score and test bleu
Jun 25, 2020
c186a4a
metrics: fixed type hints in bleu
Jun 25, 2020
3fc1f08
bleu score moved to metrics/functional/nlp.py
Jul 7, 2020
22267c6
refactor with torch.Tensor
Jul 7, 2020
25bd5f0
Update test_sequence.py
Borda Jul 7, 2020
15af2df
refactor as Borda requests and nltk==3.2
Jul 8, 2020
4fd8e30
locked nltk==3.3
Jul 8, 2020
339a1ca
nltk>=3.3, parametrized smooth argument for test
Jul 8, 2020
1bfca67
fix bleu_score example
Jul 8, 2020
d6b426c
added class BLEUScore metrics and test
Jul 11, 2020
cd21b81
added class BLEUScore metrics and test
Jul 11, 2020
2aba4dc
update CHANGELOG
Jul 11, 2020
fc79638
refactor with torchtext
Jul 12, 2020
5c42ecf
torchtext changed to optional import
Jul 15, 2020
f76576e
fix E501 line too long
Jul 15, 2020
9107284
add else: in optional import
Jul 15, 2020
efd09c6
remove pragma: no-cover
Jul 15, 2020
cd03664
constants changed to CAPITALS
Jul 15, 2020
92e7dd4
remove class in tests
Jul 15, 2020
420bb1f
List -> Sequence, conda -> pip, cast with tensor
Jul 15, 2020
8af8f90
add torchtext in test.txt
Jul 16, 2020
3c74bb2
remove torchtext from test.txt
Jul 16, 2020
49ce1bb
bump torchtext to 0.5.0
Jul 16, 2020
5647ab8
bump torchtext to 0.5.0
Jul 16, 2020
f9d598a
Apply suggestions from code review
Borda Jul 16, 2020
47477a0
ignore bleu score in doctest, renamed to nlp.py
Jul 17, 2020
bf254e7
back to implementation with torch
Jul 17, 2020
1f85f2a
remove --ignore in CI test, proper reference format
Jul 18, 2020
7a2f80b
apply justus comment
Jul 20, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

### Changed

34 changes: 23 additions & 11 deletions docs/source/metrics.rst
@@ -28,7 +28,7 @@ Example::

.. warning::
The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR!
to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug.
Member:
it seems to be something missing here...

Contributor Author @ydcjeff (Jul 16, 2020):

The metrics package is still in development! If we’re missing a metric or you find a mistake, please send a PR! to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug.

Same in the latest docs.

Member:

it has been added in #2209 @SkafteNicki mind check what you had in mind? :]


----------------

@@ -73,7 +73,7 @@ Here's an example showing how to implement a NumpyMetric
class RMSE(NumpyMetric):
def forward(self, x, y):
return np.sqrt(np.mean(np.power(x-y, 2.0)))


.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
:noindex:
@@ -138,6 +138,12 @@ AUROC
.. autoclass:: pytorch_lightning.metrics.classification.AUROC
:noindex:

BLEUScore
^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.nlp.BLEUScore
:noindex:

ConfusionMatrix
^^^^^^^^^^^^^^^

@@ -283,6 +289,12 @@ average_precision (F)
.. autofunction:: pytorch_lightning.metrics.functional.average_precision
:noindex:

bleu_score (F)
^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.bleu_score
:noindex:

confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^

@@ -418,22 +430,22 @@ to_onehot (F)

Sklearn interface
-----------------

Lightning supports `sklearn's metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
as a backend for calculating metrics. Sklearn's metrics are well tested and robust,
but they require conversion between PyTorch and NumPy, which may slow down your computations.

To use the sklearn backend of metrics simply import as

.. code-block:: python

import pytorch_lightning.metrics.sklearns as plm
metric = plm.Accuracy(normalize=True)
val = metric(pred, target)

Each converted sklearn metric has the same interface as its
original counterpart (e.g. accuracy takes the additional `normalize` keyword).
Like the native Lightning metrics, these converted sklearn metrics also come
with built-in distributed (ddp) support.

SklearnMetric (sk)
@@ -460,7 +472,7 @@ AveragePrecision (sk)
.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
:noindex:


ConfusionMatrix (sk)
^^^^^^^^^^^^^^^^^^^^

1 change: 1 addition & 0 deletions environment.yml
@@ -30,6 +30,7 @@ dependencies:
- twine==1.13.0
- pillow<7.0.0
- scikit-image
- nltk>=3.3

# Optional
- scipy>=0.13.3
48 changes: 25 additions & 23 deletions pytorch_lightning/metrics/__init__.py
@@ -5,7 +5,7 @@
MSE,
PSNR,
RMSE,
RMSLE
RMSLE,
)
from pytorch_lightning.metrics.classification import (
Accuracy,
@@ -28,30 +28,32 @@
PrecisionRecallCurve,
SklearnMetric,
)
from pytorch_lightning.metrics.nlp import BLEUScore

__classification_metrics = [
'AUC',
'AUROC',
'Accuracy',
'AveragePrecision',
'ConfusionMatrix',
'DiceCoefficient',
'F1',
'FBeta',
'MulticlassPrecisionRecall',
'MulticlassROC',
'Precision',
'PrecisionRecall',
'PrecisionRecallCurve',
'ROC',
'Recall',
'IoU',
"AUC",
"AUROC",
"Accuracy",
"AveragePrecision",
"ConfusionMatrix",
"DiceCoefficient",
"F1",
"FBeta",
"MulticlassPrecisionRecall",
"MulticlassROC",
"Precision",
"PrecisionRecall",
"PrecisionRecallCurve",
"ROC",
"Recall",
"IoU",
]
__regression_metrics = [
'MAE',
'MSE',
'PSNR',
'RMSE',
'RMSLE'
"MAE",
"MSE",
"PSNR",
"RMSE",
"RMSLE",
]
__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']
__sequence_metrics = ["BLEUScore"]
__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/__init__.py
@@ -25,5 +25,6 @@
mse,
psnr,
rmse,
rmsle
rmsle,
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
92 changes: 92 additions & 0 deletions pytorch_lightning/metrics/functional/nlp.py
@@ -0,0 +1,92 @@
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from typing import Sequence
from collections import Counter

import torch


def _count_ngram(ngram_input_list: list, n_gram: int) -> Counter:
"""Counting how many times each word appears in a given text with ngram

Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4

Return:
ngram_counter: a collections.Counter object of ngram
"""

ngram_counter = Counter()

for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j : i + j])
ngram_counter[ngram_key] += 1

return ngram_counter


def bleu_score(
translate_corpus: Sequence, reference_corpus: Sequence, n_gram: int = 4, smooth: bool = False
) -> torch.Tensor:
"""Calculate BLEU score of machine translated text with one or more references.

Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4 (Default 4)
smooth: Whether or not to apply smoothing – Lin et al. 2004

Return:
A Tensor with BLEU Score

Example:

>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> bleu_score(translate_corpus, reference_corpus)
tensor(0.7598)
"""

assert len(translate_corpus) == len(reference_corpus)
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
precision_scores = torch.zeros(n_gram)
c = 0.0
r = 0.0
for (translation, references) in zip(translate_corpus, reference_corpus):
c += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
translation_counter = _count_ngram(translation, n_gram)
reference_counter = Counter()
for ref in references:
reference_counter |= _count_ngram(ref, n_gram)

ngram_counter_clip = translation_counter & reference_counter
for counter_clip in ngram_counter_clip:
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]

for counter in translation_counter:
denominator[len(counter) - 1] += translation_counter[counter]

trans_len = torch.tensor(c)
ref_len = torch.tensor(r)
if min(numerator) == 0.0:
return torch.tensor(0.0)

if smooth:
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
else:
precision_scores = numerator / denominator
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
bleu = brevity_penalty * geometric_mean

return bleu
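
For readers skimming the diff, here is a minimal sketch (not part of the PR) of how the new helper and the public function behave on the doctest example above; it only assumes the module as added in this PR, and the smoothed value is printed rather than asserted.

```python
from pytorch_lightning.metrics.functional.nlp import _count_ngram, bleu_score

translation = "the cat is on the mat".split()

# All 1-grams and 2-grams of the candidate sentence, with their counts.
counts = _count_ngram(translation, n_gram=2)
print(counts[("the",)])       # "the" occurs twice
print(counts[("on", "the")])  # the bigram "on the" occurs once

# Corpus-level call: one candidate sentence with two references.
translate_corpus = [translation]
reference_corpus = [[
    "there is a cat on the mat".split(),
    "a cat is on the mat".split(),
]]
print(bleu_score(translate_corpus, reference_corpus))               # tensor(0.7598)
print(bleu_score(translate_corpus, reference_corpus, smooth=True))  # smoothed variant
```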
46 changes: 46 additions & 0 deletions pytorch_lightning/metrics/nlp.py
@@ -0,0 +1,46 @@
import torch

from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.metric import Metric


class BLEUScore(Metric):
"""
Calculate BLEU score of machine translated text with one or more references.

Example:

>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> metric = BLEUScore()
>>> metric(translate_corpus, reference_corpus)
tensor(0.7598)
"""

def __init__(self, n_gram: int = 4, smooth: bool = False):
"""
Args:
n_gram: Gram value ranged from 1 to 4 (Default 4)
smooth: Whether or not to apply smoothing – Lin et al. 2004
"""
super().__init__(name="bleu")
self.n_gram = n_gram
self.smooth = smooth

def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
"""
Actual metric computation

Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus

Return:
torch.Tensor: BLEU Score
"""
return bleu_score(
translate_corpus=translate_corpus,
reference_corpus=reference_corpus,
n_gram=self.n_gram,
smooth=self.smooth,
)
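
A minimal usage sketch (not part of the PR) showing how the class-based metric could be wired into a LightningModule; `TranslationModel.decode_batch` and the batch layout are hypothetical placeholders, not part of this PR or the Lightning API.

```python
import pytorch_lightning as pl
from pytorch_lightning.metrics import BLEUScore


class TranslationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.bleu = BLEUScore(n_gram=4, smooth=False)

    def validation_step(self, batch, batch_idx):
        # decode_batch (hypothetical) returns a list of token lists;
        # batch["references"] holds a list of lists of reference token lists.
        translations = self.decode_batch(batch)
        references = batch["references"]
        return {"val_bleu": self.bleu(translations, references)}
```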
2 changes: 1 addition & 1 deletion requirements/extra.txt
@@ -11,4 +11,4 @@ horovod>=0.19.1
omegaconf>=2.0.0
# scipy>=0.13.3
scikit-learn>=0.20.0
torchtext>=0.3.1
torchtext>=0.3.1
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -12,3 +12,4 @@ black==19.10b0
pre-commit>=1.0

cloudpickle>=1.2
nltk>=3.3
66 changes: 66 additions & 0 deletions tests/metrics/functional/test_nlp.py
@@ -0,0 +1,66 @@
import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu

from pytorch_lightning.metrics.functional.nlp import bleu_score

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
HYPOTHESIS1 = tuple(
"It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE2 = tuple(
"It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())


# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()

REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()

LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]

# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2


@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
pytest.param([1], 1, None, False),
pytest.param([0.5, 0.5], 2, smooth_func, True),
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
)
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))

nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))


def test_bleu_empty():
hyp = [[]]
ref = [[[]]]
assert bleu_score(hyp, ref) == torch.tensor(0.0)


def test_no_4_gram():
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(hyps, refs) == torch.tensor(0.0)
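
As a quick manual cross-check (not part of the PR), the class-based metric can be compared against NLTK the same way the functional tests above do, reusing the HYPOTHESES and LIST_OF_REFERENCES constants defined in this file; it assumes nltk>=3.3 is installed.

```python
import torch
from nltk.translate.bleu_score import corpus_bleu

from pytorch_lightning.metrics import BLEUScore

metric = BLEUScore(n_gram=4)  # default BLEU-4, no smoothing
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES)
pl_output = metric(HYPOTHESES, LIST_OF_REFERENCES)
assert torch.allclose(pl_output, torch.tensor(nltk_output))
```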