First implementation of VW regular backend #249

Merged · 24 commits · Jan 29, 2019

Commits
db79408
Initial implementation of Vowpal Wabbit backend
osma Jan 14, 2019
239a378
add Travis package dependencies needed by VW install
osma Jan 14, 2019
8a1ef9b
add python-dev travis dependency, hoping it will fix VW build
osma Jan 14, 2019
78ccb4e
Write examples to file before training in vw backend; increase weight…
osma Jan 14, 2019
9ebe881
use apt addon to install packages; switch libboost-python default ver…
osma Jan 15, 2019
2522c37
only switch libboost-python versions for python 3.5 under travis
osma Jan 15, 2019
91701fa
add missing sudo commands
osma Jan 15, 2019
d5a51fa
add missing install section to .travis.yml
osma Jan 15, 2019
b9fa078
clean up garbage left after conflict merging
osma Jan 15, 2019
2e6359b
install deb packages using apt addon (even though they're unnecessary…
osma Jan 15, 2019
cf59bb3
rename load_corpus to train, to match changes on master branch
osma Jan 15, 2019
a506e55
Initial unit tests for vw backend
osma Jan 15, 2019
cc3fb3d
Implement splitting to chunks in vw backend
osma Jan 15, 2019
ab11514
refactor: extract chunking functionality into mixin class
osma Jan 15, 2019
2ccc243
rename vw backend to vw_multi to better reflect its purpose and make …
osma Jan 17, 2019
6b7c2d3
Initial support for VW parameters set in config file
osma Jan 21, 2019
497eac7
avoid passing pipe characters to VW
osma Jan 23, 2019
89d4a52
ChunkingBackend: handle special case when there is no text to analyze
osma Jan 29, 2019
412fc9d
Merge branch 'master' into vw-backend
osma Jan 29, 2019
a331419
Add missing import
osma Jan 29, 2019
01485ce
Remove unused method
osma Jan 29, 2019
3997eb2
Add "pragma: no cover" annotation for abstract method
osma Jan 29, 2019
72b8594
Additional tests for vw_multi backend
osma Jan 29, 2019
92e737e
Further tests for vw_multi backend
osma Jan 29, 2019
11 changes: 10 additions & 1 deletion .travis.yml
@@ -9,18 +9,27 @@ addons:
packages:
- libvoikko1
- voikko-fi
- libboost-program-options-dev
- libboost-python-dev
- zlib1g-dev
cache: pip
before_install:
- export BOTO_CONFIG=/dev/null
install:
- pip install pipenv
- pip install --upgrade pytest
- pipenv install --dev --skip-lock
- travis_wait 30 python -m nltk.downloader punkt
# For Python 3.5, also install optional dependencies that were not specified in Pipfile
# For other Python versions we will only run the tests that depend on pure Python modules
# - fastText dependencies
- if [[ $TRAVIS_PYTHON_VERSION == '3.5' ]]; then pip install fasttextmirror; fi
# - voikko dependencies
- if [[ $TRAVIS_PYTHON_VERSION == '3.5' ]]; then pip install voikko; fi
- travis_wait 30 python -m nltk.downloader punkt
# - vw dependencies
- if [[ $TRAVIS_PYTHON_VERSION == '3.5' ]]; then sudo ln -sf /usr/lib/x86_64-linux-gnu/libboost_python-py35.a /usr/lib/x86_64-linux-gnu/libboost_python.a; fi
- if [[ $TRAVIS_PYTHON_VERSION == '3.5' ]]; then sudo ln -sf /usr/lib/x86_64-linux-gnu/libboost_python-py35.so /usr/lib/x86_64-linux-gnu/libboost_python.so; fi
- if [[ $TRAVIS_PYTHON_VERSION == '3.5' ]]; then pip install vowpalwabbit; fi
script:
- pytest --cov=./
after_success:
7 changes: 7 additions & 0 deletions annif/backend/__init__.py
@@ -36,3 +36,10 @@ def get_backend(backend_id):
register_backend(fasttext.FastTextBackend)
except ImportError:
annif.logger.debug("fastText not available, not enabling fasttext backend")

try:
from . import vw_multi
register_backend(vw_multi.VWMultiBackend)
except ImportError:
annif.logger.debug(
"vowpalwabbit not available, not enabling vw_multi backend")
20 changes: 2 additions & 18 deletions annif/backend/fasttext.py
@@ -7,9 +7,10 @@
from annif.exception import NotInitializedException
import fastText
from . import backend
from . import mixins


class FastTextBackend(backend.AnnifBackend):
class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend):
"""fastText backend for Annif"""

name = "fasttext"
@@ -125,20 +126,3 @@ def _analyze_chunks(self, chunktexts, project):
label=subject[1],
score=score / len(chunktexts)))
return ListAnalysisResult(results, project.subjects)

def _analyze(self, text, project, params):
self.initialize()
self.debug('Analyzing text "{}..." (len={})'.format(
text[:20], len(text)))
sentences = project.analyzer.tokenize_sentences(text)
self.debug('Found {} sentences'.format(len(sentences)))
chunksize = int(params['chunksize'])
chunktexts = []
for i in range(0, len(sentences), chunksize):
chunktext = ' '.join(sentences[i:i + chunksize])
normalized = self._normalize_text(project, chunktext)
if normalized != '':
chunktexts.append(normalized)
self.debug('Split sentences into {} chunks'.format(len(chunktexts)))

return self._analyze_chunks(chunktexts, project)
34 changes: 34 additions & 0 deletions annif/backend/mixins.py
@@ -0,0 +1,34 @@
"""Annif backend mixins that can be used to implement features"""


import abc
from annif.hit import ListAnalysisResult


class ChunkingBackend(metaclass=abc.ABCMeta):
"""Annif backend mixin that implements chunking of input"""

@abc.abstractmethod
def _analyze_chunks(self, chunktexts, project):
"""Analyze the chunked text; should be implemented by the subclass
inheriting this mixin"""

pass # pragma: no cover

def _analyze(self, text, project, params):
self.initialize()
self.debug('Analyzing text "{}..." (len={})'.format(
text[:20], len(text)))
sentences = project.analyzer.tokenize_sentences(text)
self.debug('Found {} sentences'.format(len(sentences)))
chunksize = int(params['chunksize'])
chunktexts = []
for i in range(0, len(sentences), chunksize):
chunktext = ' '.join(sentences[i:i + chunksize])
normalized = self._normalize_text(project, chunktext)
if normalized != '':
chunktexts.append(normalized)
self.debug('Split sentences into {} chunks'.format(len(chunktexts)))
if len(chunktexts) == 0: # nothing to analyze, empty result
return ListAnalysisResult(hits=[], subject_index=project.subjects)
return self._analyze_chunks(chunktexts, project)
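
For orientation, a minimal sketch (not part of this diff) of how a backend is meant to plug into the mixin: the subclass provides _analyze_chunks and the _normalize_text hook that _analyze calls, while the mixin contributes the sentence tokenization and chunking logic; the mixin's _analyze also expects a chunksize value in the backend params. The ToyBackend name and its placeholder behaviour below are hypothetical.

# Hypothetical backend built on the ChunkingBackend mixin (illustration only).
# The mixin supplies _analyze(); the subclass supplies the two hooks it calls.
from annif.hit import ListAnalysisResult
from . import backend
from . import mixins


class ToyBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    name = "toy"

    @classmethod
    def _normalize_text(cls, project, text):
        # hook called by _analyze for each chunk; trivial normalization here
        return text.lower().strip()

    def _analyze_chunks(self, chunktexts, project):
        # placeholder logic: return an empty result regardless of input
        return ListAnalysisResult(hits=[], subject_index=project.subjects)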
126 changes: 126 additions & 0 deletions annif/backend/vw_multi.py
@@ -0,0 +1,126 @@
"""Annif backend using the Vorpal Wabbit multiclass and multilabel
classifiers"""

import random
import os.path
import annif.util
from vowpalwabbit import pyvw
import numpy as np
from annif.hit import AnalysisHit, VectorAnalysisResult
from annif.exception import ConfigurationException, NotInitializedException
from . import backend
from . import mixins


class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
"""Vorpal Wabbit multiclass/multilabel backend for Annif"""

name = "vw_multi"
needs_subject_index = True

VW_PARAMS = {
# each param specifier is a pair (allowed_values, default_value)
# where allowed_values is either a type or a list of allowed values
# and default_value may be None, to let VW decide by itself
'bit_precision': (int, None),
'learning_rate': (float, None),
'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
'l1': (float, None),
'l2': (float, None),
'passes': (int, None)
}

MODEL_FILE = 'vw-model'
TRAIN_FILE = 'vw-train.txt'

# defaults for uninitialized instances
_model = None

def initialize(self):
if self._model is None:
path = os.path.join(self._get_datadir(), self.MODEL_FILE)
self.debug('loading VW model from {}'.format(path))
if os.path.exists(path):
self._model = pyvw.vw(
i=path,
quiet=True,
loss_function='logistic',
probabilities=True)
self.debug('loaded model {}'.format(str(self._model)))
else:
raise NotInitializedException(
'model {} not found'.format(path),
backend_id=self.backend_id)

@classmethod
def _normalize_text(cls, project, text):
ntext = ' '.join(project.analyzer.tokenize_words(text))
# colon and pipe chars have special meaning in VW and must be avoided
return ntext.replace(':', '').replace('|', '')

def _write_train_file(self, examples, filename):
with open(filename, 'w') as trainfile:
for ex in examples:
print(ex, file=trainfile)

def _create_train_file(self, corpus, project):
self.info('creating VW train file')
examples = []
for doc in corpus.documents:
text = self._normalize_text(project, doc.text)
for uri in doc.uris:
subject_id = project.subjects.by_uri(uri)
if subject_id is None:
continue
exstr = '{} | {}'.format(subject_id + 1, text)
examples.append(exstr)
random.shuffle(examples)
annif.util.atomic_save(examples,
self._get_datadir(),
self.TRAIN_FILE,
method=self._write_train_file)

def _convert_param(self, param, val):
pspec, _ = self.VW_PARAMS[param]
if isinstance(pspec, list):
if val in pspec:
return val
raise ConfigurationException(
"{} is not a valid value for {} (allowed: {})".format(
val, param, ', '.join(pspec)), backend_id=self.backend_id)
try:
return pspec(val)
except ValueError:
raise ConfigurationException(
"The {} value {} cannot be converted to {}".format(
param, val, pspec), backend_id=self.backend_id)

def _create_model(self, project):
self.info('creating VW model')
trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
params = {param: defaultval
for param, (_, defaultval) in self.VW_PARAMS.items()
if defaultval is not None}
params.update({param: self._convert_param(param, val)
for param, val in self.params.items()
if param in self.VW_PARAMS})
self.debug("model parameters: {}".format(params))
self._model = pyvw.vw(
oaa=len(project.subjects),
probabilities=True,
data=trainpath,
**params)
modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
self._model.save(modelpath)

def train(self, corpus, project):
self._create_train_file(corpus, project)
self._create_model(project)

def _analyze_chunks(self, chunktexts, project):
results = []
for chunktext in chunktexts:
example = ' | {}'.format(chunktext)
results.append(np.array(self._model.predict(example)))
return VectorAnalysisResult(
np.array(results).mean(axis=0), project.subjects)
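
A side note on the training data produced above: each line written by _write_train_file is one example in VW's plain-text input format, i.e. a 1-based class label, a pipe separator and the normalized document text as features. A small worked illustration (the subject id and text are invented for this sketch):

# Illustration of the example format built in _create_train_file;
# the subject id and text below are invented.
subject_id = 41     # 0-based position in the project's subject index
text = 'arkeologia tutkii ihmisen menneisyyttä'   # already normalized
exstr = '{} | {}'.format(subject_id + 1, text)    # VW labels are 1-based
print(exstr)        # prints: 42 | arkeologia tutkii ihmisen menneisyyttä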
1 change: 1 addition & 0 deletions setup.py
@@ -30,6 +30,7 @@ def read(fname):
extras_require={
'fasttext': ['fasttextmirror'],
'voikko': ['voikko'],
'vw': ['vowpalwabbit'],
},
entry_points={
'console_scripts': ['annif=annif.cli:cli']},
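Usage note: with this extra declared, the optional dependency can be pulled in through pip's extras syntax, for example pip install annif[vw] (or pip install .[vw] in a development checkout); the import guard added to annif/backend/__init__.py then registers the vw_multi backend automatically whenever vowpalwabbit is importable.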
111 changes: 111 additions & 0 deletions tests/test_backend_vw.py
@@ -0,0 +1,111 @@
"""Unit tests for the fastText backend in Annif"""

import pytest
import annif.backend
import annif.corpus
from annif.exception import ConfigurationException, NotInitializedException

pytest.importorskip("annif.backend.vw_multi")


@pytest.fixture(scope='function')
def vw_corpus(tmpdir):
"""return a small document corpus for testing VW training"""
tmpfile = tmpdir.join('document.tsv')
tmpfile.write("nonexistent\thttp://example.com/nonexistent\n" +
"arkeologia\thttp://www.yso.fi/onto/yso/p1265")
return annif.corpus.DocumentFile(str(tmpfile))


def test_vw_analyze_no_model(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4},
datadir=str(datadir))

with pytest.raises(NotInitializedException):
results = vw.analyze("example text", project)


def test_vw_train(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={
'chunksize': 4,
'learning_rate': 0.5,
'loss_function': 'hinge'},
datadir=str(datadir))

vw.train(document_corpus, project)
assert vw._model is not None
assert datadir.join('vw-model').exists()
assert datadir.join('vw-model').size() > 0


def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4, 'loss_function': 'invalid'},
datadir=str(datadir))

with pytest.raises(ConfigurationException):
vw.train(vw_corpus, project)


def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4, 'learning_rate': 'high'},
datadir=str(datadir))

with pytest.raises(ConfigurationException):
vw.train(vw_corpus, project)


def test_vw_train_unknown_subject(datadir, project, vw_corpus):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4},
datadir=str(datadir))

vw.train(vw_corpus, project)
assert vw._model is not None
assert datadir.join('vw-model').exists()
assert datadir.join('vw-model').size() > 0


def test_vw_analyze(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4},
datadir=str(datadir))

results = vw.analyze("""Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
pohjaan.""", project)

assert len(results) > 0
assert 'http://www.yso.fi/onto/yso/p1265' in [
result.uri for result in results]
assert 'arkeologia' in [result.label for result in results]


def test_vw_analyze_empty(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4},
datadir=str(datadir))

results = vw.analyze("...", project)

assert len(results) == 0