-
Notifications
You must be signed in to change notification settings - Fork 28.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010. Key features: + Supports binary classification and regression + Supports gini, entropy and variance for information gain calculation + Supports both continuous and categorical features The algorithm has gone through several development iterations over the last few months leading to a highly optimized implementation. Optimizations include: 1. Level-wise training to reduce passes over the entire dataset. 2. Bin-wise split calculation to reduce computation overhead. 3. Aggregation over partitions before combining to reduce communication overhead. Author: Manish Amde <manish9ue@gmail.com> Author: manishamde <manish9ue@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #79 from manishamde/tree and squashes the following commits: 1e8c704 [Manish Amde] remove numBins field in the Strategy class 7d54b4f [manishamde] Merge pull request #4 from mengxr/dtree f536ae9 [Xiangrui Meng] another pass on code style e1dd86f [Manish Amde] implementing code style suggestions 62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing 201702f [Manish Amde] making some more methods private f963ef5 [Manish Amde] making methods private c487e6a [manishamde] Merge pull request #1 from mengxr/dtree 24500c5 [Xiangrui Meng] minor style updates 4576b64 [Manish Amde] documentation and for to while loop conversion ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins 632818f [Manish Amde] removing threshold for classification predict method 2116360 [Manish Amde] removing dummy bin calculation for categorical variables 6068356 [Manish Amde] ensuring num bins is always greater than max number of categories 62c2562 [Manish Amde] fixing comment indentation ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions d1ef4f6 [Manish Amde] more documentation 794ff4d [Manish Amde] minor improvements to docs and style eb8fcbe [Manish Amde] minor code style updates cd2c2b4 [Manish Amde] fixing code style based on feedback 63e786b [Manish Amde] added multiple train methods for java compatability d3023b3 [Manish Amde] adding more docs for nested methods 84f85d6 [Manish Amde] code documentation 9372779 [Manish Amde] code style: max line lenght <= 100 dd0c0d7 [Manish Amde] minor: some docs 0dd7659 [manishamde] basic doc 5841c28 [Manish Amde] unit tests for categorical features f067d68 [Manish Amde] minor cleanup c0e522b [Manish Amde] updated predict and split threshold logic b09dc98 [Manish Amde] minor refactoring 6b7de78 [Manish Amde] minor refactoring and tests d504eb1 [Manish Amde] more tests for categorical features dbb7ac1 [Manish Amde] categorical feature support 6df35b9 [Manish Amde] regression predict logic 53108ed [Manish Amde] fixing index for highest bin e23c2e5 [Manish Amde] added regression support c8f6d60 [Manish Amde] adding enum for feature type b0e3e76 [Manish Amde] adding enum for feature type 154aa77 [Manish Amde] enums for configurations 733d6dd [Manish Amde] fixed tests 02c595c [Manish Amde] added command line parsing 98ec8d5 [Manish Amde] tree building and prediction logic b0eb866 [Manish Amde] added logic to handle leaf nodes 80e8c66 [Manish Amde] working version of multi-level split calculation 4798aae [Manish Amde] added gain stats class dad0afc [Manish Amde] decison stump functionality working 03f534c [Manish Amde] some more tests 0012a77 [Manish Amde] basic stump working 8bca1e2 [Manish Amde] additional code for creating intermediate RDD 92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested. cd53eae [Manish Amde] skeletal framework
- Loading branch information
1 parent
45df912
commit 8b3045c
Showing
17 changed files
with
2,188 additions
and
0 deletions.
There are no files selected for viewing
1,150 changes: 1,150 additions & 0 deletions
1,150
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Large diffs are not rendered by default.
Oops, something went wrong.
17 changes: 17 additions & 0 deletions
17
mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
This package contains the default implementation of the decision tree algorithm. | ||
|
||
The decision tree algorithm supports: | ||
+ Binary classification | ||
+ Regression | ||
+ Information loss calculation with entropy and gini for classification and variance for regression | ||
+ Both continuous and categorical features | ||
|
||
# Tree improvements | ||
+ Node model pruning | ||
+ Printing to dot files | ||
|
||
# Future Ensemble Extensions | ||
|
||
+ Random forests | ||
+ Boosting | ||
+ Extremely randomized trees |
26 changes: 26 additions & 0 deletions
26
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.configuration | ||
|
||
/** | ||
* Enum to select the algorithm for the decision tree | ||
*/ | ||
object Algo extends Enumeration { | ||
type Algo = Value | ||
val Classification, Regression = Value | ||
} |
26 changes: 26 additions & 0 deletions
26
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.configuration | ||
|
||
/** | ||
* Enum to describe whether a feature is "continuous" or "categorical" | ||
*/ | ||
object FeatureType extends Enumeration { | ||
type FeatureType = Value | ||
val Continuous, Categorical = Value | ||
} |
26 changes: 26 additions & 0 deletions
26
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.configuration | ||
|
||
/** | ||
* Enum for selecting the quantile calculation strategy | ||
*/ | ||
object QuantileStrategy extends Enumeration { | ||
type QuantileStrategy = Value | ||
val Sort, MinMax, ApproxHist = Value | ||
} |
43 changes: 43 additions & 0 deletions
43
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.configuration | ||
|
||
import org.apache.spark.mllib.tree.impurity.Impurity | ||
import org.apache.spark.mllib.tree.configuration.Algo._ | ||
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ | ||
|
||
/** | ||
* Stores all the configuration options for tree construction | ||
* @param algo classification or regression | ||
* @param impurity criterion used for information gain calculation | ||
* @param maxDepth maximum depth of the tree | ||
* @param maxBins maximum number of bins used for splitting features | ||
* @param quantileCalculationStrategy algorithm for calculating quantiles | ||
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the | ||
* number of discrete values they take. For example, an entry (n -> | ||
* k) implies the feature n is categorical with k categories 0, | ||
* 1, 2, ... , k-1. It's important to note that features are | ||
* zero-indexed. | ||
*/ | ||
class Strategy ( | ||
val algo: Algo, | ||
val impurity: Impurity, | ||
val maxDepth: Int, | ||
val maxBins: Int = 100, | ||
val quantileCalculationStrategy: QuantileStrategy = Sort, | ||
val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable |
47 changes: 47 additions & 0 deletions
47
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.impurity | ||
|
||
/** | ||
* Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during | ||
* binary classification. | ||
*/ | ||
object Entropy extends Impurity { | ||
|
||
def log2(x: Double) = scala.math.log(x) / scala.math.log(2) | ||
|
||
/** | ||
* entropy calculation | ||
* @param c0 count of instances with label 0 | ||
* @param c1 count of instances with label 1 | ||
* @return entropy value | ||
*/ | ||
def calculate(c0: Double, c1: Double): Double = { | ||
if (c0 == 0 || c1 == 0) { | ||
0 | ||
} else { | ||
val total = c0 + c1 | ||
val f0 = c0 / total | ||
val f1 = c1 / total | ||
-(f0 * log2(f0)) - (f1 * log2(f1)) | ||
} | ||
} | ||
|
||
def calculate(count: Double, sum: Double, sumSquares: Double): Double = | ||
throw new UnsupportedOperationException("Entropy.calculate") | ||
} |
46 changes: 46 additions & 0 deletions
46
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.impurity | ||
|
||
/** | ||
* Class for calculating the | ||
* [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] | ||
* during binary classification. | ||
*/ | ||
object Gini extends Impurity { | ||
|
||
/** | ||
* Gini coefficient calculation | ||
* @param c0 count of instances with label 0 | ||
* @param c1 count of instances with label 1 | ||
* @return Gini coefficient value | ||
*/ | ||
override def calculate(c0: Double, c1: Double): Double = { | ||
if (c0 == 0 || c1 == 0) { | ||
0 | ||
} else { | ||
val total = c0 + c1 | ||
val f0 = c0 / total | ||
val f1 = c1 / total | ||
1 - f0 * f0 - f1 * f1 | ||
} | ||
} | ||
|
||
def calculate(count: Double, sum: Double, sumSquares: Double): Double = | ||
throw new UnsupportedOperationException("Gini.calculate") | ||
} |
42 changes: 42 additions & 0 deletions
42
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.impurity | ||
|
||
/** | ||
* Trait for calculating information gain. | ||
*/ | ||
trait Impurity extends Serializable { | ||
|
||
/** | ||
* information calculation for binary classification | ||
* @param c0 count of instances with label 0 | ||
* @param c1 count of instances with label 1 | ||
* @return information value | ||
*/ | ||
def calculate(c0 : Double, c1 : Double): Double | ||
|
||
/** | ||
* information calculation for regression | ||
* @param count number of instances | ||
* @param sum sum of labels | ||
* @param sumSquares summation of squares of the labels | ||
* @return information value | ||
*/ | ||
def calculate(count: Double, sum: Double, sumSquares: Double): Double | ||
|
||
} |
37 changes: 37 additions & 0 deletions
37
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.impurity | ||
|
||
/** | ||
* Class for calculating variance during regression | ||
*/ | ||
object Variance extends Impurity { | ||
override def calculate(c0: Double, c1: Double): Double = | ||
throw new UnsupportedOperationException("Variance.calculate") | ||
|
||
/** | ||
* variance calculation | ||
* @param count number of instances | ||
* @param sum sum of labels | ||
* @param sumSquares summation of squares of the labels | ||
*/ | ||
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { | ||
val squaredLoss = sumSquares - (sum * sum) / count | ||
squaredLoss / count | ||
} | ||
} |
33 changes: 33 additions & 0 deletions
33
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.model | ||
|
||
import org.apache.spark.mllib.tree.configuration.FeatureType._ | ||
|
||
/** | ||
* Used for "binning" the features bins for faster best split calculation. For a continuous | ||
* feature, a bin is determined by a low and a high "split". For a categorical feature, | ||
* the a bin is determined using a single label value (category). | ||
* @param lowSplit signifying the lower threshold for the continuous feature to be | ||
* accepted in the bin | ||
* @param highSplit signifying the upper threshold for the continuous feature to be | ||
* accepted in the bin | ||
* @param featureType type of feature -- categorical or continuous | ||
* @param category categorical label value accepted in the bin | ||
*/ | ||
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) |
Oops, something went wrong.