diff --git a/annif/backend/vw_multi.py b/annif/backend/vw_multi.py index 86d7aa33c..355a1a34e 100644 --- a/annif/backend/vw_multi.py +++ b/annif/backend/vw_multi.py @@ -48,6 +48,10 @@ def initialize(self): backend_id=self.backend_id) self.debug('loading VW model from {}'.format(path)) params = self._create_params({'i': path, 'quiet': True}) + if 'passes' in params: + # don't confuse the model with passes + del params['passes'] + self.debug("model parameters: {}".format(params)) self._model = pyvw.vw(**params) self.debug('loaded model {}'.format(str(self._model))) @@ -111,9 +115,6 @@ def _create_params(self, params): params.update({param: self._convert_param(param, val) for param, val in self.params.items() if param in self.VW_PARAMS}) - if params.get('passes', 1) > 1: - # need a cache file when there are multiple passes - params.update({'cache': True, 'kill_cache': True}) if self.algorithm == 'oaa': # only the oaa algorithm supports probabilities output params.update({'probabilities': True, 'loss_function': 'logistic'}) @@ -124,6 +125,9 @@ def _create_model(self, project): trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE) params = self._create_params( {'data': trainpath, self.algorithm: len(project.subjects)}) + if params.get('passes', 1) > 1: + # need a cache file when there are multiple passes + params.update({'cache': True, 'kill_cache': True}) self.debug("model parameters: {}".format(params)) self._model = pyvw.vw(**params) modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE) diff --git a/tests/test_backend_vw.py b/tests/test_backend_vw.py index 30cd55bff..5598f3122 100644 --- a/tests/test_backend_vw.py +++ b/tests/test_backend_vw.py @@ -96,19 +96,6 @@ def test_vw_train_invalid_learning_rate(datadir, project, vw_corpus): vw.train(vw_corpus, project) -def test_vw_train_unknown_subject(datadir, project, vw_corpus): - vw_type = annif.backend.get_backend('vw_multi') - vw = vw_type( - backend_id='vw_multi', - params={'chunksize': 4}, - datadir=str(datadir)) - - vw.train(vw_corpus, project) - assert vw._model is not None - assert datadir.join('vw-model').exists() - assert datadir.join('vw-model').size() > 0 - - def test_vw_analyze(datadir, project): vw_type = annif.backend.get_backend('vw_multi') vw = vw_type( @@ -141,6 +128,18 @@ def test_vw_analyze_empty(datadir, project): assert len(results) == 0 +def test_vw_analyze_multiple_passes(datadir, project): + vw_type = annif.backend.get_backend('vw_multi') + vw = vw_type( + backend_id='vw_multi', + params={'chunksize': 4, 'passes': 2}, + datadir=str(datadir)) + + results = vw.analyze("...", project) + + assert len(results) == 0 + + def test_vw_train_ect(datadir, document_corpus, project): vw_type = annif.backend.get_backend('vw_multi') vw = vw_type(