vw-multi backend: use input from other projects #256

Merged: 9 commits, Feb 5, 2019
annif/backend/fasttext.py (10 additions, 1 deletion)

@@ -107,9 +107,18 @@ def train(self, corpus, project):
         self._create_train_file(corpus, project)
         self._create_model()
 
+    def _predict_chunks(self, chunktexts, project, limit):
+        normalized_chunks = []
+        for chunktext in chunktexts:
+            normalized = self._normalize_text(project, chunktext)
+            if normalized != '':
+                normalized_chunks.append(normalized)
+        return self._model.predict(normalized_chunks, limit)
+
     def _analyze_chunks(self, chunktexts, project):
         limit = int(self.params['limit'])
-        chunklabels, chunkscores = self._model.predict(chunktexts, limit)
+        chunklabels, chunkscores = self._predict_chunks(
+            chunktexts, project, limit)
         label_scores = collections.defaultdict(float)
         for labels, scores in zip(chunklabels, chunkscores):
             for label, score in zip(labels, scores):
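In short, normalization moves out of the shared chunking mixin (see the mixins.py change below) and into the fastText backend itself, so chunks that normalize to the empty string are dropped right before the model call. A minimal sketch of that filtering; FakeModel and normalize are illustrative stand-ins, not part of the PR:

# Sketch of the new _predict_chunks flow. FakeModel stands in for the
# real fastText model, which takes a list of texts plus a result limit.
class FakeModel:
    def predict(self, texts, limit):
        # fastText-style: one (labels, scores) pair per input text
        return ([['__label__x']] * len(texts), [[0.5]] * len(texts))

def normalize(text):
    # stand-in for _normalize_text (tokenization etc.)
    return text.strip().lower()

def predict_chunks(chunktexts, limit, model=FakeModel()):
    normalized = [normalize(c) for c in chunktexts]
    normalized = [c for c in normalized if c != '']  # drop empty chunks
    return model.predict(normalized, limit)

labels, scores = predict_chunks(['Some text.', '   '], limit=1)
print(labels, scores)  # only the non-empty chunk reaches the model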
annif/backend/mixins.py (1 addition, 4 deletions)

@@ -24,10 +24,7 @@ def _analyze(self, text, project, params):
         chunksize = int(params['chunksize'])
         chunktexts = []
         for i in range(0, len(sentences), chunksize):
-            chunktext = ' '.join(sentences[i:i + chunksize])
-            normalized = self._normalize_text(project, chunktext)
-            if normalized != '':
-                chunktexts.append(normalized)
+            chunktexts.append(' '.join(sentences[i:i + chunksize]))
         self.debug('Split sentences into {} chunks'.format(len(chunktexts)))
         if len(chunktexts) == 0:  # nothing to analyze, empty result
             return ListAnalysisResult(hits=[], subject_index=project.subjects)
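After this change the mixin only splits sentences into chunks and leaves normalization to each backend's _analyze_chunks. A quick illustration of the remaining chunking loop, with invented sentence strings:

# chunking as done in the mixin after this change: join every
# `chunksize` consecutive sentences into one chunk, no normalization
sentences = ['S1.', 'S2.', 'S3.', 'S4.', 'S5.']
chunksize = 2
chunktexts = [' '.join(sentences[i:i + chunksize])
              for i in range(0, len(sentences), chunksize)]
print(chunktexts)  # ['S1. S2.', 'S3. S4.', 'S5.']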
annif/backend/vw_multi.py (47 additions, 8 deletions)

@@ -6,7 +6,7 @@
 import annif.util
 from vowpalwabbit import pyvw
 import numpy as np
-from annif.hit import VectorAnalysisResult
+from annif.hit import ListAnalysisResult, VectorAnalysisResult
 from annif.exception import ConfigurationException, NotInitializedException
 from . import backend
 from . import mixins
@@ -23,7 +23,7 @@ class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
         # where allowed_values is either a type or a list of allowed values
         # and default_value may be None, to let VW decide by itself
         'bit_precision': (int, None),
-        'ngram': (int, None),
+        'ngram': (lambda x: '_{}'.format(int(x)), None),
         'learning_rate': (float, None),
         'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
         'l1': (float, None),
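The 'ngram' converter now prefixes the value with an underscore. My reading, not stated in the diff: since examples can now contain several namespaces, '--ngram _2' restricts n-gram generation to namespaces whose names start with '_' (i.e. the _text_ namespace) instead of applying it to the project-input namespaces as well. The converter itself is trivial:

def convert_ngram(x):
    # same behaviour as the lambda in VW_PARAMS
    return '_{}'.format(int(x))

print(convert_ngram('2'))  # '_2' (plain '2' before this change)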
@@ -35,6 +35,8 @@ class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
     DEFAULT_ALGORITHM = 'oaa'
     SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')
 
+    DEFAULT_INPUTS = '_text_'
+
     MODEL_FILE = 'vw-model'
     TRAIN_FILE = 'vw-train.txt'

@@ -67,11 +69,20 @@ def algorithm(self):
                 backend_id=self.backend_id)
         return algorithm
 
+    @property
+    def inputs(self):
+        inputs = self.params.get('inputs', self.DEFAULT_INPUTS)
+        return inputs.split(',')
+
+    @staticmethod
+    def _cleanup_text(text):
+        # colon and pipe chars have special meaning in VW and must be avoided
+        return text.replace(':', '').replace('|', '')
+
     @staticmethod
     def _normalize_text(project, text):
         ntext = ' '.join(project.analyzer.tokenize_words(text))
-        # colon and pipe chars have special meaning in VW and must be avoided
-        return ntext.replace(':', '').replace('|', '')
+        return VWMultiBackend._cleanup_text(ntext)
 
     @staticmethod
     def _write_train_file(examples, filename):
@@ -91,16 +102,39 @@ def _uris_to_subject_ids(project, uris):
     def _format_examples(self, project, text, uris):
         subject_ids = self._uris_to_subject_ids(project, uris)
         if self.algorithm == 'multilabel_oaa':
-            yield '{} | {}'.format(','.join(map(str, subject_ids)), text)
+            yield '{} {}'.format(','.join(map(str, subject_ids)), text)
         else:
             for subject_id in subject_ids:
-                yield '{} | {}'.format(subject_id + 1, text)
+                yield '{} {}'.format(subject_id + 1, text)
 
+    def _inputs_to_exampletext(self, project, text):
+        namespaces = {}
+        for input in self.inputs:
+            if input == '_text_':
+                normalized = self._normalize_text(project, text)
+                if normalized != '':
+                    namespaces['_text_'] = normalized
+            else:
+                proj = annif.project.get_project(input)
+                result = proj.analyze(text)
+                features = [
+                    '{}:{}'.format(self._cleanup_text(hit.uri), hit.score)
+                    for hit in result.hits]
+                namespaces[input] = ' '.join(features)
+        if not namespaces:
+            return None
+        return ' '.join(['|{} {}'.format(namespace, featurestr)
+                         for namespace, featurestr in namespaces.items()])
+
     def _create_train_file(self, corpus, project):
         self.info('creating VW train file')
         examples = []
         for doc in corpus.documents:
-            text = self._normalize_text(project, doc.text)
+            text = self._inputs_to_exampletext(project, doc.text)
+            if not text:
+                continue
             examples.extend(self._format_examples(project, text, doc.uris))
         random.shuffle(examples)
         annif.util.atomic_save(examples,
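To make the new input handling concrete, here is a small standalone sketch (the project name, URI, and scores are invented) of the example text that _inputs_to_exampletext builds: one VW namespace per input, with the text namespace holding normalized tokens and each project namespace holding URI:score features. Note that _cleanup_text strips ':' and '|' from the URIs, since those characters are reserved in VW's input format:

def cleanup_text(text):
    # same cleanup as VWMultiBackend._cleanup_text
    return text.replace(':', '').replace('|', '')

# hypothetical hits returned by a source project called 'dummy-en'
hits = [('http://example.org/dummy', 1.0)]
namespaces = {
    '_text_': 'some normalized tokens',
    'dummy-en': ' '.join('{}:{}'.format(cleanup_text(uri), score)
                         for uri, score in hits),
}
exampletext = ' '.join('|{} {}'.format(ns, feats)
                       for ns, feats in namespaces.items())
print(exampletext)
# |_text_ some normalized tokens |dummy-en http//example.org/dummy:1.0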
@@ -152,7 +186,10 @@ def train(self, corpus, project):
     def _analyze_chunks(self, chunktexts, project):
         results = []
         for chunktext in chunktexts:
-            example = ' | {}'.format(chunktext)
+            exampletext = self._inputs_to_exampletext(project, chunktext)
+            if not exampletext:
+                continue
+            example = ' {}'.format(exampletext)
             result = self._model.predict(example)
             if self.algorithm == 'multilabel_oaa':
                 # result is a list of subject IDs - need to vectorize
@@ -167,5 +204,7 @@ def _analyze_chunks(self, chunktexts, project):
             else:
                 result = np.array(result)
             results.append(result)
+        if len(results) == 0:  # empty result
+            return ListAnalysisResult(hits=[], subject_index=project.subjects)
         return VectorAnalysisResult(
             np.array(results).mean(axis=0), project.subjects)
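Per-chunk score vectors are then averaged into a single result vector, and the new guard returns an empty ListAnalysisResult when no chunk yielded any example text, rather than averaging an empty array. A toy illustration of the averaging step, with made-up two-subject score vectors:

import numpy as np

# one score vector per analyzed chunk (values invented)
results = [np.array([0.1, 0.9]), np.array([0.3, 0.7])]
if len(results) == 0:
    print('empty result')  # the new guard path
else:
    print(np.array(results).mean(axis=0))  # [0.2 0.8]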
tests/test_backend_vw_multi.py (16 additions)

@@ -44,6 +44,22 @@ def test_vw_multi_train(datadir, document_corpus, project):
     assert datadir.join('vw-model').size() > 0
 
 
+def test_vw_multi_train_from_project(app, datadir, document_corpus, project):
+    vw_type = annif.backend.get_backend('vw_multi')
+    vw = vw_type(
+        backend_id='vw_multi',
+        params={
+            'chunksize': 4,
+            'inputs': '_text_,dummy-en'},
+        datadir=str(datadir))
+
+    with app.app_context():
+        vw.train(document_corpus, project)
+    assert vw._model is not None
+    assert datadir.join('vw-model').exists()
+    assert datadir.join('vw-model').size() > 0
+
+
 def test_vw_multi_train_multiple_passes(datadir, document_corpus, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(