Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancements to vw_multi backend #254

Merged
merged 5 commits into from
Feb 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions annif/backend/vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
# where allowed_values is either a type or a list of allowed values
# and default_value may be None, to let VW decide by itself
'bit_precision': (int, None),
'ngram': (int, None),
'learning_rate': (float, None),
'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
'l1': (float, None),
'l2': (float, None),
'passes': (int, None)
'passes': (int, None),
'probabilities': (bool, None)
}

DEFAULT_ALGORITHM = 'oaa'
SUPPORTED_ALGORITHMS = ('oaa', 'ect')
SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

MODEL_FILE = 'vw-model'
TRAIN_FILE = 'vw-train.txt'
Expand Down Expand Up @@ -71,22 +73,35 @@ def _normalize_text(cls, project, text):
# colon and pipe chars have special meaning in VW and must be avoided
return ntext.replace(':', '').replace('|', '')

def _write_train_file(self, examples, filename):
@classmethod
def _write_train_file(cls, examples, filename):
with open(filename, 'w') as trainfile:
for ex in examples:
print(ex, file=trainfile)

@classmethod
def _uris_to_subject_ids(cls, project, uris):
subject_ids = []
for uri in uris:
subject_id = project.subjects.by_uri(uri)
if subject_id is not None:
subject_ids.append(subject_id)
return subject_ids

def _format_examples(self, project, text, uris):
subject_ids = self._uris_to_subject_ids(project, uris)
if self.algorithm == 'multilabel_oaa':
yield '{} | {}'.format(','.join(map(str, subject_ids)), text)
else:
for subject_id in subject_ids:
yield '{} | {}'.format(subject_id + 1, text)

def _create_train_file(self, corpus, project):
self.info('creating VW train file')
examples = []
for doc in corpus.documents:
text = self._normalize_text(project, doc.text)
for uri in doc.uris:
subject_id = project.subjects.by_uri(uri)
if subject_id is None:
continue
exstr = '{} | {}'.format(subject_id + 1, text)
examples.append(exstr)
examples.extend(self._format_examples(project, text, doc.uris))
random.shuffle(examples)
annif.util.atomic_save(examples,
self._get_datadir(),
Expand Down Expand Up @@ -115,9 +130,6 @@ def _create_params(self, params):
params.update({param: self._convert_param(param, val)
for param, val in self.params.items()
if param in self.VW_PARAMS})
if self.algorithm == 'oaa':
# only the oaa algorithm supports probabilities output
params.update({'probabilities': True, 'loss_function': 'logistic'})
return params

def _create_model(self, project):
Expand All @@ -142,8 +154,13 @@ def _analyze_chunks(self, chunktexts, project):
for chunktext in chunktexts:
example = ' | {}'.format(chunktext)
result = self._model.predict(example)
if isinstance(result, int):
# just a single integer - need to one-hot-encode
if self.algorithm == 'multilabel_oaa':
# result is a list of subject IDs - need to vectorize
mask = np.zeros(len(project.subjects))
mask[result] = 1.0
result = mask
elif isinstance(result, int):
# result is a single integer - need to one-hot-encode
mask = np.zeros(len(project.subjects))
mask[result - 1] = 1.0
result = mask
Expand Down
93 changes: 81 additions & 12 deletions tests/test_backend_vw.py → tests/test_backend_vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def vw_corpus(tmpdir):
return annif.corpus.DocumentFile(str(tmpfile))


def test_vw_analyze_no_model(datadir, project):
def test_vw_multi_analyze_no_model(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -28,7 +28,7 @@ def test_vw_analyze_no_model(datadir, project):
results = vw.analyze("example text", project)


def test_vw_train(datadir, document_corpus, project):
def test_vw_multi_train(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -44,7 +44,7 @@ def test_vw_train(datadir, document_corpus, project):
assert datadir.join('vw-model').size() > 0


def test_vw_train_multiple_passes(datadir, document_corpus, project):
def test_vw_multi_train_multiple_passes(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -60,7 +60,7 @@ def test_vw_train_multiple_passes(datadir, document_corpus, project):
assert datadir.join('vw-model').size() > 0


def test_vw_train_invalid_algorithm(datadir, document_corpus, project):
def test_vw_multi_train_invalid_algorithm(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -74,7 +74,7 @@ def test_vw_train_invalid_algorithm(datadir, document_corpus, project):
vw.train(document_corpus, project)


def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
def test_vw_multi_train_invalid_loss_function(datadir, project, vw_corpus):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -85,7 +85,7 @@ def test_vw_train_invalid_loss_function(datadir, project, vw_corpus):
vw.train(vw_corpus, project)


def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
def test_vw_multi_train_invalid_learning_rate(datadir, project, vw_corpus):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -96,11 +96,11 @@ def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
vw.train(vw_corpus, project)


def test_vw_analyze(datadir, project):
def test_vw_multi_analyze(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
params={'chunksize': 4},
params={'chunksize': 4, 'probabilities': 1},
datadir=str(datadir))

results = vw.analyze("""Arkeologiaa sanotaan joskus myös
Expand All @@ -116,7 +116,7 @@ def test_vw_analyze(datadir, project):
assert 'arkeologia' in [result.label for result in results]


def test_vw_analyze_empty(datadir, project):
def test_vw_multi_analyze_empty(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -128,7 +128,7 @@ def test_vw_analyze_empty(datadir, project):
assert len(results) == 0


def test_vw_analyze_multiple_passes(datadir, project):
def test_vw_multi_analyze_multiple_passes(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -140,7 +140,7 @@ def test_vw_analyze_multiple_passes(datadir, project):
assert len(results) == 0


def test_vw_train_ect(datadir, document_corpus, project):
def test_vw_multi_train_ect(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -156,7 +156,7 @@ def test_vw_train_ect(datadir, document_corpus, project):
assert datadir.join('vw-model').size() > 0


def test_vw_analyze_ect(datadir, project):
def test_vw_multi_analyze_ect(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -172,3 +172,72 @@ def test_vw_analyze_ect(datadir, project):
pohjaan.""", project)

assert len(results) > 0


def test_vw_multi_train_log_multi(datadir, document_corpus, project):
    # training with the log_multi algorithm should produce a model file
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 4,
                'learning_rate': 0.5,
                'algorithm': 'log_multi'},
        datadir=str(datadir))

    backend.train(document_corpus, project)
    model_file = datadir.join('vw-model')
    assert backend._model is not None
    assert model_file.exists()
    assert model_file.size() > 0


def test_vw_multi_analyze_log_multi(datadir, project):
    # analyze with the log_multi algorithm using the model trained in the
    # preceding test; expects at least one suggestion for Finnish sample text
    vw_type = annif.backend.get_backend('vw_multi')
    vw = vw_type(
        backend_id='vw_multi',
        params={'chunksize': 1,
                'algorithm': 'log_multi'},
        datadir=str(datadir))

    results = vw.analyze("""Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
pohjaan.""", project)

    assert len(results) > 0


def test_vw_multi_train_multilabel_oaa(datadir, document_corpus, project):
    # training with the multilabel_oaa algorithm should produce a model file
    backend_type = annif.backend.get_backend('vw_multi')
    backend = backend_type(
        backend_id='vw_multi',
        params={'chunksize': 4,
                'learning_rate': 0.5,
                'algorithm': 'multilabel_oaa'},
        datadir=str(datadir))

    backend.train(document_corpus, project)
    model_file = datadir.join('vw-model')
    assert backend._model is not None
    assert model_file.exists()
    assert model_file.size() > 0


def test_vw_multi_analyze_multilabel_oaa(datadir, project):
    # analyze with the multilabel_oaa algorithm using the model trained in
    # the preceding test
    vw_type = annif.backend.get_backend('vw_multi')
    vw = vw_type(
        backend_id='vw_multi',
        params={'chunksize': 1,
                'algorithm': 'multilabel_oaa'},
        datadir=str(datadir))

    results = vw.analyze("""Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
pohjaan.""", project)

    # deliberately weak assertion: multilabel_oaa often produces zero hits,
    # so this only verifies that analyze runs and returns a sized result
    assert len(results) >= 0