-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
metrics: add BLEU #2535
Merged
Merged
metrics: add BLEU #2535
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
5cc619c
metrics: added bleu score and test bleu
c186a4a
metrics: fixed type hints in bleu
3fc1f08
bleu score moved to metrics/functional/nlp.py
22267c6
refactor with torch.Tensor
25bd5f0
Update test_sequence.py
Borda 15af2df
refactor as Borda requests and nltk==3.2
4fd8e30
locked nltk==3.3
339a1ca
nltk>=3.3, parametrized smooth argument for test
1bfca67
fix bleu_score example
d6b426c
added class BLEUScore metrics and test
cd21b81
added class BLEUScore metrics and test
2aba4dc
update CHANGELOG
fc79638
refactor with torchtext
5c42ecf
torchtext changed to optional import
f76576e
fix E501 line too long
9107284
add else: in optional import
efd09c6
remove pragma: no-cover
cd03664
constants changed to CAPITALS
92e7dd4
remove class in tests
420bb1f
List -> Sequence, conda -> pip, cast with tensor
8af8f90
add torchtext in test.txt
3c74bb2
remove torchtext from test.txt
49ce1bb
bump torchtext to 0.5.0
5647ab8
bump torchtext to 0.5.0
f9d598a
Apply suggestions from code review
Borda 47477a0
ignore bleu score in doctest, renamed to nlp.py
bf254e7
back to implementation with torch
1f85f2a
remove --ignore in CI test, proper reference format
7a2f80b
apply justus comment
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ dependencies: | |
- twine==1.13.0 | ||
- pillow<7.0.0 | ||
- scikit-image | ||
- nltk>=3.3 | ||
|
||
# Optional | ||
- scipy>=0.13.3 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,5 +25,6 @@ | |
mse, | ||
psnr, | ||
rmse, | ||
rmsle | ||
rmsle, | ||
) | ||
from pytorch_lightning.metrics.functional.nlp import bleu_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# referenced from | ||
# Library Name: torchtext | ||
# Authors: torchtext authors and @sluks | ||
# Date: 2020-07-18 | ||
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score | ||
from typing import Sequence, List | ||
from collections import Counter | ||
|
||
import torch | ||
|
||
|
||
def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: | ||
"""Counting how many times each word appears in a given text with ngram | ||
|
||
Args: | ||
ngram_input_list: A list of translated text or reference texts | ||
n_gram: gram value ranged 1 to 4 | ||
|
||
Return: | ||
ngram_counter: a collections.Counter object of ngram | ||
""" | ||
|
||
ngram_counter = Counter() | ||
|
||
for i in range(1, n_gram + 1): | ||
for j in range(len(ngram_input_list) - i + 1): | ||
ngram_key = tuple(ngram_input_list[j : i + j]) | ||
ngram_counter[ngram_key] += 1 | ||
|
||
return ngram_counter | ||
|
||
|
||
def bleu_score( | ||
translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False | ||
) -> torch.Tensor: | ||
"""Calculate BLEU score of machine translated text with one or more references. | ||
|
||
Args: | ||
translate_corpus: An iterable of machine translated corpus | ||
reference_corpus: An iterable of iterables of reference corpus | ||
n_gram: Gram value ranged from 1 to 4 (Default 4) | ||
smooth: Whether or not to apply smoothing – Lin et al. 2004 | ||
|
||
Return: | ||
A Tensor with BLEU Score | ||
|
||
Example: | ||
|
||
>>> translate_corpus = ['the cat is on the mat'.split()] | ||
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] | ||
>>> bleu_score(translate_corpus, reference_corpus) | ||
tensor(0.7598) | ||
""" | ||
|
||
assert len(translate_corpus) == len(reference_corpus) | ||
numerator = torch.zeros(n_gram) | ||
denominator = torch.zeros(n_gram) | ||
precision_scores = torch.zeros(n_gram) | ||
c = 0.0 | ||
r = 0.0 | ||
for (translation, references) in zip(translate_corpus, reference_corpus): | ||
c += len(translation) | ||
ref_len_list = [len(ref) for ref in references] | ||
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] | ||
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] | ||
translation_counter = _count_ngram(translation, n_gram) | ||
reference_counter = Counter() | ||
for ref in references: | ||
reference_counter |= _count_ngram(ref, n_gram) | ||
|
||
ngram_counter_clip = translation_counter & reference_counter | ||
for counter_clip in ngram_counter_clip: | ||
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] | ||
|
||
for counter in translation_counter: | ||
denominator[len(counter) - 1] += translation_counter[counter] | ||
|
||
trans_len = torch.tensor(c) | ||
ref_len = torch.tensor(r) | ||
if min(numerator) == 0.0: | ||
return torch.tensor(0.0) | ||
|
||
if smooth: | ||
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) | ||
else: | ||
precision_scores = numerator / denominator | ||
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) | ||
geometric_mean = torch.exp(torch.sum(log_precision_scores)) | ||
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) | ||
bleu = brevity_penalty * geometric_mean | ||
|
||
return bleu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import torch | ||
|
||
from pytorch_lightning.metrics.functional.nlp import bleu_score | ||
from pytorch_lightning.metrics.metric import Metric | ||
|
||
|
||
class BLEUScore(Metric): | ||
""" | ||
Calculate BLEU score of machine translated text with one or more references. | ||
|
||
Example: | ||
|
||
>>> translate_corpus = ['the cat is on the mat'.split()] | ||
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] | ||
>>> metric = BLEUScore() | ||
>>> metric(translate_corpus, reference_corpus) | ||
tensor(0.7598) | ||
""" | ||
|
||
def __init__(self, n_gram: int = 4, smooth: bool = False): | ||
""" | ||
Args: | ||
n_gram: Gram value ranged from 1 to 4 (Default 4) | ||
smooth: Whether or not to apply smoothing – Lin et al. 2004 | ||
""" | ||
super().__init__(name="bleu") | ||
self.n_gram = n_gram | ||
self.smooth = smooth | ||
|
||
def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor: | ||
""" | ||
Actual metric computation | ||
|
||
Args: | ||
translate_corpus: An iterable of machine translated corpus | ||
reference_corpus: An iterable of iterables of reference corpus | ||
|
||
Return: | ||
torch.Tensor: BLEU Score | ||
""" | ||
return bleu_score( | ||
translate_corpus=translate_corpus, | ||
reference_corpus=reference_corpus, | ||
n_gram=self.n_gram, | ||
smooth=self.smooth, | ||
).to(self.device, self.dtype) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,4 +11,4 @@ horovod>=0.19.1 | |
omegaconf>=2.0.0 | ||
# scipy>=0.13.3 | ||
scikit-learn>=0.20.0 | ||
torchtext>=0.3.1 | ||
torchtext>=0.3.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,4 @@ black==19.10b0 | |
pre-commit>=1.0 | ||
|
||
cloudpickle>=1.2 | ||
nltk>=3.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import pytest | ||
import torch | ||
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu | ||
|
||
from pytorch_lightning.metrics.functional.nlp import bleu_score | ||
|
||
# example taken from | ||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu | ||
HYPOTHESIS1 = tuple( | ||
"It is a guide to action which ensures that the military always obeys the commands of the party".split() | ||
) | ||
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split()) | ||
REFERENCE2 = tuple( | ||
"It is a guiding principle which makes the military forces always being under the command of the Party".split() | ||
) | ||
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split()) | ||
|
||
|
||
# example taken from | ||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu | ||
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() | ||
HYP2 = "he read the book because he was interested in world history".split() | ||
|
||
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() | ||
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split() | ||
REF1C = "It is the practical guide for the army always to heed the directions of the party".split() | ||
REF2A = "he was interested in world history because he read the book".split() | ||
|
||
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] | ||
HYPOTHESES = [HYP1, HYP2] | ||
|
||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction | ||
smooth_func = SmoothingFunction().method2 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["weights", "n_gram", "smooth_func", "smooth"], | ||
[ | ||
pytest.param([1], 1, None, False), | ||
pytest.param([0.5, 0.5], 2, smooth_func, True), | ||
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False), | ||
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True), | ||
], | ||
) | ||
def test_bleu_score(weights, n_gram, smooth_func, smooth): | ||
nltk_output = sentence_bleu( | ||
[REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func | ||
) | ||
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) | ||
assert torch.allclose(pl_output, torch.tensor(nltk_output)) | ||
|
||
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) | ||
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) | ||
assert torch.allclose(pl_output, torch.tensor(nltk_output)) | ||
|
||
|
||
def test_bleu_empty(): | ||
hyp = [[]] | ||
ref = [[[]]] | ||
assert bleu_score(hyp, ref) == torch.tensor(0.0) | ||
|
||
|
||
def test_no_4_gram(): | ||
hyps = [["My", "full", "pytorch-lightning"]] | ||
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]] | ||
assert bleu_score(hyps, refs) == torch.tensor(0.0) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems to be something missing here...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same in the latest docs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it has been added in #2209 @SkafteNicki mind check what you had in mind? :]