Skip to content

Commit

Permalink
Merge pull request #677 from NatLibFi/batching-in-nn-ensemble-suggestions
Browse files Browse the repository at this point in the history

 Batch suggest in ensemble backends
  • Loading branch information
juhoinkinen authored Mar 7, 2023
2 parents 3e8f42f + f280342 commit eb437a8
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 48 deletions.
65 changes: 37 additions & 28 deletions annif/backend/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,35 +33,32 @@ def _normalize_hits(self, hits, source_project):
by subclasses."""
return hits

def _suggest_with_sources(self, text, sources):
hits_from_sources = []
def _suggest_with_sources(self, texts, sources):
hit_sets_from_sources = []
for project_id, weight in sources:
source_project = self.project.registry.get_project(project_id)
hits = source_project.suggest([text])[0]
self.debug(
"Got {} hits from project {}, weight {}".format(
len(hits), source_project.project_id, weight
hit_sets = source_project.suggest(texts)
norm_hit_sets = [
self._normalize_hits(hits, source_project) for hits in hit_sets
]
hit_sets_from_sources.append(
annif.suggestion.WeightedSuggestionsBatch(
hit_sets=norm_hit_sets,
weight=weight,
subjects=source_project.subjects,
)
)
norm_hits = self._normalize_hits(hits, source_project)
hits_from_sources.append(
annif.suggestion.WeightedSuggestion(
hits=norm_hits, weight=weight, subjects=source_project.subjects
)
)
return hits_from_sources
return hit_sets_from_sources

def _merge_hits_from_sources(self, hits_from_sources, params):
"""Hook for merging hits from sources. Can be overridden by
def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
"""Hook for merging hit sets from sources. Can be overridden by
subclasses."""
return annif.util.merge_hits(hits_from_sources, len(self.project.subjects))
return annif.util.merge_hits(hit_sets_from_sources, len(self.project.subjects))

def _suggest(self, text, params):
def _suggest_batch(self, texts, params):
sources = annif.util.parse_sources(params["sources"])
hits_from_sources = self._suggest_with_sources(text, sources)
merged_hits = self._merge_hits_from_sources(hits_from_sources, params)
self.debug("{} hits after merging".format(len(merged_hits)))
return merged_hits
hit_sets_from_sources = self._suggest_with_sources(texts, sources)
return self._merge_hit_sets_from_sources(hit_sets_from_sources, params)


class EnsembleOptimizer(hyperopt.HyperparameterOptimizer):
Expand Down Expand Up @@ -95,11 +92,23 @@ def _prepare(self, n_jobs=1):
jobs, pool_class = annif.parallel.get_pool(n_jobs)

with pool_class(jobs) as pool:
for hits, subject_set in pool.imap_unordered(
psmap.suggest, self._corpus.documents
for hit_sets, subject_sets in pool.imap_unordered(
psmap.suggest_batch, self._corpus.doc_batches
):
self._gold_subjects.append(subject_set)
self._source_hits.append(hits)
self._gold_subjects.extend(subject_sets)
self._source_hits.extend(self._hit_sets_to_list(hit_sets))

def _hit_sets_to_list(self, hit_sets):
"""Convert a dict of lists of hits to a list of dicts of hits, e.g.
{"proj-1": [p-1-doc-1-hits, p-1-doc-2-hits]
"proj-2": [p-2-doc-1-hits, p-2-doc-2-hits]}
to
[{"proj-1": p-1-doc-1-hits, "proj-2": p-2-doc-1-hits},
{"proj-1": p-1-doc-2-hits, "proj-2": p-2-doc-2-hits}]
"""
return [
dict(zip(hit_sets.keys(), doc_hits)) for doc_hits in zip(*hit_sets.values())
]

def _normalize(self, hps):
total = sum(hps.values())
Expand All @@ -120,16 +129,16 @@ def _objective(self, trial):
weighted_hits = []
for project_id, hits in srchits.items():
weighted_hits.append(
annif.suggestion.WeightedSuggestion(
hits=hits,
annif.suggestion.WeightedSuggestionsBatch(
hit_sets=[hits],
weight=weights[project_id],
subjects=self._backend.project.subjects,
)
)
batch.evaluate(
annif.util.merge_hits(
weighted_hits, len(self._backend.project.subjects)
),
)[0],
goldsubj,
)
results = batch.results(metrics=[self._metric])
Expand Down
21 changes: 12 additions & 9 deletions annif/backend/nn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,21 @@ def initialize(self, parallel=False):
model_filename, custom_objects={"MeanLayer": MeanLayer}
)

def _merge_hits_from_sources(self, hits_from_sources, params):
score_vector = np.array(
def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
score_vectors = np.array(
[
np.sqrt(hits.as_vector(len(subjects))) * weight * len(hits_from_sources)
for hits, weight, subjects in hits_from_sources
[
np.sqrt(hits.as_vector(len(subjects)))
* weight
* len(hit_sets_from_sources)
for hits in proj_hit_set
]
for proj_hit_set, weight, subjects in hit_sets_from_sources
],
dtype=np.float32,
)
results = self._model.predict(
np.expand_dims(score_vector.transpose(), 0), verbose=0
)
return VectorSuggestionResult(results[0])
).transpose(1, 2, 0)
results = self._model(score_vectors).numpy()
return [VectorSuggestionResult(res) for res in results]

def _create_model(self, sources):
self.info("creating NN ensemble model")
Expand Down
4 changes: 2 additions & 2 deletions annif/suggestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np

SubjectSuggestion = collections.namedtuple("SubjectSuggestion", "subject_id score")
WeightedSuggestion = collections.namedtuple(
"WeightedSuggestion", "hits weight subjects"
WeightedSuggestionsBatch = collections.namedtuple(
"WeightedSuggestionsBatch", "hit_sets weight subjects"
)


Expand Down
23 changes: 14 additions & 9 deletions annif/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,20 @@ def cleanup_uri(uri):
return uri


def merge_hits(weighted_hits, size):
"""Merge hits from multiple sources. Input is a sequence of WeightedSuggestion
objects. The size parameter determines the length of the subject vector.
Returns an SuggestionResult object."""

weights = [whit.weight for whit in weighted_hits]
scores = [whit.hits.as_vector(size) for whit in weighted_hits]
result = np.average(scores, axis=0, weights=weights)
return VectorSuggestionResult(result)
def merge_hits(weighted_hits_batches, size):
    """Merge hit sets from multiple sources into per-document results.

    weighted_hits_batches: sequence of WeightedSuggestionsBatch objects,
        each carrying a list of per-document hit sets and a source weight.
    size: length of the subject vector used for each hit set.
    Returns a list of SuggestionResult objects, one per document.
    """
    weights = []
    stacked_scores = []
    for batch in weighted_hits_batches:
        weights.append(batch.weight)
        # One row of score vectors per document in this source's batch.
        stacked_scores.append([hits.as_vector(size) for hits in batch.hit_sets])
    # Weighted mean across sources (axis 0) leaves one vector per document.
    averaged = np.average(np.array(stacked_scores), axis=0, weights=weights)
    return [VectorSuggestionResult(vec) for vec in averaged]


def parse_sources(sourcedef):
Expand Down

0 comments on commit eb437a8

Please sign in to comment.