vocabulary: remove old checks, add documentation
ozancaglayan committed Feb 20, 2020
1 parent fd45bab commit f4f9901
Showing 1 changed file with 21 additions and 24 deletions.
45 changes: 21 additions & 24 deletions nmtpytorch/vocabulary.py
@@ -56,33 +56,28 @@ def __init__(self, fname, short_list=0):
         self.counts = None
         self._allmap = None
         self.n_tokens = None
-        # By default, we start with all special tokens
-        for tok in self.TOKENS:
-            setattr(self, f'has_{tok[1:-1]}', True)
 
         # Load file
-        data = json.load(open(self.vocab))
+        with open(self.vocab) as f:
+            data = json.load(f)
 
         if self.short_list > 0:
             # Get a slice of most frequent `short_list` items
             data = dict(list(data.items())[:self.short_list])
 
-        # Detect vocabulary: values can be int or a string of type "id count"
-        elem = next(iter(data.values()))
-        if isinstance(elem, str):
-            self._map = {k: int(v.split()[0]) for k, v in data.items()}
-            self.counts = {k: int(v.split()[1]) for k, v in data.items()}
-            total_count = sum(self.counts.values())
-            self.freqs = {k: v / total_count for k, v in self.counts.items()}
-        elif isinstance(elem, int):
-            self._map = data
-        else:
-            raise RuntimeError('Unknown vocabulary format.')
+        self._map = {k: int(v.split()[0]) for k, v in data.items()}
+        self.counts = {k: int(v.split()[1]) for k, v in data.items()}
+
+        total_count = sum(self.counts.values())
+        self.freqs = {k: v / total_count for k, v in self.counts.items()}
 
         # Sanity check for placeholder tokens
         for tok, idx in self.TOKENS.items():
             if self._map.get(tok, -1) != idx:
+                logger.info(f'{tok} not found in {self.vocab.name!r}')
                 setattr(self, f'has_{tok[1:-1]}', False)
+            else:
+                setattr(self, f'has_{tok[1:-1]}', True)
 
         # Set # of tokens
         self.n_tokens = len(self._map)
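
For reference, the rewritten loader assumes exactly one vocabulary format: a JSON object mapping each token to an "id count" string. A minimal, runnable sketch of that format and of the parsing logic above, with an invented file name (vocab.json) and invented ids and counts:

import json

# Hypothetical vocabulary file in the "id count" format the new code assumes.
vocab = {
    "<pad>": "0 0",
    "<bos>": "1 0",
    "<eos>": "2 0",
    "<unk>": "3 0",
    "the": "4 1234",
    "cat": "5 87",
}
with open("vocab.json", "w") as f:
    json.dump(vocab, f)

# Mirror of the parsing done in __init__, outside the class for illustration.
with open("vocab.json") as f:
    data = json.load(f)
token_to_id = {k: int(v.split()[0]) for k, v in data.items()}
counts = {k: int(v.split()[1]) for k, v in data.items()}
total_count = sum(counts.values())
freqs = {k: v / total_count for k, v in counts.items()}
print(token_to_id["cat"], counts["cat"], round(freqs["cat"], 4))  # 5 87 0.0659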
@@ -98,12 +93,6 @@ def __init__(self, fname, short_list=0):
         assert len(self._allmap) == (len(self._map) + len(self._imap)), \
             "Merged vocabulary size is not equal to sum of both."
 
-    def __getitem__(self, key):
-        return self._allmap[key]
-
-    def __len__(self):
-        return len(self._map)
-
     def sent_to_idxs(self, line, explicit_bos=False, explicit_eos=True):
         """Convert from list of strings to list of token indices."""
         tidxs = []
@@ -115,12 +104,14 @@ def sent_to_idxs(self, line, explicit_bos=False, explicit_eos=True):
             for tok in line.split():
                 tidxs.append(self._map.get(tok, self.TOKENS["<unk>"]))
         else:
-            # Silently remove unknown tokens from the words
+            # Remove unknown tokens from the words
             for tok in line.split():
                 try:
                     tidxs.append(self._map[tok])
-                except KeyError as ke:
-                    pass
+                except KeyError:
+                    # Log every dropped word on purpose: this path should
+                    # be used cautiously, only for some specific models
+                    logger.info('No <unk> token, removing word from sentence')
 
         if explicit_eos and self.has_eos:
             tidxs.append(self.TOKENS["<eos>"])
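
As a usage note for the two branches above: with an <unk> token, out-of-vocabulary words map to its index; without one, they are dropped (and now logged). A standalone sketch with a plain dict standing in for the Vocabulary instance (ids are invented for illustration):

# Plain-dict stand-in for self._map; ids are invented.
UNK_ID = 3
word2idx = {"<unk>": UNK_ID, "the": 4, "cat": 5}

def sent_to_idxs(line, has_unk=True):
    tidxs = []
    for tok in line.split():
        if has_unk:
            # Unknown words fall back to the <unk> index.
            tidxs.append(word2idx.get(tok, UNK_ID))
        elif tok in word2idx:
            # Without <unk>, unknown words are dropped here
            # (the class above logs each dropped word instead).
            tidxs.append(word2idx[tok])
    return tidxs

print(sent_to_idxs("the cat flies"))                 # [4, 5, 3]
print(sent_to_idxs("the cat flies", has_unk=False))  # [4, 5]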
@@ -171,5 +162,11 @@ def list_of_idxs_to_sents(self, lidxs):
             results.append(" ".join(result))
         return results
 
+    def __getitem__(self, key):
+        return self._allmap[key]
+
+    def __len__(self):
+        return len(self._map)
+
     def __repr__(self):
         return f"Vocabulary of {self.n_tokens} items ({self.vocab.name!r})"
