diff --git a/nmtpytorch/vocabulary.py b/nmtpytorch/vocabulary.py
index 889097c2..cf04b03e 100644
--- a/nmtpytorch/vocabulary.py
+++ b/nmtpytorch/vocabulary.py
@@ -56,33 +56,28 @@ def __init__(self, fname, short_list=0):
         self.counts = None
         self._allmap = None
         self.n_tokens = None
-        # By default, we start with all special tokens
-        for tok in self.TOKENS:
-            setattr(self, f'has_{tok[1:-1]}', True)
 
         # Load file
-        data = json.load(open(self.vocab))
+        with open(self.vocab) as f:
+            data = json.load(f)
+
         if self.short_list > 0:
             # Get a slice of most frequent `short_list` items
             data = dict(list(data.items())[:self.short_list])
 
-        # Detect vocabulary: values can be int or a string of type "id count"
-        elem = next(iter(data.values()))
-        if isinstance(elem, str):
-            self._map = {k: int(v.split()[0]) for k, v in data.items()}
-            self.counts = {k: int(v.split()[1]) for k, v in data.items()}
-            total_count = sum(self.counts.values())
-            self.freqs = {k: v / total_count for k, v in self.counts.items()}
-        elif isinstance(elem, int):
-            self._map = data
-        else:
-            raise RuntimeError('Unknown vocabulary format.')
+        self._map = {k: int(v.split()[0]) for k, v in data.items()}
+        self.counts = {k: int(v.split()[1]) for k, v in data.items()}
+
+        total_count = sum(self.counts.values())
+        self.freqs = {k: v / total_count for k, v in self.counts.items()}
 
         # Sanity check for placeholder tokens
         for tok, idx in self.TOKENS.items():
             if self._map.get(tok, -1) != idx:
                 logger.info(f'{tok} not found in {self.vocab.name!r}')
                 setattr(self, f'has_{tok[1:-1]}', False)
+            else:
+                setattr(self, f'has_{tok[1:-1]}', True)
 
         # Set # of tokens
         self.n_tokens = len(self._map)
@@ -98,12 +93,6 @@ def __init__(self, fname, short_list=0):
         assert len(self._allmap) == (len(self._map) + len(self._imap)), \
             "Merged vocabulary size is not equal to sum of both."
 
-    def __getitem__(self, key):
-        return self._allmap[key]
-
-    def __len__(self):
-        return len(self._map)
-
     def sent_to_idxs(self, line, explicit_bos=False, explicit_eos=True):
         """Convert from list of strings to list of token indices."""
         tidxs = []
@@ -115,12 +104,14 @@ def sent_to_idxs(self, line, explicit_bos=False, explicit_eos=True):
             for tok in line.split():
                 tidxs.append(self._map.get(tok, self.TOKENS["<unk>"]))
         else:
-            # Silently remove unknown tokens from the words
+            # Remove unknown tokens from the words
            for tok in line.split():
                 try:
                     tidxs.append(self._map[tok])
-                except KeyError as ke:
-                    pass
+                except KeyError as _:
+                    # Log each removal on purpose: this path should be used
+                    # cautiously, only for some specific models
+                    logger.info('No token, removing word from sentence')
 
         if explicit_eos and self.has_eos:
             tidxs.append(self.TOKENS["<eos>"])
@@ -171,5 +162,11 @@ def list_of_idxs_to_sents(self, lidxs):
             results.append(" ".join(result))
         return results
 
+    def __getitem__(self, key):
+        return self._allmap[key]
+
+    def __len__(self):
+        return len(self._map)
+
     def __repr__(self):
         return f"Vocabulary of {self.n_tokens} items ({self.vocab.name!r})"
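
Note on the format change above: the loader no longer auto-detects plain-int values, so every value in the vocabulary JSON must now be an `"id count"` string, with entries ordered by descending frequency for `short_list` slicing to make sense. A minimal standalone sketch of what the patched `__init__` expects; the file path and token entries below are made up for illustration:

```python
import json

# Hypothetical vocabulary file: every value must be an "id count" string,
# with special tokens at the fixed indices defined by Vocabulary.TOKENS
# and regular tokens ordered by descending frequency (for `short_list`).
vocab = {
    "<pad>": "0 0",
    "<bos>": "1 0",
    "<eos>": "2 0",
    "<unk>": "3 0",
    "the": "4 1200",
    "cat": "5 37",
}
with open("/tmp/vocab.json", "w") as f:
    json.dump(vocab, f)

# Same parsing logic as the patched __init__ above
with open("/tmp/vocab.json") as f:
    data = json.load(f)

_map = {k: int(v.split()[0]) for k, v in data.items()}
counts = {k: int(v.split()[1]) for k, v in data.items()}
total_count = sum(counts.values())
freqs = {k: v / total_count for k, v in counts.items()}

print(_map["cat"])             # 5
print(round(freqs["the"], 3))  # 0.97 (1200 / 1237)
```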
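And a hedged sketch of the two unknown-token paths in `sent_to_idxs`: the real method picks its branch from vocabulary state that is not shown in the hunk context, so the `map_unk` flag below is a stand-in for that condition, and the token table is an assumption:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('vocabulary')

# Assumed token table; the real class builds this from the loaded file
TOKENS = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
_map = {**TOKENS, "the": 4, "cat": 5}

def sent_to_idxs(line, map_unk=True, explicit_eos=True):
    """Mimic the patched method: map unknowns to <unk>, or drop them loudly."""
    tidxs = []
    for tok in line.split():
        if map_unk:
            tidxs.append(_map.get(tok, TOKENS["<unk>"]))
        elif tok in _map:
            tidxs.append(_map[tok])
        else:
            # the patch logs every dropped word instead of silently passing
            logger.info('No token, removing word from sentence')
    if explicit_eos:
        tidxs.append(TOKENS["<eos>"])
    return tidxs

print(sent_to_idxs("the cat sat"))                 # [4, 5, 3, 2]
print(sent_to_idxs("the cat sat", map_unk=False))  # logs once -> [4, 5, 2]
```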