Skip to content

Commit

Permalink
Emit predicted category using an appropriate JSON type.
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Dec 5, 2019
1 parent 617e5b9 commit 1eb3b44
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[
(See {ml-pull}818[#818].)
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
* Reduce runtime of classification and regression. (See {ml-pull}863[#863].)
* Emit `prediction_field_name` in {ml} results using the type of the `dependent_variable`.
(See {ml-pull}877[#877].)

=== Bug Fixes
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
Expand Down
5 changes: 5 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
const TRowRef& row,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! Write the predicted category value as string, int or bool.
//!
//! The JSON type is selected by the "dependent_variable_type" analysis
//! parameter: "int" writes an integer, "bool" writes a boolean which is
//! true only when the value parses to 1, and any other setting writes the
//! value unchanged as a string.
//!
//! \param[in] categoryValue The category value in its string form.
//! \param[in,out] writer The JSON writer which receives the value.
void writePredictedCategoryValue(const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! \return A serialisable definition of the trained classification model.
TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames,
Expand All @@ -55,6 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final

private:
std::size_t m_NumTopClasses;
std::string m_DependentVariableType;
};

//! \brief Makes a core::CDataFrame boosted tree classification runner.
Expand Down
22 changes: 20 additions & 2 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using TSizeVec = std::vector<std::size_t>;

// Configuration
const std::string NUM_TOP_CLASSES{"num_top_classes"};
const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"};
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};

// Output
Expand All @@ -47,6 +48,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(DEPENDENT_VARIABLE_TYPE,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(BALANCED_CLASS_LOSS,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
return theReader;
Expand All @@ -60,6 +63,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
: CDataFrameTrainBoostedTreeRunner{spec, parameters} {

m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
m_DependentVariableType =
parameters[DEPENDENT_VARIABLE_TYPE].fallback(std::string("string"));
this->boostedTreeFactory().balanceClassTrainingLoss(
parameters[BALANCED_CLASS_LOSS].fallback(true));

Expand Down Expand Up @@ -119,7 +124,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(

writer.StartObject();
writer.Key(this->predictionFieldName());
writer.String(categoryValues[predictedCategoryId]);
writePredictedCategoryValue(categoryValues[predictedCategoryId], writer);
writer.Key(PREDICTION_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[predictedCategoryId]);
writer.Key(IS_TRAINING_FIELD_NAME);
Expand All @@ -135,7 +140,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writer.String(categoryValues[categoryIds[i]]);
writePredictedCategoryValue(categoryValues[categoryIds[i]], writer);
writer.Key(CLASS_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[i]);
writer.EndObject();
Expand All @@ -158,6 +163,19 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
columnHoldingPrediction, row, writer);
}

//! Write \p categoryValue to \p writer using the JSON type named by the
//! "dependent_variable_type" parameter.
//!
//! - "int": the value is parsed as a 64 bit integer and written as a number.
//! - "bool": the value is parsed as an integer and written as true if and
//!   only if it equals 1.
//! - anything else: the value is written unchanged as a string.
//!
//! If the value cannot be parsed as an integer (it is not numeric or it is
//! out of range) it is written as a string instead of letting the parsing
//! exception escape while results are being written.
//!
//! \param[in] categoryValue The category value in its string form.
//! \param[in,out] writer The JSON writer which receives the value.
void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
    const std::string& categoryValue,
    core::CRapidJsonConcurrentLineWriter& writer) const {

    const bool writeAsInt{m_DependentVariableType == "int"};
    const bool writeAsBool{m_DependentVariableType == "bool"};

    if (writeAsInt || writeAsBool) {
        long long parsedValue{0};
        bool parsed{true};
        try {
            // Parse as 64 bit so values outside the range of int are preserved
            // rather than raising std::out_of_range.
            parsedValue = std::stoll(categoryValue);
        } catch (const std::exception&) {
            // std::invalid_argument or std::out_of_range: the category is not
            // representable as an integer so fall back to the string form.
            parsed = false;
        }
        if (parsed) {
            if (writeAsInt) {
                writer.Int64(parsedValue);
            } else {
                writer.Bool(parsedValue == 1);
            }
            return;
        }
    }

    writer.String(categoryValue);
}

CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame,
std::size_t dependentVariableColumn) const {
Expand Down
57 changes: 42 additions & 15 deletions lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes].");
}

BOOST_AUTO_TEST_CASE(testWriteOneRow) {
template<typename T>
void testWriteOneRow(const std::string& dependentVariableField,
const std::string& dependentVariableType,
T (rapidjson::Value::*extract)() const,
const std::vector<T>& expectedPredictions) {
// Prepare input data frame
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
const TStrVec categoricalColumns{"x1", "x2", "x5"};
const std::string predictionField = dependentVariableField + "_prediction";
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField};
const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"};
const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"},
{"a", "b", "2.0", "2.0", "cat", "-0.5"},
{"a", "b", "5.0", "5.0", "dog", "-0.1"},
{"c", "d", "5.0", "5.0", "dog", "1.0"},
{"e", "f", "5.0", "5.0", "dog", "1.5"}};
{"a", "b", "1.0", "1.0", "cat", "-0.5"},
{"a", "b", "5.0", "0.0", "dog", "-0.1"},
{"c", "d", "5.0", "0.0", "dog", "1.0"},
{"e", "f", "5.0", "0.0", "dog", "1.5"}};
std::unique_ptr<core::CDataFrame> frame =
core::makeMainStorageDataFrame(columnNames.size()).first;
frame->columnNames(columnNames);
Expand All @@ -67,10 +72,13 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

// Create classification analysis runner object
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
"classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0,
categoricalColumns)};
"classification", dependentVariableField, rows.size(),
columnNames.size(), 13000000, 0, 0, categoricalColumns)};
rapidjson::Document jsonParameters;
jsonParameters.Parse("{\"dependent_variable\": \"x5\"}");
jsonParameters.Parse("{"
" \"dependent_variable\": \"" + dependentVariableField + "\","
" \"dependent_variable_type\": \"" + dependentVariableType + "\""
"}");
const auto parameters{
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
Expand All @@ -83,10 +91,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
const auto columnHoldingDependentVariable{
std::find(columnNames.begin(), columnNames.end(), "x5") -
std::find(columnNames.begin(), columnNames.end(), dependentVariableField) -
columnNames.begin()};
const auto columnHoldingPrediction{
std::find(columnNames.begin(), columnNames.end(), "x5_prediction") -
std::find(columnNames.begin(), columnNames.end(), predictionField) -
columnNames.begin()};
for (auto row = beginRows; row != endRows; ++row) {
runner.writeOneRow(*frame, columnHoldingDependentVariable,
Expand All @@ -95,17 +103,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
});
}
// Verify results
const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
rapidjson::Document arrayDoc;
arrayDoc.Parse<rapidjson::kParseDefaultFlags>(output.str().c_str());
BOOST_TEST_REQUIRE(arrayDoc.IsArray());
BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size());
BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size());
for (std::size_t i = 0; i < arrayDoc.Size(); ++i) {
BOOST_TEST_CONTEXT("Result for row " << i) {
const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)];
BOOST_TEST_REQUIRE(object.IsObject());
BOOST_TEST_REQUIRE(object.HasMember("x5_prediction"));
BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() ==
BOOST_TEST_REQUIRE(object.HasMember(predictionField));
BOOST_TEST_REQUIRE((object[predictionField].*extract)() ==
expectedPredictions[i]);
BOOST_TEST_REQUIRE(object.HasMember("prediction_probability"));
BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5);
Expand All @@ -115,4 +123,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
}
}

// With dependent_variable_type "int" the predictions for column x3 must be
// readable as JSON integers.
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsInt) {
    const std::vector<int> expectedInts{1, 1, 1, 5, 5};
    testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, expectedInts);
}

// With dependent_variable_type "bool" the predictions for column x4 must be
// readable as JSON booleans.
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsBool) {
    const std::vector<bool> expectedBools{true, true, true, false, false};
    testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool, expectedBools);
}

// With dependent_variable_type "string" the predictions for column x5 must be
// readable as JSON strings.
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsString) {
    const std::vector<const char*> expectedStrings{"cat", "cat", "cat", "dog", "dog"};
    testWriteOneRow("x5", "string", &rapidjson::Value::GetString, expectedStrings);
}

// With an empty dependent_variable_type the predictions for column x5 must
// still be readable as JSON strings.
// NOTE(review): the helper appears to always emit the "dependent_variable_type"
// key, so this passes an empty value rather than omitting the parameter; the
// "string" fallback for a truly absent parameter may not be covered - confirm.
BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableTypeMissing) {
    const std::vector<const char*> expectedStrings{"cat", "cat", "cat", "dog", "dog"};
    testWriteOneRow("x5", "", &rapidjson::Value::GetString, expectedStrings);
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 1eb3b44

Please sign in to comment.