Skip to content

Commit

Permalink
feat: not self consistent speed up (#4652)
Browse files Browse the repository at this point in the history
* feat: Improved runtime performance of EMT when using the not_self_consistent_rank flag.

* feat: Improved runtime performance of EMT when using the not_self_consistent_rank flag.

* Fixed code formatting.

---------

Co-authored-by: Alexey Taymanov <41013086+ataymano@users.noreply.github.com>
  • Loading branch information
mrucker and ataymano authored Jan 23, 2024
1 parent b4612f8 commit f8091b6
Showing 1 changed file with 23 additions and 28 deletions.
51 changes: 23 additions & 28 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,7 @@ void tree_bound(emt_tree& b, emt_example* ec)
}
}

// Appends the non-zero entries of a sparse feature list to `out`.
//
// f1:  sequence of pairs; `first` is passed to push_back as the second
//      argument and `second` as the first argument.
// out: destination feature container; entries are appended, nothing is cleared here.
//
// Entries whose `second` member compares equal to 0 are skipped.
void scorer_features(const emt_feats& f1, VW::features& out)
{
  for (const auto& entry : f1)
  {
    if (entry.second == 0) { continue; }
    out.push_back(entry.second, entry.first);
  }
}

void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out)
void scorer_features_sub(const emt_feats& f1, const emt_feats& f2, VW::features& out)
{
auto iter1 = f1.begin();
auto iter2 = f2.begin();
Expand Down Expand Up @@ -535,15 +527,31 @@ void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out
}
}

// Appends to `out` the product of the `second` members of every pair of entries
// from f1 and f2 that share the same `first` member (the key).
//
// The two inputs are walked in lock-step: whichever side currently has the
// smaller key is advanced, and on a key match the product is emitted under
// that key before both sides advance. Keys present in only one input produce
// no output.
//
// NOTE(review): this single forward pass only finds every common key if both
// inputs are ordered ascending by `first` -- presumably an invariant of
// emt_feats; confirm against where emt_feats is built.
//
// `out` is appended to, never cleared here. Products are emitted even when
// they evaluate to 0.
void scorer_features_mul(const emt_feats& f1, const emt_feats& f2, VW::features& out)
{
  auto lhs = f1.begin();
  auto rhs = f2.begin();

  while (lhs != f1.end() && rhs != f2.end())
  {
    // Left key is behind: it cannot match anything remaining on the right.
    if (lhs->first < rhs->first)
    {
      ++lhs;
      continue;
    }

    // Right key is behind: same reasoning, mirrored.
    if (rhs->first < lhs->first)
    {
      ++rhs;
      continue;
    }

    // Keys are equal: emit the product under the shared key and move both sides on.
    out.push_back(lhs->second * rhs->second, lhs->first);
    ++lhs;
    ++rhs;
  }
}

void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)
{
VW::example& out = *b.ex;

static constexpr VW::namespace_index X_NS = 'x';
static constexpr VW::namespace_index Z_NS = 'z';

out.feature_space[X_NS].clear();
out.feature_space[Z_NS].clear();

if (b.scorer_type == emt_scorer_type::SELF_CONSISTENT_RANK)
{
Expand All @@ -552,7 +560,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)

out.interactions->clear();

scorer_features(ex1.full, ex2.full, out.feature_space[X_NS]);
scorer_features_sub(ex1.full, ex2.full, out.feature_space[X_NS]);

out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size();
Expand All @@ -565,26 +573,13 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)
{
out.indices.clear();
out.indices.push_back(X_NS);
out.indices.push_back(Z_NS);

out.interactions->clear();
out.interactions->push_back({X_NS, Z_NS});

b.all->feature_tweaks_config.ignore_some_linear = true;
b.all->feature_tweaks_config.ignore_linear[X_NS] = true;
b.all->feature_tweaks_config.ignore_linear[Z_NS] = true;

scorer_features(ex1.full, out.feature_space[X_NS]);
scorer_features(ex2.full, out.feature_space[Z_NS]);
scorer_features_mul(ex1.full, ex2.full, out.feature_space[X_NS]);

// When we receive ex1 and ex2, their features are indexed on top of each other. To make
// sure VW treats the features from the two examples as separate features, we remap the
// indices: multiply by 2 for the first example, and multiply by 2 then offset by 1 for the second.
for (auto& j : out.feature_space[X_NS].indices) { j = j * 2; }
for (auto& j : out.feature_space[Z_NS].indices) { j = j * 2 + 1; }

out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq + out.feature_space[Z_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size() + out.feature_space[Z_NS].size();
out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size();

auto initial = emt_initial(b.initial_type, ex1.full, ex2.full);
out.ex_reduction_features.get<VW::simple_label_reduction_features>().initial = initial;
Expand Down

0 comments on commit f8091b6

Please sign in to comment.