-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #249 from NatLibFi/vw-backend
First implementation of VW regular backend
- Loading branch information
Showing
7 changed files
with
291 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
"""Annif backend mixins that can be used to implement features""" | ||
|
||
|
||
import abc | ||
from annif.hit import ListAnalysisResult | ||
|
||
|
||
class ChunkingBackend(metaclass=abc.ABCMeta):
    """Annif backend mixin that implements chunking of input"""

    @abc.abstractmethod
    def _analyze_chunks(self, chunktexts, project):
        """Analyze the chunked text; should be implemented by the subclass
        inheriting this mixin"""
        pass  # pragma: no cover

    def _analyze(self, text, project, params):
        # make sure the backend is ready before doing any work
        self.initialize()
        self.debug('Analyzing text "{}..." (len={})'.format(
            text[:20], len(text)))
        sentences = project.analyzer.tokenize_sentences(text)
        self.debug('Found {} sentences'.format(len(sentences)))
        chunksize = int(params['chunksize'])
        # group consecutive sentences into chunks of at most chunksize
        # sentences, normalize each chunk and drop the ones that come out
        # empty after normalization
        raw_chunks = (' '.join(sentences[pos:pos + chunksize])
                      for pos in range(0, len(sentences), chunksize))
        normalized = (self._normalize_text(project, chunk)
                      for chunk in raw_chunks)
        chunktexts = [chunk for chunk in normalized if chunk != '']
        self.debug('Split sentences into {} chunks'.format(len(chunktexts)))
        if not chunktexts:  # nothing to analyze, empty result
            return ListAnalysisResult(hits=[], subject_index=project.subjects)
        return self._analyze_chunks(chunktexts, project)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
"""Annif backend using the Vorpal Wabbit multiclass and multilabel | ||
classifiers""" | ||
|
||
import random | ||
import os.path | ||
import annif.util | ||
from vowpalwabbit import pyvw | ||
import numpy as np | ||
from annif.hit import AnalysisHit, VectorAnalysisResult | ||
from annif.exception import ConfigurationException, NotInitializedException | ||
from . import backend | ||
from . import mixins | ||
|
||
|
||
class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """Vowpal Wabbit multiclass/multilabel backend for Annif"""

    name = "vw_multi"
    needs_subject_index = True

    VW_PARAMS = {
        # each param specifier is a pair (allowed_values, default_value)
        # where allowed_values is either a type or a list of allowed values
        # and default_value may be None, to let VW decide by itself
        'bit_precision': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None)
    }

    # filenames used within the backend data directory
    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load a previously trained VW model from the data directory.

        Raises NotInitializedException if no model file exists. Does
        nothing if a model has already been loaded or trained."""
        if self._model is None:
            path = os.path.join(self._get_datadir(), self.MODEL_FILE)
            self.debug('loading VW model from {}'.format(path))
            if os.path.exists(path):
                self._model = pyvw.vw(
                    i=path,
                    quiet=True,
                    loss_function='logistic',
                    probabilities=True)
                self.debug('loaded model {}'.format(str(self._model)))
            else:
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)

    @classmethod
    def _normalize_text(cls, project, text):
        """Tokenize text with the project analyzer and strip characters
        that would break the VW input format."""
        ntext = ' '.join(project.analyzer.tokenize_words(text))
        # colon and pipe chars have special meaning in VW and must be avoided
        return ntext.replace(':', '').replace('|', '')

    def _write_train_file(self, examples, filename):
        """Write one VW example string per line into the given file;
        used as the save method for annif.util.atomic_save."""
        with open(filename, 'w') as trainfile:
            for ex in examples:
                print(ex, file=trainfile)

    def _create_train_file(self, corpus, project):
        """Convert a document corpus into a shuffled VW training file
        inside the backend data directory, skipping unknown subjects."""
        self.info('creating VW train file')
        examples = []
        for doc in corpus.documents:
            text = self._normalize_text(project, doc.text)
            for uri in doc.uris:
                subject_id = project.subjects.by_uri(uri)
                if subject_id is None:
                    continue
                # VW multiclass labels are 1-based, hence the +1 offset
                exstr = '{} | {}'.format(subject_id + 1, text)
                examples.append(exstr)
        # shuffle so VW's online learning isn't biased by document order
        random.shuffle(examples)
        annif.util.atomic_save(examples,
                               self._get_datadir(),
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _convert_param(self, param, val):
        """Validate and convert a single configured VW parameter value
        according to its VW_PARAMS specifier.

        Raises ConfigurationException for a disallowed or unconvertible
        value."""
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            # specifier is an explicit list of allowed string values
            if val in pspec:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
        try:
            # specifier is a type used to coerce the value
            return pspec(val)
        except ValueError:
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id)

    def _create_model(self, project):
        """Train a new VW one-against-all model from the previously
        written train file and save it into the data directory."""
        self.info('creating VW model')
        trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
        # start from the VW_PARAMS defaults, then let configured values
        # (validated/converted) override them
        params = {param: defaultval
                  for param, (_, defaultval) in self.VW_PARAMS.items()
                  if defaultval is not None}
        params.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(
            oaa=len(project.subjects),
            probabilities=True,
            data=trainpath,
            **params)
        modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
        self._model.save(modelpath)

    def train(self, corpus, project):
        """Train the backend: build the VW train file from the corpus,
        then fit and persist the model."""
        self._create_train_file(corpus, project)
        self._create_model(project)

    def _analyze_chunks(self, chunktexts, project):
        """Predict subject probabilities for each chunk and return the
        mean probability vector over all chunks as the result."""
        results = []
        for chunktext in chunktexts:
            # empty label section before the pipe: a test-time example
            example = ' | {}'.format(chunktext)
            results.append(np.array(self._model.predict(example)))
        return VectorAnalysisResult(
            np.array(results).mean(axis=0), project.subjects)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
"""Unit tests for the fastText backend in Annif""" | ||
|
||
import pytest | ||
import annif.backend | ||
import annif.corpus | ||
from annif.exception import ConfigurationException, NotInitializedException | ||
|
||
pytest.importorskip("annif.backend.vw_multi") | ||
|
||
|
||
@pytest.fixture(scope='function')
def vw_corpus(tmpdir):
    """return a small document corpus for testing VW training"""
    docfile = tmpdir.join('document.tsv')
    docfile.write("nonexistent\thttp://example.com/nonexistent\n"
                  "arkeologia\thttp://www.yso.fi/onto/yso/p1265")
    return annif.corpus.DocumentFile(str(docfile))
|
||
|
||
def test_vw_analyze_no_model(datadir, project):
    """Analyzing without a trained model must raise NotInitializedException."""
    vw_type = annif.backend.get_backend('vw_multi')
    vw = vw_type(
        backend_id='vw_multi',
        params={'chunksize': 4},
        datadir=str(datadir))

    with pytest.raises(NotInitializedException):
        # no assignment: the call is expected to raise, so the return
        # value was an unused local (flake8 F841) in the original
        vw.analyze("example text", project)
|
||
|
||
def test_vw_train(datadir, document_corpus, project):
    """Training should produce an in-memory model and a non-empty model file."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={
            'chunksize': 4,
            'learning_rate': 0.5,
            'loss_function': 'hinge'},
        datadir=str(backend_datadir := datadir))

    backend.train(document_corpus, project)
    modelfile = backend_datadir.join('vw-model')
    assert backend._model is not None
    assert modelfile.exists()
    assert modelfile.size() > 0
|
||
|
||
def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
    """A loss_function outside the allowed set must raise at train time."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 4, 'loss_function': 'invalid'},
        datadir=str(datadir))

    with pytest.raises(ConfigurationException):
        backend.train(vw_corpus, project)
|
||
|
||
def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
    """A learning_rate that cannot be converted to float must raise."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 4, 'learning_rate': 'high'},
        datadir=str(datadir))

    with pytest.raises(ConfigurationException):
        backend.train(vw_corpus, project)
|
||
|
||
def test_vw_train_unknown_subject(datadir, project, vw_corpus):
    """Training should succeed even when the corpus contains a subject
    URI unknown to the project (the document is simply skipped)."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 4},
        datadir=str(datadir))

    backend.train(vw_corpus, project)
    modelfile = datadir.join('vw-model')
    assert backend._model is not None
    assert modelfile.exists()
    assert modelfile.size() > 0
|
||
|
||
def test_vw_analyze(datadir, project):
    """Analyzing Finnish text about archaeology with a trained model
    should return hits including the 'arkeologia' subject (YSO p1265).

    NOTE(review): relies on the model trained by an earlier test in this
    module via the shared datadir — verify test ordering assumptions."""
    vw_type = annif.backend.get_backend('vw_multi')
    vw = vw_type(
        backend_id='vw_multi',
        params={'chunksize': 4},
        datadir=str(datadir))

    results = vw.analyze("""Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
pohjaan.""", project)

    assert len(results) > 0
    assert 'http://www.yso.fi/onto/yso/p1265' in [
        result.uri for result in results]
    assert 'arkeologia' in [result.label for result in results]
|
||
|
||
def test_vw_analyze_empty(datadir, project):
    """Input that normalizes to nothing should yield an empty result."""
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 4},
        datadir=str(datadir))

    results = backend.analyze("...", project)

    assert len(results) == 0