diff --git a/include/LightGBM/utils/yamc/alternate_shared_mutex.hpp b/include/LightGBM/utils/yamc/alternate_shared_mutex.hpp new file mode 100644 index 000000000000..45442f61770d --- /dev/null +++ b/include/LightGBM/utils/yamc/alternate_shared_mutex.hpp @@ -0,0 +1,212 @@ +/* + * alternate_shared_mutex.hpp + * + * MIT License + * + * Copyright (c) 2017 yohhoy + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef YAMC_ALTERNATE_SHARED_MUTEX_HPP_ +#define YAMC_ALTERNATE_SHARED_MUTEX_HPP_ + +#include +#include +#include +#include + +#include "yamc_rwlock_sched.hpp" + +namespace yamc { + +/* + * alternate implementation of shared mutex variants + * + * - yamc::alternate::shared_mutex + * - yamc::alternate::shared_timed_mutex + * - yamc::alternate::basic_shared_mutex + * - yamc::alternate::basic_shared_timed_mutex + */ +namespace alternate { + +namespace detail { + +template +class shared_mutex_base { + protected: + typename RwLockPolicy::state state_; + std::condition_variable cv_; + std::mutex mtx_; + + void lock() { + std::unique_lock lk(mtx_); + RwLockPolicy::before_wait_wlock(state_); + while (RwLockPolicy::wait_wlock(state_)) { + cv_.wait(lk); + } + RwLockPolicy::after_wait_wlock(state_); + RwLockPolicy::acquire_wlock(&state_); + } + + bool try_lock() { + std::lock_guard lk(mtx_); + if (RwLockPolicy::wait_wlock(state_)) return false; + RwLockPolicy::acquire_wlock(state_); + return true; + } + + void unlock() { + std::lock_guard lk(mtx_); + RwLockPolicy::release_wlock(&state_); + cv_.notify_all(); + } + + void lock_shared() { + std::unique_lock lk(mtx_); + while (RwLockPolicy::wait_rlock(state_)) { + cv_.wait(lk); + } + RwLockPolicy::acquire_rlock(&state_); + } + + bool try_lock_shared() { + std::lock_guard lk(mtx_); + if (RwLockPolicy::wait_rlock(state_)) return false; + RwLockPolicy::acquire_rlock(state_); + return true; + } + + void unlock_shared() { + std::lock_guard lk(mtx_); + if (RwLockPolicy::release_rlock(&state_)) { + cv_.notify_all(); + } + } +}; + +} // namespace detail + +template +class basic_shared_mutex : private detail::shared_mutex_base { + using base = detail::shared_mutex_base; + + public: + basic_shared_mutex() = default; + ~basic_shared_mutex() = default; + + basic_shared_mutex(const basic_shared_mutex&) = delete; + basic_shared_mutex& operator=(const basic_shared_mutex&) = delete; + + using base::lock; + using base::try_lock; + using base::unlock; + + using base::lock_shared; + using base::try_lock_shared; + using base::unlock_shared; +}; + +using shared_mutex = basic_shared_mutex; + +template +class basic_shared_timed_mutex + : private detail::shared_mutex_base { + using base = detail::shared_mutex_base; + + using base::cv_; + using base::mtx_; + using base::state_; + + template + bool do_try_lockwait(const std::chrono::time_point& tp) { + std::unique_lock lk(mtx_); + RwLockPolicy::before_wait_wlock(state_); + while (RwLockPolicy::wait_wlock(state_)) { + if (cv_.wait_until(lk, tp) == std::cv_status::timeout) { + if (!RwLockPolicy::wait_wlock(state_)) // re-check predicate + break; + RwLockPolicy::after_wait_wlock(state_); + return false; + } + } + RwLockPolicy::after_wait_wlock(state_); + RwLockPolicy::acquire_wlock(state_); + return true; + } + + template + bool do_try_lock_sharedwait( + const std::chrono::time_point& tp) { + std::unique_lock lk(mtx_); + while (RwLockPolicy::wait_rlock(state_)) { + if (cv_.wait_until(lk, tp) == std::cv_status::timeout) { + if (!RwLockPolicy::wait_rlock(state_)) // re-check predicate + break; + return false; + } + } + RwLockPolicy::acquire_rlock(state_); + return true; + } + + public: + basic_shared_timed_mutex() = default; + ~basic_shared_timed_mutex() = default; + + basic_shared_timed_mutex(const basic_shared_timed_mutex&) = delete; + basic_shared_timed_mutex& operator=(const basic_shared_timed_mutex&) = delete; + + using base::lock; + using base::try_lock; + using base::unlock; + + template + bool try_lock_for(const std::chrono::duration& duration) { + const auto tp = std::chrono::steady_clock::now() + duration; + return do_try_lockwait(tp); + } + + template + bool try_lock_until(const std::chrono::time_point& tp) { + return do_try_lockwait(tp); + } + + using base::lock_shared; + using base::try_lock_shared; + using base::unlock_shared; + + template + bool try_lock_shared_for(const std::chrono::duration& duration) { + const auto tp = std::chrono::steady_clock::now() + duration; + return do_try_lock_sharedwait(tp); + } + + template + bool try_lock_shared_until( + const std::chrono::time_point& tp) { + return do_try_lock_sharedwait(tp); + } +}; + +using shared_timed_mutex = basic_shared_timed_mutex; + +} // namespace alternate +} // namespace yamc + +#endif diff --git a/include/LightGBM/utils/yamc/yamc_rwlock_sched.hpp b/include/LightGBM/utils/yamc/yamc_rwlock_sched.hpp new file mode 100644 index 000000000000..b7d36950ce36 --- /dev/null +++ b/include/LightGBM/utils/yamc/yamc_rwlock_sched.hpp @@ -0,0 +1,149 @@ +/* + * yamc_rwlock_sched.hpp + * + * MIT License + * + * Copyright (c) 2017 yohhoy + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef YAMC_RWLOCK_SCHED_HPP_ +#define YAMC_RWLOCK_SCHED_HPP_ + +#include +#include + +/// default shared_mutex rwlock policy +#ifndef YAMC_RWLOCK_SCHED_DEFAULT +#define YAMC_RWLOCK_SCHED_DEFAULT yamc::rwlock::ReaderPrefer +#endif + +namespace yamc { + +/* + * readers-writer locking policy for basic_shared_(timed)_mutex + * + * - yamc::rwlock::ReaderPrefer + * - yamc::rwlock::WriterPrefer + */ +namespace rwlock { + +/// Reader prefer scheduling +/// +/// NOTE: +// This policy might introduce "Writer Starvation" if readers continuously +// hold shared lock. PThreads rwlock implementation in Linux use this +// scheduling policy as default. (see also PTHREAD_RWLOCK_PREFER_READER_NP) +// +struct ReaderPrefer { + static const std::size_t writer_mask = ~(~std::size_t(0u) >> 1); // MSB 1bit + static const std::size_t reader_mask = ~std::size_t(0u) >> 1; + + struct state { + std::size_t rwcount = 0; + }; + + static void before_wait_wlock(const state&) {} + static void after_wait_wlock(const state&) {} + + static bool wait_wlock(const state& s) { return (s.rwcount != 0); } + + static void acquire_wlock(state* s) { + assert(!(s->rwcount & writer_mask)); + s->rwcount |= writer_mask; + } + + static void release_wlock(state* s) { + assert(s->rwcount & writer_mask); + s->rwcount &= ~writer_mask; + } + + static bool wait_rlock(const state& s) { return (s.rwcount & writer_mask) != 0; } + + static void acquire_rlock(state* s) { + assert((s->rwcount & reader_mask) < reader_mask); + ++(s->rwcount); + } + + static bool release_rlock(state* s) { + assert(0 < (s->rwcount & reader_mask)); + return (--(s->rwcount) == 0); + } +}; + +/// Writer prefer scheduling +/// +/// NOTE: +/// If there are waiting writer, new readers are blocked until all shared lock +/// are released, +// and the writer thread can get exclusive lock in preference to blocked +// reader threads. This policy might introduce "Reader Starvation" if writers +// continuously request exclusive lock. +/// (see also PTHREAD_RWLOCK_PREFER_WRITER_NONRECURSIVE_NP) +/// +struct WriterPrefer { + static const std::size_t locked = ~(~std::size_t(0u) >> 1); // MSB 1bit + static const std::size_t wait_mask = ~std::size_t(0u) >> 1; + + struct state { + std::size_t nwriter = 0; + std::size_t nreader = 0; + }; + + static void before_wait_wlock(state* s) { + assert((s->nwriter & wait_mask) < wait_mask); + ++(s->nwriter); + } + + static bool wait_wlock(const state& s) { + return ((s.nwriter & locked) || 0 < s.nreader); + } + + static void after_wait_wlock(state* s) { + assert(0 < (s->nwriter & wait_mask)); + --(s->nwriter); + } + + static void acquire_wlock(state* s) { + assert(!(s->nwriter & locked)); + s->nwriter |= locked; + } + + static void release_wlock(state* s) { + assert(s->nwriter & locked); + s->nwriter &= ~locked; + } + + static bool wait_rlock(const state& s) { return (s.nwriter != 0); } + + static void acquire_rlock(state* s) { + assert(!(s->nwriter & locked)); + ++(s->nreader); + } + + static bool release_rlock(state* s) { + assert(0 < s->nreader); + return (--(s->nreader) == 0); + } +}; + +} // namespace rwlock +} // namespace yamc + +#endif diff --git a/include/LightGBM/utils/yamc/yamc_shared_lock.hpp b/include/LightGBM/utils/yamc/yamc_shared_lock.hpp new file mode 100644 index 000000000000..658c0c357d78 --- /dev/null +++ b/include/LightGBM/utils/yamc/yamc_shared_lock.hpp @@ -0,0 +1,197 @@ +/* + * yamc_shared_lock.hpp + * + * MIT License + * + * Copyright (c) 2017 yohhoy + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef YAMC_SHARED_LOCK_HPP_ +#define YAMC_SHARED_LOCK_HPP_ + +#include +#include +#include +#include +#include // std::swap + +/* + * std::shared_lock in C++14 Standard Library + * + * - yamc::shared_lock + */ +namespace yamc { + +template +class shared_lock { + void locking_precondition(const char* emsg) { + if (pm_ == nullptr) { + throw std::system_error( + std::make_error_code(std::errc::operation_not_permitted), emsg); + } + if (owns_) { + throw std::system_error( + std::make_error_code(std::errc::resource_deadlock_would_occur), emsg); + } + } + + public: + using mutex_type = Mutex; + + shared_lock() noexcept = default; + + explicit shared_lock(mutex_type* m) { + m->lock_shared(); + pm_ = m; + owns_ = true; + } + + shared_lock(const mutex_type& m, std::defer_lock_t) noexcept { + pm_ = &m; + owns_ = false; + } + + shared_lock(const mutex_type& m, std::try_to_lock_t) { + pm_ = &m; + owns_ = m.try_lock_shared(); + } + + shared_lock(const mutex_type& m, std::adopt_lock_t) { + pm_ = &m; + owns_ = true; + } + + template + shared_lock(const mutex_type& m, + const std::chrono::time_point& abs_time) { + pm_ = &m; + owns_ = m.try_lock_shared_until(abs_time); + } + + template + shared_lock(const mutex_type& m, + const std::chrono::duration& rel_time) { + pm_ = &m; + owns_ = m.try_lock_shared_for(rel_time); + } + + ~shared_lock() { + if (owns_) { + assert(pm_ != nullptr); + pm_->unlock_shared(); + } + } + + shared_lock(const shared_lock&) = delete; + shared_lock& operator=(const shared_lock&) = delete; + + shared_lock(shared_lock&& rhs) noexcept { + if (pm_ && owns_) { + pm_->unlock_shared(); + } + pm_ = rhs.pm_; + owns_ = rhs.owns_; + rhs.pm_ = nullptr; + rhs.owns_ = false; + } + + shared_lock& operator=(shared_lock&& rhs) noexcept { + if (pm_ && owns_) { + pm_->unlock_shared(); + } + pm_ = rhs.pm_; + owns_ = rhs.owns_; + rhs.pm_ = nullptr; + rhs.owns_ = false; + return *this; + } + + void lock() { + locking_precondition("shared_lock::lock"); + pm_->lock_shared(); + owns_ = true; + } + + bool try_lock() { + locking_precondition("shared_lock::try_lock"); + return (owns_ = pm_->try_lock_shared()); + } + + template + bool try_lock_for(const std::chrono::duration& rel_time) { + locking_precondition("shared_lock::try_lock_for"); + return (owns_ = pm_->try_lock_shared_for(rel_time)); + } + + template + bool try_lock_until( + const std::chrono::time_point& abs_time) { + locking_precondition("shared_lock::try_lock_until"); + return (owns_ = pm_->try_lock_shared_until(abs_time)); + } + + void unlock() { + assert(pm_ != nullptr); + if (!owns_) { + throw std::system_error( + std::make_error_code(std::errc::operation_not_permitted), + "shared_lock::unlock"); + } + pm_->unlock_shared(); + owns_ = false; + } + + void swap(shared_lock& sl) noexcept { + std::swap(pm_, sl.pm_); + std::swap(owns_, sl.owns_); + } + + mutex_type* release() noexcept { + mutex_type* result = pm_; + pm_ = nullptr; + owns_ = false; + return result; + } + + bool owns_lock() const noexcept { return owns_; } + + explicit operator bool() const noexcept { return owns_; } + + mutex_type* mutex() const noexcept { return pm_; } + + private: + mutex_type* pm_ = nullptr; + bool owns_ = false; +}; + +} // namespace yamc + +namespace std { + +/// std::swap() specialization for yamc::shared_lock type +template +void swap(yamc::shared_lock& lhs, + yamc::shared_lock& rhs) noexcept { + lhs.swap(rhs); +} + +} // namespace std + +#endif diff --git a/src/c_api.cpp b/src/c_api.cpp index 091f3cd49a14..4c5e9170b642 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -27,6 +27,8 @@ #include #include "application/predictor.hpp" +#include +#include namespace LightGBM { @@ -46,6 +48,12 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \ catch(...) { return LGBM_APIHandleException("unknown exception"); } \ return 0; +#define UNIQUE_LOCK(mtx) \ +std::unique_lock lock(mtx); + +#define SHARED_LOCK(mtx) \ +yamc::shared_lock lock(&mtx); + const int PREDICTOR_TYPES = 4; // Single row predictor to abstract away caching logic @@ -133,7 +141,7 @@ class Booster { } void MergeFrom(const Booster* other) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) boosting_->MergeFrom(other->boosting_.get()); } @@ -166,7 +174,7 @@ class Booster { void ResetTrainingData(const Dataset* train_data) { if (train_data != train_data_) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) train_data_ = train_data; CreateObjectiveAndMetrics(); // reset the boosting @@ -284,7 +292,7 @@ class Booster { } void ResetConfig(const char* parameters) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) auto param = Config::Str2Map(parameters); if (param.count("num_class")) { Log::Fatal("Cannot change num_class during training"); @@ -322,7 +330,7 @@ class Booster { } void AddValidData(const Dataset* valid_data) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) valid_metrics_.emplace_back(); for (auto metric_type : config_.metric) { auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_)); @@ -336,12 +344,12 @@ class Booster { } bool TrainOneIter() { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) return boosting_->TrainOneIter(nullptr, nullptr); } void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) std::vector> v_leaf_preds(nrow, std::vector(ncol, 0)); for (int i = 0; i < nrow; ++i) { for (int j = 0; j < ncol; ++j) { @@ -352,37 +360,42 @@ class Booster { } bool TrainOneIter(const score_t* gradients, const score_t* hessians) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) return boosting_->TrainOneIter(gradients, hessians); } void RollbackOneIter() { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) boosting_->RollbackOneIter(); } - void PredictSingleRow(int num_iteration, int predict_type, int ncol, + void SetSingleRowPredictor(int num_iteration, int predict_type, const Config& config) { + UNIQUE_LOCK(mutex_) + if (single_row_predictor_[predict_type].get() == nullptr || + !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) { + single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(), + config, num_iteration)); + } + } + + void PredictSingleRow(int predict_type, int ncol, std::function>(int row_idx)> get_row_fun, const Config& config, - double* out_result, int64_t* out_len) { + double* out_result, int64_t* out_len) const { if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) { Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\ "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1); } - std::lock_guard lock(mutex_); - if (single_row_predictor_[predict_type].get() == nullptr || - !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) { - single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(), - config, num_iteration)); - } + SHARED_LOCK(mutex_) + const auto& single_row_predictor = single_row_predictor_[predict_type]; auto one_row = get_row_fun(0); auto pred_wrt_ptr = out_result; - single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr); + single_row_predictor->predict_function(one_row, pred_wrt_ptr); - *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row; + *out_len = single_row_predictor->num_pred_in_one_row; } - Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) { + Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) const { if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) { Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \ "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1); @@ -408,8 +421,8 @@ class Booster { void Predict(int num_iteration, int predict_type, int nrow, int ncol, std::function>(int row_idx)> get_row_fun, const Config& config, - double* out_result, int64_t* out_len) { - std::lock_guard lock(mutex_); + double* out_result, int64_t* out_len) const { + SHARED_LOCK(mutex_); auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); bool is_predict_leaf = false; bool predict_contrib = false; @@ -438,7 +451,7 @@ class Booster { const Config& config, int64_t* out_elements_size, std::vector>>* agg_ptr, int32_t** out_indices, void** out_data, int data_type, - bool* is_data_float32_ptr, int num_matrices) { + bool* is_data_float32_ptr, int num_matrices) const { auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); auto pred_sparse_fun = predictor.GetPredictSparseFunction(); std::vector>>& agg = *agg_ptr; @@ -479,8 +492,8 @@ class Booster { std::function>(int64_t row_idx)> get_row_fun, const Config& config, int64_t* out_len, void** out_indptr, int indptr_type, - int32_t** out_indices, void** out_data, int data_type) { - std::lock_guard lock(mutex_); + int32_t** out_indices, void** out_data, int data_type) const { + SHARED_LOCK(mutex_); // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices) int num_matrices = boosting_->NumModelPerIteration(); bool is_indptr_int32 = false; @@ -563,8 +576,8 @@ class Booster { std::function>(int64_t row_idx)> get_row_fun, const Config& config, int64_t* out_len, void** out_col_ptr, int col_ptr_type, - int32_t** out_indices, void** out_data, int data_type) { - std::lock_guard lock(mutex_); + int32_t** out_indices, void** out_data, int data_type) const { + SHARED_LOCK(mutex_); // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices) int num_matrices = boosting_->NumModelPerIteration(); auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); @@ -665,8 +678,8 @@ class Booster { void Predict(int num_iteration, int predict_type, const char* data_filename, int data_has_header, const Config& config, - const char* result_filename) { - std::lock_guard lock(mutex_); + const char* result_filename) const { + SHARED_LOCK(mutex_) bool is_predict_leaf = false; bool is_raw_score = false; bool predict_contrib = false; @@ -685,11 +698,11 @@ class Booster { predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check); } - void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { + void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) const { boosting_->GetPredictAt(data_idx, out_result, out_len); } - void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) { + void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) const { boosting_->SaveModelToFile(start_iteration, num_iteration, feature_importance_type, filename); } @@ -699,46 +712,48 @@ class Booster { } std::string SaveModelToString(int start_iteration, int num_iteration, - int feature_importance_type) { + int feature_importance_type) const { return boosting_->SaveModelToString(start_iteration, num_iteration, feature_importance_type); } std::string DumpModel(int start_iteration, int num_iteration, - int feature_importance_type) { + int feature_importance_type) const { return boosting_->DumpModel(start_iteration, num_iteration, feature_importance_type); } - std::vector FeatureImportance(int num_iteration, int importance_type) { + std::vector FeatureImportance(int num_iteration, int importance_type) const { return boosting_->FeatureImportance(num_iteration, importance_type); } double UpperBoundValue() const { - std::lock_guard lock(mutex_); + SHARED_LOCK(mutex_) return boosting_->GetUpperBoundValue(); } double LowerBoundValue() const { - std::lock_guard lock(mutex_); + SHARED_LOCK(mutex_) return boosting_->GetLowerBoundValue(); } double GetLeafValue(int tree_idx, int leaf_idx) const { + SHARED_LOCK(mutex_) return dynamic_cast(boosting_.get())->GetLeafValue(tree_idx, leaf_idx); } void SetLeafValue(int tree_idx, int leaf_idx, double val) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) dynamic_cast(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val); } void ShuffleModels(int start_iter, int end_iter) { - std::lock_guard lock(mutex_); + UNIQUE_LOCK(mutex_) boosting_->ShuffleModels(start_iter, end_iter); } int GetEvalCounts() const { + SHARED_LOCK(mutex_) int ret = 0; for (const auto& metric : train_metric_) { ret += static_cast(metric->GetName().size()); @@ -747,6 +762,7 @@ class Booster { } int GetEvalNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const { + SHARED_LOCK(mutex_) *out_buffer_len = 0; int idx = 0; for (const auto& metric : train_metric_) { @@ -763,6 +779,7 @@ class Booster { } int GetFeatureNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const { + SHARED_LOCK(mutex_) *out_buffer_len = 0; int idx = 0; for (const auto& name : boosting_->FeatureNames()) { @@ -792,7 +809,7 @@ class Booster { /*! \brief Training objective function */ std::unique_ptr objective_fun_; /*! \brief mutex for threading safe call */ - mutable std::mutex mutex_; + mutable yamc::alternate::shared_mutex mutex_; }; } // namespace LightGBM @@ -1916,7 +1933,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); - ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast(num_col), get_row_fun, config, out_result, out_len); + ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config); + ref_booster->PredictSingleRow(predict_type, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } @@ -1960,7 +1978,7 @@ int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle, API_BEGIN(); FastConfig *fastConfig = reinterpret_cast(fastConfig_handle); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem); - fastConfig->booster->PredictSingleRow(num_iteration, predict_type, fastConfig->ncol, + fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol, get_row_fun, fastConfig->config, out_result, out_len); API_END(); } @@ -2058,7 +2076,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, } Booster* ref_booster = reinterpret_cast(handle); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); - ref_booster->PredictSingleRow(num_iteration, predict_type, ncol, get_row_fun, config, out_result, out_len); + ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config); + ref_booster->PredictSingleRow(predict_type, ncol, get_row_fun, config, out_result, out_len); API_END(); } @@ -2092,7 +2111,7 @@ int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle, FastConfig *fastConfig = reinterpret_cast(fastConfig_handle); // Single row in row-major format: auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1); - fastConfig->booster->PredictSingleRow(num_iteration, predict_type, fastConfig->ncol, + fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol, get_row_fun, fastConfig->config, out_result, out_len); API_END(); diff --git a/windows/LightGBM.vcxproj b/windows/LightGBM.vcxproj index 0f814a341977..9dd319527229 100644 --- a/windows/LightGBM.vcxproj +++ b/windows/LightGBM.vcxproj @@ -244,6 +244,7 @@ + @@ -255,6 +256,8 @@ + + diff --git a/windows/LightGBM.vcxproj.filters b/windows/LightGBM.vcxproj.filters index f122c865afee..231798919efe 100644 --- a/windows/LightGBM.vcxproj.filters +++ b/windows/LightGBM.vcxproj.filters @@ -210,6 +210,15 @@ src\treelearner + + include\LightGBM\utils\yamc + + + include\LightGBM\utils\yamc + + + include\LightGBM\utils\yamc +