Skip to content

Commit

Permalink
[SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator builder fails…
Browse files Browse the repository at this point in the history
… for uppercase impurity type Gini

Fix bug: DecisionTreeModel can't recongnize Impurity "Gini" when loading

TODO:
+ [x] add unit test
+ [x] fix the bug

Author: 颜发才(Yan Facai) <facai.yan@gmail.com>

Closes #17407 from facaiy/BUG/decision_tree_loader_failer_with_Gini_impurity.

(cherry picked from commit 7d432af)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
  • Loading branch information
facaiy authored and jkbradley committed Mar 28, 2017
1 parent 4964dbe commit 3095480
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private[spark] object ImpurityCalculator {
* the given stats.
*/
def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
impurity match {
impurity.toLowerCase match {
case "gini" => new GiniCalculator(stats)
case "entropy" => new EntropyCalculator(stats)
case "variance" => new VarianceCalculator(stats)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,20 @@ class DecisionTreeClassifierSuite
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
checkModelData)
}

test("SPARK-20043: " +
"ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") {
val rdd = TreeTests.getTreeReadWriteData(sc)
val data: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)

val dt = new DecisionTreeClassifier()
.setImpurity("Gini")
.setMaxDepth(2)
val model = dt.fit(data)

testDefaultReadWrite(model)
}
}

private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
Expand Down

0 comments on commit 3095480

Please sign in to comment.