-
Notifications
You must be signed in to change notification settings - Fork 834
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
016883d
commit ee412f5
Showing
4 changed files
with
364 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
//> DependsOn: core | ||
//> DependsOn: utils |
208 changes: 208 additions & 0 deletions
208
src/clean-missing-data/src/main/scala/CleanMissingData.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.ml.spark | ||
|
||
import org.apache.hadoop.fs.Path | ||
import org.apache.spark.annotation.DeveloperApi | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.util._ | ||
import org.apache.spark.ml._ | ||
import org.apache.spark.sql._ | ||
import org.apache.spark.sql.types.StructType | ||
|
||
import scala.collection.mutable.ListBuffer | ||
|
||
object CleanMissingData extends DefaultParamsReadable[CleanMissingData] { | ||
val meanOpt = "Mean" | ||
val medianOpt = "Median" | ||
val customOpt = "Custom" | ||
val modes = Array(meanOpt, medianOpt, customOpt) | ||
|
||
def validateAndTransformSchema(schema: StructType, | ||
inputCols: Array[String], | ||
outputCols: Array[String]): StructType = { | ||
inputCols.zip(outputCols).foldLeft(schema)((oldSchema, io) => { | ||
if (oldSchema.fieldNames.contains(io._2)) { | ||
val index = oldSchema.fieldIndex(io._2) | ||
val fields = oldSchema.fields | ||
fields(index) = oldSchema.fields(oldSchema.fieldIndex(io._1)) | ||
StructType(fields) | ||
} else { | ||
oldSchema.add(oldSchema.fields(oldSchema.fieldIndex(io._1))) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
/** | ||
* Removes missing values from input dataset. | ||
* The following modes are supported: | ||
* Mean - replaces missings with mean of fit column | ||
* Median - replaces missings with approximate median of fit column | ||
* Custom - replaces missings with custom value specified by user | ||
*/ | ||
class CleanMissingData(override val uid: String) extends Estimator[CleanMissingDataModel] | ||
with HasInputCols with HasOutputCols with MMLParams { | ||
|
||
def this() = this(Identifiable.randomUID("CleanMissingData")) | ||
|
||
val cleaningMode = StringParam(this, "cleaningMode", "cleaning mode", CleanMissingData.meanOpt) | ||
def setCleaningMode(value: String): this.type = set(cleaningMode, value) | ||
def getCleaningMode: String = $(cleaningMode) | ||
|
||
val customValue = DoubleParam(this, "customValue", "custom value for replacement") | ||
def setCustomValue(value: Double): this.type = set(customValue, value) | ||
def getCustomValue: Double = $(customValue) | ||
|
||
/** | ||
* Fits the dataset, prepares the transformation function. | ||
* | ||
* @param dataset The input dataset. | ||
* @return The model for removing missings. | ||
*/ | ||
override def fit(dataset: Dataset[_]): CleanMissingDataModel = { | ||
val replacementValues = getReplacementValues(dataset, getInputCols, getOutputCols, getCleaningMode) | ||
new CleanMissingDataModel(uid, replacementValues, getInputCols, getOutputCols) | ||
} | ||
|
||
override def copy(extra: ParamMap): Estimator[CleanMissingDataModel] = defaultCopy(extra) | ||
|
||
@DeveloperApi | ||
override def transformSchema(schema: StructType): StructType = | ||
CleanMissingData.validateAndTransformSchema(schema, getInputCols, getOutputCols) | ||
|
||
private def getReplacementValues(dataset: Dataset[_], | ||
colsToClean: Array[String], | ||
outputCols: Array[String], | ||
mode: String): Map[String, Any] = { | ||
import org.apache.spark.sql.functions._ | ||
val columns = colsToClean.map(col => dataset(col)) | ||
val metrics = getCleaningMode match { | ||
case CleanMissingData.meanOpt => { | ||
val row = dataset.select(columns.map(column => avg(column)): _*).collect()(0) | ||
rowToValues(row) | ||
} | ||
case CleanMissingData.medianOpt => { | ||
val row = dataset.select(columns.map(column => callUDF("percentile_approx", column, lit(0.5))): _*).collect()(0) | ||
rowToValues(row) | ||
} | ||
case CleanMissingData.customOpt => { | ||
colsToClean.map(col => getCustomValue) | ||
} | ||
} | ||
outputCols.zip(metrics).toMap | ||
} | ||
|
||
private def rowToValues(row: Row): Array[Double] = { | ||
val avgs = ListBuffer[Double]() | ||
for (i <- 0 until row.size) { | ||
avgs += row.getDouble(i) | ||
} | ||
avgs.toArray | ||
} | ||
} | ||
|
||
/** | ||
* Model produced by [[CleanMissingData]]. | ||
*/ | ||
class CleanMissingDataModel(val uid: String, | ||
val replacementValues: Map[String, Any], | ||
val inputCols: Array[String], | ||
val outputCols: Array[String]) | ||
extends Model[CleanMissingDataModel] with MLWritable { | ||
|
||
override def write: MLWriter = new CleanMissingDataModel.CleanMissingDataModelWriter(uid, | ||
replacementValues, | ||
inputCols, | ||
outputCols) | ||
|
||
override def copy(extra: ParamMap): CleanMissingDataModel = | ||
new CleanMissingDataModel(uid, replacementValues, inputCols, outputCols) | ||
|
||
override def transform(dataset: Dataset[_]): DataFrame = { | ||
val datasetCols = dataset.columns.map(name => dataset(name)).toList | ||
val datasetInputCols = inputCols.zip(outputCols) | ||
.flatMap(io => | ||
if (io._1 == io._2) { | ||
None | ||
} else { | ||
Some(dataset(io._1).as(io._2)) | ||
}).toList | ||
val addedCols = dataset.select((datasetCols ::: datasetInputCols):_*) | ||
addedCols.na.fill(replacementValues) | ||
} | ||
|
||
@DeveloperApi | ||
override def transformSchema(schema: StructType): StructType = | ||
CleanMissingData.validateAndTransformSchema(schema, inputCols, outputCols) | ||
} | ||
|
||
object CleanMissingDataModel extends MLReadable[CleanMissingDataModel] { | ||
|
||
private val replacementValuesPart = "replacementValues" | ||
private val inputColsPart = "inputCols" | ||
private val outputColsPart = "outputCols" | ||
private val dataPart = "data" | ||
|
||
override def read: MLReader[CleanMissingDataModel] = new CleanMissingDataModelReader | ||
|
||
override def load(path: String): CleanMissingDataModel = super.load(path) | ||
|
||
/** [[MLWriter]] instance for [[CleanMissingDataModel]] */ | ||
private[CleanMissingDataModel] | ||
class CleanMissingDataModelWriter(val uid: String, | ||
val replacementValues: Map[String, Any], | ||
val inputCols: Array[String], | ||
val outputCols: Array[String]) | ||
extends MLWriter { | ||
private case class Data(uid: String) | ||
|
||
override protected def saveImpl(path: String): Unit = { | ||
val overwrite = this.shouldOverwrite | ||
val qualPath = PipelineUtilities.makeQualifiedPath(sc, path) | ||
// Required in order to allow this to be part of an ML pipeline | ||
PipelineUtilities.saveMetadata(uid, | ||
CleanMissingDataModel.getClass.getName.replace("$", ""), | ||
new Path(path, "metadata").toString, | ||
sc, | ||
overwrite) | ||
|
||
// save the replacement values | ||
ObjectUtilities.writeObject(replacementValues, qualPath, replacementValuesPart, sc, overwrite) | ||
|
||
// save the input cols and output cols | ||
ObjectUtilities.writeObject(inputCols, qualPath, inputColsPart, sc, overwrite) | ||
ObjectUtilities.writeObject(outputCols, qualPath, outputColsPart, sc, overwrite) | ||
|
||
// save model data | ||
val data = Data(uid) | ||
val dataPath = new Path(qualPath, dataPart).toString | ||
val saveMode = | ||
if (overwrite) SaveMode.Overwrite | ||
else SaveMode.ErrorIfExists | ||
sparkSession.createDataFrame(Seq(data)).repartition(1).write.mode(saveMode).parquet(dataPath) | ||
} | ||
} | ||
|
||
private class CleanMissingDataModelReader | ||
extends MLReader[CleanMissingDataModel] { | ||
|
||
override def load(path: String): CleanMissingDataModel = { | ||
val qualPath = PipelineUtilities.makeQualifiedPath(sc, path) | ||
// load the uid | ||
val dataPath = new Path(qualPath, dataPart).toString | ||
val data = sparkSession.read.format("parquet").load(dataPath) | ||
val Row(uid: String) = data.select("uid").head() | ||
|
||
// get the replacement values | ||
val replacementValues = ObjectUtilities.loadObject[Map[String, Any]](qualPath, replacementValuesPart, sc) | ||
// get the input and output cols | ||
val inputCols = ObjectUtilities.loadObject[Array[String]](qualPath, inputColsPart, sc) | ||
val outputCols = ObjectUtilities.loadObject[Array[String]](qualPath, outputColsPart, sc) | ||
|
||
new CleanMissingDataModel(uid, replacementValues, inputCols, outputCols) | ||
} | ||
} | ||
|
||
} |
142 changes: 142 additions & 0 deletions
142
src/clean-missing-data/src/test/scala/VerifyCleanMissingData.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.ml.spark | ||
|
||
import org.apache.spark.ml.Estimator | ||
import org.apache.spark.sql.DataFrame | ||
import org.apache.spark.sql.types.StructType | ||
import java.lang.{Double => JDouble, Integer => JInt} | ||
|
||
import org.scalactic.TolerantNumerics | ||
|
||
/** | ||
* Tests to validate the functionality of Clean Missing Data estimator. | ||
*/ | ||
class VerifyCleanMissingData extends EstimatorFuzzingTest { | ||
|
||
val tolerance = 0.01 | ||
implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(tolerance) | ||
val tolEq = TolerantNumerics.tolerantDoubleEquality(tolerance) | ||
|
||
import session.implicits._ | ||
def createMockDataset: DataFrame = { | ||
Seq[(JInt, JInt, JDouble, JDouble, JInt)]( | ||
(0, 2, 0.50, 0.60, 0), | ||
(1, 3, 0.40, null, null), | ||
(0, 4, 0.78, 0.99, 2), | ||
(1, 5, 0.12, 0.34, 3), | ||
(0, 1, 0.50, 0.60, 0), | ||
(null, null, null, null, null), | ||
(0, 3, 0.78, 0.99, 2), | ||
(1, 4, 0.12, 0.34, 3), | ||
(0, null, 0.50, 0.60, 0), | ||
(1, 2, 0.40, 0.50, null), | ||
(0, 3, null, 0.99, 2), | ||
(1, 4, 0.12, 0.34, 3)) | ||
.toDF("col1", "col2", "col3", "col4", "col5") | ||
} | ||
|
||
test("Test for cleaning missing data with mean") { | ||
val dataset = createMockDataset | ||
val cmd = new CleanMissingData() | ||
.setInputCols(dataset.columns) | ||
.setOutputCols(dataset.columns) | ||
.setCleaningMode(CleanMissingData.meanOpt) | ||
val cmdModel = cmd.fit(dataset) | ||
val result = cmdModel.transform(dataset) | ||
// Calculate mean of column values | ||
val numCols = dataset.columns.length | ||
val meanValues = Array.ofDim[Double](numCols) | ||
val counts = Array.ofDim[Double](numCols) | ||
val collected = dataset.collect() | ||
collected.foreach(row => { | ||
for (i <- 0 until numCols) { | ||
val rawValue = row.get(i) | ||
val rowValue = | ||
if (rawValue == null) 0 | ||
else if (i == 2 || i == 3) { | ||
counts(i) += 1 | ||
row.get(i).asInstanceOf[JDouble].doubleValue() | ||
} else { | ||
counts(i) += 1 | ||
row.get(i).asInstanceOf[JInt].doubleValue() | ||
} | ||
meanValues(i) += rowValue | ||
} | ||
}) | ||
for (i <- 0 until numCols) { | ||
meanValues(i) /= counts(i) | ||
if (i != 2 && i != 3) { | ||
meanValues(i) = meanValues(i).toInt.toDouble | ||
} | ||
} | ||
verifyReplacementValues(dataset, result, meanValues) | ||
} | ||
|
||
test("Test for cleaning missing data with median") { | ||
val dataset = createMockDataset | ||
val cmd = new CleanMissingData() | ||
.setInputCols(dataset.columns) | ||
.setOutputCols(dataset.columns) | ||
.setCleaningMode(CleanMissingData.medianOpt) | ||
val cmdModel = cmd.fit(dataset) | ||
val result = cmdModel.transform(dataset) | ||
val medianValues = Array[Double](0, 3, 0.4, 0.6, 2) | ||
verifyReplacementValues(dataset, result, medianValues) | ||
} | ||
|
||
test("Test for cleaning missing data with custom value") { | ||
val dataset = createMockDataset | ||
val customValue = -1.5 | ||
val cmd = new CleanMissingData() | ||
.setInputCols(dataset.columns) | ||
.setOutputCols(dataset.columns) | ||
.setCleaningMode(CleanMissingData.customOpt) | ||
.setCustomValue(customValue) | ||
val cmdModel = cmd.fit(dataset) | ||
val result = cmdModel.transform(dataset) | ||
val replacesValues = Array.fill[Double](dataset.columns.length)(customValue) | ||
val numCols = replacesValues.length | ||
for (i <- 0 until numCols) { | ||
if (i != 2 && i != 3) { | ||
replacesValues(i) = replacesValues(i).toInt.toDouble | ||
} | ||
} | ||
verifyReplacementValues(dataset, result, replacesValues) | ||
} | ||
|
||
private def verifyReplacementValues(expected: DataFrame, result: DataFrame, expectedValues: Array[Double]) = { | ||
val collectedExp = expected.collect() | ||
val collectedResult = result.collect() | ||
val numRows = result.count().toInt | ||
val numCols = result.columns.length | ||
for (j <- 0 until numRows) { | ||
for (i <- 0 until numCols) { | ||
val row = collectedExp(j) | ||
val (rowValue, actualValue) = | ||
if (i == 2 || i == 3) { | ||
(row.get(i).asInstanceOf[JDouble], collectedResult(j)(i).asInstanceOf[Double]) | ||
} else { | ||
(row.get(i).asInstanceOf[JInt], collectedResult(j)(i).asInstanceOf[Int].toDouble) | ||
} | ||
if (rowValue == null) { | ||
val expectedValue = expectedValues(i) | ||
assert(tolEq.areEquivalent(expectedValue, actualValue), | ||
s"Values do not match, expected: $expectedValue, result: $actualValue") | ||
} | ||
} | ||
} | ||
} | ||
|
||
override def createFitDataset: DataFrame = { | ||
createMockDataset | ||
} | ||
|
||
override def schemaForDataset: StructType = ??? | ||
|
||
override def getEstimator(): Estimator[_] = { | ||
val dataset = createFitDataset | ||
new CleanMissingData().setInputCols(dataset.columns).setOutputCols(dataset.columns) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters