Skip to content

Commit

Permalink
update pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 11, 2014
1 parent 4f9e34f commit 6e7c1c7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
37 changes: 20 additions & 17 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,31 @@ class Pipeline extends Estimator[PipelineModel] {
val map = this.paramMap ++ paramMap
val theStages = map(stages)
// Search for the last estimator.
var lastIndexOfEstimator = -1
var indexOfLastEstimator = -1
theStages.view.zipWithIndex.foreach { case (stage, index) =>
stage match {
case _: Estimator[_] =>
lastIndexOfEstimator = index
indexOfLastEstimator = index
case _ =>
}
}
var curDataset = dataset
val transformers = ListBuffer.empty[Transformer]
theStages.view.zipWithIndex.foreach { case (stage, index) =>
stage match {
case estimator: Estimator[_] =>
val transformer = estimator.fit(curDataset, paramMap)
if (index < lastIndexOfEstimator) {
curDataset = transformer.transform(curDataset, paramMap)
}
transformers += transformer
case transformer: Transformer =>
if (index < lastIndexOfEstimator) {
curDataset = transformer.transform(curDataset, paramMap)
}
transformers += transformer
case _ =>
throw new IllegalArgumentException
if (index <= indexOfLastEstimator) {
val transformer = stage match {
case estimator: Estimator[_] =>
estimator.fit(curDataset, paramMap)
case t: Transformer =>
t
case _ =>
throw new IllegalArgumentException(
s"Do not support stage $stage of type ${stage.getClass}")
}
curDataset = transformer.transform(curDataset, paramMap)
transformers += transformer
} else {
transformers += stage.asInstanceOf[Transformer]
}
}

Expand All @@ -117,7 +117,10 @@ class Pipeline extends Estimator[PipelineModel] {

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(stages).foldLeft(schema)((cur, stage) => stage.transform(cur, paramMap))
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
"Cannot have duplicate components in a pipeline.")
theStages.foldLeft(schema)((cur, stage) => stage.transform(cur, paramMap))
}
}

Expand Down
10 changes: 10 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,14 @@ class PipelineSuite extends FunSuite {
val output = pipelineModel.transform(dataset0)
assert(output.eq(dataset4))
}

test("pipeline with duplicate stages") {
val estimator = mock[Estimator[MyModel]]
val pipeline = new Pipeline()
.setStages(Array(estimator, estimator))
val dataset = mock[SchemaRDD]
intercept[IllegalArgumentException] {
pipeline.fit(dataset)
}
}
}

0 comments on commit 6e7c1c7

Please sign in to comment.