Skip to content

Commit

Permalink
Merge pull request #288 from NatLibFi/issue235-more-features
Browse files Browse the repository at this point in the history
Add more features to the model in vw_ensemble
  • Loading branch information
osma authored Jul 2, 2019
2 parents 8521652 + d531b74 commit 2ea1130
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
41 changes: 32 additions & 9 deletions annif/backend/vw_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class VWEnsembleBackend(
# will make it more careful so that it will require more training data.
DEFAULT_DISCOUNT_RATE = 0.01

# score threshold for "zero features": scores lower than this will be
# considered zero and marked with a zero feature given to VW
ZERO_THRESHOLD = 0.001

def _load_subject_freq(self):
path = os.path.join(self.datadir, self.FREQ_FILE)
if not os.path.exists(path):
Expand Down Expand Up @@ -94,17 +98,30 @@ def _source_project_ids(self):
sources = annif.util.parse_sources(self.params['sources'])
return [project_id for project_id, _ in sources]

def _format_example(self, subject_id, scores, true=None):
@staticmethod
def _format_value(true):
if true is None:
val = ''
return ''
elif true:
val = 1
return 1
else:
val = -1
ex = "{} |{}".format(val, subject_id)
for proj_idx, proj in enumerate(self._source_project_ids):
ex += " {}:{:.6f}".format(proj, scores[proj_idx])
return ex
return -1

def _format_example(self, subject_id, scores, true=None):
features = " ".join(["{}:{:.6f}".format(proj, scores[proj_idx])
for proj_idx, proj
in enumerate(self._source_project_ids)])
zero_features = " ".join(["zero^{}".format(proj)
for proj_idx, proj
in enumerate(self._source_project_ids)
if scores[proj_idx] < self.ZERO_THRESHOLD])
return "{} |raw {} {} |{} {} {}".format(
self._format_value(true),
features,
zero_features,
subject_id,
features,
zero_features)

def _doc_score_vector(self, doc, source_projects):
score_vectors = []
Expand All @@ -119,7 +136,8 @@ def _doc_to_example(self, doc, project, source_projects):
true = subjects.as_vector(project.subjects)
score_vector = self._doc_score_vector(doc, source_projects)
for subj_id in range(len(true)):
if true[subj_id] or score_vector[:, subj_id].sum() > 0.0:
if true[subj_id] \
or score_vector[:, subj_id].sum() >= self.ZERO_THRESHOLD:
ex = (subj_id, self._format_example(
subj_id,
score_vector[:, subj_id],
Expand All @@ -136,6 +154,11 @@ def _create_examples(self, corpus, project):
random.shuffle(examples)
return examples

def _create_model(self, project):
# add interactions between raw (descriptor-invariant) features to
# the mix
super()._create_model(project, {'q': 'rr'})

@staticmethod
def _write_freq_file(subject_freq, filename):
with open(filename, 'w') as freqfile:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_backend_vw_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_vw_ensemble_format_example(datadir):
datadir=str(datadir))

ex = vw_ensemble._format_example(0, [0.5])
assert ex == ' |0 dummy-en:0.500000'
assert ex == ' |raw dummy-en:0.500000 |0 dummy-en:0.500000 '


def test_vw_ensemble_format_example_avoid_sci_notation(datadir):
Expand All @@ -137,4 +137,5 @@ def test_vw_ensemble_format_example_avoid_sci_notation(datadir):
datadir=str(datadir))

ex = vw_ensemble._format_example(0, [7.24e-05])
assert ex == ' |0 dummy-en:0.000072'
assert ex == ' |raw dummy-en:0.000072 zero^dummy-en' + \
' |0 dummy-en:0.000072 zero^dummy-en'

0 comments on commit 2ea1130

Please sign in to comment.