-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
133 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,79 @@ | ||
#include "vw/core/example.h" | ||
#include "vw/core/model_utils.h" | ||
|
||
bool VW::ec_is_example_header_cb_with_observations(VW::example const& ec) | ||
namespace | ||
{ | ||
const auto& costs = ec.l.cb_with_observations.event.costs; | ||
if (costs.size() != 1) { return false; } | ||
if (costs[0].probability == -1.f) { return true; } | ||
float cb_with_observations_weight(const VW::cb_with_observations_label& ld) { return ld.event.weight; } | ||
|
||
return false; | ||
void parse_label_cb_with_observations(VW::cb_with_observations_label& /*ld*/, VW::reduction_features& /*red_features*/, | ||
VW::label_parser_reuse_mem& /*reuse_mem*/, const std::vector<VW::string_view>& /*words*/, | ||
VW::io::logger& /*logger*/) | ||
{ | ||
// TODO: implement text format parsing for cb with observations | ||
} | ||
} // namespace | ||
|
||
namespace | ||
namespace VW | ||
{ | ||
namespace model_utils | ||
{ | ||
size_t read_model_field(io_buf& io, VW::cb_with_observations_label& cb_with_obs) | ||
{ | ||
float weight_cb_with_observations(const VW::cb_with_observations_label& ld) { return ld.event.weight; } | ||
size_t bytes = 0; | ||
bytes += read_model_field(io, cb_with_obs.event); | ||
bytes += read_model_field(io, cb_with_obs.is_observation); | ||
bytes += read_model_field(io, cb_with_obs.is_definitely_bad); | ||
return bytes; | ||
} | ||
|
||
void default_label_cb_with_observations(VW::cb_with_observations_label& ld) | ||
size_t write_model_field(io_buf& io, const VW::cb_with_observations_label& cb_with_obs, | ||
const std::string& upstream_name, bool text) | ||
{ | ||
ld.event.reset_to_default(); | ||
ld.is_observation = false; | ||
ld.is_definitely_bad = false; | ||
size_t bytes = 0; | ||
bytes += VW::model_utils::write_model_field(io, cb_with_obs.event, upstream_name + "_event", text); | ||
bytes += write_model_field(io, cb_with_obs.is_observation, upstream_name + "_is_observation", text); | ||
bytes += write_model_field(io, cb_with_obs.is_definitely_bad, upstream_name + "_is_definitely_bad", text); | ||
return bytes; | ||
} | ||
} // namespace model_utils | ||
|
||
bool test_label_cb_with_observations(const VW::cb_with_observations_label& ld) { return ld.event.is_test_label(); } | ||
bool ec_is_example_header_cb_with_observations(VW::example const& ec) | ||
{ | ||
const auto& costs = ec.l.cb_with_observations.event.costs; | ||
if (costs.size() != 1) { return false; } | ||
if (costs[0].probability == -1.f) { return true; } | ||
|
||
void parse_label_cb_with_observations(VW::cb_with_observations_label& /*ld*/, VW::reduction_features& /*red_features*/, | ||
VW::label_parser_reuse_mem& /*reuse_mem*/, const std::vector<VW::string_view>& /*words*/, | ||
VW::io::logger& /*logger*/) | ||
return false; | ||
} | ||
|
||
void cb_with_observations_label::reset_to_default() | ||
{ | ||
// TODO: implement text format parsing for cb with observations | ||
event.reset_to_default(); | ||
is_observation = false; | ||
is_definitely_bad = false; | ||
} | ||
} // namespace | ||
|
||
VW::label_parser VW::cb_with_observations_global = { | ||
bool cb_with_observations_label::is_test_label() const { return event.is_test_label(); } | ||
|
||
VW::label_parser cb_with_observations_global = { | ||
// default_label | ||
[](VW::polylabel& label) { default_label_cb_with_observations(label.cb_with_observations); }, | ||
[](VW::polylabel& label) { label.cb_with_observations.reset_to_default(); }, | ||
// parse_label | ||
[](VW::polylabel& label, VW::reduction_features& red_features, VW::label_parser_reuse_mem& reuse_mem, | ||
const VW::named_labels* /*ldict*/, const std::vector<VW::string_view>& words, VW::io::logger& logger) | ||
{ parse_label_cb_with_observations(label.cb_with_observations, red_features, reuse_mem, words, logger); }, | ||
// cache_label | ||
// TODO: implement this for cb_with_observations | ||
[](const VW::polylabel& label, const VW::reduction_features& /*red_features*/, io_buf& cache, | ||
const std::string& upstream_name, bool text) | ||
{ return VW::model_utils::write_model_field(cache, label.cb, upstream_name, text); }, | ||
{ return VW::model_utils::write_model_field(cache, label.cb_with_observations, upstream_name, text); }, | ||
// read_cached_label | ||
[](VW::polylabel& label, VW::reduction_features& /*red_features*/, io_buf& cache) | ||
{ return VW::model_utils::read_model_field(cache, label.cb); }, // TODO: implement this for cb_with_observations | ||
{ return VW::model_utils::read_model_field(cache, label.cb_with_observations); }, | ||
// get_weight | ||
[](const VW::polylabel& label, const VW::reduction_features& /*red_features*/) | ||
{ return weight_cb_with_observations(label.cb_with_observations); }, | ||
{ return cb_with_observations_weight(label.cb_with_observations); }, | ||
// test_label | ||
[](const VW::polylabel& label) { return test_label_cb_with_observations(label.cb_with_observations); }, | ||
[](const VW::polylabel& label) { return label.cb_with_observations.is_test_label(); }, | ||
// Label type | ||
VW::label_type_t::CB_WITH_OBSERVATIONS}; | ||
} // namespace VW |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
vowpalwabbit/core/tests/cb_with_observations_parser_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// Copyright (c) by respective owners including Yahoo!, Microsoft, and | ||
// individual contributors. All rights reserved. Released under a BSD (revised) | ||
// license as described in the file LICENSE. | ||
|
||
#include "vw/common/string_view.h" | ||
#include "vw/common/text_utils.h" | ||
#include "vw/core/cb_with_observations_label.h" | ||
#include "vw/core/memory.h" | ||
#include "vw/core/parse_primitives.h" | ||
#include "vw/core/parser.h" | ||
#include "vw/io/logger.h" | ||
|
||
#include <gmock/gmock.h> | ||
#include <gtest/gtest.h> | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
|
||
TEST(CbWithObservations, CacheLabel) | ||
{ | ||
auto backing_vector = std::make_shared<std::vector<char>>(); | ||
VW::io_buf io_writer; | ||
io_writer.add_file(VW::io::create_vector_writer(backing_vector)); | ||
|
||
VW::cb_with_observations_label cb_with_obs_label; | ||
cb_with_obs_label.event.weight = 5.f; | ||
cb_with_obs_label.is_definitely_bad = true; | ||
cb_with_obs_label.is_observation = true; | ||
|
||
VW::model_utils::write_model_field(io_writer, cb_with_obs_label, "", false); | ||
io_writer.flush(); | ||
|
||
VW::io_buf io_reader; | ||
io_reader.add_file(VW::io::create_buffer_view(backing_vector->data(), backing_vector->size())); | ||
|
||
auto uncached_label = VW::make_unique<VW::cb_with_observations_label>(); | ||
uncached_label->reset_to_default(); | ||
VW::model_utils::read_model_field(io_reader, *uncached_label); | ||
|
||
EXPECT_FLOAT_EQ(uncached_label->event.weight, 5.f); | ||
EXPECT_EQ(uncached_label->is_definitely_bad, true); | ||
EXPECT_EQ(uncached_label->is_observation, true); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters