-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add nn_ensemble backend #331
Merged
Merged
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
0cf82c5
First rough version of neural network ensemble backend
osma 9b71ba7
Merge branch 'master' into issue330-nn-ensemble-backend
osma f0180a5
clean up imports
osma 2a422b1
Install Keras and TensorFlow under Travis CI
osma 069dab6
Add basic unit tests for nn_ensemble backend
osma fbf3419
Add more nn_ensemble unit tests + fix bug in suggest method
osma 2691716
remove print statements
osma 036027f
Merge branch 'master' into issue330-nn-ensemble-backend
osma 10b0b00
Merge branch 'master' into issue330-nn-ensemble-backend
osma 6ea609f
Specify explicit Keras and tensorflow versions. Use TF 1.15 RC to avo…
osma 4d3a83d
fix setup.py syntax
osma 1d08613
upgrade pip under Travis CI before installing anything (tensorflow ne…
osma c88c218
make Keras and tensorflow core dependencies, not optional; pin numpy …
osma 1b94abe
upgrade pip on scrutinizer (tensorflow needs pip 19.*)
osma a789f46
nn_ensemble is now a core backend, remove conditional imports
osma 94992b4
Merge branch 'master' into issue330-nn-ensemble-backend
osma bcb779b
Make hyperparameters configurable in nn_ensemble (with defaults)
osma 9b7148b
fix syntax, pep8 and tests (doh)
osma 3d6aff2
Merge branch 'master' into issue330-nn-ensemble-backend
osma 5b90d25
Turn nn_ensemble into an optional feature again. I've had some trouble
osma f12929a
Adjust Pipfile, setup.py and .travis.yml to make nn feature optional
osma 40c8af3
fix syntax (oops)
osma 7c38754
Upgrade to TensorFlow 2.0
osma db8f893
Merge branch 'master' into issue330-nn-ensemble-backend
osma 50e8202
more elegant handling of file name prefixes in annif.util.atomic_save
osma 3c081f8
Refactor: Split learn method in nn_ensemble backend
osma 8331a16
Avoid testing nn features on Python 3.6, to increase overall test cov…
osma bd6649b
set explicit dtype=float32 for numpy arrays to avoid wasting memory
osma b9093e7
Merge branch 'master' into issue330-nn-ensemble-backend
osma f0014df
Up the default to 100 nodes since it may produce better results
osma ce49174
Install tensorflow in Docker image
juhoinkinen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
"""Neural network based ensemble backend that combines results from multiple | ||
projects.""" | ||
|
||
|
||
import os.path | ||
import numpy as np | ||
from tensorflow.keras.layers import Input, Dense, Add, Flatten, Lambda, Dropout | ||
from tensorflow.keras.models import Model, load_model | ||
import tensorflow.keras.backend as K | ||
import annif.corpus | ||
import annif.project | ||
import annif.util | ||
from annif.exception import NotInitializedException | ||
from annif.suggestion import VectorSuggestionResult | ||
from . import ensemble | ||
|
||
|
||
class NNEnsembleBackend(ensemble.EnsembleBackend): | ||
"""Neural network ensemble backend that combines results from multiple | ||
projects""" | ||
|
||
name = "nn_ensemble" | ||
|
||
MODEL_FILE = "nn-model.h5" | ||
|
||
DEFAULT_PARAMS = { | ||
'nodes': 60, | ||
'dropout_rate': 0.2, | ||
'optimizer': 'adam', | ||
'epochs': 10, | ||
} | ||
|
||
# defaults for uninitialized instances | ||
_model = None | ||
|
||
def default_params(self): | ||
params = {} | ||
params.update(super().default_params()) | ||
params.update(self.DEFAULT_PARAMS) | ||
return params | ||
|
||
def initialize(self): | ||
if self._model is not None: | ||
return # already initialized | ||
model_filename = os.path.join(self.datadir, self.MODEL_FILE) | ||
if not os.path.exists(model_filename): | ||
raise NotInitializedException( | ||
'model file {} not found'.format(model_filename), | ||
backend_id=self.backend_id) | ||
self.debug('loading Keras model from {}'.format(model_filename)) | ||
self._model = load_model(model_filename) | ||
|
||
def _merge_hits_from_sources(self, hits_from_sources, project, params): | ||
score_vector = np.array([hits.vector * weight | ||
for hits, weight in hits_from_sources]) | ||
results = self._model.predict( | ||
np.expand_dims(score_vector.transpose(), 0)) | ||
return VectorSuggestionResult(results[0], project.subjects) | ||
|
||
def _create_model(self, sources, project): | ||
self.info("creating NN ensemble model") | ||
|
||
inputs = Input(shape=(len(project.subjects), len(sources))) | ||
|
||
flat_input = Flatten()(inputs) | ||
drop_input = Dropout( | ||
rate=float( | ||
self.params['dropout_rate']))(flat_input) | ||
hidden = Dense(int(self.params['nodes']), | ||
activation="relu")(drop_input) | ||
drop_hidden = Dropout(rate=float(self.params['dropout_rate']))(hidden) | ||
delta = Dense(len(project.subjects), | ||
kernel_initializer='zeros', | ||
bias_initializer='zeros')(drop_hidden) | ||
|
||
mean = Lambda(lambda x: K.mean(x, axis=2))(inputs) | ||
|
||
predictions = Add()([mean, delta]) | ||
|
||
self._model = Model(inputs=inputs, outputs=predictions) | ||
self._model.compile(optimizer=self.params['optimizer'], | ||
loss='binary_crossentropy', | ||
metrics=['top_k_categorical_accuracy']) | ||
|
||
summary = [] | ||
self._model.summary(print_fn=summary.append) | ||
self.debug("Created model: \n" + "\n".join(summary)) | ||
|
||
def train(self, corpus, project): | ||
sources = annif.util.parse_sources(self.params['sources']) | ||
self._create_model(sources, project) | ||
self.learn(corpus, project) | ||
|
||
def _corpus_to_vectors(self, corpus, project): | ||
# pass corpus through all source projects | ||
sources = [(annif.project.get_project(project_id), weight) | ||
for project_id, weight | ||
in annif.util.parse_sources(self.params['sources'])] | ||
|
||
score_vectors = [] | ||
true_vectors = [] | ||
for doc in corpus.documents: | ||
doc_scores = [] | ||
for source_project, weight in sources: | ||
hits = source_project.suggest(doc.text) | ||
doc_scores.append(hits.vector * weight) | ||
score_vectors.append(np.array(doc_scores).transpose()) | ||
subjects = annif.corpus.SubjectSet((doc.uris, doc.labels)) | ||
true_vectors.append(subjects.as_vector(project.subjects)) | ||
# collect the results into a single vector, considering weights | ||
scores = np.array(score_vectors) | ||
# collect the gold standard values into another vector | ||
true = np.array(true_vectors) | ||
return (scores, true) | ||
|
||
def learn(self, corpus, project): | ||
scores, true = self._corpus_to_vectors(corpus, project) | ||
|
||
# fit the model | ||
self._model.fit(scores, true, batch_size=32, verbose=True, | ||
epochs=int(self.params['epochs'])) | ||
|
||
annif.util.atomic_save( | ||
self._model, | ||
self.datadir, | ||
self.MODEL_FILE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
"""Unit tests for the nn_ensemble backend in Annif""" | ||
|
||
import time | ||
import pytest | ||
import annif.backend | ||
import annif.corpus | ||
import annif.project | ||
from annif.exception import NotInitializedException | ||
|
||
pytest.importorskip("annif.backend.nn_ensemble") | ||
|
||
|
||
def test_nn_ensemble_suggest_no_model(datadir, project): | ||
nn_ensemble_type = annif.backend.get_backend('nn_ensemble') | ||
nn_ensemble = nn_ensemble_type( | ||
backend_id='nn_ensemble', | ||
config_params={'sources': 'dummy-en'}, | ||
datadir=str(datadir)) | ||
|
||
with pytest.raises(NotInitializedException): | ||
results = nn_ensemble.suggest("example text", project) | ||
|
||
|
||
def test_nn_ensemble_train_and_learn(app, datadir, tmpdir): | ||
nn_ensemble_type = annif.backend.get_backend("nn_ensemble") | ||
nn_ensemble = nn_ensemble_type( | ||
backend_id='nn_ensemble', | ||
config_params={'sources': 'dummy-en'}, | ||
datadir=str(datadir)) | ||
|
||
tmpfile = tmpdir.join('document.tsv') | ||
tmpfile.write("dummy\thttp://example.org/dummy\n" + | ||
"another\thttp://example.org/dummy\n" + | ||
"none\thttp://example.org/none") | ||
document_corpus = annif.corpus.DocumentFile(str(tmpfile)) | ||
project = annif.project.get_project('dummy-en') | ||
|
||
with app.app_context(): | ||
nn_ensemble.train(document_corpus, project) | ||
assert datadir.join('nn-model.h5').exists() | ||
assert datadir.join('nn-model.h5').size() > 0 | ||
|
||
# test online learning | ||
modelfile = datadir.join('nn-model.h5') | ||
|
||
old_size = modelfile.size() | ||
old_mtime = modelfile.mtime() | ||
|
||
time.sleep(0.1) # make sure the timestamp has a chance to increase | ||
|
||
nn_ensemble.learn(document_corpus, project) | ||
|
||
assert modelfile.size() != old_size or modelfile.mtime() != old_mtime | ||
|
||
|
||
def test_nn_ensemble_initialize(app, datadir): | ||
nn_ensemble_type = annif.backend.get_backend("nn_ensemble") | ||
nn_ensemble = nn_ensemble_type( | ||
backend_id='nn_ensemble', | ||
config_params={'sources': 'dummy-en'}, | ||
datadir=str(datadir)) | ||
|
||
assert nn_ensemble._model is None | ||
with app.app_context(): | ||
nn_ensemble.initialize() | ||
assert nn_ensemble._model is not None | ||
# initialize a second time - this shouldn't do anything | ||
with app.app_context(): | ||
nn_ensemble.initialize() | ||
|
||
|
||
def test_nn_ensemble_suggest(app, datadir): | ||
nn_ensemble_type = annif.backend.get_backend("nn_ensemble") | ||
nn_ensemble = nn_ensemble_type( | ||
backend_id='nn_ensemble', | ||
config_params={'sources': 'dummy-en'}, | ||
datadir=str(datadir)) | ||
|
||
project = annif.project.get_project('dummy-en') | ||
|
||
results = nn_ensemble.suggest("""Arkeologiaa sanotaan joskus myös | ||
muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede | ||
tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä. | ||
Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä, | ||
joita ihmisten toiminta on jättänyt maaperään tai vesistöjen | ||
pohjaan.""", project) | ||
|
||
assert nn_ensemble._model is not None | ||
assert len(results) > 0 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any comments @juhoinkinen or @mvsjober to the Keras model defined here? Would you do something differently?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good