diff --git a/CHANGELOG.md b/CHANGELOG.md
index a890e96e3ee..7069c082e8c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,11 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
 
 ## Unreleased (1.x branch)
 
+### Fixed
+
+- `GumbelSampler` now sorts the beams by their true log prob.
+
+
 ## [v1.2.1](https://github.com/allenai/allennlp/releases/tag/v1.2.1) - 2020-11-10
 
 ### Added
@@ -48,7 +53,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
 - Added more documentation about plugins.
 - Added sampler class and parameter in beam search for non-deterministic search, with several
   implementations, including `MultinomialSampler`, `TopKSampler`, `TopPSampler`, and
-  `GumbelMaxSampler`. Utilizing `GumbelMaxSampler` will give [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
+  `GumbelSampler`. Utilizing `GumbelSampler` will give [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
 
 ### Changed
 
@@ -67,6 +72,8 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
 - Fixed typo with registered name of ROUGE metric. Previously was `rogue`, fixed to `rouge`.
 - Fixed default masks that were erroneously created on the CPU even when a GPU is available.
 - Fixed pretrained embeddings for transformers that don't use end tokens.
+- Fixed the transformer tokenizer cache when the tokenizers are initialized with custom kwargs.
+
 
 ## [v1.2.0](https://github.com/allenai/allennlp/releases/tag/v1.2.0) - 2020-10-29
 
diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py
index 04c13e8f83c..e3e700af8a2 100644
--- a/allennlp/common/cached_transformers.py
+++ b/allennlp/common/cached_transformers.py
@@ -94,11 +94,13 @@ def strip_prefix(s):
     return transformer
 
 
-_tokenizer_cache: Dict[Tuple[str, frozenset], transformers.PreTrainedTokenizer] = {}
+_tokenizer_cache: Dict[Tuple[str, str], transformers.PreTrainedTokenizer] = {}
 
 
 def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer:
-    cache_key = (model_name, frozenset(kwargs.items()))
+    from allennlp.common.util import hash_object
+
+    cache_key = (model_name, hash_object(kwargs))
 
     global _tokenizer_cache
     tokenizer = _tokenizer_cache.get(cache_key, None)
diff --git a/allennlp/common/util.py b/allennlp/common/util.py
index 3c6ce64ea39..7b82738a994 100644
--- a/allennlp/common/util.py
+++ b/allennlp/common/util.py
@@ -1,6 +1,9 @@
 """
 Various utilities that don't fit anywhere else.
 """
+import hashlib
+import io
+import pickle
 from datetime import timedelta
 import importlib
 import json
@@ -679,3 +682,12 @@ def cycle_iterator_function(iterator_function: Callable[[], Iterable[T]]) -> Ite
             yield next(iterator)
         except StopIteration:
             iterator = iter(iterator_function())
+
+
+def hash_object(o: Any) -> str:
+    """Returns a 128-character hash code of arbitrary Python objects."""
+    m = hashlib.blake2b()
+    with io.BytesIO() as buffer:
+        pickle.dump(o, buffer)
+        m.update(buffer.getbuffer())
+        return m.hexdigest()
diff --git a/allennlp/nn/beam_search.py b/allennlp/nn/beam_search.py
index 55e99d0d6b1..fff07b7dac2 100644
--- a/allennlp/nn/beam_search.py
+++ b/allennlp/nn/beam_search.py
@@ -385,12 +385,18 @@ def sample_beams(
         # shape (both): (batch_size, beam_size)
         G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
 
-        # shape: (batch_size * beam_size,)
-        G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
-
         # shape: (batch_size, beam_size)
         selected_log_probs = log_probs.gather(1, selected_indices)
 
+        # Now sort the selected beams by their true log prob.
+        # shape (all): (batch_size, beam_size)
+        selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
+        selected_indices = selected_indices.gather(1, sort_indices)
+        G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
+
+        # shape: (batch_size * beam_size,)
+        G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
+
         # shape: (batch_size * beam_size,)
         phi_S = selected_log_probs.reshape(batch_size * beam_size)
 
diff --git a/tests/nn/beam_search_test.py b/tests/nn/beam_search_test.py
index 9da55b37613..88614f88e9d 100644
--- a/tests/nn/beam_search_test.py
+++ b/tests/nn/beam_search_test.py
@@ -434,10 +434,14 @@ def test_gumbel_sampler(self):
         num_classes = len(log_probabilities[0])
         sampler_state = sampler.init_state(log_probabilities, batch_size=2, num_classes=num_classes)
 
-        probabilities, classes, state = sampler.sample_beams(log_probabilities, 3, sampler_state)
+        log_probs, indices, state = sampler.sample_beams(log_probabilities, 3, sampler_state)
 
-        assert probabilities.size() == classes.size()
-        assert classes.size() == (2, 3)
+        assert log_probs.size() == indices.size()
+        assert indices.size() == (2, 3)
+
+        # Make sure the log probs are sorted in descending order.
+        _, sorted_indices = log_probs.sort(dim=-1, descending=True)
+        assert (sorted_indices == torch.arange(3).unsqueeze(0)).all()
 
-        assert all([x >= 0 and x < 4 for x in classes[0]])
-        assert all([x > 1 and x <= 5 for x in classes[1]])
+        assert all([x >= 0 and x < 4 for x in indices[0]])
+        assert all([x > 1 and x <= 5 for x in indices[1]])