From 73a000ba22bbef8e248ef7d6df7872fa8f1d6d94 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sun, 9 Nov 2014 21:51:40 -0800
Subject: [PATCH] add schema transformation layer

---
 .../scala/org/apache/spark/ml/Pipeline.scala  | 23 ++++++--
 .../org/apache/spark/ml/Transformer.scala     | 23 +++++++-
 .../classification/LogisticRegression.scala   | 55 ++++++++++++++++++-
 .../spark/ml/feature/StandardScaler.scala     | 27 ++++++++-
 .../org/apache/spark/ml/param/params.scala    |  9 ++-
 .../spark/ml/tuning/CrossValidator.scala      | 12 +++-
 6 files changed, 134 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index fe511bd904e70..a96a6936b2670 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,12 +20,18 @@ package org.apache.spark.ml
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.{StructType, SchemaRDD}
 
 /**
  * A stage in a pipeline, either an Estimator or a Transformer.
  */
-abstract class PipelineStage
+abstract class PipelineStage {
+
+  /**
+   * Derives the output schema from the input schema and parameters.
+   */
+  def transform(schema: StructType, paramMap: ParamMap): StructType
+}
 
 /**
  * A simple pipeline, which acts as an estimator.
@@ -70,6 +76,11 @@ class Pipeline extends Estimator[PipelineModel] {
 
     new PipelineModel(this, map, transformers.toArray)
   }
+
+  override def transform(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = this.paramMap ++ paramMap
+    map(stages).foldLeft(schema)((cur, stage) => stage.transform(cur, paramMap))
+  }
 }
 
 /**
@@ -99,8 +110,10 @@ class PipelineModel(
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
-    transformers.foldLeft(dataset) { (dataset, transformer) =>
-      transformer.transform(dataset, paramMap)
-    }
+    transformers.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+  }
+
+  override def transform(schema: StructType, paramMap: ParamMap): StructType = {
+    transformers.foldLeft(schema)((cur, transformer) => transformer.transform(cur, paramMap))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index eba1b9c36c7ff..0f40d1ceea852 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.ml
 
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.types.{StructField, StructType}
+
 import scala.annotation.varargs
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.ml.param._
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.{DataType, SchemaRDD}
 import org.apache.spark.sql.api.java.JavaSchemaRDD
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.dsl._
@@ -81,6 +84,24 @@ abstract class UnaryTransformer[IN, OUT: TypeTag, SELF <: UnaryTransformer[IN, O
    */
   protected def createTransformFunc(paramMap: ParamMap): IN => OUT
 
+  /**
+   * Validates the input type. Throws an exception if it is invalid.
+   */
+  protected def validateInputType(inputType: DataType): Unit = {}
+
+  override def transform(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = this.paramMap ++ paramMap
+    val inputType = schema(map(inputCol)).dataType
+    validateInputType(inputType)
+    if (schema.fieldNames.contains(map(outputCol))) {
+      throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
+    }
+    val output = ScalaReflection.schemaFor[OUT]
+    val outputFields = schema.fields :+
+      StructField(map(outputCol), output.dataType, output.nullable)
+    StructType(outputFields)
+  }
+
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index e0607dd2e0bea..5108e646e4f48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.classification
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector}
+import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.{Cast, Row}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -54,7 +55,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
   override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
-    val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+    val instances = dataset.select(Cast(map(labelCol).attr, DoubleType), map(featuresCol).attr)
       .map { case Row(label: Double, features: Vector) =>
         LabeledPoint(label, features)
       }.persist(StorageLevel.MEMORY_AND_DISK)
@@ -68,6 +69,33 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
     Params.copyValues(this, lrm)
     lrm
   }
+
+  override def transform(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = this.paramMap ++ paramMap
+    val featuresType = schema(map(featuresCol)).dataType
+    // TODO: Support casting Array[Double] and Array[Float] to Vector.
+    if (!featuresType.isInstanceOf[VectorUDT]) {
+      throw new IllegalArgumentException(
+        s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
+    }
+    val validLabelTypes = Set[DataType](FloatType, DoubleType, IntegerType, BooleanType, LongType)
+    val labelType = schema(map(labelCol)).dataType
+    if (!validLabelTypes.contains(labelType)) {
+      throw new IllegalArgumentException(
+        s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
+    }
+    val fieldNames = schema.fieldNames
+    if (fieldNames.contains(map(scoreCol))) {
+      throw new IllegalArgumentException(s"Score column ${map(scoreCol)} already exists.")
+    }
+    if (fieldNames.contains(map(predictionCol))) {
+      throw new IllegalArgumentException(s"Prediction column ${map(predictionCol)} already exists.")
+    }
+    val outputFields = schema.fields ++ Seq(
+      StructField(map(scoreCol), DoubleType, false),
+      StructField(map(predictionCol), DoubleType, false))
+    StructType(outputFields)
+  }
 }
 
 /**
@@ -83,6 +111,27 @@ class LogisticRegressionModel private[ml] (
   def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
   def setPredictionCol(value: String): this.type = { set(predictionCol, value); this }
 
+  override def transform(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = this.paramMap ++ paramMap
+    val featuresType = schema(map(featuresCol)).dataType
+    // TODO: Support casting Array[Double] and Array[Float] to Vector.
+    if (!featuresType.isInstanceOf[VectorUDT]) {
+      throw new IllegalArgumentException(
+        s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
+    }
+    val fieldNames = schema.fieldNames
+    if (fieldNames.contains(map(scoreCol))) {
+      throw new IllegalArgumentException(s"Score column ${map(scoreCol)} already exists.")
+    }
+    if (fieldNames.contains(map(predictionCol))) {
+      throw new IllegalArgumentException(s"Prediction column ${map(predictionCol)} already exists.")
+    }
+    val outputFields = schema.fields ++ Seq(
+      StructField(map(scoreCol), DoubleType, false),
+      StructField(map(predictionCol), DoubleType, false))
+    StructType(outputFields)
+  }
+
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index a5f95538b3853..efb29747f06a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -20,8 +20,9 @@ package org.apache.spark.ml.feature
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.sql.catalyst.types.StructField
+import org.apache.spark.sql.{StructType, SchemaRDD}
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.catalyst.expressions.Row
@@ -52,6 +53,17 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
     Params.copyValues(this, model)
     model
   }
+
+  override def transform(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = this.paramMap ++ paramMap
+    val inputType = schema(map(inputCol)).dataType
+    require(inputType.isInstanceOf[VectorUDT],
s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } } /** @@ -73,4 +85,15 @@ class StandardScalerModel private[ml] ( } dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) } + + override def transform(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 4999729780b0a..699f6ef3c775e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -129,6 +129,12 @@ trait Params extends Identifiable with Serializable { */ def explainParams(): String = params.mkString("\n") + /** Checks whether a param is explicitly set. */ + def isSet(param: Param[_]): Boolean = { + require(param.parent.eq(this)) + paramMap.contains(param) + } + /** Gets a param by its name. */ private[ml] def getParam(paramName: String): Param[Any] = { val m = this.getClass.getMethod(paramName) @@ -142,9 +148,6 @@ trait Params extends Identifiable with Serializable { */ protected val paramMap: ParamMap = ParamMap.empty - /** Checks whether a param is explicitly set. */ - protected def isSet(param: Param[_]): Boolean = paramMap.contains(param) - /** * Sets a parameter in the own parameter map. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f7b3f9b04caef..144684833a626 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.{StructType, SchemaRDD} /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. @@ -99,6 +99,11 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP Params.copyValues(this, cvModel) cvModel } + + override def transform(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + map(estimator).transform(schema, paramMap) + } } /** @@ -109,7 +114,12 @@ class CrossValidatorModel private[ml] ( override val fittingParamMap: ParamMap, bestModel: Model, metric: Double) extends Model with CrossValidatorParams { + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { bestModel.transform(dataset, paramMap) } + + override def transform(schema: StructType, paramMap: ParamMap): StructType = { + bestModel.transform(schema, paramMap) + } }