Skip to content
This repository has been archived by the owner on Nov 30, 2019. It is now read-only.

Commit

Permalink
update example
Browse files (browse the repository at this point in the history)
  • Loading branch information
mengxr committed Jan 26, 2015
1 parent d3e8dbe commit 05e3e40
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@
from pyspark.ml.classification import LogisticRegression


"""
A simple text classification pipeline that recognizes "spark" from
input text. This is to show how to create and configure a Spark ML
pipeline in Python. Run with:
bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py
"""


if __name__ == "__main__":
sc = SparkContext(appName="SimpleTextClassificationPipeline")
sqlCtx = SQLContext(sc)
Expand Down Expand Up @@ -53,5 +62,9 @@
(7L, "apache hadoop")])
.map(lambda x: Row(id=x[0], text=x[1])))

- for row in model.transform(test).collect():
+ prediction = model.transform(test)
+
+ prediction.registerTempTable("prediction")
+ selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
+ for row in selected.collect():
print row
2 changes: 1 addition & 1 deletion python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def getStages(self):

def fit(self, dataset, params={}):
paramMap = self._merge_params(params)
- stages = paramMap(self.stages)
+ stages = paramMap[self.stages]
for stage in stages:
if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
raise ValueError(
Expand Down

0 comments on commit 05e3e40

Please sign in to comment.