-
Notifications
You must be signed in to change notification settings - Fork 1
/
decisionTree.hpp
33 lines (25 loc) · 1.06 KB
/
decisionTree.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#pragma once
//----------------------------------------------------------------------------
//----------------------------------------------------------------------------
/// @brief Trains an OpenCV decision tree on @p dataset and stores it on disk.
///
/// The tree is configured with the fixed hyper-parameters below, trained on
/// the given data, and written to "DecisionTree.xml" in the current working
/// directory.
///
/// Declared `inline` because this definition lives in a header: without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// @param dataset Training data handed unchanged to cv::ml::DTrees::train().
/// @throws cv::Exception if training reports failure (in addition to whatever
///         OpenCV itself raises while training or saving).
inline void trainDecisionTree(const cv::Ptr<cv::ml::TrainData> &dataset)
{
    auto decision_tree = cv::ml::DTrees::create();

    // Fixed hyper-parameters for this project.
    decision_tree->setMaxCategories(2);
    decision_tree->setMaxDepth(20);
    decision_tree->setMinSampleCount(1);
    decision_tree->setTruncatePrunedTree(true);
    decision_tree->setUse1SERule(true);
    decision_tree->setUseSurrogates(false);
    // NOTE(review): OpenCV documents cross-validation pruning only for
    // CVFolds > 1, so 1 presumably disables pruning - confirm that the two
    // pruning-related settings above are intended.
    decision_tree->setCVFolds(1);

    // Do not persist a model that failed to train; report it through the same
    // exception type OpenCV uses for its own errors.
    if (!decision_tree->train(dataset))
    {
        cv::error(cv::Error::StsError, "Decision tree training failed",
                  __func__, __FILE__, __LINE__);
    }

    decision_tree->save("DecisionTree.xml");
}
//----------------------------------------------------------------------------
//----------------------------------------------------------------------------
/// @brief Loads the tree stored in "DecisionTree.xml" and evaluates it.
///
/// Declared `inline` because this definition lives in a header: without it,
/// including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// @param dataset Data passed to cv::ml::StatModel::calcError(). The call
///        passes `test = true`, which per the OpenCV documentation selects the
///        test subset of the data rather than the training subset.
/// @return The error value reported by calcError().
/// @throws cv::Exception if the stored model cannot be loaded (in addition to
///         whatever OpenCV itself raises while loading or evaluating).
inline float testDecisionTree(const cv::Ptr<cv::ml::TrainData> &dataset)
{
    auto decision_tree = cv::ml::DTrees::load("DecisionTree.xml");

    // load() can hand back an empty pointer (e.g. the file holds no usable
    // model); fail with a clear error instead of dereferencing null.
    if (decision_tree.empty())
    {
        cv::error(cv::Error::StsError,
                  "Could not load a decision tree from DecisionTree.xml",
                  __func__, __FILE__, __LINE__);
    }

    // calcError() fills this with per-sample responses; only the aggregate
    // error is returned to the caller.
    std::vector<int32_t> predictions;
    const float error = decision_tree->calcError(dataset, true, predictions);
    return error;
}