Skip to content

Commit

Permalink
fix predict
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Nov 12, 2014
1 parent 84324fb commit 81172aa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.mllib.tree.model

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
features.map(x => predict(x))
}


/**
 * Predict values for a Java-API data set using the trained model.
 *
 * @param features JavaRDD whose elements are the data points to be predicted
 * @return JavaRDD holding the predictions for each of the given data points
 */
def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
  // Delegate to the RDD-based overload on the underlying Scala RDD ...
  val predictions: RDD[Double] = predict(features.rdd)
  // ... then wrap the result so Java callers get a JavaRDD back.
  JavaRDD.fromRDD(predictions)
}

/**
* Get number of nodes in tree, including leaf nodes.
*/
Expand Down
26 changes: 14 additions & 12 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,13 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
>>> model.predict(array([1.0])) > 0
True
>>> model.predict(array([0.0])) == 0
True
>>> model.predict(array([1.0]))
1.0
>>> model.predict(array([0.0]))
0.0
>>> rdd = sc.parallelize([[1.0], [0.0]])
>>> model.predict(rdd).collect()
[1.0, 0.0]
"""
return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
Expand Down Expand Up @@ -170,14 +173,13 @@ def trainRegressor(data, categoricalFeaturesInfo,
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
>>> model.predict(array([0.0, 1.0])) == 1
True
>>> model.predict(array([0.0, 0.0])) == 0
True
>>> model.predict(SparseVector(2, {1: 1.0})) == 1
True
>>> model.predict(SparseVector(2, {1: 0.0})) == 0
True
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
>>> model.predict(SparseVector(2, {1: 0.0}))
0.0
>>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
>>> model.predict(rdd).collect()
[1.0, 0.0]
"""
return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
Expand Down

0 comments on commit 81172aa

Please sign in to comment.