Skip to content

Commit

Permalink
refactor: use model utils instead of macro in recall tree (#4248)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored Nov 10, 2022
1 parent 89da871 commit 417012d
Showing 1 changed file with 62 additions and 49 deletions.
111 changes: 62 additions & 49 deletions vowpalwabbit/core/src/reductions/recall_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "vw/config/options.h"
#include "vw/core/global_data.h"
#include "vw/core/learner.h"
#include "vw/core/model_utils.h"
#include "vw/core/numeric_casts.h"
#include "vw/core/parser.h"
#include "vw/core/rand48.h"
Expand Down Expand Up @@ -151,20 +152,20 @@ node_pred* find_or_create(recall_tree& b, uint32_t cn, VW::example& ec)
return ls;
}

void compute_recall_lbest(recall_tree& b, node* n)
void compute_recall_lbest(const recall_tree& b, node& n)
{
if (n->n <= 0) { return; }
if (n.n <= 0) { return; }

double mass_at_k = 0;

for (node_pred* ls = n->preds.begin(); ls != n->preds.end() && ls < n->preds.begin() + b.max_candidates; ++ls)
for (node_pred* ls = n.preds.begin(); ls != n.preds.end() && ls < n.preds.begin() + b.max_candidates; ++ls)
{ mass_at_k += ls->label_count; }

float f = static_cast<float>(mass_at_k) / static_cast<float>(n->n);
float stdf = std::sqrt(f * (1.f - f) / static_cast<float>(n->n));
float diamf = 15.f / (std::sqrt(18.f) * static_cast<float>(n->n));
float f = static_cast<float>(mass_at_k) / static_cast<float>(n.n);
float stdf = std::sqrt(f * (1.f - f) / static_cast<float>(n.n));
float diamf = 15.f / (std::sqrt(18.f) * static_cast<float>(n.n));

n->recall_lbest = std::max(0.f, f - std::sqrt(b.bern_hyper) * stdf - b.bern_hyper * diamf);
n.recall_lbest = std::max(0.f, f - std::sqrt(b.bern_hyper) * stdf - b.bern_hyper * diamf);
}

double plogp(double c, double n) { return (c == 0) ? 0 : (c / n) * log(c / n); }
Expand Down Expand Up @@ -214,7 +215,7 @@ void insert_example_at_node(recall_tree& b, uint32_t cn, VW::example& ec)

b.nodes[cn].n += ec.weight;

compute_recall_lbest(b, &b.nodes[cn]);
compute_recall_lbest(b, b.nodes[cn]);
}

// TODO: handle if features already in this namespace
Expand Down Expand Up @@ -441,54 +442,66 @@ void save_load_tree(recall_tree& b, io_buf& model_file, bool read, bool text)
{
if (model_file.num_files() > 0)
{
std::stringstream msg;

WRITEIT(b.k, "k");
WRITEIT(b.node_only, "node_only");
WRITEITVAR(b.nodes.size(), "nodes", n_nodes);

if (read)
{
VW::model_utils::read_model_field(model_file, b.k);
VW::model_utils::read_model_field(model_file, b.node_only);
size_t n_nodes = 0;
VW::model_utils::read_model_field(model_file, n_nodes);
b.nodes.clear();
for (uint32_t j = 0; j < n_nodes; ++j) { b.nodes.push_back(node()); }
}

WRITEIT(b.max_candidates, "max_candidates");
WRITEIT(b.max_depth, "max_depth");

for (uint32_t j = 0; j < n_nodes; ++j)
{
node* cn = &b.nodes[j];

WRITEIT(cn->parent, "parent");
WRITEIT(cn->recall_lbest, "recall_lbest");
WRITEIT(cn->internal, "internal");
WRITEIT(cn->depth, "depth");
WRITEIT(cn->base_router, "base_router");
WRITEIT(cn->left, "left");
WRITEIT(cn->right, "right");
WRITEIT(cn->n, "n");
WRITEIT(cn->entropy, "entropy");
WRITEIT(cn->passes, "passes");

WRITEITVAR(cn->preds.size(), "n_preds", n_preds);

if (read)
for (uint32_t j = 0; j < n_nodes; ++j) { b.nodes.emplace_back(); }
VW::model_utils::read_model_field(model_file, b.max_candidates);
VW::model_utils::read_model_field(model_file, b.max_depth);
for (auto& cn : b.nodes)
{
cn->preds.clear();

for (uint32_t k = 0; k < n_preds; ++k) { cn->preds.push_back(node_pred(0)); }
VW::model_utils::read_model_field(model_file, cn.parent);
VW::model_utils::read_model_field(model_file, cn.recall_lbest);
VW::model_utils::read_model_field(model_file, cn.internal);
VW::model_utils::read_model_field(model_file, cn.depth);
VW::model_utils::read_model_field(model_file, cn.base_router);
VW::model_utils::read_model_field(model_file, cn.left);
VW::model_utils::read_model_field(model_file, cn.right);
VW::model_utils::read_model_field(model_file, cn.n);
VW::model_utils::read_model_field(model_file, cn.entropy);
VW::model_utils::read_model_field(model_file, cn.passes);
size_t n_preds = 0;
VW::model_utils::read_model_field(model_file, n_preds);
cn.preds.clear();
for (uint32_t k = 0; k < n_preds; ++k) { cn.preds.push_back(node_pred(0)); }
for (auto& pred : cn.preds)
{
VW::model_utils::read_model_field(model_file, pred.label);
VW::model_utils::read_model_field(model_file, pred.label_count);
}
compute_recall_lbest(b, cn);
}

for (uint32_t k = 0; k < n_preds; ++k)
}
else
{
VW::model_utils::write_model_field(model_file, b.k, "k", text);
VW::model_utils::write_model_field(model_file, b.node_only, "node_only", text);
VW::model_utils::write_model_field(model_file, b.nodes.size(), "nodes", text);
VW::model_utils::write_model_field(model_file, b.max_candidates, "max_candidates", text);
VW::model_utils::write_model_field(model_file, b.max_depth, "max_depth", text);
for (const auto& cn : b.nodes)
{
node_pred* pred = &cn->preds[k];

WRITEIT(pred->label, "label");
WRITEIT(pred->label_count, "label_count");
VW::model_utils::write_model_field(model_file, cn.parent, "parent", text);
VW::model_utils::write_model_field(model_file, cn.recall_lbest, "recall_lbest", text);
VW::model_utils::write_model_field(model_file, cn.internal, "internal", text);
VW::model_utils::write_model_field(model_file, cn.depth, "depth", text);
VW::model_utils::write_model_field(model_file, cn.base_router, "base_router", text);
VW::model_utils::write_model_field(model_file, cn.left, "left", text);
VW::model_utils::write_model_field(model_file, cn.right, "right", text);
VW::model_utils::write_model_field(model_file, cn.n, "n", text);
VW::model_utils::write_model_field(model_file, cn.entropy, "entropy", text);
VW::model_utils::write_model_field(model_file, cn.passes, "passes", text);
VW::model_utils::write_model_field(model_file, cn.preds.size(), "n_preds", text);
for (const auto& pred : cn.preds)
{
VW::model_utils::write_model_field(model_file, pred.label, "label", text);
VW::model_utils::write_model_field(model_file, pred.label_count, "label_count", text);
}
}

if (read) { compute_recall_lbest(b, cn); }
}
}
}
Expand Down

0 comments on commit 417012d

Please sign in to comment.