Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make background refresh optional #1655

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions reinforcement_learning/include/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace reinforcement_learning { namespace name {
const char *const MODEL_BLOB_URI = "model.blob.uri";
const char *const MODEL_REFRESH_INTERVAL_MS = "model.refreshintervalms";
const char *const MODEL_IMPLEMENTATION = "model.implementation"; // VW vs other ML
const char *const MODEL_BACKGROUND_REFRESH = "model.backgroundrefresh";
rajan-chari marked this conversation as resolved.
Show resolved Hide resolved
const char *const VW_CMDLINE = "vw.commandline";
const char *const INITIAL_EPSILON = "initial_exploration.epsilon";
const char *const INTERACTION_EH_HOST = "interaction.eventhub.host";
Expand Down Expand Up @@ -40,5 +41,6 @@ namespace reinforcement_learning { namespace value {
const char *const INTERACTION_EH_SENDER = "INTERACTION_EH_SENDER";
const char *const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER";
const char *const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER";
const bool MODEL_BACKGROUND_REFRESH = true;
}}

2 changes: 1 addition & 1 deletion reinforcement_learning/include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ namespace reinforcement_learning {
/**
* @brief Default move constructor for live model object.
*/
live_model(live_model&&) = default;
live_model(live_model&&);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that move constructor is defined, move assignment should be as well


/**
* @brief Default move assignment operator swaps implementation.
Expand Down
5 changes: 5 additions & 0 deletions reinforcement_learning/rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ namespace reinforcement_learning
new live_model_impl(config, fn, err_context, trace_factory, t_factory, m_factory, sender_factory));
}

live_model::live_model(live_model&& model) {
_pimpl = std::move(model._pimpl);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: destructor is trivial now but that might change. would a swap be better for these variables?

_initialized = model._initialized;
}

live_model::~live_model() = default;

int live_model::init(api_status* status) {
Expand Down
26 changes: 20 additions & 6 deletions reinforcement_learning/rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,15 @@ namespace reinforcement_learning {
_trace_factory(trace_factory),
_t_factory{t_factory},
_m_factory{m_factory},
_sender_factory{sender_factory},
_bg_model_proc(config.get_int(name::MODEL_REFRESH_INTERVAL_MS, 60 * 1000), _watchdog, "Model downloader", &_error_cb) {
_sender_factory{sender_factory} {
// If there is no user supplied error callback, supply a default one that does nothing but report unhandled background errors.
if (fn == nullptr) {
_error_cb.set(&default_error_callback, &_watchdog);
}

if (_configuration.get_bool(name::MODEL_BACKGROUND_REFRESH, value::MODEL_BACKGROUND_REFRESH)) {
_bg_model_proc.reset(new utility::periodic_background_proc<model_management::model_downloader>(config.get_int(name::MODEL_REFRESH_INTERVAL_MS, 60 * 1000), _watchdog, "Model downloader", &_error_cb));
}
}

int live_model_impl::init_trace(api_status* status) {
Expand Down Expand Up @@ -240,10 +243,21 @@ namespace reinforcement_learning {
m::i_data_transport* ptransport;
RETURN_IF_FAIL(_t_factory->create(&ptransport, tranport_impl, _configuration, status));
// This class manages lifetime of transport
this->_transport.reset(ptransport);
// Initialize background process and start downloading models
this->_model_download.reset(new m::model_downloader(ptransport, &_data_cb, _trace_logger.get()));
return _bg_model_proc.init(_model_download.get(), status);
_transport.reset(ptransport);

if (_bg_model_proc) {
// Initialize background process and start downloading models
_model_download.reset(new m::model_downloader(ptransport, &_data_cb, _trace_logger.get()));
return _bg_model_proc->init(_model_download.get(), status);
}
else {
Copy link
Member

@ataymano ataymano Oct 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we cannot update model later. Maybe it is better to add update_model method in live_model interface?

// update the model synchronously
model_management::model_data md;
RETURN_IF_FAIL(_transport->get_data(md, status));
RETURN_IF_FAIL(_model->update(md, status));
}

return error_code::success;
}

//helper: check if at least one of the arguments is null or empty
Expand Down
2 changes: 1 addition & 1 deletion reinforcement_learning/rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ namespace reinforcement_learning
std::unique_ptr<model_management::model_downloader> _model_download{nullptr};
std::unique_ptr<i_trace> _trace_logger{nullptr};

utility::periodic_background_proc<model_management::model_downloader> _bg_model_proc;
std::unique_ptr<utility::periodic_background_proc<model_management::model_downloader>> _bg_model_proc;
uint64_t _seed_shift;
};

Expand Down
89 changes: 68 additions & 21 deletions reinforcement_learning/unit_test/live_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace cfg = reinforcement_learning::utility::config;

using namespace fakeit;

const auto JSON_CFG = R"(
namespace {
const auto JSON_CFG = R"(
{
"ApplicationID": "rnc-123456-a",
"EventHubInteractionConnectionString": "Endpoint=sb://localhost:8080/;SharedAccessKeyName=RMSAKey;SharedAccessKey=<ASharedAccessKey>=;EntityPath=interaction",
Expand All @@ -40,17 +41,42 @@ const auto JSON_CFG = R"(
"InitialExplorationEpsilon": 1.0
}
)";
const auto JSON_CONTEXT = R"({"_multi":[{},{}]})";
const auto JSON_CONTEXT = R"({"_multi":[{},{}]})";

BOOST_AUTO_TEST_CASE(live_model_ranking_request) {
auto mock_sender = get_mock_sender();
auto mock_data_transport = get_mock_data_transport();
auto mock_model = get_mock_model();
r::live_model create_mock_live_model(
const u::configuration& config,
r::data_transport_factory_t* data_transport_factory = nullptr,
r::model_factory_t* model_factory = nullptr,
r::sender_factory_t* sender_factory = nullptr) {

auto sender_factory = get_mock_sender_factory(mock_sender.get(), mock_sender.get());
auto data_transport_factory = get_mock_data_transport_factory(mock_data_transport.get());
auto model_factory = get_mock_model_factory(mock_model.get());
static auto mock_sender = get_mock_sender();
rajan-chari marked this conversation as resolved.
Show resolved Hide resolved
static auto mock_data_transport = get_mock_data_transport();
static auto mock_model = get_mock_model();

static auto default_sender_factory = get_mock_sender_factory(mock_sender.get(), mock_sender.get());
static auto default_data_transport_factory = get_mock_data_transport_factory(mock_data_transport.get());
static auto default_model_factory = get_mock_model_factory(mock_model.get());

if (!data_transport_factory) {
data_transport_factory = default_data_transport_factory.get();
}

if (!model_factory) {
model_factory = default_model_factory.get();
}

if (!sender_factory) {
sender_factory = default_sender_factory.get();
}


r::live_model model(config, nullptr, nullptr, &r::trace_logger_factory, data_transport_factory, model_factory, sender_factory);
return model;
}
}


BOOST_AUTO_TEST_CASE(live_model_ranking_request) {
//create a simple ds configuration
u::configuration config;
cfg::create_from_json(JSON_CFG, config);
Expand All @@ -59,7 +85,7 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request) {
r::api_status status;

//create the ds live_model, and initialize it with the config
r::live_model ds(config, nullptr, nullptr, &r::trace_logger_factory, data_transport_factory.get(), model_factory.get(), sender_factory.get());
auto ds = create_mock_live_model(config);
BOOST_CHECK_EQUAL(ds.init(&status), err::success);

const auto event_id = "event_id";
Expand Down Expand Up @@ -91,21 +117,13 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request) {
}

BOOST_AUTO_TEST_CASE(live_model_outcome) {
auto mock_sender = get_mock_sender();
auto mock_data_transport = get_mock_data_transport();
auto mock_model = get_mock_model();

auto sender_factory = get_mock_sender_factory(mock_sender.get(), mock_sender.get());
auto data_transport_factory = get_mock_data_transport_factory(mock_data_transport.get());
auto model_factory = get_mock_model_factory(mock_model.get());

//create a simple ds configuration
u::configuration config;
cfg::create_from_json(JSON_CFG, config);
config.set(r::name::EH_TEST, "true");

//create a ds live_model, and initialize with configuration
r::live_model ds(config, nullptr, nullptr, &r::trace_logger_factory, data_transport_factory.get(), model_factory.get(), sender_factory.get());
r::live_model ds = create_mock_live_model(config);

//check api_status content when errors are returned
r::api_status status;
Expand Down Expand Up @@ -213,7 +231,7 @@ BOOST_AUTO_TEST_CASE(live_model_mocks) {
cfg::create_from_json(JSON_CFG, config);
config.set(r::name::EH_TEST, "true");
{
r::live_model model(config, nullptr, nullptr, &r::trace_logger_factory, data_transport_factory.get(), model_factory.get(), sender_factory.get());
r::live_model model = create_mock_live_model(config, data_transport_factory.get(), model_factory.get(), sender_factory.get());

r::api_status status;
BOOST_CHECK_EQUAL(model.init(&status), err::success);
Expand All @@ -229,6 +247,35 @@ BOOST_AUTO_TEST_CASE(live_model_mocks) {
BOOST_CHECK_EQUAL(recorded.size(), 2);
}

BOOST_AUTO_TEST_CASE(live_model_no_background_refresh) {
u::configuration config;
cfg::create_from_json(JSON_CFG, config);

config.set(r::name::EH_TEST, "true");
config.set(r::name::MODEL_BACKGROUND_REFRESH, "false");

r::live_model model = create_mock_live_model(config);

r::api_status status;
BOOST_CHECK_EQUAL(model.init(&status), err::success);
}

BOOST_AUTO_TEST_CASE(live_model_no_background_refresh_failure) {
u::configuration config;
cfg::create_from_json(JSON_CFG, config);

config.set(r::name::EH_TEST, "true");
config.set(r::name::MODEL_BACKGROUND_REFRESH, "false");

auto mock_data_transport = get_mock_failing_data_transport();
auto data_transport_factory = get_mock_data_transport_factory(mock_data_transport.get());

r::live_model model = create_mock_live_model(config, data_transport_factory.get());

r::api_status status;
BOOST_CHECK_NE(model.init(&status), err::success);
}

BOOST_AUTO_TEST_CASE(live_model_logger_receive_data) {
std::vector<std::string> recorded_observations;
auto mock_observation_sender = get_mock_sender(recorded_observations);
Expand Down Expand Up @@ -262,7 +309,7 @@ BOOST_AUTO_TEST_CASE(live_model_logger_receive_data) {
std::string expected_interactions;
std::string expected_observations;
{
r::live_model model(config, nullptr, nullptr, &r::trace_logger_factory, data_transport_factory.get(), model_factory.get(), logger_factory.get());
r::live_model model = create_mock_live_model(config, data_transport_factory.get(), model_factory.get(), logger_factory.get());

r::api_status status;
BOOST_CHECK_EQUAL(model.init(&status), err::success);
Expand Down
10 changes: 10 additions & 0 deletions reinforcement_learning/unit_test/mock_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ std::unique_ptr<fakeit::Mock<m::i_data_transport>> get_mock_data_transport() {
return mock;
}

std::unique_ptr<fakeit::Mock<m::i_data_transport>> get_mock_failing_data_transport() {
auto mock = std::unique_ptr<fakeit::Mock<m::i_data_transport>>(
new fakeit::Mock<m::i_data_transport>());

When(Method((*mock), get_data)).AlwaysReturn(r::error_code::exception_during_http_req);
Fake(Dtor((*mock)));

return mock;
}

std::unique_ptr<fakeit::Mock<m::i_model>> get_mock_model() {
auto mock = std::unique_ptr<fakeit::Mock<m::i_model>>(
new fakeit::Mock<m::i_model>());
Expand Down
3 changes: 2 additions & 1 deletion reinforcement_learning/unit_test/mock_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ std::unique_ptr<fakeit::Mock<reinforcement_learning::i_sender>> get_mock_sender(
std::unique_ptr<fakeit::Mock<reinforcement_learning::i_sender>> get_mock_sender(std::vector<std::string>& recorded_messages);

std::unique_ptr<fakeit::Mock<reinforcement_learning::model_management::i_data_transport>> get_mock_data_transport();
std::unique_ptr<fakeit::Mock<reinforcement_learning::model_management::i_data_transport>> get_mock_failing_data_transport();
std::unique_ptr<fakeit::Mock<reinforcement_learning::model_management::i_model>> get_mock_model();

std::unique_ptr<reinforcement_learning::sender_factory_t> get_mock_sender_factory(fakeit::Mock<reinforcement_learning::i_sender>* mock_observation_sender,
fakeit::Mock<reinforcement_learning::i_sender>* mock_interaction_sender);
std::unique_ptr<reinforcement_learning::data_transport_factory_t> get_mock_data_transport_factory(fakeit::Mock<reinforcement_learning::model_management::i_data_transport>* mock_data_transport);
std::unique_ptr<reinforcement_learning::model_factory_t> get_mock_model_factory(fakeit::Mock<reinforcement_learning::model_management::i_model>* mock_model);
std::unique_ptr<reinforcement_learning::model_factory_t> get_mock_model_factory(fakeit::Mock<reinforcement_learning::model_management::i_model>* mock_model);