Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Migrate raw function pointers to std::function #4461

Merged
merged 80 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
b02042d
remove x permission
byronxu99 Jan 11, 2023
52457b7
migrate some function pointers to std::function
byronxu99 Jan 11, 2023
7022800
no template for learner class, use std::function, bind data pointer i…
byronxu99 Jan 17, 2023
f6cd651
move learner class function implementations to .cc file
byronxu99 Jan 17, 2023
d00a557
fix failing tests
byronxu99 Jan 17, 2023
778a647
remove BaseLearnerT from templates
byronxu99 Jan 17, 2023
cd4320d
change polymorphic_ex to use more const
byronxu99 Jan 18, 2023
337f37e
change base_learner to learner
byronxu99 Jan 18, 2023
871af99
clang format
byronxu99 Jan 18, 2023
62119bc
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 18, 2023
e683afe
clang format
byronxu99 Jan 18, 2023
54e1d21
update comments in learner.h
byronxu99 Jan 18, 2023
0ca9b80
make variable names in learner.h more consistent
byronxu99 Jan 18, 2023
f309576
fix .net core binding to json parser
byronxu99 Jan 18, 2023
de793dc
fix cli bindings
byronxu99 Jan 18, 2023
42032b1
format and add comments
byronxu99 Jan 18, 2023
d31c1d0
avoid binding nullptr to base learner reference
byronxu99 Jan 18, 2023
a17d183
avoid UB with no data reduction
byronxu99 Jan 18, 2023
3d048ad
assert function pointers are not null when building learner, fix errors
byronxu99 Jan 18, 2023
ba46574
formatting
byronxu99 Jan 18, 2023
d09cfe3
add typecasting wrapper around merge functions + remove no_sanitize_u…
byronxu99 Jan 18, 2023
dbad8fc
use std::function for set_minmax
byronxu99 Jan 19, 2023
95dbfcf
add ExampleT type check in learner builder
byronxu99 Jan 19, 2023
1408481
formatting
byronxu99 Jan 19, 2023
40ba14c
migrate print_by_ref
byronxu99 Jan 19, 2023
0a1a075
fix error with get_loss_function and set_minmax function
byronxu99 Jan 19, 2023
cf126fe
higher performance templated polymorphic_ex class
byronxu99 Jan 20, 2023
fbefced
add fake error handling learner when going past last learner in stack
byronxu99 Jan 20, 2023
4213131
clarify base vs previous learner, add comments
byronxu99 Jan 20, 2023
e669e54
update comments
byronxu99 Jan 20, 2023
418fe4b
fix failing test due to change in error message string
byronxu99 Jan 20, 2023
2fa55d4
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 20, 2023
725c054
add benchmark for std::function
byronxu99 Jan 23, 2023
5e0e039
use nullptr for set_minmax instead of noop function
byronxu99 Jan 24, 2023
2e1f77a
use shared_ptr for shared_data
byronxu99 Jan 24, 2023
1ee2648
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 24, 2023
b376410
fix typo
byronxu99 Jan 24, 2023
0449c64
make all_reduce a unique_ptr and fix error with shared_data
byronxu99 Jan 24, 2023
5b9a01e
use shared_ptr for learner instead of raw pointer
byronxu99 Jan 24, 2023
3867e30
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 24, 2023
ea1f02c
fix java and python bindings
byronxu99 Jan 24, 2023
83747a2
add pointer benchmark
byronxu99 Jan 25, 2023
cc3396d
add pointer benchmark results
byronxu99 Jan 25, 2023
0282aaf
formatting
byronxu99 Jan 25, 2023
5951eef
get pointer before std function benchmark
byronxu99 Jan 25, 2023
5dec49c
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 25, 2023
d03bbef
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 26, 2023
28c5ef8
use function pointer for csharp binding
byronxu99 Jan 26, 2023
ae52cd9
use lambda instead of std::bind
byronxu99 Jan 26, 2023
0f05a61
remove previous learner from function signatures, bind at creation in…
byronxu99 Jan 26, 2023
08afc01
don't run std library benchmarks by default
byronxu99 Jan 26, 2023
a1d0da9
Merge branch 'master' into std_function
byronxu99 Jan 27, 2023
c2e9e36
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 30, 2023
25ea8ea
Merge branch 'std_function' of github.com:byronxu99/vowpal_wabbit int…
byronxu99 Jan 30, 2023
0685080
use function pointer when std::function is unnecessary
byronxu99 Jan 30, 2023
7633f32
add comments, rename stuff, formatting
byronxu99 Jan 30, 2023
c3f036e
clarify use of shared_ptr in reduction stack setup
byronxu99 Jan 30, 2023
2c7b1cf
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 30, 2023
5c793f7
always use shared_ptr when taking over ownership
byronxu99 Jan 31, 2023
7f21989
use shared_ptr variant of require functions
byronxu99 Jan 31, 2023
686bba2
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 31, 2023
cf622f7
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 31, 2023
86e554d
use macro to reduce code duplication in learner builders
byronxu99 Jan 31, 2023
e3bae1c
allocate placeholder char for no-data reduction
byronxu99 Jan 31, 2023
9fceb32
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Jan 31, 2023
e24ba37
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 1, 2023
aa02e3d
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 1, 2023
d4b5a5b
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 1, 2023
f8e8bfc
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 2, 2023
6bfbe87
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 2, 2023
778b07c
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 2, 2023
8861fba
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 3, 2023
5b90557
don't use the term previous learner
byronxu99 Feb 6, 2023
7ce17ca
rename foundation to bottom
byronxu99 Feb 7, 2023
05be7e6
Merge branch 'master' of https://github.com/VowpalWabbit/vowpal_wabbi…
byronxu99 Feb 7, 2023
92f0616
clarify naming of reduction and learner
byronxu99 Feb 7, 2023
4a3bca2
remove debug log for base_learner
byronxu99 Feb 7, 2023
16e62df
fix python bindings
byronxu99 Feb 7, 2023
dcea4cb
make naming more consistent
byronxu99 Feb 7, 2023
5006ed6
change reduction to learner in test files
byronxu99 Feb 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
34 changes: 18 additions & 16 deletions cs/cli/vowpalwabbit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "vw/json_parser/parse_example_json.h"
#include "vw/core/shared_data.h"

#include <functional>

using namespace System;
using namespace System::Collections::Generic;
using namespace System::Text;
Expand All @@ -34,12 +36,12 @@ VowpalWabbit::VowpalWabbit(VowpalWabbitSettings^ settings)
auto total = settings->ParallelOptions->MaxDegreeOfParallelism;

if (settings->Root == nullptr)
{ m_vw->all_reduce = new all_reduce_threads(total, settings->Node);
{ m_vw->all_reduce.reset(new all_reduce_threads(total, settings->Node));
}
else
{ auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->all_reduce;
{ auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->all_reduce.get();

m_vw->all_reduce = new all_reduce_threads(parent_all_reduce, total, settings->Node);
m_vw->all_reduce.reset(new all_reduce_threads(parent_all_reduce, total, settings->Node));
}
}

Expand Down Expand Up @@ -172,7 +174,7 @@ void VowpalWabbit::Learn(List<VowpalWabbitExample^>^ examples)
m_vw->learn(ex_coll);

// as this is not a ring-based example it is not freed
as_multiline(m_vw->l)->finish_example(*m_vw, ex_coll);
require_multiline(m_vw->l)->finish_example(*m_vw, ex_coll);
}
CATCHRETHROW
finally{ }
Expand All @@ -189,10 +191,10 @@ void VowpalWabbit::Predict(List<VowpalWabbitExample^>^ examples)
ex_coll.push_back(pex);
}

as_multiline(m_vw->l)->predict(ex_coll);
require_multiline(m_vw->l)->predict(ex_coll);

// as this is not a ring-based example it is not freed
as_multiline(m_vw->l)->finish_example(*m_vw, ex_coll);
require_multiline(m_vw->l)->finish_example(*m_vw, ex_coll);
}
CATCHRETHROW
finally{ }
Expand All @@ -210,7 +212,7 @@ void VowpalWabbit::Learn(VowpalWabbitExample^ ex)
{ m_vw->learn(*ex->m_example);

// as this is not a ring-based example it is not free'd
as_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);
require_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);
}
CATCHRETHROW
}
Expand All @@ -231,7 +233,7 @@ generic<typename T> T VowpalWabbit::Learn(VowpalWabbitExample^ ex, IVowpalWabbit
auto prediction = predictionFactory->Create(m_vw, ex->m_example);

// as this is not a ring-based example it is not free'd
as_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);
require_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);

return prediction;
}
Expand All @@ -246,10 +248,10 @@ void VowpalWabbit::Predict(VowpalWabbitExample^ ex)
#endif

try
{ as_singleline(m_vw->l)->predict(*ex->m_example);
{ require_singleline(m_vw->l)->predict(*ex->m_example);

// as this is not a ring-based example it is not free'd
as_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);
require_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);
}
CATCHRETHROW
}
Expand All @@ -262,12 +264,12 @@ generic<typename T> T VowpalWabbit::Predict(VowpalWabbitExample^ ex, IVowpalWabb
#endif

try
{ as_singleline(m_vw->l)->predict(*ex->m_example);
{ require_singleline(m_vw->l)->predict(*ex->m_example);

auto prediction = predictionFactory->Create(m_vw, ex->m_example);

// as this is not a ring-based example it is not free'd
as_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);
require_singleline(m_vw->l)->finish_example(*m_vw, *ex->m_example);

return prediction;
}
Expand Down Expand Up @@ -320,9 +322,9 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By
VW::parsers::json::decision_service_interaction interaction;

if (m_vw->audit)
VW::parsers::json::read_line_decision_service_json<true>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, get_example_from_pool, &state, &interaction);
VW::parsers::json::read_line_decision_service_json<true>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction);
else
VW::parsers::json::read_line_decision_service_json<false>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, get_example_from_pool, &state, &interaction);
VW::parsers::json::read_line_decision_service_json<false>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction);

// finalize example
VW::setup_examples(*m_vw, examples);
Expand Down Expand Up @@ -384,9 +386,9 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By
interior_ptr<ParseJsonState^> state_ptr = &state;

if (m_vw->audit)
VW::parsers::json::read_line_json<true>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, get_example_from_pool, &state);
VW::parsers::json::read_line_json<true>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state));
else
VW::parsers::json::read_line_json<false>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, get_example_from_pool, &state);
VW::parsers::json::read_line_json<false>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state));

// finalize example
VW::setup_examples(*m_vw, examples);
Expand Down
18 changes: 9 additions & 9 deletions cs/vw.net.native/vw.net.workspace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,15 @@ API size_t WorkspaceHashFeature(
API void WorkspaceSetUpAllReduceThreadsRoot(vw_net_native::workspace_context* workspace, size_t total, size_t node)
{
workspace->vw->selected_all_reduce_type = VW::all_reduce_type::THREAD;
workspace->vw->all_reduce = new VW::all_reduce_threads(total, node);
workspace->vw->all_reduce.reset(new VW::all_reduce_threads(total, node));
}

API void WorkspaceSetUpAllReduceThreadsNode(vw_net_native::workspace_context* workspace, size_t total, size_t node,
vw_net_native::workspace_context* root_workspace)
{
workspace->vw->selected_all_reduce_type = VW::all_reduce_type::THREAD;
workspace->vw->all_reduce =
new VW::all_reduce_threads((VW::all_reduce_threads*)root_workspace->vw->all_reduce, total, node);
workspace->vw->all_reduce.reset(
new VW::all_reduce_threads((VW::all_reduce_threads*)root_workspace->vw->all_reduce.get(), total, node));
}

API vw_net_native::ERROR_CODE WorkspaceRunMultiPass(
Expand Down Expand Up @@ -269,11 +269,11 @@ API vw_net_native::ERROR_CODE WorkspacePredict(vw_net_native::workspace_context*
{
try
{
as_singleline(workspace->vw->l)->predict(*ex);
require_singleline(workspace->vw->l)->predict(*ex);

if (create_prediction != nullptr) { create_prediction(); }

as_singleline(workspace->vw->l)->finish_example(*workspace->vw, *ex);
require_singleline(workspace->vw->l)->finish_example(*workspace->vw, *ex);

return VW::experimental::error_code::success;
}
Expand All @@ -289,7 +289,7 @@ API vw_net_native::ERROR_CODE WorkspaceLearn(vw_net_native::workspace_context* w

if (create_prediction != nullptr) { create_prediction(); }

as_singleline(workspace->vw->l)->finish_example(*workspace->vw, *ex);
require_singleline(workspace->vw->l)->finish_example(*workspace->vw, *ex);

return VW::experimental::error_code::success;
}
Expand All @@ -301,11 +301,11 @@ API vw_net_native::ERROR_CODE WorkspacePredictMulti(vw_net_native::workspace_con
{
try
{
as_multiline(workspace->vw->l)->predict(*ex_coll);
require_multiline(workspace->vw->l)->predict(*ex_coll);

if (create_prediction != nullptr) { create_prediction(); }

as_multiline(workspace->vw->l)->finish_example(*workspace->vw, *ex_coll);
require_multiline(workspace->vw->l)->finish_example(*workspace->vw, *ex_coll);

return VW::experimental::error_code::success;
}
Expand All @@ -321,7 +321,7 @@ API vw_net_native::ERROR_CODE WorkspaceLearnMulti(vw_net_native::workspace_conte

if (create_prediction != nullptr) { create_prediction(); }

as_multiline(workspace->vw->l)->finish_example(*workspace->vw, *ex_coll);
require_multiline(workspace->vw->l)->finish_example(*workspace->vw, *ex_coll);

return VW::experimental::error_code::success;
}
Expand Down
14 changes: 8 additions & 6 deletions cs/vw.net.native/vw.net.workspace_parse_json.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "vw.net.workspace.h"
#include "vw/json_parser/parse_example_json.h"

#include <functional>

API vw_net_native::ERROR_CODE WorkspaceParseJson(vw_net_native::workspace_context* workspace, char* json, size_t length,
vw_net_native::example_pool_get_example_fn get_example, void* example_pool_context,
VW::experimental::api_status* status)
Expand All @@ -15,12 +17,12 @@ API vw_net_native::ERROR_CODE WorkspaceParseJson(vw_net_native::workspace_contex
if (workspace->vw->audit)
{
VW::parsers::json::read_line_json<true>(
*workspace->vw, examples, json, length, get_example, example_pool_context);
*workspace->vw, examples, json, length, std::bind(get_example, example_pool_context));
}
else
{
VW::parsers::json::read_line_json<false>(
*workspace->vw, examples, json, length, get_example, example_pool_context);
*workspace->vw, examples, json, length, std::bind(get_example, example_pool_context));
}

VW::setup_examples(*workspace->vw, examples);
Expand Down Expand Up @@ -48,13 +50,13 @@ API vw_net_native::ERROR_CODE WorkspaceParseDecisionServiceJson(vw_net_native::w
{
if (workspace->vw->audit)
{
VW::parsers::json::read_line_decision_service_json<true>(
*workspace->vw, examples, actual_json, length, copy_json, get_example, example_pool_context, interaction);
VW::parsers::json::read_line_decision_service_json<true>(*workspace->vw, examples, actual_json, length, copy_json,
std::bind(get_example, example_pool_context), interaction);
}
else
{
VW::parsers::json::read_line_decision_service_json<false>(
*workspace->vw, examples, actual_json, length, copy_json, get_example, example_pool_context, interaction);
VW::parsers::json::read_line_decision_service_json<false>(*workspace->vw, examples, actual_json, length,
copy_json, std::bind(get_example, example_pool_context), interaction);
}

VW::setup_examples(*workspace->vw, examples);
Expand Down
4 changes: 2 additions & 2 deletions java/src/main/c++/jni_spark_vw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_learn(JNI
all->learn(*ex);

// as this is not a ring-based example it is not free'd
VW::LEARNER::as_singleline(all->l)->finish_example(*all, *ex);
all->l->finish_example(*all, *ex);
}
catch (...)
{
Expand All @@ -862,7 +862,7 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_predic
all->predict(*ex);

// as this is not a ring-based example it is not free'd
VW::LEARNER::as_singleline(all->l)->finish_example(*all, *ex);
all->l->finish_example(*all, *ex);

return getJavaPrediction(env, all, ex);
}
Expand Down
44 changes: 24 additions & 20 deletions python/pylibvw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ class OptionManager : VW::config::typed_option_visitor
// see pyvw.py class VWOption
py::object m_py_opt_class;
VW::config::options_i& m_opt;
std::vector<std::string>& m_enabled_reductions;
std::vector<std::string>& m_enabled_learners;
std::string default_group_name;

py::object* m_visitor_output_var;

public:
OptionManager(VW::config::options_i& options, std::vector<std::string>& enabled_reductions, py::object py_class)
OptionManager(VW::config::options_i& options, std::vector<std::string>& enabled_learners, py::object py_class)
: m_opt(options)
, m_enabled_reductions(enabled_reductions)
, m_enabled_learners(enabled_learners)
, m_option_group_dic(options.get_collection_of_options())
, m_py_opt_class(py_class)
{
Expand Down Expand Up @@ -197,7 +197,7 @@ class OptionManager : VW::config::typed_option_visitor
while (it != m_option_group_dic.end())
{
auto reduction_enabled =
std::find(m_enabled_reductions.begin(), m_enabled_reductions.end(), it->first) != m_enabled_reductions.end();
std::find(m_enabled_learners.begin(), m_enabled_learners.end(), it->first) != m_enabled_learners.end();

if (((it->first).compare(default_group_name) != 0) && enabled_only && !reduction_enabled)
{
Expand Down Expand Up @@ -347,7 +347,7 @@ py::dict get_learner_metrics(vw_ptr all)

if (all->global_metrics.are_metrics_enabled())
{
auto metrics = all->global_metrics.collect_metrics(all->l);
auto metrics = all->global_metrics.collect_metrics(all->l.get());

python_dict_writer writer(dictionary);
metrics.visit(writer);
Expand All @@ -370,9 +370,9 @@ search_ptr get_search_ptr(vw_ptr all)

py::object get_options(vw_ptr all, py::object py_class, bool enabled_only)
{
std::vector<std::string> enabled_reductions;
if (all->l) all->l->get_enabled_reductions(enabled_reductions);
auto opt_manager = OptionManager(*all->options, enabled_reductions, py_class);
std::vector<std::string> enabled_learners;
if (all->l) all->l->get_enabled_learners(enabled_learners);
auto opt_manager = OptionManager(*all->options, enabled_learners, py_class);
return opt_manager.get_vw_option_pyobjects(enabled_only);
}

Expand All @@ -391,14 +391,14 @@ std::string get_arguments(vw_ptr all)
return serializer.str();
}

py::list get_enabled_reductions(vw_ptr all)
py::list get_enabled_learners(vw_ptr all)
{
py::list py_enabled_reductions;
std::vector<std::string> enabled_reductions;
if (all->l) all->l->get_enabled_reductions(enabled_reductions);
for (auto ex : enabled_reductions) { py_enabled_reductions.append(ex); }
py::list py_enabled_learners;
std::vector<std::string> enabled_learners;
if (all->l) all->l->get_enabled_learners(enabled_learners);
for (auto ex : enabled_learners) { py_enabled_learners.append(ex); }

return py_enabled_reductions;
return py_enabled_learners;
}

predictor_ptr get_predictor(search_ptr _sch, ptag my_tag)
Expand Down Expand Up @@ -529,25 +529,25 @@ multi_ex unwrap_example_list(py::list& ec)
return ex_coll;
}

void my_finish_example(vw_ptr all, example_ptr ec) { as_singleline(all->l)->finish_example(*all, *ec); }
void my_finish_example(vw_ptr all, example_ptr ec) { all->l->finish_example(*all, *ec); }

void my_finish_multi_ex(vw_ptr& all, py::list& ec)
{
auto ex_col = unwrap_example_list(ec);
as_multiline(all->l)->finish_example(*all, ex_col);
all->l->finish_example(*all, ex_col);
}

void my_learn(vw_ptr all, example_ptr ec)
{
if (ec->test_only) { as_singleline(all->l)->predict(*ec); }
if (ec->test_only) { all->l->predict(*ec); }
else { all->learn(*ec.get()); }
}

std::string my_json_weights(vw_ptr all) { return all->dump_weights_to_json_experimental(); }

float my_predict(vw_ptr all, example_ptr ec)
{
as_singleline(all->l)->predict(*ec);
all->l->predict(*ec);
return ec->partial_prediction;
}

Expand All @@ -560,7 +560,7 @@ void predict_or_learn(vw_ptr& all, py::list& ec)
if (learn)
all->learn(ex_coll);
else
as_multiline(all->l)->predict(ex_coll);
all->l->predict(ex_coll);
}

py::list my_parse(vw_ptr& all, char* str)
Expand Down Expand Up @@ -1420,7 +1420,11 @@ BOOST_PYTHON_MODULE(pylibvw)
.def("audit_example", &my_audit_example, "print example audit information")
.def("get_id", &get_model_id, "return the model id")
.def("get_arguments", &get_arguments, "return the arguments after resolving all dependencies")
.def("get_enabled_reductions", &get_enabled_reductions, "return the list of names of the enabled reductions")

// this returns all learners, not just reduction learners, but the API was originally called
// get_enabled_reductions
.def("get_enabled_learners", &get_enabled_learners, "return the list of names of the enabled learners")
.def("get_enabled_reductions", &get_enabled_learners, "return the list of names of the enabled learners")

.def("learn_multi", &my_learn_multi_ex, "given a list pyvw examples, learn (and predict) on those examples")
.def("predict_multi", &my_predict_multi_ex, "given a list of pyvw examples, predict on that example")
Expand Down
4 changes: 4 additions & 0 deletions test/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ if (NOT BUILD_ONLY_STANDALONE_BENCHMARKS)
set(all_sources ${all_sources}
input_format_benchmarks.cc
benchmark_funcs.cc

# These are just for benchmarking specific standard library operations
#benchmark_std_function.cc
#benchmark_pointers.cc
)
endif()

Expand Down
Loading