Skip to content

Commit

Permalink
Turn WeightedSuggestion to WeightedSuggestionsBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
juhoinkinen committed Mar 7, 2023
1 parent 42e1afc commit f280342
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 27 deletions.
24 changes: 10 additions & 14 deletions annif/backend/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,18 @@ def _suggest_with_sources(self, texts, sources):
self._normalize_hits(hits, source_project) for hits in hit_sets
]
hit_sets_from_sources.append(
[
annif.suggestion.WeightedSuggestion(
hits=norm_hits, weight=weight, subjects=source_project.subjects
)
for norm_hits in norm_hit_sets
]
annif.suggestion.WeightedSuggestionsBatch(
hit_sets=norm_hit_sets,
weight=weight,
subjects=source_project.subjects,
)
)
return hit_sets_from_sources

def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
"""Hook for merging hits from sources. Can be overridden by
"""Hook for merging hit sets from sources. Can be overridden by
subclasses."""
return [
annif.util.merge_hits(hits, len(self.project.subjects))
for hits in hit_sets_from_sources
]
return annif.util.merge_hits(hit_sets_from_sources, len(self.project.subjects))

def _suggest_batch(self, texts, params):
sources = annif.util.parse_sources(params["sources"])
Expand Down Expand Up @@ -133,16 +129,16 @@ def _objective(self, trial):
weighted_hits = []
for project_id, hits in srchits.items():
weighted_hits.append(
annif.suggestion.WeightedSuggestion(
hits=hits,
annif.suggestion.WeightedSuggestionsBatch(
hit_sets=[hits],
weight=weights[project_id],
subjects=self._backend.project.subjects,
)
)
batch.evaluate(
annif.util.merge_hits(
weighted_hits, len(self._backend.project.subjects)
),
)[0],
goldsubj,
)
results = batch.results(metrics=[self._metric])
Expand Down
4 changes: 2 additions & 2 deletions annif/backend/nn_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def _merge_hit_sets_from_sources(self, hit_sets_from_sources, params):
np.sqrt(hits.as_vector(len(subjects)))
* weight
* len(hit_sets_from_sources)
for hits, weight, subjects in hits_from_sources
for hits in proj_hit_set
]
for hits_from_sources in hit_sets_from_sources
for proj_hit_set, weight, subjects in hit_sets_from_sources
],
dtype=np.float32,
).transpose(1, 2, 0)
Expand Down
4 changes: 2 additions & 2 deletions annif/suggestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np

SubjectSuggestion = collections.namedtuple("SubjectSuggestion", "subject_id score")
WeightedSuggestion = collections.namedtuple(
"WeightedSuggestion", "hits weight subjects"
WeightedSuggestionsBatch = collections.namedtuple(
"WeightedSuggestionsBatch", "hit_sets weight subjects"
)


Expand Down
23 changes: 14 additions & 9 deletions annif/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,20 @@ def cleanup_uri(uri):
return uri


def merge_hits(weighted_hits, size):
"""Merge hits from multiple sources. Input is a sequence of WeightedSuggestion
objects. The size parameter determines the length of the subject vector.
Returns an SuggestionResult object."""

weights = [whit.weight for whit in weighted_hits]
scores = [whit.hits.as_vector(size) for whit in weighted_hits]
result = np.average(scores, axis=0, weights=weights)
return VectorSuggestionResult(result)
def merge_hits(weighted_hits_batches, size):
"""Merge hit sets from multiple sources. Input is a sequence of
WeightedSuggestionsBatch objects. The size parameter determines the length of the
subject vector. Returns a list of SuggestionResult objects."""

weights = [batch.weight for batch in weighted_hits_batches]
score_vectors = np.array(
[
[whits.as_vector(size) for whits in batch.hit_sets]
for batch in weighted_hits_batches
]
)
results = np.average(score_vectors, axis=0, weights=weights)
return [VectorSuggestionResult(res) for res in results]


def parse_sources(sourcedef):
Expand Down

0 comments on commit f280342

Please sign in to comment.