Skip to content

Commit

Permalink
Python API for ChiSqSelector
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed May 6, 2015
1 parent 32cdc81 commit cdaac99
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ private[python] class PythonMLLibAPI extends Serializable {
new StandardScaler(withMean, withStd).fit(data.rdd)
}

/**
 * Java stub for ChiSqSelector.fit(), invoked from the Python API.
 * The caller receives a handle to the fitted Java model object rather than
 * the content of that object, so the Python code has to take extra care
 * that the handle gets freed on exit; see the Py4J documentation.
 *
 * @param numTopFeatures value handed to the ChiSqSelector constructor
 * @param data labeled points the selector is fitted on
 * @return the model produced by ChiSqSelector.fit()
 */
def fitChiSqSelector(numTopFeatures: Int, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
  val selector = new ChiSqSelector(numTopFeatures)
  selector.fit(data.rdd)
}

/**
* Java stub for IDF.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
Expand Down
59 changes: 57 additions & 2 deletions python/pyspark/mllib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
from pyspark import SparkContext
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import Vectors, _convert_to_vector
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint

__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
'ChiSqSelector', 'ChiSqSelectorModel']


class VectorTransformer(object):
Expand Down Expand Up @@ -199,6 +201,59 @@ def fit(self, dataset):
return StandardScalerModel(jmodel)


class ChiSqSelectorModel(JavaVectorTransformer):
    """
    .. note:: Experimental

    Model of a Chi Squared feature selector.
    """

    def transform(self, vector):
        """
        Apply the selector's transformation to the given input.

        :param vector: Vector or RDD of Vector to be transformed.
        :return: transformed vector.
        """
        # Delegate to the shared Java-backed implementation; this override
        # exists so the method carries selector-specific documentation.
        transformed = JavaVectorTransformer.transform(self, vector)
        return transformed


class ChiSqSelector(object):
    """
    .. note:: Experimental

    Builder for a ChiSquared feature selector.

    >>> data = [
    ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
    ...     LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
    ...     LabeledPoint(1.0, [0.0, 9.0, 8.0]),
    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0])
    ... ]
    >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
    >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
    SparseVector(1, {0: 6.0})
    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
    DenseVector([5.0])
    """

    def __init__(self, numTopFeatures):
        """
        Store the requested feature count, coerced with ``int()``.

        :param numTopFeatures: number of features that selector will select.
        """
        top_count = int(numTopFeatures)
        self.numTopFeatures = top_count

    def fit(self, data):
        """
        Fit the selector on a labeled dataset and return the resulting
        :class:`ChiSqSelectorModel`.

        :param data: an `RDD[LabeledPoint]` containing the labeled dataset
                     with categorical features. Real-valued features will be
                     treated as categorical for each distinct value.
                     Apply feature discretizer before using this function.
        :return: a :class:`ChiSqSelectorModel` wrapping the fitted Java model.
        """
        # The JVM-side stub hands back a handle to the fitted Java model.
        java_model = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
        return ChiSqSelectorModel(java_model)


class HashingTF(object):
"""
.. note:: Experimental
Expand Down

0 comments on commit cdaac99

Please sign in to comment.