Skip to content

Commit

Permalink
Allow dictionaries to overwrite entries with #fairseq:overwrite comme…
Browse files Browse the repository at this point in the history
…nt (#1073)

Summary:
[This commit](dd1298e) made it so that duplicate entries in a dictionary are ignored. Unfortunately the Camembert model depends on overwriting `<unk>`, `<s>` and `</s>`.

The proposed solution here is to allow the dictionary to have entries like:
```
<unk> 999 #fairseq:overwrite
<s> 999 #fairseq:overwrite
</s> 999 #fairseq:overwrite
, 999
▁de 999
. 999
(...)
```

These will preserve the old overwriting behavior. Thus we can release a new `camembert.v0.tar.gz` with a dictionary like above and it works.
Pull Request resolved: fairinternal/fairseq-py#1073

Reviewed By: kahne

Differential Revision: D20284569

Pulled By: myleott

fbshipit-source-id: bf78fbff13c94bf8a6485cbdda62305ddc30c056
  • Loading branch information
myleott authored and facebook-github-bot committed Mar 8, 2020
1 parent 3dd221c commit 937535d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 8 deletions.
32 changes: 24 additions & 8 deletions fairseq/data/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def unk_string(self, escape=False):
else:
return self.unk_word

def add_symbol(self, word, n=1):
def add_symbol(self, word, n=1, overwrite=False):
"""Adds a word to the dictionary"""
if word in self.indices:
if word in self.indices and not overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
Expand Down Expand Up @@ -215,15 +215,31 @@ def add_from_file(self, f):

lines = f.readlines()
indices_start_line = self._load_meta(lines)

for line in lines[indices_start_line:]:
idx = line.rfind(" ")
if idx == -1:
try:
line, field = line.rstrip().rsplit(" ", 1)
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
word = line
if word in self and not overwrite:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file."
.format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError(
"Incorrect dictionary format, expected '<token> <cnt>'"
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
)
word = line[:idx]
count = int(line[idx + 1 :])
self.add_symbol(word, n=count)

def _save(self, f, kv_iterator):
if isinstance(f, str):
Expand Down
46 changes: 46 additions & 0 deletions tests/test_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import tempfile
import unittest

Expand Down Expand Up @@ -65,6 +66,51 @@ def assertMatch(ids, ref_ids):
assertMatch(reload_ids, ref_ids2)
assertMatch(finalized_ids, reload_ids)

def test_overwrite(self):
# for example, Camembert overwrites <unk>, <s> and </s>
dict_file = io.StringIO(
"<unk> 999 #fairseq:overwrite\n"
"<s> 999 #fairseq:overwrite\n"
"</s> 999 #fairseq:overwrite\n"
", 999\n"
"▁de 999\n"
)
d = Dictionary()
d.add_from_file(dict_file)
self.assertEqual(d.index('<pad>'), 1)
self.assertEqual(d.index('foo'), 3)
self.assertEqual(d.index('<unk>'), 4)
self.assertEqual(d.index('<s>'), 5)
self.assertEqual(d.index('</s>'), 6)
self.assertEqual(d.index(','), 7)
self.assertEqual(d.index('▁de'), 8)

def test_no_overwrite(self):
# for example, Camembert overwrites <unk>, <s> and </s>
dict_file = io.StringIO(
"<unk> 999\n"
"<s> 999\n"
"</s> 999\n"
", 999\n"
"▁de 999\n"
)
d = Dictionary()
with self.assertRaisesRegex(RuntimeError, 'Duplicate'):
d.add_from_file(dict_file)

def test_space(self):
# for example, character models treat space as a symbol
dict_file = io.StringIO(
" 999\n"
"a 999\n"
"b 999\n"
)
d = Dictionary()
d.add_from_file(dict_file)
self.assertEqual(d.index(' '), 4)
self.assertEqual(d.index('a'), 5)
self.assertEqual(d.index('b'), 6)


if __name__ == '__main__':
unittest.main()

0 comments on commit 937535d

Please sign in to comment.