Descale feature contribution for Linear Regression & Logistic Regression (#345)

TuanNguyen27 authored and leahmcguire committed Jul 25, 2019
1 parent 3e02bf7 commit 82bb2c1
Showing 2 changed files with 187 additions and 7 deletions.
78 changes: 74 additions & 4 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
@@ -484,10 +484,12 @@ case object ModelInsights {
s" to fill in model insights"
)

val labelSummary = getLabelSummary(label, checkerSummary)
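// compute the label summary once and share it with getFeatureInsights for descaling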

ModelInsights(
- label = getLabelSummary(label, checkerSummary),
+ label = labelSummary,
features = getFeatureInsights(vectorInput, checkerSummary, model, rawFeatures,
- blacklistedFeatures, blacklistedMapKeys, rawFeatureFilterResults),
+ blacklistedFeatures, blacklistedMapKeys, rawFeatureFilterResults, labelSummary),
selectedModelInfo = getModelInfo(model),
trainingParams = trainingParams,
stageInfo = RawFeatureFilterConfig.toStageInfo(rawFeatureFilterResults.rawFeatureFilterConfig)
@@ -537,7 +539,8 @@ case object ModelInsights {
rawFeatures: Array[features.OPFeature],
blacklistedFeatures: Array[features.OPFeature],
blacklistedMapKeys: Map[String, Set[String]],
- rawFeatureFilterResults: RawFeatureFilterResults = RawFeatureFilterResults()
+ rawFeatureFilterResults: RawFeatureFilterResults = RawFeatureFilterResults(),
+ label: LabelSummary
): Seq[FeatureInsights] = {
val featureInsights = (vectorInfo, summary) match {
case (Some(v), Some(s)) =>
@@ -557,6 +560,42 @@
case _ => None
}
val keptIndex = indexInToIndexKept.get(h.index)
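// default to a unit standard deviation (a no-op for descaling) when the feature variance is unavailable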
val featureStd = math.sqrt(getIfExists(h.index, s.featuresStatistics.variance).getOrElse(1.0))
val sparkFtrContrib = keptIndex
.map(i => contributions.map(_.applyOrElse(i, (_: Int) => 0.0))).getOrElse(Seq.empty)
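// a label std of 1.0 keeps descaling a no-op when the label variance is zero or unknown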
val defaultLabelStd = 1.0
val labelStd = label.distribution match {
case Some(Continuous(_, _, _, variance)) =>
if (variance == 0) {
log.warn("The standard deviation of the label is zero, " +
"so the coefficients and intercepts of the model will be zeros, training is not needed.")
defaultLabelStd
}
else math.sqrt(variance)
case Some(Discrete(domain, prob)) =>
// mean = sum (x_i * p_i)
val mean = (domain zip prob).foldLeft(0.0) {
case (weightSum, (d, p)) => weightSum + d.toDouble * p
}
// variance = sum (x_i - mu)^2 * p_i
val discreteVariance = (domain zip prob).foldLeft(0.0) {
case (sqweightSum, (d, p)) => sqweightSum + (d.toDouble - mean) * (d.toDouble - mean) * p
}
if (discreteVariance == 0) {
log.warn("The standard deviation of the label is zero, " +
"so the coefficients and intercepts of the model will be zeros, training is not needed.")
defaultLabelStd
}
else math.sqrt(discreteVariance)
case Some(_) =>
log.warn("Cannot perform weight descaling because the label distribution type is unsupported.")
defaultLabelStd
case None =>
log.warn("Label distribution does not exist; please check your data.")
defaultLabelStd
}
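// labelStd is only used by the linear regression branch of descaleLRContrib below,
// since logistic regression does not rescale by the label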

h.parentFeatureOrigins ->
Insights(
@@ -579,7 +618,8 @@
case _ => Map.empty[String, Double]
},
contribution =
- keptIndex.map(i => contributions.map(_.applyOrElse(i, (_: Int) => 0.0))).getOrElse(Seq.empty),
+ descaleLRContrib(model, sparkFtrContrib, featureStd, labelStd).getOrElse(sparkFtrContrib),

min = getIfExists(h.index, s.featuresStatistics.min),
max = getIfExists(h.index, s.featuresStatistics.max),
mean = getIfExists(h.index, s.featuresStatistics.mean),
@@ -647,6 +687,36 @@
}
}

private[op] def descaleLRContrib(
model: Option[Model[_]],
sparkFtrContrib: Seq[Double],
featureStd: Double,
labelStd: Double): Option[Seq[Double]] = {
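// returns None when the model wraps no Spark stage; otherwise returns the contribution
// vector, descaled only for linear/logistic regression models fit with standardization on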
val stage = model.flatMap {
case m: SparkWrapperParams[_] => m.getSparkMlStage()
case _ => None
}
stage.collect {
case m: LogisticRegressionModel =>
if (m.getStandardization && sparkFtrContrib.nonEmpty) {
// scale entire feature contribution vector
// See https://think-lab.github.io/d/205/
// § 4.5.2 Standardized Interpretations, An Introduction to Categorical Data Analysis, Alan Agresti
sparkFtrContrib.map(_ * featureStd)
}
else sparkFtrContrib
case m: LinearRegressionModel =>
if (m.getStandardization && sparkFtrContrib.nonEmpty) {
// need to also divide by labelStd for linear regression
// See https://u.demog.berkeley.edu/~andrew/teaching/standard_coeff.pdf
// See https://en.wikipedia.org/wiki/Standardized_coefficient
sparkFtrContrib.map(_ * featureStd / labelStd)
}
else sparkFtrContrib
case _ => sparkFtrContrib
}
}
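For intuition, the two branches above implement the textbook standardized-coefficient formulas from the cited references: a coefficient β fit on raw data corresponds to the standardized coefficient β·σ_feature for logistic regression (the label is not scaled) and β·σ_feature/σ_label for linear regression. A minimal self-contained sketch of just this arithmetic, in plain Scala with no Spark (the object name and all values are illustrative, not part of this commit):

// Sketch of the descaling arithmetic used in descaleLRContrib (illustrative values).
object DescaleSketch extends App {
  val coefficients = Seq(0.5, -1.2) // model coefficients, one per feature
  val featureStds = Seq(10.0, 3.0)  // standard deviation of each feature
  val labelStd = 50.0               // standard deviation of the label (linear regression only)

  // logistic regression: multiply each coefficient by its feature's std
  val logisticDescaled = coefficients.zip(featureStds).map { case (b, s) => b * s }

  // linear regression: additionally divide by the label's std
  val linearDescaled = logisticDescaled.map(_ / labelStd)

  println(s"logistic: $logisticDescaled") // List(5.0, -3.6)
  println(s"linear:   $linearDescaled")   // List(0.1, -0.072)
}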

private[op] def getModelContributions
(model: Option[Model[_]], featureVectorSize: Option[Int] = None): Seq[Seq[Double]] = {
val stage = model.flatMap {
116 changes: 113 additions & 3 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
@@ -31,7 +31,7 @@
package com.salesforce.op

import com.salesforce.op.features.types._
-import com.salesforce.op.features.{Feature, FeatureDistributionType}
+import com.salesforce.op.features.{Feature, FeatureDistributionType, FeatureLike}
import com.salesforce.op.filters._
import com.salesforce.op.stages.impl.classification._
import com.salesforce.op.stages.impl.preparators._
@@ -40,12 +40,15 @@ import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
import com.salesforce.op.stages.impl.selector.SelectedModel
import com.salesforce.op.stages.impl.selector.ValidationType._
import com.salesforce.op.stages.impl.tuning.{DataCutter, DataSplitter}
-import com.salesforce.op.test.PassengerSparkFixtureTest
+import com.salesforce.op.test.{PassengerSparkFixtureTest, TestFeatureBuilder}
import com.salesforce.op.testkit.RandomReal
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import ml.dmlc.xgboost4j.scala.spark.OpXGBoostQuietLogging
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.junit.runner.RunWith
import com.salesforce.op.features.types.Real
import org.apache.spark.sql.DataFrame
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@@ -95,6 +98,72 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
.setInput(label, features)
.getOutput()

val smallFeatureVariance = 10.0
val mediumFeatureVariance = 1.0
val bigFeatureVariance = 100.0
val smallNorm = RandomReal.normal[Real](0.0, smallFeatureVariance).limit(1000)
val mediumNorm = RandomReal.normal[Real](10, mediumFeatureVariance).limit(1000)
val bigNorm = RandomReal.normal[Real](10000.0, bigFeatureVariance).limit(1000)
val noise = RandomReal.normal[Real](0.0, 100.0).limit(1000)
// make a simple linear combination of the features (with noise), pass through sigmoid function and binarize
// to make labels for logistic reg toy data
def binarize(x: Double): Int = {
val sigmoid = 1.0 / (1.0 + math.exp(-x))
if (sigmoid > 0.5) 1 else 0
}
val logisticRegLabel = (smallNorm, mediumNorm, noise)
.zipped.map(_.toDouble.get * 10 + _.toDouble.get + _.toDouble.get).map(binarize(_)).map(RealNN(_))
// toy label for linear reg is a sum of two scaled Normals, hence we also know its standard deviation
val linearRegLabel = (smallNorm, bigNorm)
.zipped.map(_.toDouble.get * 5000 + _.toDouble.get).map(RealNN(_))
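// for independent X and Y, Var(5000 * X + Y) = 5000^2 * Var(X) + Var(Y),
// which gives the label's standard deviation in closed form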
val labelStd = math.sqrt(5000 * 5000 * smallFeatureVariance + bigFeatureVariance)

def twoFeatureDF(feature1: List[Real], feature2: List[Real], label: List[RealNN]):
(Feature[RealNN], FeatureLike[OPVector], DataFrame) = {
val generatedData = feature1.zip(feature2).zip(label).map {
case ((f1, f2), label) => (f1, f2, label)
}
val (rawDF, raw1, raw2, rawLabel) = TestFeatureBuilder("feature1", "feature2", "label", generatedData)
val labelData = rawLabel.copy(isResponse = true)
val featureVector = raw1
.vectorize(fillValue = 0, fillWithMean = true, trackNulls = false, others = Array(raw2))
val checkedFeatures = labelData.sanityCheck(featureVector, removeBadFeatures = false)
(labelData, checkedFeatures, rawDF)
}

val linRegDF = twoFeatureDF(smallNorm, bigNorm, linearRegLabel)
val logRegDF = twoFeatureDF(smallNorm, mediumNorm, logisticRegLabel)

val unstandardizedLinpred = new OpLinearRegression().setStandardization(false)
.setInput(linRegDF._1, linRegDF._2).getOutput()

val standardizedLinpred = new OpLinearRegression().setStandardization(true)
.setInput(linRegDF._1, linRegDF._2).getOutput()

val unstandardizedLogpred = new OpLogisticRegression().setStandardization(false)
.setInput(logRegDF._1, logRegDF._2).getOutput()

val standardizedLogpred = new OpLogisticRegression().setStandardization(true)
.setInput(logRegDF._1, logRegDF._2).getOutput()

def getFeatureImp(standardizedModel: FeatureLike[Prediction],
unstandardizedModel: FeatureLike[Prediction],
df: DataFrame): Array[Double] = {
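// train one workflow containing both models, then read each model's
// per-feature contributions back from its ModelInsights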
lazy val workFlow = new OpWorkflow()
.setResultFeatures(standardizedModel, unstandardizedModel).setInputDataset(df)
lazy val model = workFlow.train()
val unstandardizedFtImp = model.modelInsights(unstandardizedModel)
.features.map(_.derivedFeatures.map(_.contribution))
val standardizedFtImp = model.modelInsights(standardizedModel)
.features.map(_.derivedFeatures.map(_.contribution))
val descaledSmallCoeff = standardizedFtImp.flatten.flatten.head
val originalSmallCoeff = unstandardizedFtImp.flatten.flatten.head
val descaledBigCoeff = standardizedFtImp.flatten.flatten.last
val originalBigCoeff = unstandardizedFtImp.flatten.flatten.last
Array(descaledSmallCoeff, originalSmallCoeff, descaledBigCoeff, originalBigCoeff)
}


val params = new OpParams()

lazy val workflow = new OpWorkflow().setResultFeatures(predLin, pred).setParameters(params).setReader(dataReader)
@@ -508,9 +577,11 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
}

it should "correctly extract the FeatureInsights from the sanity checker summary and vector metadata" in {
val labelSum = ModelInsights.getLabelSummary(Option(lbl), Option(summary))

val featureInsights = ModelInsights.getFeatureInsights(
Option(meta), Option(summary), None, Array(f1, f0), Array.empty, Map.empty[String, Set[String]],
- RawFeatureFilterResults()
+ RawFeatureFilterResults(), labelSum
)
featureInsights.size shouldBe 2

@@ -651,4 +722,43 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
f.cramersV.isEmpty shouldBe true
}
}

val tol = 0.03
it should "correctly return the descaled coefficient for linear regression, " +
"when standardization is on" in {

// Since 5000 & 1 are always returned as the coefficients of the model
// trained on unstandardized data and we can analytically calculate
// the scaled version of them by the linear regression formula, the coefficients
// of the model trained on standardized data should be within a small distance of the analytical formula.

// difference between the real coefficient and the analytical formula
val coeffs = getFeatureImp(standardizedLinpred, unstandardizedLinpred, linRegDF._3)
val descaledsmallCoeff = coeffs(0)
val originalsmallCoeff = coeffs(1)
val descaledbigCoeff = coeffs(2)
val orginalbigCoeff = coeffs(3)
val absError = math.abs(orginalbigCoeff * math.sqrt(smallFeatureVariance) / labelStd - descaledbigCoeff)
val bigCoeffSum = orginalbigCoeff * math.sqrt(smallFeatureVariance) / labelStd + descaledbigCoeff
val absError2 = math.abs(originalsmallCoeff * math.sqrt(bigFeatureVariance) / labelStd - descaledsmallCoeff)
val smallCoeffSum = originalsmallCoeff * math.sqrt(bigFeatureVariance) / labelStd + descaledsmallCoeff
absError / bigCoeffSum < tol shouldBe true
absError2 / smallCoeffSum < tol shouldBe true
}

it should "correctly return the descaled coefficient for logistic regression, " +
"when standardization is on" in {
val coeffs = getFeatureImp(standardizedLogpred, unstandardizedLogpred, logRegDF._3)
val descaledSmallCoeff = coeffs(0)
val originalSmallCoeff = coeffs(1)
val descaledBigCoeff = coeffs(2)
val originalBigCoeff = coeffs(3)
// relative difference between the descaled coefficient and the analytical formula;
// no division by labelStd here since logistic regression does not scale the label
val absError = math.abs(originalBigCoeff * math.sqrt(smallFeatureVariance) - descaledBigCoeff)
val bigCoeffSum = originalBigCoeff * math.sqrt(smallFeatureVariance) + descaledBigCoeff
val absError2 = math.abs(originalSmallCoeff * math.sqrt(mediumFeatureVariance) - descaledSmallCoeff)
val smallCoeffSum = originalSmallCoeff * math.sqrt(mediumFeatureVariance) + descaledSmallCoeff
absError / bigCoeffSum < tol shouldBe true
absError2 / smallCoeffSum < tol shouldBe true
}
}
