From 34dc5a40b5657cb77c3ae9d02a3c1662c3a09454 Mon Sep 17 00:00:00 2001 From: Osma Suominen Date: Fri, 28 Jun 2019 11:33:28 +0300 Subject: [PATCH] Refactor: split _merge_hits_from_sources in vw_ensemble --- annif/backend/vw_ensemble.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/annif/backend/vw_ensemble.py b/annif/backend/vw_ensemble.py index e25b43964..13f188041 100644 --- a/annif/backend/vw_ensemble.py +++ b/annif/backend/vw_ensemble.py @@ -67,21 +67,25 @@ def initialize(self): self._load_subject_freq() super().initialize() + def _calculate_scores(self, subj_id, subj_score_vector): + ex = self._format_example(subj_id, subj_score_vector) + raw_score = subj_score_vector.mean() + pred_score = (self._model.predict(ex) + 1.0) / 2.0 + return raw_score, pred_score + def _merge_hits_from_sources(self, hits_from_sources, project, params): score_vector = np.array([hits.vector for hits, _ in hits_from_sources]) + discount_rate = self.params.get('discount_rate', + self.DEFAULT_DISCOUNT_RATE) result = np.zeros(score_vector.shape[1]) for subj_id in range(score_vector.shape[1]): - if score_vector[:, subj_id].sum() > 0.0: - ex = self._format_example( - subj_id, - score_vector[:, subj_id]) - discount_rate = self.params.get( - 'discount_rate', self.DEFAULT_DISCOUNT_RATE) + subj_score_vector = score_vector[:, subj_id] + if subj_score_vector.sum() > 0.0: + raw_score, pred_score = self._calculate_scores( + subj_id, subj_score_vector) raw_weight = 1.0 / \ ((discount_rate * self._subject_freq[subj_id]) + 1) - raw_score = score_vector[:, subj_id].mean() - pred_score = (self._model.predict(ex) + 1.0) / 2.0 result[subj_id] = (raw_weight * raw_score) + \ (1.0 - raw_weight) * pred_score return VectorSuggestionResult(result, project.subjects)