Skip to content

Commit

Permalink
some code clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 6, 2014
1 parent 2d040b3 commit 6e86d98
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
3 changes: 3 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.ml

/**
* A trained model.
*/
abstract class Model extends Transformer {
// def parent: Estimator
// def trainingParameters: ParamMap
Expand Down
30 changes: 13 additions & 17 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,29 @@

package org.apache.spark.ml

import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.sql.SchemaRDD

import scala.collection.mutable.ListBuffer

import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.sql.SchemaRDD

/**
* A stage in a pipeline, either an Estimator or an Transformer.
*/
trait PipelineStage extends Identifiable

/**
* A simple pipeline, which acts as an estimator.
*/
class Pipeline extends Estimator[PipelineModel] {

val stages: Param[Array[PipelineStage]] =
new Param(this, "stages", "stages of the pipeline")

def setStages(stages: Array[PipelineStage]): this.type = {
set(this.stages, stages)
this
}

def getStages: Array[PipelineStage] = {
get(stages)
}
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
def getStages: Array[PipelineStage] = get(stages)

override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
val map = this.paramMap ++ paramMap
val theStages = map(stages)
// Search for last estimator.
// Search for the last estimator.
var lastIndexOfEstimator = -1
theStages.view.zipWithIndex.foreach { case (stage, index) =>
stage match {
Expand Down Expand Up @@ -75,10 +70,11 @@ class Pipeline extends Estimator[PipelineModel] {

new PipelineModel(transformers.toArray)
}

override def params: Array[Param[_]] = Array.empty
}

/**
* Represents a compiled pipeline.
*/
class PipelineModel(val transformers: Array[Transformer]) extends Model {

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
Expand Down

0 comments on commit 6e86d98

Please sign in to comment.