Skip to content

Commit

Permalink
Merge pull request #37 from SoftwareBuildingBlocks/icml_push
Browse files Browse the repository at this point in the history
Rename prob_dist_new, pdf_new, pmf_to_pdf_new
  • Loading branch information
mmajzoubi authored May 6, 2020
2 parents f37ed5c + 7cfd827 commit 8c4f372
Show file tree
Hide file tree
Showing 19 changed files with 168 additions and 483 deletions.
2 changes: 1 addition & 1 deletion test/unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_executable(vw-unit-test.out
main.cc options_test.cc options_boost_po_test.cc cb_explore_adf_test.cc explore_test.cc
stable_unique_tests.cc test_common.h object_pool_test.cc ccb_parser_test.cc json_parser_test.cc
dsjson_parser_test.cc ccb_test.cc offset_tree_tests.cc random_test.cc cats_tree_tests.cc pmf_to_pdf_test.cc
dsjson_parser_test.cc ccb_test.cc offset_tree_tests.cc random_test.cc cats_tree_tests.cc pmf_to_pdf_test.cc
)

# Add the include directories from vw target for testing
Expand Down
20 changes: 7 additions & 13 deletions test/unit_test/pmf_to_pdf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct reduction_test_harness
reduction_test_harness() : _curr_idx(0) {}

void set_predict_response(const vector<pair<uint32_t, float>>& predictions) { _predictions = predictions; }

void test_predict(single_learner& base, example& ec)
{
ec.pred.a_s.clear();
Expand All @@ -39,11 +39,11 @@ struct reduction_test_harness
void test_learn(single_learner& base, example& ec)
{
cout << "ec.l.cb.costs after:" << endl;
for (uint32_t i = 0; i < ec.l.cb.costs.size(); i++)
for (uint32_t i = 0; i < ec.l.cb.costs.size(); i++)
{
cout << "(" << ec.l.cb.costs[i].action << " , " << ec.l.cb.costs[i].cost << " , " << ec.l.cb.costs[i].probability
<< " , " << ec.l.cb.costs[i].partial_prediction << "), " << endl;
}
}
}

static void predict(reduction_test_harness& test_reduction, single_learner& base, example& ec)
Expand Down Expand Up @@ -81,7 +81,7 @@ BOOST_AUTO_TEST_CASE(pmf_to_pdf_basic)
const auto test_harness = VW::pmf_to_pdf::get_test_harness_reduction(prediction_scores);

example ec;

auto data = scoped_calloc_or_throw<VW::pmf_to_pdf::reduction>();
data->num_actions = k;
data->bandwidth = h;
Expand All @@ -97,7 +97,7 @@ BOOST_AUTO_TEST_CASE(pmf_to_pdf_basic)
cout << "ec.pred.p_d (PDF): " << endl;
for (uint32_t i = 0; i < ec.pred.prob_dist.size(); i++)
{
cout << "(" << ec.pred.prob_dist[i].left << " , " << ec.pred.prob_dist[i].right <<
cout << "(" << ec.pred.prob_dist[i].left << " , " << ec.pred.prob_dist[i].right <<
": " << ec.pred.prob_dist[i].pdf_value << ")" << endl;
sum += ec.pred.prob_dist[i].pdf_value * (ec.pred.prob_dist[i].right - ec.pred.prob_dist[i].left);
}
Expand All @@ -106,18 +106,12 @@ BOOST_AUTO_TEST_CASE(pmf_to_pdf_basic)
ec.l.cb_cont = VW::cb_continuous::continuous_label();
ec.l.cb_cont.costs = v_init<VW::cb_continuous::continuous_label_elm>();
ec.l.cb_cont.costs.clear();
ec.l.cb_cont.costs.push_back({1010.17f, .5f, .05f, 0.f}); // action, cost, prob, partial
ec.l.cb_cont.costs.push_back({1010.17f, .5f, .05f}); // action, cost, prob

cout << "ec.l.cb_cont.costs after:" << endl;
cout << "(" << ec.l.cb_cont.costs[0].action << " , " << ec.l.cb_cont.costs[0].cost << " , " << ec.l.cb_cont.costs[0].probability
<< " , " << ec.l.cb_cont.costs[0].partial_prediction << "), " << endl;
cout << "(" << ec.l.cb_cont.costs[0].action << " , " << ec.l.cb_cont.costs[0].cost << " , " << ec.l.cb_cont.costs[0].probability << "), " << endl;

learn(*data, *as_singleline(test_harness), ec);

float chosen_action = 1080;
cout << "pdf value of " << chosen_action << " is = " << VW::actions_pdf::get_pdf_value(ec.pred.prob_dist, chosen_action)
<< std::endl;
cout << "here" << endl;
}

namespace VW { namespace pmf_to_pdf {
Expand Down
4 changes: 2 additions & 2 deletions test/unit_test/unit_test.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
</ProjectConfiguration>
</ItemGroup>
<ItemGroup>
<ClCompile Include="cats_tree_tests.cc" />
<ClCompile Include="cb_explore_adf_test.cc" />
<ClCompile Include="ccb_test.cc" />
<ClCompile Include="ccb_parser_test.cc" />
<ClCompile Include="object_pool_test.cc" />
<ClCompile Include="offset_tree_cont_tests.cc" />
<ClCompile Include="offset_tree_tests.cc" />
<ClCompile Include="options_boost_po_test.cc" />
<ClCompile Include="options_test.cc" />
Expand All @@ -48,7 +48,7 @@
</ProjectReference>
</ItemGroup>
<ItemGroup>
<ClInclude Include="offset_tree_cont_tests.h" />
<ClInclude Include="cats_tree_tests.h" />
<ClInclude Include="test_common.h" />
</ItemGroup>
<PropertyGroup Label="Globals">
Expand Down
8 changes: 4 additions & 4 deletions test/unit_test/unit_test.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@
<ClCompile Include="offset_tree_tests.cc">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="offset_tree_cont_tests.cc">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="object_pool_test.cc">
<Filter>Source Files</Filter>
</ClCompile>
Expand All @@ -63,6 +60,9 @@
<ClCompile Include="pmf_to_pdf_test.cc">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="cats_tree_tests.cc">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
Expand All @@ -71,7 +71,7 @@
<ClInclude Include="test_common.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="offset_tree_cont_tests.h">
<ClInclude Include="cats_tree_tests.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup>
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ set(vw_all_headers
classweight.h parse_regressor.h kernel_svm.h confidence.h label_dictionary.h
config.h.in primitives.h lda_core.h print.h vw_versions.h offset_tree.h shared_feature_merger.h conditional_contextual_bandit.h
ccb_label.h pmf_to_pdf.h api_status.h cb_label_parser.h errors_data.h
err_constants.h prob_dist_cont.h cb_continuous.h cats.h cats_pdf.h cb_explore_pdf.h get_pmf.h sample_pdf.h cats_tree.h
err_constants.h prob_dist_cont.h cb_continuous_label.h cats.h cats_pdf.h cb_explore_pdf.h get_pmf.h sample_pdf.h cats_tree.h
)

set(vw_all_sources
Expand All @@ -52,7 +52,7 @@ set(vw_all_sources
comp_io.cc interactions.cc vw_validate.cc audit_regressor.cc gen_cs_example.cc cb_explore.cc
action_score.cc cb_explore_adf.cc OjaNewton.cc baseline.cc classweight.cc
offset_tree.cc vw_exception.cc no_label.cc shared_feature_merger.cc conditional_contextual_bandit.cc
cb_sample.cc ccb_label.cc version.cc pmf_to_pdf.cc api_status.cc prob_dist_cont.cc cb_continuous.cc
cb_sample.cc ccb_label.cc version.cc pmf_to_pdf.cc api_status.cc prob_dist_cont.cc cb_continuous_label.cc
cats.cc cats_pdf.cc cb_explore_pdf.cc get_pmf.cc sample_pdf.cc cats_tree.cc
)

Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/cats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ int cats::learn(example& ec, api_status* status = nullptr)
{
assert(!ec.test_only);
predict(ec, status);
VW_DBG(ec) << "cats::learn(), " << cont_label_to_string(ec) << features_to_string(ec) << endl;
VW_DBG(ec) << "cats::learn(), " << to_string(ec.l.cb_cont) << features_to_string(ec) << endl;
_base->learn(ec);
return error_code::success;
}
Expand Down
71 changes: 5 additions & 66 deletions vowpalwabbit/cats_pdf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include "parse_args.h"
#include "err_constants.h"
#include "api_status.h"
#include "cb_label_parser.h"
#include "cb_continuous_label.h"
#include "debug_log.h"

// Aliases
Expand Down Expand Up @@ -56,7 +56,7 @@ namespace cats_pdf
{
assert(!ec.test_only);
predict(ec, status);
VW_DBG(ec) << "cats_pdf::learn(), " << cont_label_to_string(ec) << features_to_string(ec) << endl;
VW_DBG(ec) << "cats_pdf::learn(), " << to_string(ec.l.cb_cont) << features_to_string(ec) << endl;
_base->learn(ec);
return error_code::success;
}
Expand Down Expand Up @@ -143,67 +143,6 @@ namespace cats_pdf
// END: functions to output progress
////////////////////////////////////////////////////

////////////////////////////////////////////////////
// Begin: parse a,c,p,x file format
namespace lbl_parser
{
void parse_label(parser* p, shared_data*, void* v, v_array<substring>& words)
{
auto ld = static_cast<continuous_label*>(v);
ld->costs.clear();
for (auto word : words)
{
continuous_label_elm f{0.f, FLT_MAX, 0.f, 0.f};
tokenize(':', word, p->parse_name);

if (p->parse_name.empty() || p->parse_name.size() > 3)
THROW("malformed cost specification: " << p->parse_name);

f.action = float_of_substring(p->parse_name[0]);

if (p->parse_name.size() > 1)
f.cost = float_of_substring(p->parse_name[1]);

if (nanpattern(f.cost))
THROW("error NaN cost (" << p->parse_name[1] << " for action: " << p->parse_name[0]);

f.probability = .0;
if (p->parse_name.size() > 2)
f.probability = float_of_substring(p->parse_name[2]);

if (nanpattern(f.probability))
THROW("error NaN probability (" << p->parse_name[2] << " for action: " << p->parse_name[0]);

if (f.probability > 1.0)
{
std::cerr << "invalid probability > 1 specified for an action, resetting to 1." << endl;
f.probability = 1.0;
}
if (f.probability < 0.0)
{
std::cerr << "invalid probability < 0 specified for an action, resetting to 0." << endl;
f.probability = .0;
}

ld->costs.push_back(f);
}
}

label_parser cont_tbd_label_parser = {
CB::default_label<continuous_label>,
parse_label,
CB::cache_label<continuous_label, continuous_label_elm>,
CB::read_cached_label<continuous_label, continuous_label_elm>,
CB::delete_label<continuous_label>,
CB::weight,
CB::copy_label<continuous_label>,
CB::is_test_label<continuous_label>,
sizeof(continuous_label)};
}

// End: parse a,c,p,x file format
////////////////////////////////////////////////////

// Setup reduction in stack
LEARNER::base_learner* setup(config::options_i& options, vw& all)
{
Expand Down Expand Up @@ -241,11 +180,11 @@ namespace cats_pdf
predict_or_learn<false>, 1, prediction_type::action_pdf_value);

l.set_finish_example(finish_example);
all.p->lp = lbl_parser::cont_tbd_label_parser;
all.p->lp = cb_continuous::the_label_parser;
all.delete_prediction = nullptr;

return make_base(l);
}
}
}
}
}
} // namespace VW
Loading

0 comments on commit 8c4f372

Please sign in to comment.