Skip to content

Commit

Permalink
Pass only the backend's own params to its methods
Browse files Browse the repository at this point in the history
  • Loading branch information
juhoinkinen committed Dec 17, 2019
1 parent fc967b9 commit 5b4bb73
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 20 deletions.
3 changes: 1 addition & 2 deletions annif/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def params(self):
def _get_backend_params(self, params):
backend_params = dict(self.params)
if params is not None:
beparams = params.get(self.backend_id, {})
backend_params.update(beparams)
backend_params.update(params)
return backend_params

def _train(self, corpus, params):
Expand Down
25 changes: 17 additions & 8 deletions annif/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,11 @@ def initialize(self):

self.initialized = True

def _suggest_with_backend(self, text, be_params):
hits = self.backend.suggest(text, be_params)
def _suggest_with_backend(self, text, backend_params):
if backend_params is None:
backend_params = {}
beparams = backend_params.get(self.backend.backend_id, {})
hits = self.backend.suggest(text, beparams)
logger.debug(
'Got %d hits from backend %s',
len(hits), self.backend.backend_id)
Expand Down Expand Up @@ -150,27 +153,33 @@ def vocab(self):
def subjects(self):
return self.vocab.subjects

def suggest(self, text, be_params=None):
def suggest(self, text, backend_params=None):
"""Suggest subjects the given text by passing it to the backend. Returns a
list of SubjectSuggestion objects ordered by decreasing score."""
logger.debug('Suggesting subjects for text "%s..." (len=%d)',
text[:20], len(text))
hits = self._suggest_with_backend(text, be_params)
hits = self._suggest_with_backend(text, backend_params)
logger.debug('%d hits from backend', len(hits))
return hits

def train(self, corpus, be_params=None):
def train(self, corpus, backend_params=None):
"""train the project using documents from a metadata source"""
corpus.set_subject_index(self.subjects)
self.backend.train(corpus, be_params)
if backend_params is None:
backend_params = {}
beparams = backend_params.get(self.backend.backend_id, {})
self.backend.train(corpus, beparams)

def learn(self, corpus, be_params=None):
def learn(self, corpus, backend_params=None):
"""further train the project using documents from a metadata source"""
corpus.set_subject_index(self.subjects)
if backend_params is None:
backend_params = {}
beparams = backend_params.get(self.backend.backend_id, {})
if isinstance(
self.backend,
annif.backend.backend.AnnifLearningBackend):
self.backend.learn(corpus, be_params)
self.backend.learn(corpus, beparams)
else:
raise NotSupportedException("Learning not supported by backend",
project_id=self.project_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_backend_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_train_fasttext_params(document_corpus, project, caplog):
'epoch': 21,
'loss': 'hs'},
project=project)
params = {'fasttext': {'dim': 1, 'lr': 42.1, 'epoch': 0}}
params = {'dim': 1, 'lr': 42.1, 'epoch': 0}

with caplog.at_level(logging.DEBUG):
fasttext.train(document_corpus, params)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_backend_nn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def test_nn_ensemble_train_and_learn_params(app, tmpdir, capfd):
"none\thttp://example.org/none")
document_corpus = annif.corpus.DocumentFile(str(tmpfile))

train_params = {'nn_ensemble': {'epochs': 3}}
train_params = {'epochs': 3}
with app.app_context():
nn_ensemble.train(document_corpus, train_params)
out, _ = capfd.readouterr()
assert 'Epoch 3/3' in out

learn_params = {'nn_ensemble': {'learn-epochs': 2}}
learn_params = {'learn-epochs': 2}
nn_ensemble.learn(document_corpus, learn_params)
out, _ = capfd.readouterr()
assert 'Epoch 2/2' in out
Expand Down
3 changes: 1 addition & 2 deletions tests/test_backend_omikuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def test_omikuji_train_params(datadir, document_corpus, project, capfd):
backend_id='omikuji',
config_params={},
project=project)
params = {'omikuji': {
'cluster_k': 1, 'max_depth': 2, 'collapse_every_n_layers': 42}}
params = {'cluster_k': 1, 'max_depth': 2, 'collapse_every_n_layers': 42}
omikuji.train(document_corpus, params)

out, _ = capfd.readouterr()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_backend_pav.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_pav_train_params(app, tmpdir, project, caplog):
"another\thttp://example.org/dummy\n" +
"none\thttp://example.org/none")
document_corpus = annif.corpus.DocumentFile(str(tmpfile))
params = {'pav': {'min-docs': 5}}
params = {'min-docs': 5}

with app.app_context(), caplog.at_level(logging.DEBUG):
pav.train(document_corpus, params)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_backend_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_suggest_params(project):
backend_id='tfidf',
config_params={'limit': 10},
project=project)
params = {'tfidf': {'limit': 3}}
params = {'limit': 3}

results = tfidf.suggest("""Arkeologiaa sanotaan joskus myös
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
Expand Down
4 changes: 1 addition & 3 deletions tests/test_backend_vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ def test_vw_multi_train_params(project, vw_corpus, caplog):
backend_id='vw_multi',
config_params={'chunksize': 4, 'learning_rate': 0.5},
project=project)
params = {'vw_multi': {
'loss_function': 'logistic',
'learning_rate': 42.1}}
params = {'loss_function': 'logistic', 'learning_rate': 42.1}

with caplog.at_level(logging.DEBUG):
vw.train(vw_corpus, params)
Expand Down

0 comments on commit 5b4bb73

Please sign in to comment.