diff --git a/src/clean-missing-data/build.sbt b/src/clean-missing-data/build.sbt new file mode 100644 index 00000000000..47a5d9cbee2 --- /dev/null +++ b/src/clean-missing-data/build.sbt @@ -0,0 +1,2 @@ +//> DependsOn: core +//> DependsOn: utils diff --git a/src/clean-missing-data/src/main/scala/CleanMissingData.scala b/src/clean-missing-data/src/main/scala/CleanMissingData.scala new file mode 100644 index 00000000000..13f3c277689 --- /dev/null +++ b/src/clean-missing-data/src/main/scala/CleanMissingData.scala @@ -0,0 +1,208 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.spark + +import org.apache.hadoop.fs.Path +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util._ +import org.apache.spark.ml._ +import org.apache.spark.sql._ +import org.apache.spark.sql.types.StructType + +import scala.collection.mutable.ListBuffer + +object CleanMissingData extends DefaultParamsReadable[CleanMissingData] { + val meanOpt = "Mean" + val medianOpt = "Median" + val customOpt = "Custom" + val modes = Array(meanOpt, medianOpt, customOpt) + + def validateAndTransformSchema(schema: StructType, + inputCols: Array[String], + outputCols: Array[String]): StructType = { + inputCols.zip(outputCols).foldLeft(schema)((oldSchema, io) => { + if (oldSchema.fieldNames.contains(io._2)) { + val index = oldSchema.fieldIndex(io._2) + val fields = oldSchema.fields + fields(index) = oldSchema.fields(oldSchema.fieldIndex(io._1)) + StructType(fields) + } else { + oldSchema.add(oldSchema.fields(oldSchema.fieldIndex(io._1))) + } + }) + } +} + +/** + * Removes missing values from input dataset. + * The following modes are supported: + * Mean - replaces missings with mean of fit column + * Median - replaces missings with approximate median of fit column + * Custom - replaces missings with custom value specified by user + */ +class CleanMissingData(override val uid: String) extends Estimator[CleanMissingDataModel] + with HasInputCols with HasOutputCols with MMLParams { + + def this() = this(Identifiable.randomUID("CleanMissingData")) + + val cleaningMode = StringParam(this, "cleaningMode", "cleaning mode", CleanMissingData.meanOpt) + def setCleaningMode(value: String): this.type = set(cleaningMode, value) + def getCleaningMode: String = $(cleaningMode) + + val customValue = DoubleParam(this, "customValue", "custom value for replacement") + def setCustomValue(value: Double): this.type = set(customValue, value) + def getCustomValue: Double = $(customValue) + + /** + * Fits the dataset, prepares the transformation function. + * + * @param dataset The input dataset. + * @return The model for removing missings. + */ + override def fit(dataset: Dataset[_]): CleanMissingDataModel = { + val replacementValues = getReplacementValues(dataset, getInputCols, getOutputCols, getCleaningMode) + new CleanMissingDataModel(uid, replacementValues, getInputCols, getOutputCols) + } + + override def copy(extra: ParamMap): Estimator[CleanMissingDataModel] = defaultCopy(extra) + + @DeveloperApi + override def transformSchema(schema: StructType): StructType = + CleanMissingData.validateAndTransformSchema(schema, getInputCols, getOutputCols) + + private def getReplacementValues(dataset: Dataset[_], + colsToClean: Array[String], + outputCols: Array[String], + mode: String): Map[String, Any] = { + import org.apache.spark.sql.functions._ + val columns = colsToClean.map(col => dataset(col)) + val metrics = getCleaningMode match { + case CleanMissingData.meanOpt => { + val row = dataset.select(columns.map(column => avg(column)): _*).collect()(0) + rowToValues(row) + } + case CleanMissingData.medianOpt => { + val row = dataset.select(columns.map(column => callUDF("percentile_approx", column, lit(0.5))): _*).collect()(0) + rowToValues(row) + } + case CleanMissingData.customOpt => { + colsToClean.map(col => getCustomValue) + } + } + outputCols.zip(metrics).toMap + } + + private def rowToValues(row: Row): Array[Double] = { + val avgs = ListBuffer[Double]() + for (i <- 0 until row.size) { + avgs += row.getDouble(i) + } + avgs.toArray + } +} + +/** + * Model produced by [[CleanMissingData]]. + */ +class CleanMissingDataModel(val uid: String, + val replacementValues: Map[String, Any], + val inputCols: Array[String], + val outputCols: Array[String]) + extends Model[CleanMissingDataModel] with MLWritable { + + override def write: MLWriter = new CleanMissingDataModel.CleanMissingDataModelWriter(uid, + replacementValues, + inputCols, + outputCols) + + override def copy(extra: ParamMap): CleanMissingDataModel = + new CleanMissingDataModel(uid, replacementValues, inputCols, outputCols) + + override def transform(dataset: Dataset[_]): DataFrame = { + val datasetCols = dataset.columns.map(name => dataset(name)).toList + val datasetInputCols = inputCols.zip(outputCols) + .flatMap(io => + if (io._1 == io._2) { + None + } else { + Some(dataset(io._1).as(io._2)) + }).toList + val addedCols = dataset.select((datasetCols ::: datasetInputCols):_*) + addedCols.na.fill(replacementValues) + } + + @DeveloperApi + override def transformSchema(schema: StructType): StructType = + CleanMissingData.validateAndTransformSchema(schema, inputCols, outputCols) +} + +object CleanMissingDataModel extends MLReadable[CleanMissingDataModel] { + + private val replacementValuesPart = "replacementValues" + private val inputColsPart = "inputCols" + private val outputColsPart = "outputCols" + private val dataPart = "data" + + override def read: MLReader[CleanMissingDataModel] = new CleanMissingDataModelReader + + override def load(path: String): CleanMissingDataModel = super.load(path) + + /** [[MLWriter]] instance for [[CleanMissingDataModel]] */ + private[CleanMissingDataModel] + class CleanMissingDataModelWriter(val uid: String, + val replacementValues: Map[String, Any], + val inputCols: Array[String], + val outputCols: Array[String]) + extends MLWriter { + private case class Data(uid: String) + + override protected def saveImpl(path: String): Unit = { + val overwrite = this.shouldOverwrite + val qualPath = PipelineUtilities.makeQualifiedPath(sc, path) + // Required in order to allow this to be part of an ML pipeline + PipelineUtilities.saveMetadata(uid, + CleanMissingDataModel.getClass.getName.replace("$", ""), + new Path(path, "metadata").toString, + sc, + overwrite) + + // save the replacement values + ObjectUtilities.writeObject(replacementValues, qualPath, replacementValuesPart, sc, overwrite) + + // save the input cols and output cols + ObjectUtilities.writeObject(inputCols, qualPath, inputColsPart, sc, overwrite) + ObjectUtilities.writeObject(outputCols, qualPath, outputColsPart, sc, overwrite) + + // save model data + val data = Data(uid) + val dataPath = new Path(qualPath, dataPart).toString + val saveMode = + if (overwrite) SaveMode.Overwrite + else SaveMode.ErrorIfExists + sparkSession.createDataFrame(Seq(data)).repartition(1).write.mode(saveMode).parquet(dataPath) + } + } + + private class CleanMissingDataModelReader + extends MLReader[CleanMissingDataModel] { + + override def load(path: String): CleanMissingDataModel = { + val qualPath = PipelineUtilities.makeQualifiedPath(sc, path) + // load the uid + val dataPath = new Path(qualPath, dataPart).toString + val data = sparkSession.read.format("parquet").load(dataPath) + val Row(uid: String) = data.select("uid").head() + + // get the replacement values + val replacementValues = ObjectUtilities.loadObject[Map[String, Any]](qualPath, replacementValuesPart, sc) + // get the input and output cols + val inputCols = ObjectUtilities.loadObject[Array[String]](qualPath, inputColsPart, sc) + val outputCols = ObjectUtilities.loadObject[Array[String]](qualPath, outputColsPart, sc) + + new CleanMissingDataModel(uid, replacementValues, inputCols, outputCols) + } + } + +} diff --git a/src/clean-missing-data/src/test/scala/VerifyCleanMissingData.scala b/src/clean-missing-data/src/test/scala/VerifyCleanMissingData.scala new file mode 100644 index 00000000000..261fcd37b95 --- /dev/null +++ b/src/clean-missing-data/src/test/scala/VerifyCleanMissingData.scala @@ -0,0 +1,142 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.ml.spark + +import org.apache.spark.ml.Estimator +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType +import java.lang.{Double => JDouble, Integer => JInt} + +import org.scalactic.TolerantNumerics + +/** + * Tests to validate the functionality of Clean Missing Data estimator. + */ +class VerifyCleanMissingData extends EstimatorFuzzingTest { + + val tolerance = 0.01 + implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(tolerance) + val tolEq = TolerantNumerics.tolerantDoubleEquality(tolerance) + + import session.implicits._ + def createMockDataset: DataFrame = { + Seq[(JInt, JInt, JDouble, JDouble, JInt)]( + (0, 2, 0.50, 0.60, 0), + (1, 3, 0.40, null, null), + (0, 4, 0.78, 0.99, 2), + (1, 5, 0.12, 0.34, 3), + (0, 1, 0.50, 0.60, 0), + (null, null, null, null, null), + (0, 3, 0.78, 0.99, 2), + (1, 4, 0.12, 0.34, 3), + (0, null, 0.50, 0.60, 0), + (1, 2, 0.40, 0.50, null), + (0, 3, null, 0.99, 2), + (1, 4, 0.12, 0.34, 3)) + .toDF("col1", "col2", "col3", "col4", "col5") + } + + test("Test for cleaning missing data with mean") { + val dataset = createMockDataset + val cmd = new CleanMissingData() + .setInputCols(dataset.columns) + .setOutputCols(dataset.columns) + .setCleaningMode(CleanMissingData.meanOpt) + val cmdModel = cmd.fit(dataset) + val result = cmdModel.transform(dataset) + // Calculate mean of column values + val numCols = dataset.columns.length + val meanValues = Array.ofDim[Double](numCols) + val counts = Array.ofDim[Double](numCols) + val collected = dataset.collect() + collected.foreach(row => { + for (i <- 0 until numCols) { + val rawValue = row.get(i) + val rowValue = + if (rawValue == null) 0 + else if (i == 2 || i == 3) { + counts(i) += 1 + row.get(i).asInstanceOf[JDouble].doubleValue() + } else { + counts(i) += 1 + row.get(i).asInstanceOf[JInt].doubleValue() + } + meanValues(i) += rowValue + } + }) + for (i <- 0 until numCols) { + meanValues(i) /= counts(i) + if (i != 2 && i != 3) { + meanValues(i) = meanValues(i).toInt.toDouble + } + } + verifyReplacementValues(dataset, result, meanValues) + } + + test("Test for cleaning missing data with median") { + val dataset = createMockDataset + val cmd = new CleanMissingData() + .setInputCols(dataset.columns) + .setOutputCols(dataset.columns) + .setCleaningMode(CleanMissingData.medianOpt) + val cmdModel = cmd.fit(dataset) + val result = cmdModel.transform(dataset) + val medianValues = Array[Double](0, 3, 0.4, 0.6, 2) + verifyReplacementValues(dataset, result, medianValues) + } + + test("Test for cleaning missing data with custom value") { + val dataset = createMockDataset + val customValue = -1.5 + val cmd = new CleanMissingData() + .setInputCols(dataset.columns) + .setOutputCols(dataset.columns) + .setCleaningMode(CleanMissingData.customOpt) + .setCustomValue(customValue) + val cmdModel = cmd.fit(dataset) + val result = cmdModel.transform(dataset) + val replacesValues = Array.fill[Double](dataset.columns.length)(customValue) + val numCols = replacesValues.length + for (i <- 0 until numCols) { + if (i != 2 && i != 3) { + replacesValues(i) = replacesValues(i).toInt.toDouble + } + } + verifyReplacementValues(dataset, result, replacesValues) + } + + private def verifyReplacementValues(expected: DataFrame, result: DataFrame, expectedValues: Array[Double]) = { + val collectedExp = expected.collect() + val collectedResult = result.collect() + val numRows = result.count().toInt + val numCols = result.columns.length + for (j <- 0 until numRows) { + for (i <- 0 until numCols) { + val row = collectedExp(j) + val (rowValue, actualValue) = + if (i == 2 || i == 3) { + (row.get(i).asInstanceOf[JDouble], collectedResult(j)(i).asInstanceOf[Double]) + } else { + (row.get(i).asInstanceOf[JInt], collectedResult(j)(i).asInstanceOf[Int].toDouble) + } + if (rowValue == null) { + val expectedValue = expectedValues(i) + assert(tolEq.areEquivalent(expectedValue, actualValue), + s"Values do not match, expected: $expectedValue, result: $actualValue") + } + } + } + } + + override def createFitDataset: DataFrame = { + createMockDataset + } + + override def schemaForDataset: StructType = ??? + + override def getEstimator(): Estimator[_] = { + val dataset = createFitDataset + new CleanMissingData().setInputCols(dataset.columns).setOutputCols(dataset.columns) + } +} diff --git a/src/core/contracts/src/main/scala/Params.scala b/src/core/contracts/src/main/scala/Params.scala index f0ab6e8473e..698f5ec17fe 100644 --- a/src/core/contracts/src/main/scala/Params.scala +++ b/src/core/contracts/src/main/scala/Params.scala @@ -121,6 +121,18 @@ trait HasOutputCol extends Wrappable { def getOutputCol: String = $(outputCol) } +trait HasInputCols extends Wrappable { + val inputCols = new StringArrayParam(this, "inputCols", "The names of the input columns") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + def getInputCols: Array[String] = $(inputCols) +} + +trait HasOutputCols extends Wrappable { + val outputCols = new StringArrayParam(this, "outputCols", "The names of the output columns") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + def getOutputCols: Array[String] = $(outputCols) +} + trait HasLabelCol extends Wrappable { val labelCol = StringParam(this, "labelCol", "The name of the label column") def setLabelCol(value: String): this.type = set(labelCol, value)