From 9db490a98570b56ed7624c0a9229606d015ce79c Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 15 Apr 2021 08:39:43 +0800 Subject: [PATCH] [option] precise_float_parser: precise float number parsing for text input. --- docs/Parameters.rst | 6 ++++++ include/LightGBM/config.h | 5 +++++ include/LightGBM/dataset.h | 5 ++++- src/application/application.cpp | 6 ++++-- src/application/predictor.hpp | 5 +++-- src/c_api.cpp | 3 ++- src/io/config_auto.cpp | 3 +++ src/io/dataset_loader.cpp | 6 ++++-- src/io/parser.cpp | 12 +++++------- src/io/parser.hpp | 31 ++++++++++++------------------- 10 files changed, 48 insertions(+), 34 deletions(-) diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 9632e020deb6..76c25b1c719f 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -820,6 +820,12 @@ Dataset Parameters - **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent function +- ``precise_float_parser`` :raw-html:`🔗︎`, default = ``false``, type = bool + + - Use precise floating point number parsing for text parser (e.g. CSV, TSV, LibSVM input). + + - **Note**: setting this to `true` may lead to much slower text parsing. + Predict Parameters ~~~~~~~~~~~~~~~~~~ diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 73696cdb88f8..02ff1d311d02 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -712,6 +712,11 @@ struct Config { // desc = **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent function bool save_binary = false; + // [no-save] + // desc = Use precise floating point number parsing for text parser (e.g. CSV, TSV, LibSVM input). + // desc = **Note**: setting this to `true` may lead to much slower text parsing. + bool precise_float_parser = false; + #pragma endregion #pragma region Predict Parameters diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 61989e221bcc..23120e682f27 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -252,6 +252,8 @@ class Metadata { /*! \brief Interface for Parser */ class Parser { public: + typedef const char* (*AtofFunc)(const char* p, double* out); + /*! \brief virtual destructor */ virtual ~Parser() {} @@ -271,9 +273,10 @@ class Parser { * \param filename One Filename of data * \param num_features Pass num_features of this data file if you know, <=0 means don't know * \param label_idx index of label column + * \param precise_float_parser using precise floating point number parsing if true * \return Object of parser */ - static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx); + static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser); }; /*! \brief The main class of data set, diff --git a/src/application/application.cpp b/src/application/application.cpp index e82cfcada98f..d9a4d7544ebc 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -221,7 +221,8 @@ void Application::Predict() { if (config_.task == TaskType::KRefitTree) { // create predictor Predictor predictor(boosting_.get(), 0, -1, false, true, false, false, 1, 1); - predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check); + predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check, + config_.precise_float_parser); TextReader result_reader(config_.output_result.c_str(), false); result_reader.ReadAllLines(); std::vector> pred_leaf(result_reader.Lines().size()); @@ -251,7 +252,8 @@ void Application::Predict() { config_.pred_early_stop, config_.pred_early_stop_freq, config_.pred_early_stop_margin); predictor.Predict(config_.data.c_str(), - config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check); + config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check, + config_.precise_float_parser); Log::Info("Finished prediction"); } } diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp index 7c2241b36959..dff23add2df5 100644 --- a/src/application/predictor.hpp +++ b/src/application/predictor.hpp @@ -160,13 +160,14 @@ class Predictor { * \param data_filename Filename of data * \param result_filename Filename of output result */ - void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check) { + void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check, bool precise_float_parser) { auto writer = VirtualFileWriter::Make(result_filename); if (!writer->Init()) { Log::Fatal("Prediction results file %s cannot be created", result_filename); } auto label_idx = header ? -1 : boosting_->LabelIdx(); - auto parser = std::unique_ptr(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx)); + auto parser = std::unique_ptr(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx, + precise_float_parser)); if (parser == nullptr) { Log::Fatal("Could not recognize the data format of data file %s", data_filename); diff --git a/src/c_api.cpp b/src/c_api.cpp index b0f828f5072f..3d20d92da70d 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -709,7 +709,8 @@ class Booster { Predictor predictor(boosting_.get(), start_iteration, num_iteration, is_raw_score, is_predict_leaf, predict_contrib, config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin); bool bool_data_has_header = data_has_header > 0 ? true : false; - predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check); + predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check, + config.precise_float_parser); } void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) const { diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 06c53e84268a..913c3242a3eb 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -259,6 +259,7 @@ const std::unordered_set& Config::parameter_set() { "categorical_feature", "forcedbins_filename", "save_binary", + "precise_float_parser", "start_iteration_predict", "num_iteration_predict", "predict_raw_score", @@ -525,6 +526,8 @@ void Config::GetMembersFromString(const std::unordered_map(Parser::CreateParser(filename, config_.header, 0, label_idx_)); + auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_, + config_.precise_float_parser)); if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename); } @@ -267,7 +268,8 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, } auto bin_filename = CheckCanLoadFromBin(filename); if (bin_filename.size() == 0) { - auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_)); + auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_, + config_.precise_float_parser)); if (parser == nullptr) { Log::Fatal("Could not recognize data format of %s", filename); } diff --git a/src/io/parser.cpp b/src/io/parser.cpp index f16c643d1d6a..550c4e13d5c0 100644 --- a/src/io/parser.cpp +++ b/src/io/parser.cpp @@ -6,9 +6,6 @@ #include #include -#include -#include -#include #include namespace LightGBM { @@ -232,7 +229,7 @@ DataType GetDataType(const char* filename, bool header, return type; } -Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) { +Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser) { const int n_read_line = 32; auto lines = ReadKLineFromFile(filename, header, n_read_line); int num_col = 0; @@ -242,15 +239,16 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features } std::unique_ptr ret; int output_label_index = -1; + AtofFunc atof = precise_float_parser ? Common::AtofPrecise : Common::Atof; if (type == DataType::LIBSVM) { output_label_index = GetLabelIdxForLibsvm(lines[0], num_features, label_idx); - ret.reset(new LibSVMParser(output_label_index, num_col)); + ret.reset(new LibSVMParser(output_label_index, num_col, atof)); } else if (type == DataType::TSV) { output_label_index = GetLabelIdxForTSV(lines[0], num_features, label_idx); - ret.reset(new TSVParser(output_label_index, num_col)); + ret.reset(new TSVParser(output_label_index, num_col, atof)); } else if (type == DataType::CSV) { output_label_index = GetLabelIdxForCSV(lines[0], num_features, label_idx); - ret.reset(new CSVParser(output_label_index, num_col)); + ret.reset(new CSVParser(output_label_index, num_col, atof)); } if (output_label_index < 0 && label_idx >= 0) { diff --git a/src/io/parser.hpp b/src/io/parser.hpp index dad71ed31d0d..531bfad2f2d3 100644 --- a/src/io/parser.hpp +++ b/src/io/parser.hpp @@ -15,20 +15,10 @@ namespace LightGBM { -#ifdef USE_PRECISE_TEXT_PARSER -static const char* TextParserAtof(const char* p, double* out) { - return Common::AtofPrecise(p, out); -} -#else -static const char* TextParserAtof(const char* p, double* out) { - return Common::Atof(p, out); -} -#endif - class CSVParser: public Parser { public: - explicit CSVParser(int label_idx, int total_columns) - :label_idx_(label_idx), total_columns_(total_columns) { + explicit CSVParser(int label_idx, int total_columns, AtofFunc atof) + :label_idx_(label_idx), total_columns_(total_columns), atof_(atof) { } inline void ParseOneLine(const char* str, std::vector>* out_features, double* out_label) const override { @@ -37,7 +27,7 @@ class CSVParser: public Parser { int offset = 0; *out_label = 0.0f; while (*str != '\0') { - str = TextParserAtof(str, &val); + str = atof_(str, &val); if (idx == label_idx_) { *out_label = val; offset = -1; @@ -60,12 +50,13 @@ class CSVParser: public Parser { private: int label_idx_ = 0; int total_columns_ = -1; + AtofFunc atof_; }; class TSVParser: public Parser { public: - explicit TSVParser(int label_idx, int total_columns) - :label_idx_(label_idx), total_columns_(total_columns) { + explicit TSVParser(int label_idx, int total_columns, AtofFunc atof) + :label_idx_(label_idx), total_columns_(total_columns), atof_(atof) { } inline void ParseOneLine(const char* str, std::vector>* out_features, double* out_label) const override { @@ -73,7 +64,7 @@ class TSVParser: public Parser { double val = 0.0f; int offset = 0; while (*str != '\0') { - str = TextParserAtof(str, &val); + str = atof_(str, &val); if (idx == label_idx_) { *out_label = val; offset = -1; @@ -96,12 +87,13 @@ class TSVParser: public Parser { private: int label_idx_ = 0; int total_columns_ = -1; + AtofFunc atof_; }; class LibSVMParser: public Parser { public: - explicit LibSVMParser(int label_idx, int total_columns) - :label_idx_(label_idx), total_columns_(total_columns) { + explicit LibSVMParser(int label_idx, int total_columns, AtofFunc atof) + :label_idx_(label_idx), total_columns_(total_columns), atof_(atof) { if (label_idx > 0) { Log::Fatal("Label should be the first column in a LibSVM file"); } @@ -111,7 +103,7 @@ class LibSVMParser: public Parser { int idx = 0; double val = 0.0f; if (label_idx_ == 0) { - str = TextParserAtof(str, &val); + str = atof_(str, &val); *out_label = val; str = Common::SkipSpaceAndTab(str); } @@ -136,6 +128,7 @@ class LibSVMParser: public Parser { private: int label_idx_ = 0; int total_columns_ = -1; + AtofFunc atof_; }; } // namespace LightGBM