Skip to content

Commit

Permalink
Merge pull request #253 from NatLibFi/issue230-vw-multi-passes
Browse files (browse the repository at this point in the history)
Move passes-related settings outside _create_params
  • Loading branch information
osma authored Jan 30, 2019
2 parents 94b1ca3 + 7038111 commit b7c532f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
10 changes: 7 additions & 3 deletions annif/backend/vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def initialize(self):
backend_id=self.backend_id)
self.debug('loading VW model from {}'.format(path))
params = self._create_params({'i': path, 'quiet': True})
if 'passes' in params:
# don't confuse the model with passes
del params['passes']
self.debug("model parameters: {}".format(params))
self._model = pyvw.vw(**params)
self.debug('loaded model {}'.format(str(self._model)))

Expand Down Expand Up @@ -111,9 +115,6 @@ def _create_params(self, params):
params.update({param: self._convert_param(param, val)
for param, val in self.params.items()
if param in self.VW_PARAMS})
if params.get('passes', 1) > 1:
# need a cache file when there are multiple passes
params.update({'cache': True, 'kill_cache': True})
if self.algorithm == 'oaa':
# only the oaa algorithm supports probabilities output
params.update({'probabilities': True, 'loss_function': 'logistic'})
Expand All @@ -124,6 +125,9 @@ def _create_model(self, project):
trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
params = self._create_params(
{'data': trainpath, self.algorithm: len(project.subjects)})
if params.get('passes', 1) > 1:
# need a cache file when there are multiple passes
params.update({'cache': True, 'kill_cache': True})
self.debug("model parameters: {}".format(params))
self._model = pyvw.vw(**params)
modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
Expand Down
25 changes: 12 additions & 13 deletions tests/test_backend_vw.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,6 @@ def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus):
vw.train(vw_corpus, project)


def test_vw_train_unknown_subject(datadir, project, vw_corpus):
    """Train a vw_multi backend on vw_corpus and expect a model file."""
    backend_cls = annif.backend.get_backend('vw_multi')
    backend_params = {'chunksize': 4}
    backend = backend_cls(
        backend_id='vw_multi', params=backend_params, datadir=str(datadir))

    backend.train(vw_corpus, project)

    # Training must leave an in-memory model and a non-empty model file.
    model_file = datadir.join('vw-model')
    assert backend._model is not None
    assert model_file.exists()
    assert model_file.size() > 0


def test_vw_analyze(datadir, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
Expand Down Expand Up @@ -141,6 +128,18 @@ def test_vw_analyze_empty(datadir, project):
assert len(results) == 0


def test_vw_analyze_multiple_passes(datadir, project):
    """Analyze "..." with a vw_multi backend configured with passes=2.

    The test expects the analysis to return no results.
    """
    backend_cls = annif.backend.get_backend('vw_multi')
    backend = backend_cls(
        backend_id='vw_multi',
        params={'chunksize': 4, 'passes': 2},
        datadir=str(datadir))

    hits = backend.analyze("...", project)

    assert len(hits) == 0


def test_vw_train_ect(datadir, document_corpus, project):
vw_type = annif.backend.get_backend('vw_multi')
vw = vw_type(
Expand Down

0 comments on commit b7c532f

Please sign in to comment.