Skip to content

Commit

Permalink
Initial implementation of vw_ensemble backend. Fixes #235
Browse files Browse the repository at this point in the history
  • Loading branch information
osma committed Jun 24, 2019
1 parent 112554d commit a56be26
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 5 deletions.
6 changes: 4 additions & 2 deletions annif/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def get_backend(backend_id):
# Optional Vowpal Wabbit based backends. Both are registered inside one
# try block, so an ImportError from either import skips the remaining
# registrations and is reported with a single debug message.
try:
    from . import vw_multi
    register_backend(vw_multi.VWMultiBackend)
    from . import vw_ensemble
    register_backend(vw_ensemble.VWEnsembleBackend)
except ImportError:
    # Log once, using the message that names both backends; the older
    # vw_multi-only message was redundant and out of date.
    annif.logger.debug("vowpalwabbit not available, not enabling " +
                       "vw_multi & vw_ensemble backends")
10 changes: 8 additions & 2 deletions annif/backend/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ def _suggest_with_sources(self, text, sources):
hits=norm_hits, weight=weight))
return hits_from_sources

def _merge_hits_from_sources(self, hits_from_sources, project, params):
    """Combine the hits collected from all sources into one result.

    This default implementation delegates to annif.util.merge_hits using
    the subjects of the given project and does not look at params.
    Subclasses can override this hook to merge differently."""
    subject_index = project.subjects
    merged = annif.util.merge_hits(hits_from_sources, subject_index)
    return merged

def _suggest(self, text, project, params):
    """Suggest subjects for text by querying every configured source and
    merging their hits.

    The sources are parsed from params['sources']. Merging goes through
    the _merge_hits_from_sources hook only, so that subclasses overriding
    the hook fully control the merge and no merge is computed twice.
    Returns the merged hits."""
    sources = annif.util.parse_sources(params['sources'])
    hits_from_sources = self._suggest_with_sources(text, sources)
    merged_hits = self._merge_hits_from_sources(hits_from_sources,
                                                project,
                                                params)
    self.debug('{} hits after merging'.format(len(merged_hits)))
    return merged_hits
1 change: 0 additions & 1 deletion annif/backend/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def _suggest_chunks(self, chunktexts, project):
pass # pragma: no cover

def _suggest(self, text, project, params):
self.initialize()
self.debug('Suggesting subjects for text "{}..." (len={})'.format(
text[:20], len(text)))
sentences = project.analyzer.tokenize_sentences(text)
Expand Down
170 changes: 170 additions & 0 deletions annif/backend/vw_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Annif backend using Vowpal Wabbit to learn how to combine the results
of multiple source projects into an ensemble"""

import random
import os.path
import annif.util
from vowpalwabbit import pyvw
import numpy as np
from annif.suggestion import VectorSuggestionResult
from annif.exception import ConfigurationException, NotInitializedException
from . import backend
from . import ensemble


class VWEnsembleBackend(
        ensemble.EnsembleBackend,
        backend.AnnifLearningBackend):
    """Vowpal Wabbit ensemble backend that combines results from multiple
    projects and learns how well those projects/backends recognize
    particular subjects."""

    name = "vw_ensemble"

    VW_PARAMS = {
        # each param specifier is a pair (allowed_values, default_value)
        # where allowed_values is either a type or a list of allowed values
        # and default_value may be None, to let VW decide by itself
        'bit_precision': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'squared'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None)
    }

    # file names, relative to the backend's data directory
    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the saved VW model from the data directory, unless a model
        is already loaded.

        Raises NotInitializedException if the model file does not exist."""
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self.debug('loading VW model from {}'.format(path))
        params = self._create_params({'i': path, 'quiet': True})
        # don't confuse the model with passes
        params.pop('passes', None)
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        self.debug('loaded model {}'.format(str(self._model)))

    @staticmethod
    def _write_train_file(examples, filename):
        """Write the examples to filename as UTF-8 text, one per line."""
        with open(filename, 'w', encoding='utf-8') as trainfile:
            trainfile.writelines('{}\n'.format(ex) for ex in examples)

    def _merge_hits_from_sources(self, hits_from_sources, project, params):
        """Merge the hits from the sources using the loaded VW model.

        Every subject whose scores over all sources sum to more than zero
        is turned into an example and given to the model; the remaining
        subjects keep a score of zero. Returns a VectorSuggestionResult
        over the subjects of the project."""
        score_vector = np.array([hits.vector
                                 for hits, _ in hits_from_sources])
        # parse the configured sources once instead of once per subject
        source_ids = self.source_project_ids
        result = np.zeros(score_vector.shape[1])
        candidates = np.flatnonzero(score_vector.sum(axis=0) > 0.0)
        for subj_id in candidates:
            ex = self._format_example(
                subj_id,
                score_vector[:, subj_id],
                source_ids=source_ids)
            # training labels are -1/1 (see _format_example); rescale the
            # prediction to 0..1, assuming it stays within that range
            result[subj_id] = (self._model.predict(ex) + 1.0) / 2.0
        return VectorSuggestionResult(result, project.subjects)

    def _format_example(self, subject_id, scores, true=None,
                        source_ids=None):
        """Format one example line for VW.

        The line consists of the label (empty when true is None, otherwise
        1 for a true value and -1 for a false one), a namespace named after
        subject_id and one "project:score" feature per source project.
        source_ids may be given to avoid re-parsing the configured sources
        on every call; it defaults to self.source_project_ids."""
        if source_ids is None:
            source_ids = self.source_project_ids
        if true is None:
            val = ''
        elif true:
            val = 1
        else:
            val = -1
        features = ''.join(' {}:{}'.format(proj, scores[proj_idx])
                           for proj_idx, proj in enumerate(source_ids))
        return "{} |{}{}".format(val, subject_id, features)

    @property
    def source_project_ids(self):
        """List of the project ids parsed from the 'sources' parameter."""
        sources = annif.util.parse_sources(self.params['sources'])
        return [project_id for project_id, _ in sources]

    def _create_examples(self, corpus, project):
        """Create a shuffled list of labelled VW examples from a corpus.

        Each document is run through every source project. An example is
        produced for every subject that is either a true subject of the
        document or has a positive total score from the sources."""
        # parse the configured sources once instead of once per example
        source_ids = self.source_project_ids
        source_projects = [annif.project.get_project(project_id)
                           for project_id in source_ids]
        examples = []
        for doc in corpus.documents:
            subjects = annif.corpus.SubjectSet((doc.uris, doc.labels))
            true = subjects.as_vector(project.subjects)
            score_vector = np.array(
                [source_project.suggest(doc.text).vector
                 for source_project in source_projects])
            for subj_id in range(len(true)):
                if true[subj_id] or score_vector[:, subj_id].sum() > 0.0:
                    ex = self._format_example(
                        subj_id,
                        score_vector[:, subj_id],
                        true[subj_id],
                        source_ids=source_ids)
                    examples.append(ex)
        random.shuffle(examples)
        return examples

    def _create_train_file(self, corpus, project):
        """Create the examples and save them as the train file in the
        data directory."""
        self.info('creating VW train file')
        examples = self._create_examples(corpus, project)
        annif.util.atomic_save(examples,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _convert_param(self, param, val):
        """Validate or convert a configured value according to VW_PARAMS.

        When the specifier is a list, val must be one of its members and
        is returned as is; otherwise the specifier is called to convert
        val. Raises ConfigurationException for a value that is not allowed
        or cannot be converted."""
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            if val in pspec:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
        try:
            return pspec(val)
        except ValueError as err:
            # keep the conversion error as the cause for easier debugging
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id) from err

    def _create_params(self, params):
        """Return the parameters for VW as a new dict.

        The result starts from the given params, then applies the non-None
        defaults from VW_PARAMS and finally the converted values of the
        VW parameters found in the backend configuration. The dict passed
        in is left unmodified."""
        merged = dict(params)
        merged.update({param: defaultval
                       for param, (_, defaultval) in self.VW_PARAMS.items()
                       if defaultval is not None})
        merged.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        return merged

    def _create_model(self, project):
        """Create a VW model from the train file and save it in the data
        directory."""
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        params = self._create_params(
            {'data': trainpath, 'q': '::'})
        if params.get('passes', 1) > 1:
            # need a cache file when there are multiple passes
            params.update({'cache': True, 'kill_cache': True})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)

    def train(self, corpus, project):
        """Train the ensemble: write the train file, then create and save
        the model."""
        self.info("creating VW ensemble model")
        self._create_train_file(corpus, project)
        self._create_model(project)

    def learn(self, corpus, project):
        """Update the existing model with examples created from the corpus
        and save it again."""
        self.initialize()
        for example in self._create_examples(corpus, project):
            self._model.learn(example)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)
1 change: 1 addition & 0 deletions annif/backend/vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def train(self, corpus, project):
self._create_model(project)

def learn(self, corpus, project):
self.initialize()
for example in self._create_examples(corpus, project):
self._model.learn(example)
modelpath = os.path.join(self.datadir, self.MODEL_FILE)
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ def document_corpus(subject_index):
return doc_corpus


@pytest.fixture(scope='module')
def fulltext_corpus(subject_index):
    """Module-scoped fixture: the archaeology fulltext corpus located next
    to this file, with the given subject index set on it."""
    tests_dir = os.path.dirname(__file__)
    fulltext_dir = os.path.join(
        tests_dir, 'corpora', 'archaeology', 'fulltext')
    corpus = annif.corpus.DocumentDirectory(fulltext_dir)
    corpus.set_subject_index(subject_index)
    return corpus


@pytest.fixture(scope='module')
def project(document_corpus):
proj = unittest.mock.Mock()
Expand Down
56 changes: 56 additions & 0 deletions tests/test_backend_vw_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Unit tests for the vw_ensemble backend in Annif"""

import pytest
import annif.backend
import annif.corpus

pytest.importorskip("annif.backend.vw_ensemble")


def test_vw_ensemble_train(app, datadir, tmpdir, fulltext_corpus, project):
    """Training must leave a non-empty train file and model file in the
    data directory."""
    backend_cls = annif.backend.get_backend("vw_ensemble")
    ensemble_backend = backend_cls(
        backend_id='vw_ensemble',
        params={'sources': 'tfidf-fi'},
        datadir=str(datadir))

    with app.app_context():
        ensemble_backend.train(fulltext_corpus, project)

    # the train file is checked first, then the model file
    for filename in ('vw-train.txt', 'vw-model'):
        assert datadir.join(filename).exists()
        assert datadir.join(filename).size() > 0


def test_vw_ensemble_initialize(app, datadir):
    """A new backend instance has no model until initialize() is called;
    calling initialize() again afterwards must also be accepted."""
    backend_cls = annif.backend.get_backend("vw_ensemble")
    ensemble_backend = backend_cls(
        backend_id='vw_ensemble',
        params={'sources': 'tfidf-fi'},
        datadir=str(datadir))

    assert ensemble_backend._model is None

    with app.app_context():
        ensemble_backend.initialize()
    assert ensemble_backend._model is not None

    # initialize a second time - this shouldn't do anything
    with app.app_context():
        ensemble_backend.initialize()


def test_vw_ensemble_suggest(app, datadir, project):
    """Suggesting for a Finnish sample text must leave a model loaded on
    the backend and return at least one result."""
    backend_cls = annif.backend.get_backend("vw_ensemble")
    ensemble_backend = backend_cls(
        backend_id='vw_ensemble',
        params={'sources': 'tfidf-fi'},
        datadir=str(datadir))

    sample_text = """Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
pohjaan."""
    results = ensemble_backend.suggest(sample_text, project)

    assert ensemble_backend._model is not None
    assert len(results) > 0

0 comments on commit a56be26

Please sign in to comment.