Skip to content

Commit

Permalink
Merge pull request #254 from NatLibFi/issue230-vw-multi-enhancements
Browse files Browse the repository at this point in the history
Enhancements to vw_multi backend
  • Loading branch information
osma authored Feb 1, 2019
2 parents edcef39 + a0194e8 commit c08663f
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 26 deletions.
45 changes: 31 additions & 14 deletions annif/backend/vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
# where allowed_values is either a type or a list of allowed values
# and default_value may be None, to let VW decide by itself
'bit_precision': (int, None),
'ngram': (int, None),
'learning_rate': (float, None),
'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
'l1': (float, None),
'l2': (float, None),
'passes': (int, None)
'passes': (int, None),
'probabilities': (bool, None)
}

DEFAULT_ALGORITHM = 'oaa'
SUPPORTED_ALGORITHMS = ('oaa', 'ect')
SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

MODEL_FILE = 'vw-model'
TRAIN_FILE = 'vw-train.txt'
Expand Down Expand Up @@ -71,22 +73,35 @@ def _normalize_text(cls, project, text):
# colon and pipe chars have special meaning in VW and must be avoided
return ntext.replace(':', '').replace('|', '')

def _write_train_file(self, examples, filename):
@classmethod
def _write_train_file(cls, examples, filename):
    """Write the given VW training examples into a file, one per line."""
    with open(filename, 'w') as trainfile:
        trainfile.writelines('{}\n'.format(example) for example in examples)

@classmethod
def _uris_to_subject_ids(cls, project, uris):
    """Resolve subject URIs into project subject IDs.

    URIs unknown to the project (by_uri returns None) are silently
    dropped; the order of the remaining IDs follows the input order.
    """
    resolved = (project.subjects.by_uri(uri) for uri in uris)
    return [subject_id for subject_id in resolved if subject_id is not None]

def _format_examples(self, project, text, uris):
    """Yield VW training example line(s) for one document.

    The multilabel_oaa algorithm takes all subject IDs of a document as
    a single comma-separated label on one example; every other
    algorithm gets one example per subject, with 1-based class labels
    as VW multiclass reductions expect.
    """
    subject_ids = self._uris_to_subject_ids(project, uris)
    if self.algorithm == 'multilabel_oaa':
        labels = [','.join(str(sid) for sid in subject_ids)]
    else:
        labels = [str(sid + 1) for sid in subject_ids]
    for label in labels:
        yield '{} | {}'.format(label, text)

def _create_train_file(self, corpus, project):
self.info('creating VW train file')
examples = []
for doc in corpus.documents:
text = self._normalize_text(project, doc.text)
for uri in doc.uris:
subject_id = project.subjects.by_uri(uri)
if subject_id is None:
continue
exstr = '{} | {}'.format(subject_id + 1, text)
examples.append(exstr)
examples.extend(self._format_examples(project, text, doc.uris))
random.shuffle(examples)
annif.util.atomic_save(examples,
self._get_datadir(),
Expand Down Expand Up @@ -115,9 +130,6 @@ def _create_params(self, params):
params.update({param: self._convert_param(param, val)
for param, val in self.params.items()
if param in self.VW_PARAMS})
if self.algorithm == 'oaa':
# only the oaa algorithm supports probabilities output
params.update({'probabilities': True, 'loss_function': 'logistic'})
return params

def _create_model(self, project):
Expand All @@ -142,8 +154,13 @@ def _analyze_chunks(self, chunktexts, project):
for chunktext in chunktexts:
example = ' | {}'.format(chunktext)
result = self._model.predict(example)
if isinstance(result, int):
# just a single integer - need to one-hot-encode
if self.algorithm == 'multilabel_oaa':
# result is a list of subject IDs - need to vectorize
mask = np.zeros(len(project.subjects))
mask[result] = 1.0
result = mask
elif isinstance(result, int):
# result is a single integer - need to one-hot-encode
mask = np.zeros(len(project.subjects))
mask[result - 1] = 1.0
result = mask
Expand Down
93 changes: 81 additions & 12 deletions tests/test_backend_vw.py → tests/test_backend_vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def vw_corpus(tmpdir):
return annif.corpus.DocumentFile(str(tmpfile))


def test_vw_analyze_no_model(datadir, project):
def test_vw_multi_analyze_no_model(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -28,7 +28,7 @@ def test_vw_analyze_no_model(datadir, project):
results = vw.analyze("example text", project)


def test_vw_train(datadir, document_corpus, project):
def test_vw_multi_train(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -44,7 +44,7 @@ def test_vw_train(datadir, document_corpus, project):
assert datadir.join('vw-model').size() > 0


def test_vw_train_multiple_passes(datadir, document_corpus, project):
def test_vw_multi_train_multiple_passes(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -60,7 +60,7 @@ def test_vw_train_multiple_passes(datadir, document_corpus, project):
assert datadir.join('vw-model').size() > 0


def test_vw_train_invalid_algorithm(datadir, document_corpus, project):
def test_vw_multi_train_invalid_algorithm(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -74,7 +74,7 @@ def test_vw_train_invalid_algorithm(datadir, document_corpus, project):
vw.train(document_corpus, project)


def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
def test_vw_multi_train_invalid_loss_function(datadir, project, vw_corpus):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -85,7 +85,7 @@ def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
vw.train(vw_corpus, project)


def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
def test_vw_multi_train_invalid_learning_rate(datadir, project, vw_corpus):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -96,11 +96,11 @@ def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
vw.train(vw_corpus, project)


def test_vw_analyze(datadir, project):
def test_vw_multi_analyze(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4},
params={'chunksize': 4, 'probabilities': 1},
datadir=str(datadir))

results = vw.analyze("""Arkeologiaa sanotaan joskus myös
Expand All @@ -116,7 +116,7 @@ def test_vw_analyze(datadir, project):
assert 'arkeologia' in [result.label for result in results]


def test_vw_analyze_empty(datadir, project):
def test_vw_multi_analyze_empty(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -128,7 +128,7 @@ def test_vw_analyze_empty(datadir, project):
assert len(results) == 0


def test_vw_analyze_multiple_passes(datadir, project):
def test_vw_multi_analyze_multiple_passes(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -140,7 +140,7 @@ def test_vw_analyze_multiple_passes(datadir, project):
assert len(results) == 0


def test_vw_train_ect(datadir, document_corpus, project):
def test_vw_multi_train_ect(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -156,7 +156,7 @@ def test_vw_train_ect(datadir, document_corpus, project):
assert datadir.join('vw-model').size() > 0


def test_vw_analyze_ect(datadir, project):
def test_vw_multi_analyze_ect(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -172,3 +172,72 @@ def test_vw_analyze_ect(datadir, project):
pohjaan.""", project)

assert len(results) > 0


def test_vw_multi_train_log_multi(datadir, document_corpus, project):
    """Training with the log_multi algorithm should produce a model file."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={
            'chunksize': 4,
            'learning_rate': 0.5,
            'algorithm': 'log_multi'},
        datadir=str(datadir))

    backend.train(document_corpus, project)

    model_file = datadir.join('vw-model')
    assert backend._model is not None
    assert model_file.exists()
    assert model_file.size() > 0


def test_vw_multi_analyze_log_multi(datadir, project):
    """Analyzing with a log_multi model should yield at least one result."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 1,
                'algorithm': 'log_multi'},
        datadir=str(datadir))

    text = """Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
pohjaan."""
    results = backend.analyze(text, project)

    assert len(results) > 0


def test_vw_multi_train_multilabel_oaa(datadir, document_corpus, project):
    """Training with the multilabel_oaa algorithm should produce a model file."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={
            'chunksize': 4,
            'learning_rate': 0.5,
            'algorithm': 'multilabel_oaa'},
        datadir=str(datadir))

    backend.train(document_corpus, project)

    model_file = datadir.join('vw-model')
    assert backend._model is not None
    assert model_file.exists()
    assert model_file.size() > 0


def test_vw_multi_analyze_multilabel_oaa(datadir, project):
    """Analyze with a multilabel_oaa model; it must at least run cleanly."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 1,
                'algorithm': 'multilabel_oaa'},
        datadir=str(datadir))

    text = """Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
pohjaan."""
    results = backend.analyze(text, project)

    # weak assertion, but often multilabel_oaa produces zero hits
    assert len(results) >= 0

0 comments on commit c08663f

Please sign in to comment.