diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 9632e020deb6..1ea8ec4b9d45 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -820,6 +820,12 @@ Dataset Parameters
- **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent function
+- ``precise_float_parser`` :raw-html:`🔗︎`, default = ``false``, type = bool
+
+ - Use precise floating point number parsing for text parser (e.g. CSV, TSV, LibSVM input).
+
+ - **Note**: setting this to ``true`` may lead to much slower text parsing.
+
Predict Parameters
~~~~~~~~~~~~~~~~~~
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 73696cdb88f8..a7334cbe521c 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -712,6 +712,11 @@ struct Config {
// desc = **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent function
bool save_binary = false;
+ // [no-save]
+ // desc = Use precise floating point number parsing for text parser (e.g. CSV, TSV, LibSVM input).
+ // desc = **Note**: setting this to ``true`` may lead to much slower text parsing.
+ bool precise_float_parser = false;
+
#pragma endregion
#pragma region Predict Parameters
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 61989e221bcc..23120e682f27 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -252,6 +252,8 @@ class Metadata {
/*! \brief Interface for Parser */
class Parser {
public:
+ typedef const char* (*AtofFunc)(const char* p, double* out);
+
/*! \brief virtual destructor */
virtual ~Parser() {}
@@ -271,9 +273,10 @@ class Parser {
* \param filename One Filename of data
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param label_idx index of label column
+ * \param precise_float_parser whether to use precise floating point number parsing
* \return Object of parser
*/
- static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx);
+ static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser);
};
/*! \brief The main class of data set,
diff --git a/src/application/application.cpp b/src/application/application.cpp
index e82cfcada98f..d9a4d7544ebc 100644
--- a/src/application/application.cpp
+++ b/src/application/application.cpp
@@ -221,7 +221,8 @@ void Application::Predict() {
if (config_.task == TaskType::KRefitTree) {
// create predictor
Predictor predictor(boosting_.get(), 0, -1, false, true, false, false, 1, 1);
- predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
+ predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check,
+ config_.precise_float_parser);
TextReader result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines();
std::vector> pred_leaf(result_reader.Lines().size());
@@ -251,7 +252,8 @@ void Application::Predict() {
config_.pred_early_stop, config_.pred_early_stop_freq,
config_.pred_early_stop_margin);
predictor.Predict(config_.data.c_str(),
- config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check);
+ config_.output_result.c_str(), config_.header, config_.predict_disable_shape_check,
+ config_.precise_float_parser);
Log::Info("Finished prediction");
}
}
diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp
index 7c2241b36959..dff23add2df5 100644
--- a/src/application/predictor.hpp
+++ b/src/application/predictor.hpp
@@ -160,13 +160,14 @@ class Predictor {
* \param data_filename Filename of data
* \param result_filename Filename of output result
*/
- void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check) {
+ void Predict(const char* data_filename, const char* result_filename, bool header, bool disable_shape_check, bool precise_float_parser) {
auto writer = VirtualFileWriter::Make(result_filename);
if (!writer->Init()) {
Log::Fatal("Prediction results file %s cannot be created", result_filename);
}
auto label_idx = header ? -1 : boosting_->LabelIdx();
- auto parser = std::unique_ptr(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx));
+ auto parser = std::unique_ptr(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx,
+ precise_float_parser));
if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename);
diff --git a/src/c_api.cpp b/src/c_api.cpp
index b0f828f5072f..3d20d92da70d 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -709,7 +709,8 @@ class Booster {
Predictor predictor(boosting_.get(), start_iteration, num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false;
- predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
+ predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check,
+ config.precise_float_parser);
}
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) const {
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 06c53e84268a..913c3242a3eb 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -259,6 +259,7 @@ const std::unordered_set& Config::parameter_set() {
"categorical_feature",
"forcedbins_filename",
"save_binary",
+ "precise_float_parser",
"start_iteration_predict",
"num_iteration_predict",
"predict_raw_score",
@@ -525,6 +526,8 @@ void Config::GetMembersFromString(const std::unordered_map(Parser::CreateParser(filename, config_.header, 0, label_idx_));
+ auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_,
+ config_.precise_float_parser));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
@@ -267,7 +268,8 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
}
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
- auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_));
+ auto parser = std::unique_ptr(Parser::CreateParser(filename, config_.header, 0, label_idx_,
+ config_.precise_float_parser));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
diff --git a/src/io/parser.cpp b/src/io/parser.cpp
index f16c643d1d6a..550c4e13d5c0 100644
--- a/src/io/parser.cpp
+++ b/src/io/parser.cpp
@@ -6,9 +6,6 @@
#include
#include
-#include
-#include
-#include
#include
namespace LightGBM {
@@ -232,7 +229,7 @@ DataType GetDataType(const char* filename, bool header,
return type;
}
-Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) {
+Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser) {
const int n_read_line = 32;
auto lines = ReadKLineFromFile(filename, header, n_read_line);
int num_col = 0;
@@ -242,15 +239,16 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
}
std::unique_ptr ret;
int output_label_index = -1;
+ AtofFunc atof = precise_float_parser ? Common::AtofPrecise : Common::Atof;
if (type == DataType::LIBSVM) {
output_label_index = GetLabelIdxForLibsvm(lines[0], num_features, label_idx);
- ret.reset(new LibSVMParser(output_label_index, num_col));
+ ret.reset(new LibSVMParser(output_label_index, num_col, atof));
} else if (type == DataType::TSV) {
output_label_index = GetLabelIdxForTSV(lines[0], num_features, label_idx);
- ret.reset(new TSVParser(output_label_index, num_col));
+ ret.reset(new TSVParser(output_label_index, num_col, atof));
} else if (type == DataType::CSV) {
output_label_index = GetLabelIdxForCSV(lines[0], num_features, label_idx);
- ret.reset(new CSVParser(output_label_index, num_col));
+ ret.reset(new CSVParser(output_label_index, num_col, atof));
}
if (output_label_index < 0 && label_idx >= 0) {
diff --git a/src/io/parser.hpp b/src/io/parser.hpp
index dad71ed31d0d..531bfad2f2d3 100644
--- a/src/io/parser.hpp
+++ b/src/io/parser.hpp
@@ -15,20 +15,10 @@
namespace LightGBM {
-#ifdef USE_PRECISE_TEXT_PARSER
-static const char* TextParserAtof(const char* p, double* out) {
- return Common::AtofPrecise(p, out);
-}
-#else
-static const char* TextParserAtof(const char* p, double* out) {
- return Common::Atof(p, out);
-}
-#endif
-
class CSVParser: public Parser {
public:
- explicit CSVParser(int label_idx, int total_columns)
- :label_idx_(label_idx), total_columns_(total_columns) {
+ explicit CSVParser(int label_idx, int total_columns, AtofFunc atof)
+ :label_idx_(label_idx), total_columns_(total_columns), atof_(atof) {
}
inline void ParseOneLine(const char* str,
std::vector>* out_features, double* out_label) const override {
@@ -37,7 +27,7 @@ class CSVParser: public Parser {
int offset = 0;
*out_label = 0.0f;
while (*str != '\0') {
- str = TextParserAtof(str, &val);
+ str = atof_(str, &val);
if (idx == label_idx_) {
*out_label = val;
offset = -1;
@@ -60,12 +50,13 @@ class CSVParser: public Parser {
private:
int label_idx_ = 0;
int total_columns_ = -1;
+ AtofFunc atof_;
};
class TSVParser: public Parser {
public:
- explicit TSVParser(int label_idx, int total_columns)
- :label_idx_(label_idx), total_columns_(total_columns) {
+ explicit TSVParser(int label_idx, int total_columns, AtofFunc atof)
+ :label_idx_(label_idx), total_columns_(total_columns), atof_(atof) {
}
inline void ParseOneLine(const char* str,
std::vector>* out_features, double* out_label) const override {
@@ -73,7 +64,7 @@ class TSVParser: public Parser {
double val = 0.0f;
int offset = 0;
while (*str != '\0') {
- str = TextParserAtof(str, &val);
+ str = atof_(str, &val);
if (idx == label_idx_) {
*out_label = val;
offset = -1;
@@ -96,12 +87,13 @@ class TSVParser: public Parser {
private:
int label_idx_ = 0;
int total_columns_ = -1;
+ AtofFunc atof_;
};
class LibSVMParser: public Parser {
public:
- explicit LibSVMParser(int label_idx, int total_columns)
- :label_idx_(label_idx), total_columns_(total_columns) {
+ explicit LibSVMParser(int label_idx, int total_columns, AtofFunc atof)
+ :label_idx_(label_idx), total_columns_(total_columns), atof_(atof) {
if (label_idx > 0) {
Log::Fatal("Label should be the first column in a LibSVM file");
}
@@ -111,7 +103,7 @@ class LibSVMParser: public Parser {
int idx = 0;
double val = 0.0f;
if (label_idx_ == 0) {
- str = TextParserAtof(str, &val);
+ str = atof_(str, &val);
*out_label = val;
str = Common::SkipSpaceAndTab(str);
}
@@ -136,6 +128,7 @@ class LibSVMParser: public Parser {
private:
int label_idx_ = 0;
int total_columns_ = -1;
+ AtofFunc atof_;
};
} // namespace LightGBM