Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator builder fails for uppercase impurity type Gini #17407

Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private[spark] object ImpurityCalculator {
* the given stats.
*/
def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
impurity match {
impurity.toLowerCase match {
case "gini" => new GiniCalculator(stats)
case "entropy" => new EntropyCalculator(stats)
case "variance" => new VarianceCalculator(stats)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}

test("read/write: ImpurityCalculator builder did not recognize impurity type: Gini") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be more specific, how about:
"SAPRK-20043: ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

val rdd = TreeTests.getTreeReadWriteData(sc)

val categoricalData: DataFrame =
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To simplify this, you can write TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2).
Since this is not testing categorical features, there's no need to throw them in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.


// BUG: see SPARK-20043
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd put the JIRA number in the test title.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the comment.

val dt = new DecisionTreeClassifier().setImpurity("Gini")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this faster, set maxDepth = 2 (something small)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set maxDepth = 2.


val model = dt.fit(categoricalData)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The blank lines kinda stand out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete some blank lines to keep compact.

testDefaultReadWrite(model, false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about setting testParams=true for this case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

}
}

private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,20 @@ class DecisionTreeRegressorSuite
TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}

test("read/write: ImpurityCalculator builder did not recognize impurity type: Variance") {
val rdd = TreeTests.getTreeReadWriteData(sc)

val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)

// BUG: see SPARK-20043
val dt = new DecisionTreeRegressor().setImpurity("Variance")

val model = dt.fit(continuousData)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the second unit test seems redundant for this case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed.

testDefaultReadWrite(model, false)
}
}

private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
Expand Down