diff --git a/annif/backend/ensemble.py b/annif/backend/ensemble.py index 0122f15f4..c57159112 100644 --- a/annif/backend/ensemble.py +++ b/annif/backend/ensemble.py @@ -42,22 +42,18 @@ def _suggest_with_sources(self, texts, sources): self._normalize_hits(hits, source_project) for hits in hit_sets ] hit_sets_from_sources.append( - [ - annif.suggestion.WeightedSuggestion( - hits=norm_hits, weight=weight, subjects=source_project.subjects - ) - for norm_hits in norm_hit_sets - ] + annif.suggestion.WeightedSuggestionsBatch( + hit_sets=norm_hit_sets, + weight=weight, + subjects=source_project.subjects, + ) ) return hit_sets_from_sources def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params): - """Hook for merging hits from sources. Can be overridden by + """Hook for merging hit sets from sources. Can be overridden by subclasses.""" - return [ - annif.util.merge_hits(hits, len(self.project.subjects)) - for hits in hit_sets_from_sources - ] + return annif.util.merge_hits(hit_sets_from_sources, len(self.project.subjects)) def _suggest_batch(self, texts, params): sources = annif.util.parse_sources(params["sources"]) @@ -133,8 +129,8 @@ def _objective(self, trial): weighted_hits = [] for project_id, hits in srchits.items(): weighted_hits.append( - annif.suggestion.WeightedSuggestion( - hits=hits, + annif.suggestion.WeightedSuggestionsBatch( + hit_sets=[hits], weight=weights[project_id], subjects=self._backend.project.subjects, ) @@ -142,7 +138,7 @@ def _objective(self, trial): batch.evaluate( annif.util.merge_hits( weighted_hits, len(self._backend.project.subjects) - ), + )[0], goldsubj, ) results = batch.results(metrics=[self._metric]) diff --git a/annif/backend/nn_ensemble.py b/annif/backend/nn_ensemble.py index 7e37728e1..8951094ff 100644 --- a/annif/backend/nn_ensemble.py +++ b/annif/backend/nn_ensemble.py @@ -137,9 +137,9 @@ def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params): np.sqrt(hits.as_vector(len(subjects))) * weight * len(hit_sets_from_sources) - for hits, weight, subjects in hits_from_sources + for hits in proj_hit_set ] - for hits_from_sources in hit_sets_from_sources + for proj_hit_set, weight, subjects in hit_sets_from_sources ], dtype=np.float32, ).transpose(1, 2, 0) diff --git a/annif/suggestion.py b/annif/suggestion.py index fdf5aa94d..947e37451 100644 --- a/annif/suggestion.py +++ b/annif/suggestion.py @@ -7,8 +7,8 @@ import numpy as np SubjectSuggestion = collections.namedtuple("SubjectSuggestion", "subject_id score") -WeightedSuggestion = collections.namedtuple( - "WeightedSuggestion", "hits weight subjects" +WeightedSuggestionsBatch = collections.namedtuple( + "WeightedSuggestionsBatch", "hit_sets weight subjects" ) diff --git a/annif/util.py b/annif/util.py index d0bf9842c..9d3cf236c 100644 --- a/annif/util.py +++ b/annif/util.py @@ -54,15 +54,20 @@ def cleanup_uri(uri): return uri -def merge_hits(weighted_hits, size): - """Merge hits from multiple sources. Input is a sequence of WeightedSuggestion - objects. The size parameter determines the length of the subject vector. - Returns an SuggestionResult object.""" - - weights = [whit.weight for whit in weighted_hits] - scores = [whit.hits.as_vector(size) for whit in weighted_hits] - result = np.average(scores, axis=0, weights=weights) - return VectorSuggestionResult(result) +def merge_hits(weighted_hits_batches, size): + """Merge hit sets from multiple sources. Input is a sequence of + WeightedSuggestionsBatch objects. The size parameter determines the length of the + subject vector. Returns a list of SuggestionResult objects.""" + + weights = [batch.weight for batch in weighted_hits_batches] + score_vectors = np.array( + [ + [whits.as_vector(size) for whits in batch.hit_sets] + for batch in weighted_hits_batches + ] + ) + results = np.average(score_vectors, axis=0, weights=weights) + return [VectorSuggestionResult(res) for res in results] def parse_sources(sourcedef):