diff --git a/annif/backend/fasttext.py b/annif/backend/fasttext.py
index bf3abe719..5c8b6627c 100644
--- a/annif/backend/fasttext.py
+++ b/annif/backend/fasttext.py
@@ -107,9 +107,15 @@ def train(self, corpus, project):
         self._create_train_file(corpus, project)
         self._create_model()
 
+    def _predict_chunks(self, chunktexts, project, limit):
+        return self._model.predict(list(
+            filter(None, [self._normalize_text(project, chunktext)
+                          for chunktext in chunktexts])), limit)
+
     def _analyze_chunks(self, chunktexts, project):
         limit = int(self.params['limit'])
-        chunklabels, chunkscores = self._model.predict(chunktexts, limit)
+        chunklabels, chunkscores = self._predict_chunks(
+            chunktexts, project, limit)
         label_scores = collections.defaultdict(float)
         for labels, scores in zip(chunklabels, chunkscores):
             for label, score in zip(labels, scores):
diff --git a/annif/backend/mixins.py b/annif/backend/mixins.py
index f109ac9b5..f89aa5f1c 100644
--- a/annif/backend/mixins.py
+++ b/annif/backend/mixins.py
@@ -24,10 +24,7 @@ def _analyze(self, text, project, params):
         chunksize = int(params['chunksize'])
         chunktexts = []
         for i in range(0, len(sentences), chunksize):
-            chunktext = ' '.join(sentences[i:i + chunksize])
-            normalized = self._normalize_text(project, chunktext)
-            if normalized != '':
-                chunktexts.append(normalized)
+            chunktexts.append(' '.join(sentences[i:i + chunksize]))
         self.debug('Split sentences into {} chunks'.format(len(chunktexts)))
         if len(chunktexts) == 0:  # nothing to analyze, empty result
             return ListAnalysisResult(hits=[], subject_index=project.subjects)
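Note: chunk normalization, and the dropping of chunks that normalize to the
empty string, moves out of the shared ChunkingBackend mixin and into the
individual backends: fastText now normalizes inside the new _predict_chunks,
while vw_multi (below) needs the raw chunk text to build namespaced examples.
A minimal sketch of the filter(None, ...) idiom used above, with made-up
chunk texts and `normalize` standing in for self._normalize_text(project, ...):

    chunktexts = ['Some sentences here.', '...']
    normalized = [normalize(c) for c in chunktexts]  # ['some sentences here', '']
    nonempty = list(filter(None, normalized))        # '' dropped -> ['some sentences here']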
diff --git a/annif/backend/vw_multi.py b/annif/backend/vw_multi.py
index ed97b32e3..1f6d47961 100644
--- a/annif/backend/vw_multi.py
+++ b/annif/backend/vw_multi.py
@@ -6,7 +6,7 @@
 import annif.util
 from vowpalwabbit import pyvw
 import numpy as np
-from annif.hit import VectorAnalysisResult
+from annif.hit import ListAnalysisResult, VectorAnalysisResult
 from annif.exception import ConfigurationException, NotInitializedException
 from . import backend
 from . import mixins
@@ -23,7 +23,7 @@ class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
         # where allowed_values is either a type or a list of allowed values
         # and default_value may be None, to let VW decide by itself
         'bit_precision': (int, None),
-        'ngram': (int, None),
+        'ngram': (lambda x: '_{}'.format(int(x)), None),
         'learning_rate': (float, None),
         'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
         'l1': (float, None),
@@ -35,6 +35,8 @@ class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
     DEFAULT_ALGORITHM = 'oaa'
     SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')
 
+    DEFAULT_INPUTS = '_text_'
+
     MODEL_FILE = 'vw-model'
     TRAIN_FILE = 'vw-train.txt'
 
@@ -67,11 +69,20 @@ def algorithm(self):
                 backend_id=self.backend_id)
         return algorithm
 
+    @property
+    def inputs(self):
+        inputs = self.params.get('inputs', self.DEFAULT_INPUTS)
+        return inputs.split(',')
+
+    @staticmethod
+    def _cleanup_text(text):
+        # colon and pipe chars have special meaning in VW and must be avoided
+        return text.replace(':', '').replace('|', '')
+
     @staticmethod
     def _normalize_text(project, text):
         ntext = ' '.join(project.analyzer.tokenize_words(text))
-        # colon and pipe chars have special meaning in VW and must be avoided
-        return ntext.replace(':', '').replace('|', '')
+        return VWMultiBackend._cleanup_text(ntext)
 
     @staticmethod
     def _write_train_file(examples, filename):
@@ -91,16 +102,40 @@ def _uris_to_subject_ids(project, uris):
     def _format_examples(self, project, text, uris):
         subject_ids = self._uris_to_subject_ids(project, uris)
         if self.algorithm == 'multilabel_oaa':
-            yield '{} | {}'.format(','.join(map(str, subject_ids)), text)
+            yield '{} {}'.format(','.join(map(str, subject_ids)), text)
         else:
             for subject_id in subject_ids:
-                yield '{} | {}'.format(subject_id + 1, text)
+                yield '{} {}'.format(subject_id + 1, text)
+
+    def _get_input(self, input, project, text):
+        if input == '_text_':
+            return self._normalize_text(project, text)
+        else:
+            proj = annif.project.get_project(input)
+            result = proj.analyze(text)
+            features = [
+                '{}:{}'.format(self._cleanup_text(hit.uri), hit.score)
+                for hit in result.hits]
+            return ' '.join(features)
+
+    def _inputs_to_exampletext(self, project, text):
+        namespaces = {}
+        for input in self.inputs:
+            inputtext = self._get_input(input, project, text)
+            if inputtext:
+                namespaces[input] = inputtext
+        if not namespaces:
+            return None
+        return ' '.join(['|{} {}'.format(namespace, featurestr)
+                         for namespace, featurestr in namespaces.items()])
 
     def _create_train_file(self, corpus, project):
         self.info('creating VW train file')
         examples = []
         for doc in corpus.documents:
-            text = self._normalize_text(project, doc.text)
+            text = self._inputs_to_exampletext(project, doc.text)
+            if not text:
+                continue
             examples.extend(self._format_examples(project, text, doc.uris))
         random.shuffle(examples)
         annif.util.atomic_save(examples,
@@ -149,23 +184,31 @@ def train(self, corpus, project):
         self._create_train_file(corpus, project)
         self._create_model(project)
 
+    def _convert_result(self, result, project):
+        if self.algorithm == 'multilabel_oaa':
+            # result is a list of subject IDs - need to vectorize
+            mask = np.zeros(len(project.subjects))
+            mask[result] = 1.0
+            return mask
+        elif isinstance(result, int):
+            # result is a single integer - need to one-hot-encode
+            mask = np.zeros(len(project.subjects))
+            mask[result - 1] = 1.0
+            return mask
+        else:
+            # result is a list of scores (probabilities or binary 1/0)
+            return np.array(result)
+
     def _analyze_chunks(self, chunktexts, project):
         results = []
         for chunktext in chunktexts:
-            example = ' | {}'.format(chunktext)
+            exampletext = self._inputs_to_exampletext(project, chunktext)
+            if not exampletext:
+                continue
+            example = ' {}'.format(exampletext)
             result = self._model.predict(example)
-            if self.algorithm == 'multilabel_oaa':
-                # result is a list of subject IDs - need to vectorize
-                mask = np.zeros(len(project.subjects))
-                mask[result] = 1.0
-                result = mask
-            elif isinstance(result, int):
-                # result is a single integer - need to one-hot-encode
-                mask = np.zeros(len(project.subjects))
-                mask[result - 1] = 1.0
-                result = mask
-            else:
-                result = np.array(result)
-            results.append(result)
+            results.append(self._convert_result(result, project))
+        if not results:  # empty result
+            return ListAnalysisResult(hits=[], subject_index=project.subjects)
         return VectorAnalysisResult(
             np.array(results).mean(axis=0), project.subjects)
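Note: _inputs_to_exampletext now renders each input source as its own VW
namespace, which is why the '|' separator was dropped from the label format
strings in _format_examples. As an illustration (all values hypothetical):
with inputs=_text_,dummy-en, a document whose text tokenizes to "arkeologia
kaivaukset", subject ID 41, and a dummy-en project returning one hit on
http://example.org/dummy with score 1.0, the generated train line would be:

    42 |_text_ arkeologia kaivaukset |dummy-en http//example.org/dummy:1.0

The colon inside the URI is stripped by _cleanup_text, so the remaining
':1.0' is read by VW as the feature value; the leading 42 is subject_id + 1.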
diff --git a/tests/test_backend_vw_multi.py b/tests/test_backend_vw_multi.py
index a0594ae2b..359fcbdf7 100644
--- a/tests/test_backend_vw_multi.py
+++ b/tests/test_backend_vw_multi.py
@@ -13,7 +13,8 @@ def vw_corpus(tmpdir):
     """return a small document corpus for testing VW training"""
     tmpfile = tmpdir.join('document.tsv')
     tmpfile.write("nonexistent\thttp://example.com/nonexistent\n" +
-                  "arkeologia\thttp://www.yso.fi/onto/yso/p1265")
+                  "arkeologia\thttp://www.yso.fi/onto/yso/p1265\n" +
+                  "...\thttp://example.com/none")
     return annif.corpus.DocumentFile(str(tmpfile))
 
 
@@ -44,6 +45,22 @@ def test_vw_multi_train(datadir, document_corpus, project):
     assert datadir.join('vw-model').size() > 0
 
 
+def test_vw_multi_train_from_project(app, datadir, document_corpus, project):
+    vw_type = annif.backend.get_backend('vw_multi')
+    vw = vw_type(
+        backend_id='vw_multi',
+        params={
+            'chunksize': 4,
+            'inputs': '_text_,dummy-en'},
+        datadir=str(datadir))
+
+    with app.app_context():
+        vw.train(document_corpus, project)
+    assert vw._model is not None
+    assert datadir.join('vw-model').exists()
+    assert datadir.join('vw-model').size() > 0
+
+
 def test_vw_multi_train_multiple_passes(datadir, document_corpus, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
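Note: the new test above exercises the inputs parameter through the backend
params dict. In a projects.cfg entry the same setting would look roughly like
the sketch below (project and vocabulary names are illustrative; each listed
input other than _text_ must be the ID of another configured project):

    [vw-en]
    name=VW multi English
    language=en
    backend=vw_multi
    inputs=_text_,dummy-en
    vocab=yso-en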