Skip to content

Commit

Permalink
Defaulting simply by updating default params with config params
Browse files — browse the repository at this point in the history
  • Loading branch information
juhoinkinen committed Sep 25, 2019
1 parent 63a199c commit 420d3d1
Show file tree
Hide file tree
Showing 9 changed files with 17 additions and 68 deletions.
15 changes: 2 additions & 13 deletions annif/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,12 @@ def __init__(self, backend_id, config_params, datadir):
backend type."""
self.backend_id = backend_id
self.datadir = datadir
self._params = config_params
self.params = self.default_params().copy()
self.params.update(config_params)

def default_params(self):
return self.DEFAULT_PARAMS

@property
def params(self):
params_to_set = {param: str(val)
for param, val in self.default_params().items()
if param not in self._params}
if params_to_set:
self.debug(
'all parameters not set, using the following defaults: {}'
.format(params_to_set))
self._params.update(params_to_set)
return self._params.copy()

def train(self, corpus, project):
"""train the model on the given document or subject corpus"""
pass # default is to do nothing, subclasses may override
Expand Down
2 changes: 1 addition & 1 deletion annif/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def backend(self):
try:
backend_class = annif.backend.get_backend(backend_id)
self._backend = backend_class(
backend_id, config_params=dict(self.config),
backend_id, config_params=self.config,
datadir=self.datadir)
except ValueError:
logger.warning(
Expand Down
9 changes: 2 additions & 7 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,9 @@ def test_learn_dummy(app, project, tmpdir):
assert result[0].score == 1.0


def test_fill_params_with_defaults(app, caplog):
logger = annif.logger
logger.propagate = True
def test_fill_params_with_defaults(app):
dummy_type = annif.backend.get_backend('dummy')
dummy = dummy_type(backend_id='dummy', config_params={},
datadir=app.config['DATADIR'])
expected_default_params = {'limit': '100'} # From AnnifBackend class
expected_msg = 'all parameters not set, using the following defaults:'
caplog.set_level(logging.DEBUG)
expected_default_params = {'limit': 100} # From AnnifBackend class
assert expected_default_params == dummy.params
assert expected_msg in caplog.records[0].message
11 changes: 2 additions & 9 deletions tests/test_backend_fasttext.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
"""Unit tests for the fastText backend in Annif"""

import logging
import pytest
import annif.backend
import annif.corpus

fasttext = pytest.importorskip("annif.backend.fasttext")


def test_fasttext_default_params(datadir, project, caplog):
logger = annif.logger
logger.propagate = True
caplog.set_level(logging.DEBUG)

def test_fasttext_default_params(datadir, project):
fasttext_type = annif.backend.get_backend("fasttext")
fasttext = fasttext_type(
backend_id='fasttext',
Expand All @@ -27,11 +22,9 @@ def test_fasttext_default_params(datadir, project, caplog):
'epoch': 5,
'loss': 'hs',
}
expected_msg = "all parameters not set, using the following defaults:"
actual_params = fasttext.params
assert expected_msg in caplog.records[0].message
for param, val in expected_default_params.items():
assert param in actual_params and actual_params[param] == str(val)
assert param in actual_params and actual_params[param] == val


def test_fasttext_train(datadir, document_corpus, project):
Expand Down
11 changes: 2 additions & 9 deletions tests/test_backend_pav.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
"""Unit tests for the PAV backend in Annif"""

import logging
import annif.backend
import annif.corpus


def test_pav_default_params(datadir, document_corpus, project, caplog):
logger = annif.logger
logger.propagate = True
caplog.set_level(logging.DEBUG)

def test_pav_default_params(datadir, document_corpus, project):
pav_type = annif.backend.get_backend("pav")
pav = pav_type(
backend_id='pav',
Expand All @@ -19,11 +14,9 @@ def test_pav_default_params(datadir, document_corpus, project, caplog):
expected_default_params = {
'min-docs': 10,
}
expected_msg = "all parameters not set, using the following defaults:"
actual_params = pav.params
assert expected_msg in caplog.records[0].message
for param, val in expected_default_params.items():
assert param in actual_params and actual_params[param] == str(val)
assert param in actual_params and actual_params[param] == val


def test_pav_train(app, datadir, tmpdir, project):
Expand Down
11 changes: 2 additions & 9 deletions tests/test_backend_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import annif.backend
import annif.corpus
from sklearn.feature_extraction.text import TfidfVectorizer
import logging
import pytest
import unittest.mock

Expand All @@ -19,11 +18,7 @@ def project(document_corpus):
return proj


def test_tfidf_default_params(datadir, project, caplog):
logger = annif.logger
logger.propagate = True
caplog.set_level(logging.DEBUG)

def test_tfidf_default_params(datadir, project):
tfidf_type = annif.backend.get_backend("tfidf")
tfidf = tfidf_type(
backend_id='tfidf',
Expand All @@ -33,11 +28,9 @@ def test_tfidf_default_params(datadir, project, caplog):
expected_default_params = {
'limit': 100 # From AnnifBackend class
}
expected_msg = "all parameters not set, using the following defaults:"
actual_params = tfidf.params
assert expected_msg in caplog.records[0].message
for param, val in expected_default_params.items():
assert param in actual_params and actual_params[param] == str(val)
assert param in actual_params and actual_params[param] == val


def test_tfidf_train(datadir, document_corpus, project):
Expand Down
11 changes: 2 additions & 9 deletions tests/test_backend_vw_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Unit tests for the vw_ensemble backend in Annif"""

import logging
import json
import time
import pytest
Expand All @@ -12,11 +11,7 @@
pytest.importorskip("annif.backend.vw_ensemble")


def test_vw_ensemble_default_params(datadir, project, caplog):
logger = annif.logger
logger.propagate = True
caplog.set_level(logging.DEBUG)

def test_vw_ensemble_default_params(datadir, project):
vw_type = annif.backend.get_backend("vw_ensemble")
vw = vw_type(
backend_id='vw_ensemble',
Expand All @@ -28,11 +23,9 @@ def test_vw_ensemble_default_params(datadir, project, caplog):
'discount_rate': 0.01,
'loss_function': 'squared',
}
expected_msg = "all parameters not set, using the following defaults:"
actual_params = vw.params
assert expected_msg in caplog.records[0].message
for param, val in expected_default_params.items():
assert param in actual_params and actual_params[param] == str(val)
assert param in actual_params and actual_params[param] == val


def test_vw_ensemble_suggest_no_model(datadir, project):
Expand Down
11 changes: 2 additions & 9 deletions tests/test_backend_vw_multi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Unit tests for the vw_multi backend in Annif"""

import logging
import pytest
import annif.backend
import annif.corpus
Expand All @@ -19,11 +18,7 @@ def vw_corpus(tmpdir):
return annif.corpus.DocumentFile(str(tmpfile))


def test_vw_multi_default_params(datadir, project, caplog):
logger = annif.logger
logger.propagate = True
caplog.set_level(logging.DEBUG)

def test_vw_multi_default_params(datadir, project):
vw_type = annif.backend.get_backend("vw_multi")
vw = vw_type(
backend_id='vw_multi',
Expand All @@ -36,11 +31,9 @@ def test_vw_multi_default_params(datadir, project, caplog):
'algorithm': 'oaa',
'loss_function': 'logistic',
}
expected_msg = "all parameters not set, using the following defaults:"
actual_params = vw.params
assert expected_msg in caplog.records[0].message
for param, val in expected_default_params.items():
assert param in actual_params and actual_params[param] == str(val)
assert param in actual_params and actual_params[param] == val


def test_vw_multi_suggest_no_model(datadir, project):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_get_project_default_params_tfidf(app):
}
actual_params = project.backend.params
for param, val in expected_default_params.items():
assert param in actual_params and actual_params[param] == str(val)
assert param in actual_params and actual_params[param] == val


def test_get_project_default_params_fasttext(app):
Expand All @@ -118,7 +118,7 @@ def test_get_project_default_params_fasttext(app):
'loss': 'hs'}
actual_params = project.backend.params
for param, val in expected_default_params.items():
assert param in actual_params and actual_params[param] == str(val)
assert param in actual_params and actual_params[param] == val


def test_get_project_invalid_config_file(app):
Expand Down

0 comments on commit 420d3d1

Please sign in to comment.