diff --git a/annif/backend/__init__.py b/annif/backend/__init__.py
index 1b5b21651..4a5fd82e2 100644
--- a/annif/backend/__init__.py
+++ b/annif/backend/__init__.py
@@ -38,6 +38,8 @@ def get_backend(backend_id):
 try:
     from . import vw_multi
     register_backend(vw_multi.VWMultiBackend)
+    from . import vw_ensemble
+    register_backend(vw_ensemble.VWEnsembleBackend)
 except ImportError:
-    annif.logger.debug(
-        "vowpalwabbit not available, not enabling vw_multi backend")
+    annif.logger.debug("vowpalwabbit not available, not enabling " +
+                       "vw_multi & vw_ensemble backends")
diff --git a/annif/backend/ensemble.py b/annif/backend/ensemble.py
index 90d2712e7..90959979a 100644
--- a/annif/backend/ensemble.py
+++ b/annif/backend/ensemble.py
@@ -30,10 +30,16 @@ def _suggest_with_sources(self, text, sources):
                     hits=norm_hits, weight=weight))
         return hits_from_sources
 
+    def _merge_hits_from_sources(self, hits_from_sources, project, params):
+        """Hook for merging hits from sources. Can be overridden by
+        subclasses."""
+        return annif.util.merge_hits(hits_from_sources, project.subjects)
+
     def _suggest(self, text, project, params):
         sources = annif.util.parse_sources(params['sources'])
         hits_from_sources = self._suggest_with_sources(text, sources)
-        merged_hits = annif.util.merge_hits(
-            hits_from_sources, project.subjects)
+        merged_hits = self._merge_hits_from_sources(hits_from_sources,
+                                                    project,
+                                                    params)
         self.debug('{} hits after merging'.format(len(merged_hits)))
         return merged_hits
diff --git a/annif/backend/mixins.py b/annif/backend/mixins.py
index 976ed028c..04774a58c 100644
--- a/annif/backend/mixins.py
+++ b/annif/backend/mixins.py
@@ -16,7 +16,6 @@ def _suggest_chunks(self, chunktexts, project):
         pass  # pragma: no cover
 
     def _suggest(self, text, project, params):
-        self.initialize()
         self.debug('Suggesting subjects for text "{}..." (len={})'.format(
             text[:20], len(text)))
         sentences = project.analyzer.tokenize_sentences(text)
diff --git a/annif/backend/vw_ensemble.py b/annif/backend/vw_ensemble.py
new file mode 100644
index 000000000..c43242aac
--- /dev/null
+++ b/annif/backend/vw_ensemble.py
@@ -0,0 +1,170 @@
+"""Annif backend using the Vowpal Wabbit multiclass and multilabel
+classifiers"""
+
+import random
+import os.path
+import annif.util
+from vowpalwabbit import pyvw
+import numpy as np
+from annif.suggestion import VectorSuggestionResult
+from annif.exception import ConfigurationException, NotInitializedException
+from . import backend
+from . import ensemble
+
+
+class VWEnsembleBackend(
+        ensemble.EnsembleBackend,
+        backend.AnnifLearningBackend):
+    """Vowpal Wabbit ensemble backend that combines results from multiple
+    projects and learns how well those projects/backends recognize
+    particular subjects."""
+
+    name = "vw_ensemble"
+
+    VW_PARAMS = {
+        # each param specifier is a pair (allowed_values, default_value)
+        # where allowed_values is either a type or a list of allowed values
+        # and default_value may be None, to let VW decide by itself
+        'bit_precision': (int, None),
+        'learning_rate': (float, None),
+        'loss_function': (['squared', 'logistic', 'hinge'], 'squared'),
+        'l1': (float, None),
+        'l2': (float, None),
+        'passes': (int, None)
+    }
+
+    MODEL_FILE = 'vw-model'
+    TRAIN_FILE = 'vw-train.txt'
+
+    # defaults for uninitialized instances
+    _model = None
+
+    def initialize(self):
+        if self._model is None:
+            path = os.path.join(self.datadir, self.MODEL_FILE)
+            if not os.path.exists(path):
+                raise NotInitializedException(
+                    'model {} not found'.format(path),
+                    backend_id=self.backend_id)
+            self.debug('loading VW model from {}'.format(path))
+            params = self._create_params({'i': path, 'quiet': True})
+            if 'passes' in params:
+                # don't confuse the model with passes
+                del params['passes']
+            self.debug("model parameters: {}".format(params))
+            self._model = pyvw.vw(**params)
+            self.debug('loaded model {}'.format(str(self._model)))
+
+    @staticmethod
+    def _write_train_file(examples, filename):
+        with open(filename, 'w', encoding='utf-8') as trainfile:
+            for ex in examples:
+                print(ex, file=trainfile)
+
+    def _merge_hits_from_sources(self, hits_from_sources, project, params):
+        score_vector = np.array([hits.vector
+                                 for hits, _ in hits_from_sources])
+        result = np.zeros(score_vector.shape[1])
+        for subj_id in range(score_vector.shape[1]):
+            if score_vector[:, subj_id].sum() > 0.0:
+                ex = self._format_example(
+                    subj_id,
+                    score_vector[:, subj_id])
+                score = (self._model.predict(ex) + 1.0) / 2.0
+                result[subj_id] = score
+        return VectorSuggestionResult(result, project.subjects)
+
+    def _format_example(self, subject_id, scores, true=None):
+        if true is None:
+            val = ''
+        elif true:
+            val = 1
+        else:
+            val = -1
+        ex = "{} |{}".format(val, subject_id)
+        for proj_idx, proj in enumerate(self.source_project_ids):
+            ex += " {}:{}".format(proj, scores[proj_idx])
+        return ex
+
+    @property
+    def source_project_ids(self):
+        sources = annif.util.parse_sources(self.params['sources'])
+        return [project_id for project_id, _ in sources]
+
+    def _create_examples(self, corpus, project):
+        source_projects = [annif.project.get_project(project_id)
+                           for project_id in self.source_project_ids]
+        examples = []
+        for doc in corpus.documents:
+            subjects = annif.corpus.SubjectSet((doc.uris, doc.labels))
+            true = subjects.as_vector(project.subjects)
+            score_vectors = []
+            for source_project in source_projects:
+                hits = source_project.suggest(doc.text)
+                score_vectors.append(hits.vector)
+            score_vector = np.array(score_vectors)
+            for subj_id in range(len(true)):
+                if true[subj_id] or score_vector[:, subj_id].sum() > 0.0:
+                    ex = self._format_example(
+                        subj_id,
+                        score_vector[:, subj_id],
+                        true[subj_id])
+                    examples.append(ex)
+        random.shuffle(examples)
+        return examples
+
+    def _create_train_file(self, corpus, project):
+        self.info('creating VW train file')
+        examples = self._create_examples(corpus, project)
+        annif.util.atomic_save(examples,
+                               self.datadir,
+                               self.TRAIN_FILE,
+                               method=self._write_train_file)
+
+    def _convert_param(self, param, val):
+        pspec, _ = self.VW_PARAMS[param]
+        if isinstance(pspec, list):
+            if val in pspec:
+                return val
+            raise ConfigurationException(
+                "{} is not a valid value for {} (allowed: {})".format(
+                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
+        try:
+            return pspec(val)
+        except ValueError:
+            raise ConfigurationException(
+                "The {} value {} cannot be converted to {}".format(
+                    param, val, pspec), backend_id=self.backend_id)
+
+    def _create_params(self, params):
+        params.update({param: defaultval
+                       for param, (_, defaultval) in self.VW_PARAMS.items()
+                       if defaultval is not None})
+        params.update({param: self._convert_param(param, val)
+                       for param, val in self.params.items()
+                       if param in self.VW_PARAMS})
+        return params
+
+    def _create_model(self, project):
+        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
+        params = self._create_params(
+            {'data': trainpath, 'q': '::'})
+        if params.get('passes', 1) > 1:
+            # need a cache file when there are multiple passes
+            params.update({'cache': True, 'kill_cache': True})
+        self.debug("model parameters: {}".format(params))
+        self._model = pyvw.vw(**params)
+        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
+        self._model.save(modelpath)
+
+    def train(self, corpus, project):
+        self.info("creating VW ensemble model")
+        self._create_train_file(corpus, project)
+        self._create_model(project)
+
+    def learn(self, corpus, project):
+        self.initialize()
+        for example in self._create_examples(corpus, project):
+            self._model.learn(example)
+        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
+        self._model.save(modelpath)
diff --git a/annif/backend/vw_multi.py b/annif/backend/vw_multi.py
index 287704d02..e4006128f 100644
--- a/annif/backend/vw_multi.py
+++ b/annif/backend/vw_multi.py
@@ -189,6 +189,7 @@ def train(self, corpus, project):
         self._create_model(project)
 
     def learn(self, corpus, project):
+        self.initialize()
         for example in self._create_examples(corpus, project):
             self._model.learn(example)
         modelpath = os.path.join(self.datadir, self.MODEL_FILE)
diff --git a/tests/conftest.py b/tests/conftest.py
index d10c14d1e..c965d3fac 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -73,6 +73,18 @@ def document_corpus(subject_index):
     return doc_corpus
 
 
+@pytest.fixture(scope='module')
+def fulltext_corpus(subject_index):
+    docdir = os.path.join(
+        os.path.dirname(__file__),
+        'corpora',
+        'archaeology',
+        'fulltext')
+    ft_corpus = annif.corpus.DocumentDirectory(docdir)
+    ft_corpus.set_subject_index(subject_index)
+    return ft_corpus
+
+
 @pytest.fixture(scope='module')
 def project(document_corpus):
     proj = unittest.mock.Mock()
diff --git a/tests/test_backend_vw_ensemble.py b/tests/test_backend_vw_ensemble.py
new file mode 100644
index 000000000..b9308bffb
--- /dev/null
+++ b/tests/test_backend_vw_ensemble.py
@@ -0,0 +1,56 @@
+"""Unit tests for the vw_ensemble backend in Annif"""
+
+import pytest
+import annif.backend
+import annif.corpus
+
+pytest.importorskip("annif.backend.vw_ensemble")
+
+
+def test_vw_ensemble_train(app, datadir, tmpdir, fulltext_corpus, project):
+    vw_ensemble_type = annif.backend.get_backend("vw_ensemble")
+    vw_ensemble = vw_ensemble_type(
+        backend_id='vw_ensemble',
+        params={'sources': 'tfidf-fi'},
+        datadir=str(datadir))
+
+    with app.app_context():
+        vw_ensemble.train(fulltext_corpus, project)
+    assert datadir.join('vw-train.txt').exists()
+    assert datadir.join('vw-train.txt').size() > 0
+    assert datadir.join('vw-model').exists()
+    assert datadir.join('vw-model').size() > 0
+
+
+def test_vw_ensemble_initialize(app, datadir):
+    vw_ensemble_type = annif.backend.get_backend("vw_ensemble")
+    vw_ensemble = vw_ensemble_type(
+        backend_id='vw_ensemble',
+        params={'sources': 'tfidf-fi'},
+        datadir=str(datadir))
+
+    assert vw_ensemble._model is None
+    with app.app_context():
+        vw_ensemble.initialize()
+    assert vw_ensemble._model is not None
+    # initialize a second time - this shouldn't do anything
+    with app.app_context():
+        vw_ensemble.initialize()
+
+
+def test_vw_ensemble_suggest(app, datadir, project):
+    vw_ensemble_type = annif.backend.get_backend("vw_ensemble")
+    vw_ensemble = vw_ensemble_type(
+        backend_id='vw_ensemble',
+        params={'sources': 'tfidf-fi'},
+        datadir=str(datadir))
+
+    results = vw_ensemble.suggest("""Arkeologiaa sanotaan joskus myös
+        muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
+        tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
+        Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
+        joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
+        pohjaan.""", project)
+
+    assert vw_ensemble._model is not None
+    assert len(results) > 0
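
Note on the Vowpal Wabbit input format (illustrative sketch, not part of the patch): _format_example in annif/backend/vw_ensemble.py builds one VW example per (document, subject) pair, using the numeric subject id as the namespace and one feature per source project, and _merge_hits_from_sources maps VW's raw prediction from [-1, 1] to a [0, 1] score via (prediction + 1.0) / 2.0. The standalone snippet below only mirrors that string format for illustration; the project ids are placeholders ('tfidf-fi' appears in the tests above, 'fasttext-fi' is hypothetical).

# Standalone illustration of the example format built by _format_example.
def format_example(subject_id, scores, source_project_ids, true=None):
    # label: empty for prediction, 1 for a correct subject, -1 for incorrect
    val = '' if true is None else (1 if true else -1)
    ex = "{} |{}".format(val, subject_id)
    for proj, score in zip(source_project_ids, scores):
        ex += " {}:{}".format(proj, score)
    return ex

# training line for subject 42, scored by two (hypothetical) source projects
print(format_example(42, [0.7, 0.1], ['tfidf-fi', 'fasttext-fi'], true=True))
# 1 |42 tfidf-fi:0.7 fasttext-fi:0.1

# prediction line: no label; VW's output p in [-1, 1] becomes (p + 1.0) / 2.0
print(format_example(42, [0.7, 0.1], ['tfidf-fi', 'fasttext-fi']))
#  |42 tfidf-fi:0.7 fasttext-fi:0.1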