diff --git a/annif/backend/vw_multi.py b/annif/backend/vw_multi.py
index 355a1a34e..5ab9678d3 100644
--- a/annif/backend/vw_multi.py
+++ b/annif/backend/vw_multi.py
@@ -23,15 +23,17 @@ class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
         # where allowed_values is either a type or a list of allowed values
         # and default_value may be None, to let VW decide by itself
         'bit_precision': (int, None),
+        'ngram': (int, None),
         'learning_rate': (float, None),
         'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
         'l1': (float, None),
         'l2': (float, None),
-        'passes': (int, None)
+        'passes': (int, None),
+        'probabilities': (bool, None)
     }
 
     DEFAULT_ALGORITHM = 'oaa'
-    SUPPORTED_ALGORITHMS = ('oaa', 'ect')
+    SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')
 
     MODEL_FILE = 'vw-model'
     TRAIN_FILE = 'vw-train.txt'
@@ -71,22 +73,35 @@ def _normalize_text(cls, project, text):
         # colon and pipe chars have special meaning in VW and must be avoided
         return ntext.replace(':', '').replace('|', '')
 
-    def _write_train_file(self, examples, filename):
+    @classmethod
+    def _write_train_file(cls, examples, filename):
         with open(filename, 'w') as trainfile:
             for ex in examples:
                 print(ex, file=trainfile)
 
+    @classmethod
+    def _uris_to_subject_ids(cls, project, uris):
+        subject_ids = []
+        for uri in uris:
+            subject_id = project.subjects.by_uri(uri)
+            if subject_id is not None:
+                subject_ids.append(subject_id)
+        return subject_ids
+
+    def _format_examples(self, project, text, uris):
+        subject_ids = self._uris_to_subject_ids(project, uris)
+        if self.algorithm == 'multilabel_oaa':
+            yield '{} | {}'.format(','.join(map(str, subject_ids)), text)
+        else:
+            for subject_id in subject_ids:
+                yield '{} | {}'.format(subject_id + 1, text)
+
     def _create_train_file(self, corpus, project):
         self.info('creating VW train file')
         examples = []
         for doc in corpus.documents:
             text = self._normalize_text(project, doc.text)
-            for uri in doc.uris:
-                subject_id = project.subjects.by_uri(uri)
-                if subject_id is None:
-                    continue
-                exstr = '{} | {}'.format(subject_id + 1, text)
-                examples.append(exstr)
+            examples.extend(self._format_examples(project, text, doc.uris))
         random.shuffle(examples)
         annif.util.atomic_save(examples,
                                self._get_datadir(),
@@ -115,9 +130,6 @@ def _create_params(self, params):
         params.update({param: self._convert_param(param, val)
                        for param, val in self.params.items()
                        if param in self.VW_PARAMS})
-        if self.algorithm == 'oaa':
-            # only the oaa algorithm supports probabilities output
-            params.update({'probabilities': True, 'loss_function': 'logistic'})
         return params
 
     def _create_model(self, project):
@@ -142,8 +154,13 @@ def _analyze_chunks(self, chunktexts, project):
         for chunktext in chunktexts:
             example = ' | {}'.format(chunktext)
             result = self._model.predict(example)
-            if isinstance(result, int):
-                # just a single integer - need to one-hot-encode
+            if self.algorithm == 'multilabel_oaa':
+                # result is a list of subject IDs - need to vectorize
+                mask = np.zeros(len(project.subjects))
+                mask[result] = 1.0
+                result = mask
+            elif isinstance(result, int):
+                # result is a single integer - need to one-hot-encode
                 mask = np.zeros(len(project.subjects))
                 mask[result - 1] = 1.0
                 result = mask
diff --git a/tests/test_backend_vw.py b/tests/test_backend_vw_multi.py
similarity index 60%
rename from tests/test_backend_vw.py
rename to tests/test_backend_vw_multi.py
index 5598f3122..a0594ae2b 100644
--- a/tests/test_backend_vw.py
+++ b/tests/test_backend_vw_multi.py
@@ -17,7 +17,7 @@ def vw_corpus(tmpdir):
     return annif.corpus.DocumentFile(str(tmpfile))
 
 
-def test_vw_analyze_no_model(datadir, project):
+def test_vw_multi_analyze_no_model(datadir, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -28,7 +28,7 @@ def test_vw_analyze_no_model(datadir, project):
         results = vw.analyze("example text", project)
 
 
-def test_vw_train(datadir, document_corpus, project):
+def test_vw_multi_train(datadir, document_corpus, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -44,7 +44,7 @@ def test_vw_train(datadir, document_corpus, project):
     assert datadir.join('vw-model').size() > 0
 
 
-def test_vw_train_multiple_passes(datadir, document_corpus, project):
+def test_vw_multi_train_multiple_passes(datadir, document_corpus, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -60,7 +60,7 @@ def test_vw_train_multiple_passes(datadir, document_corpus, project):
     assert datadir.join('vw-model').size() > 0
 
 
-def test_vw_train_invalid_algorithm(datadir, document_corpus, project):
+def test_vw_multi_train_invalid_algorithm(datadir, document_corpus, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -74,7 +74,7 @@ def test_vw_train_invalid_algorithm(datadir, document_corpus, project):
         vw.train(document_corpus, project)
 
 
-def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
+def test_vw_multi_train_invalid_loss_function(datadir, project, vw_corpus):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -85,7 +85,7 @@ def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
         vw.train(vw_corpus, project)
 
 
-def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
+def test_vw_multi_train_invalid_learning_rate(datadir, project, vw_corpus):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -96,11 +96,11 @@ def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
         vw.train(vw_corpus, project)
 
 
-def test_vw_analyze(datadir, project):
+def test_vw_multi_analyze(datadir, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
-        params={'chunksize': 4},
+        params={'chunksize': 4, 'probabilities': 1},
         datadir=str(datadir))
 
     results = vw.analyze("""Arkeologiaa sanotaan joskus myös
@@ -116,7 +116,7 @@ def test_vw_analyze(datadir, project):
     assert 'arkeologia' in [result.label for result in results]
 
 
-def test_vw_analyze_empty(datadir, project):
+def test_vw_multi_analyze_empty(datadir, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -128,7 +128,7 @@ def test_vw_analyze_empty(datadir, project):
     assert len(results) == 0
 
 
-def test_vw_analyze_multiple_passes(datadir, project):
+def test_vw_multi_analyze_multiple_passes(datadir, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -140,7 +140,7 @@ def test_vw_analyze_multiple_passes(datadir, project):
     assert len(results) == 0
 
 
-def test_vw_train_ect(datadir, document_corpus, project):
+def test_vw_multi_train_ect(datadir, document_corpus, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -156,7 +156,7 @@ def test_vw_train_ect(datadir, document_corpus, project):
     assert datadir.join('vw-model').size() > 0
 
 
-def test_vw_analyze_ect(datadir, project):
+def test_vw_multi_analyze_ect(datadir, project):
     vw_type = annif.backend.get_backend('vw_multi')
     vw = vw_type(
         backend_id='vw_multi',
@@ -172,3 +172,72 @@ def test_vw_analyze_ect(datadir, project):
         pohjaan.""", project)
 
     assert len(results) > 0
+
+
+def test_vw_multi_train_log_multi(datadir, document_corpus, project):
+    vw_type = annif.backend.get_backend('vw_multi')
+    vw = vw_type(
+        backend_id='vw_multi',
+        params={
+            'chunksize': 4,
+            'learning_rate': 0.5,
+            'algorithm': 'log_multi'},
+        datadir=str(datadir))
+
+    vw.train(document_corpus, project)
+    assert vw._model is not None
+    assert datadir.join('vw-model').exists()
+    assert datadir.join('vw-model').size() > 0
+
+
+def test_vw_multi_analyze_log_multi(datadir, project):
+    vw_type = annif.backend.get_backend('vw_multi')
+    vw = vw_type(
+        backend_id='vw_multi',
+        params={'chunksize': 1,
+                'algorithm': 'log_multi'},
+        datadir=str(datadir))
+
+    results = vw.analyze("""Arkeologiaa sanotaan joskus myös
+        muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
+        tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
+        Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
+        joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
+        pohjaan.""", project)
+
+    assert len(results) > 0
+
+
+def test_vw_multi_train_multilabel_oaa(datadir, document_corpus, project):
+    vw_type = annif.backend.get_backend('vw_multi')
+    vw = vw_type(
+        backend_id='vw_multi',
+        params={
+            'chunksize': 4,
+            'learning_rate': 0.5,
+            'algorithm': 'multilabel_oaa'},
+        datadir=str(datadir))
+
+    vw.train(document_corpus, project)
+    assert vw._model is not None
+    assert datadir.join('vw-model').exists()
+    assert datadir.join('vw-model').size() > 0
+
+
+def test_vw_multi_analyze_multilabel_oaa(datadir, project):
+    vw_type = annif.backend.get_backend('vw_multi')
+    vw = vw_type(
+        backend_id='vw_multi',
+        params={'chunksize': 1,
+                'algorithm': 'multilabel_oaa'},
+        datadir=str(datadir))
+
+    results = vw.analyze("""Arkeologiaa sanotaan joskus myös
+        muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
+        tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
+        Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
+        joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
+        pohjaan.""", project)
+
+    # weak assertion, but often multilabel_oaa produces zero hits
+    assert len(results) >= 0