add schema transformation layer
mengxr committed Nov 10, 2014
1 parent 6736e87 commit 73a000b
Showing 6 changed files with 134 additions and 15 deletions.
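
In short: every PipelineStage now exposes transform(schema: StructType, paramMap: ParamMap): StructType, so a pipeline can derive its output schema and validate its column requirements before touching any data. A minimal usage sketch follows; the SchemaRDD named training, the default column names, and the setStages setter are assumptions based on this branch, not part of this diff.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val pipeline = new Pipeline().setStages(Array(new LogisticRegression()))

// Throws if the label or features column is missing or has an unsupported type,
// or if the score/prediction output columns would clash with existing ones;
// otherwise returns the schema of the eventual transformed output.
val outputSchema = pipeline.transform(training.schema, ParamMap.empty)

// Only after the schema check passes is it worth spending time on fitting.
val model = pipeline.fit(training, ParamMap.empty)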
23 changes: 18 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,12 +20,18 @@ package org.apache.spark.ml
import scala.collection.mutable.ListBuffer

import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.{StructType, SchemaRDD}

/**
* A stage in a pipeline, either an Estimator or a Transformer.
*/
abstract class PipelineStage
abstract class PipelineStage {

/**
* Derives the output schema from the input schema and parameters.
*/
def transform(schema: StructType, paramMap: ParamMap): StructType
}

/**
* A simple pipeline, which acts as an estimator.
@@ -70,6 +76,11 @@ class Pipeline extends Estimator[PipelineModel] {

new PipelineModel(this, map, transformers.toArray)
}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(stages).foldLeft(schema)((cur, stage) => stage.transform(cur, paramMap))
}
}

/**
@@ -99,8 +110,10 @@ class PipelineModel(
}

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
transformers.foldLeft(dataset) { (dataset, transformer) =>
transformer.transform(dataset, paramMap)
}
transformers.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
transformers.foldLeft(schema)((cur, transformer) => transformer.transform(cur, paramMap))
}
}
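
Both the Pipeline and the fitted PipelineModel expose the same schema hook: the estimator folds the schema over its stages, while the model folds it over its fitted transformers, so the check is available both before and after fitting. Continuing the sketch above (same hypothetical training data and pipeline):

// Derives the same output schema as pipeline.transform(training.schema, ParamMap.empty),
// this time by folding over the fitted transformers.
val fittedOutputSchema = model.transform(training.schema, ParamMap.empty)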
23 changes: 22 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -17,11 +17,14 @@

package org.apache.spark.ml

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.{StructField, StructType}

import scala.annotation.varargs
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.ml.param._
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.{DataType, SchemaRDD}
import org.apache.spark.sql.api.java.JavaSchemaRDD
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._
@@ -81,6 +84,24 @@ abstract class UnaryTransformer[IN, OUT: TypeTag, SELF <: UnaryTransformer[IN, O
*/
protected def createTransformFunc(paramMap: ParamMap): IN => OUT

/**
* Validates the input type. Throws an exception if it is invalid.
*/
protected def validateInputType(inputType: DataType): Unit = {}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains(map(outputCol))) {
throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
}
val output = ScalaReflection.schemaFor[OUT]
val outputFields = schema.fields :+
StructField(map(outputCol), output.dataType, output.nullable)
StructType(outputFields)
}

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
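
For unary transformers the schema work is done once in the base class: the appended field's type and nullability come from ScalaReflection.schemaFor on the OUT type parameter, and subclasses only hook in input validation. A hypothetical subclass (not part of this commit) could look roughly like this:

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataType
import org.apache.spark.sql.catalyst.types.StringType

// Maps a String column to its length; the output field is derived automatically
// as a non-nullable IntegerType by UnaryTransformer.transform(schema, paramMap).
class StringLength extends UnaryTransformer[String, Int, StringLength] {

  override protected def createTransformFunc(paramMap: ParamMap): String => Int =
    (s: String) => s.length

  // Reject anything other than a StringType input column during schema transformation.
  override protected def validateInputType(inputType: DataType): Unit =
    require(inputType == StringType,
      s"Input column must be of type StringType but got $inputType.")
}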
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.classification
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.expressions.{Cast, Row}
import org.apache.spark.storage.StorageLevel

/**
@@ -54,7 +55,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
val instances = dataset.select(Cast(map(labelCol).attr, DoubleType), map(featuresCol).attr)
.map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}.persist(StorageLevel.MEMORY_AND_DISK)
@@ -68,6 +69,33 @@
Params.copyValues(this, lrm)
lrm
}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val featuresType = schema(map(featuresCol)).dataType
// TODO: Support casting Array[Double] and Array[Float] to Vector.
if (!featuresType.isInstanceOf[VectorUDT]) {
throw new IllegalArgumentException(
s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
}
val validLabelTypes = Set[DataType](FloatType, DoubleType, IntegerType, BooleanType, LongType)
val labelType = schema(map(labelCol)).dataType
if (!validLabelTypes.contains(labelType)) {
throw new IllegalArgumentException(
s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
}
val fieldNames = schema.fieldNames
if (fieldNames.contains(map(scoreCol))) {
throw new IllegalArgumentException(s"Score column ${map(scoreCol)} already exists.")
}
if (fieldNames.contains(map(predictionCol))) {
throw new IllegalArgumentException(s"Prediction column ${map(predictionCol)} already exists.")
}
val outputFields = schema.fields ++ Seq(
StructField(map(scoreCol), DoubleType, false),
StructField(map(predictionCol), DoubleType, false))
StructType(outputFields)
}
}

/**
Expand All @@ -83,6 +111,27 @@ class LogisticRegressionModel private[ml] (
def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
def setPredictionCol(value: String): this.type = { set(predictionCol, value); this }

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val featuresType = schema(map(featuresCol)).dataType
// TODO: Support casting Array[Double] and Array[Float] to Vector.
if (!featuresType.isInstanceOf[VectorUDT]) {
throw new IllegalArgumentException(
s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
}
val fieldNames = schema.fieldNames
if (fieldNames.contains(map(scoreCol))) {
throw new IllegalArgumentException(s"Score column ${map(scoreCol)} already exists.")
}
if (fieldNames.contains(map(predictionCol))) {
throw new IllegalArgumentException(s"Prediction column ${map(predictionCol)} already exists.")
}
val outputFields = schema.fields ++ Seq(
StructField(map(scoreCol), DoubleType, false),
StructField(map(predictionCol), DoubleType, false))
StructType(outputFields)
}

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
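
The payoff of these checks is that a mis-typed input surfaces immediately instead of as a failure somewhere inside a job. A hedged sketch (the default label/features/score/prediction column names and a publicly constructible VectorUDT are assumptions based on this branch):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.catalyst.types.{DoubleType, StringType, StructField, StructType}

val lr = new LogisticRegression()

// Valid: numeric label plus a vector features column; the derived schema appends
// non-nullable DoubleType score and prediction fields.
val ok = StructType(Seq(
  StructField("label", DoubleType, nullable = false),
  StructField("features", new VectorUDT, nullable = false)))
val derived = lr.transform(ok, ParamMap.empty)

// Invalid: a StringType label cannot be converted to a double column, so this
// throws IllegalArgumentException before any data is read.
val bad = StructType(Seq(
  StructField("label", StringType, nullable = false),
  StructField("features", new VectorUDT, nullable = false)))
lr.transform(bad, ParamMap.empty)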
mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -20,8 +20,9 @@ package org.apache.spark.ml.feature
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.sql.catalyst.types.StructField
import org.apache.spark.sql.{StructType, SchemaRDD}
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.expressions.Row
@@ -52,6 +53,17 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
Params.copyValues(this, model)
model
}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${map(inputCol)} must be a vector column")
require(!schema.fieldNames.contains(map(outputCol)),
s"Output column ${map(outputCol)} already exists.")
val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
StructType(outputFields)
}
}

/**
@@ -73,4 +85,15 @@ class StandardScalerModel private[ml] (
}
dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${map(inputCol)} must be a vector column")
require(!schema.fieldNames.contains(map(outputCol)),
s"Output column ${map(outputCol)} already exists.")
val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
StructType(outputFields)
}
}
9 changes: 6 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -129,6 +129,12 @@ trait Params extends Identifiable with Serializable {
*/
def explainParams(): String = params.mkString("\n")

/** Checks whether a param is explicitly set. */
def isSet(param: Param[_]): Boolean = {
require(param.parent.eq(this))
paramMap.contains(param)
}

/** Gets a param by its name. */
private[ml] def getParam(paramName: String): Param[Any] = {
val m = this.getClass.getMethod(paramName)
@@ -142,9 +148,6 @@ trait Params extends Identifiable with Serializable {
*/
protected val paramMap: ParamMap = ParamMap.empty

/** Checks whether a param is explicitly set. */
protected def isSet(param: Param[_]): Boolean = paramMap.contains(param)

/**
* Sets a parameter in the own parameter map.
*/
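
isSet is now part of the public Params API and verifies that the queried param belongs to the instance it is asked about. A small sketch (the maxIter param and its setter are assumptions based on this codebase):

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
val other = new LogisticRegression()

lr.isSet(lr.maxIter)        // false: maxIter still carries only its default value
lr.setMaxIter(50)
lr.isSet(lr.maxIter)        // true: explicitly set above
// lr.isSet(other.maxIter)  // would fail the require(param.parent.eq(this)) ownership check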
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.{StructType, SchemaRDD}

/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
@@ -99,6 +99,11 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
Params.copyValues(this, cvModel)
cvModel
}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(estimator).transform(schema, paramMap)
}
}

/**
@@ -109,7 +114,12 @@ class CrossValidatorModel private[ml] (
override val fittingParamMap: ParamMap,
bestModel: Model,
metric: Double) extends Model with CrossValidatorParams {

override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
bestModel.transform(dataset, paramMap)
}

override def transform(schema: StructType, paramMap: ParamMap): StructType = {
bestModel.transform(schema, paramMap)
}
}
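
CrossValidator adds no columns of its own; it forwards schema derivation to the underlying estimator, and the fitted CrossValidatorModel forwards to its bestModel. A tiny sketch reusing the pipeline from the first example (the setEstimator setter is an assumption based on this branch):

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator().setEstimator(pipeline)
// Same derived schema as calling pipeline.transform(training.schema, ParamMap.empty) directly.
val cvOutputSchema = cv.transform(training.schema, ParamMap.empty)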
