Skip to content

Commit

Permalink
Refactor: split _merge_hits_from_sources in vw_ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
osma committed Jun 28, 2019
1 parent 9a1345c commit 34dc5a4
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions annif/backend/vw_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,25 @@ def initialize(self):
self._load_subject_freq()
super().initialize()

def _calculate_scores(self, subj_id, subj_score_vector):
ex = self._format_example(subj_id, subj_score_vector)
raw_score = subj_score_vector.mean()
pred_score = (self._model.predict(ex) + 1.0) / 2.0
return raw_score, pred_score

def _merge_hits_from_sources(self, hits_from_sources, project, params):
score_vector = np.array([hits.vector
for hits, _ in hits_from_sources])
discount_rate = self.params.get('discount_rate',
self.DEFAULT_DISCOUNT_RATE)
result = np.zeros(score_vector.shape[1])
for subj_id in range(score_vector.shape[1]):
if score_vector[:, subj_id].sum() > 0.0:
ex = self._format_example(
subj_id,
score_vector[:, subj_id])
discount_rate = self.params.get(
'discount_rate', self.DEFAULT_DISCOUNT_RATE)
subj_score_vector = score_vector[:, subj_id]
if subj_score_vector.sum() > 0.0:
raw_score, pred_score = self._calculate_scores(
subj_id, subj_score_vector)
raw_weight = 1.0 / \
((discount_rate * self._subject_freq[subj_id]) + 1)
raw_score = score_vector[:, subj_id].mean()
pred_score = (self._model.predict(ex) + 1.0) / 2.0
result[subj_id] = (raw_weight * raw_score) + \
(1.0 - raw_weight) * pred_score
return VectorSuggestionResult(result, project.subjects)
Expand Down

0 comments on commit 34dc5a4

Please sign in to comment.