Skip to content

Commit

Permalink
[ML] adds multi-class feature importance support (#53803)
Browse files Browse the repository at this point in the history
Adds multi-class feature importance calculation. 

Feature importance objects are now mapped as follows
(logistic) Regression:
```
{
   "feature_name": "feature_0",
   "importance": -1.3
}
```
Multi-class [class names are `foo`, `bar`, `baz`]
```
{
   "feature_name": "feature_0",
   "importance": 2.0, // sum(abs()) of class importances
   "foo": 1.0,
   "bar": 0.5,
   "baz": -0.5
},
```

For users to get the full benefit of aggregating and searching for feature importance, they should update their index mapping as follows (before turning this option on in their pipelines)
```
 "ml.inference.feature_importance": {
          "type": "nested",
          "dynamic": true,
          "properties": {
            "feature_name": {
              "type": "keyword"
            },
            "importance": {
              "type": "double"
            }
          }
        }
```
The mapping field name is as follows
`ml.<inference.target_field>.<inference.tag>.feature_importance`
if `inference.tag` is not provided in the processor definition, it is not part of the field path.
`inference.target_field` is defaulted to `ml.inference`.
//cc @lcawl ^ Where should we document this?

If this makes it in for 7.7, there shouldn't be any feature_importance at inference BWC worries as 7.7 is the first version to have it.
  • Loading branch information
benwtrent authored Mar 23, 2020
1 parent ecdbd37 commit 756a297
Show file tree
Hide file tree
Showing 18 changed files with 410 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
}

public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
List<FeatureImportance> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}

private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
List<FeatureImportance> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
Expand Down Expand Up @@ -118,7 +118,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/**
 * A single feature's contribution to an inference result.
 *
 * For regression (and logistic regression) only {@code featureName} and the scalar
 * {@code importance} are populated. For multi-class classification a per-class
 * breakdown is kept as well, and {@code importance} holds the sum of the absolute
 * per-class importances.
 */
public class FeatureImportance implements Writeable {

    private static final String FEATURE_NAME = "feature_name";
    private static final String IMPORTANCE = "importance";

    private final String featureName;
    private final double importance;
    private final Map<String, Double> classImportance;

    /**
     * Builds an importance entry for a regression (or logistic regression) result.
     * No per-class breakdown is attached.
     */
    public static FeatureImportance forRegression(String featureName, double importance) {
        return new FeatureImportance(featureName, importance, null);
    }

    /**
     * Builds an importance entry for a multi-class result. The overall importance is
     * the sum of the absolute values of the per-class importances.
     */
    public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
        double totalImportance = classImportance.values().stream().mapToDouble(Math::abs).sum();
        return new FeatureImportance(featureName, totalImportance, classImportance);
    }

    private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
        this.featureName = Objects.requireNonNull(featureName);
        this.importance = importance;
        // Regression entries carry no class breakdown; keep null to distinguish that on the wire.
        this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
    }

    public FeatureImportance(StreamInput in) throws IOException {
        this.featureName = in.readString();
        this.importance = in.readDouble();
        // A boolean flag on the wire marks whether a per-class map follows.
        this.classImportance = in.readBoolean() ? in.readMap(StreamInput::readString, StreamInput::readDouble) : null;
    }

    /** @return the per-class importance map, or {@code null} for regression results */
    public Map<String, Double> getClassImportance() {
        return classImportance;
    }

    /** @return the overall importance (sum of absolute class importances for multi-class) */
    public double getImportance() {
        return importance;
    }

    /** @return the (never null) feature name */
    public String getFeatureName() {
        return featureName;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(featureName);
        out.writeDouble(importance);
        boolean hasClassImportance = classImportance != null;
        out.writeBoolean(hasClassImportance);
        if (hasClassImportance) {
            out.writeMap(classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
        }
    }

    /**
     * Renders this object for ingest-document output: {@code feature_name} and
     * {@code importance} first, then (for multi-class) one entry per class label,
     * preserving insertion order.
     */
    public Map<String, Object> toMap() {
        Map<String, Object> asMap = new LinkedHashMap<>();
        asMap.put(FEATURE_NAME, featureName);
        asMap.put(IMPORTANCE, importance);
        if (classImportance != null) {
            asMap.putAll(classImportance);
        }
        return asMap;
    }

    @Override
    public boolean equals(Object object) {
        if (object == this) {
            return true;
        }
        if (object == null || getClass() != object.getClass()) {
            return false;
        }
        FeatureImportance other = (FeatureImportance) object;
        return Objects.equals(this.featureName, other.featureName)
            && Objects.equals(this.importance, other.importance)
            && Objects.equals(this.classImportance, other.classImportance);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureName, importance, classImportance);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ public class RawInferenceResults implements InferenceResults {
public static final String NAME = "raw";

private final double[] value;
private final Map<String, Double> featureImportance;
private final Map<String, double[]> featureImportance;

public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
this.value = value;
this.featureImportance = featureImportance;
}
Expand All @@ -29,7 +29,7 @@ public double[] getValue() {
return value;
}

public Map<String, Double> getFeatureImportance() {
public Map<String, double[]> getFeatureImportance() {
return featureImportance;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class RegressionInferenceResults extends SingleValueInferenceResults {

Expand All @@ -24,14 +25,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;

public RegressionInferenceResults(double value, InferenceConfig config) {
this(value, (RegressionConfig) config, Collections.emptyMap());
this(value, (RegressionConfig) config, Collections.emptyList());
}

public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}

private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
Expand Down Expand Up @@ -70,7 +71,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,46 @@
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.List;
import java.util.stream.Collectors;

public abstract class SingleValueInferenceResults implements InferenceResults {

private final double value;
private final Map<String, Double> featureImportance;
private final List<FeatureImportance> featureImportance;

static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
return unsortedFeatureImportances;
}
return unsortedFeatureImportances.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
.collect(Collectors.toList());
}

SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
this.featureImportance = in.readList(FeatureImportance::new);
} else {
this.featureImportance = Collections.emptyMap();
this.featureImportance = Collections.emptyList();
}
}

SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
}

public Double value() {
return value;
}

public Map<String, Double> getFeatureImportance() {
public List<FeatureImportance> getFeatureImportance() {
return featureImportance;
}

Expand All @@ -58,7 +59,7 @@ public String valueAsString() {
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
out.writeList(this.featureImportance);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -100,18 +102,46 @@ public static Double toDouble(Object value) {
return null;
}

public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, Double> featureImportances) {
public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, double[]> featureImportances) {
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
return featureImportances;
}

Map<String, Double> originalFeatureImportance = new HashMap<>();
Map<String, double[]> originalFeatureImportance = new HashMap<>();
featureImportances.forEach((feature, importance) -> {
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : sumDoubleArrays(importance, v1));
});

return originalFeatureImportance;
}

/**
 * Converts raw per-feature importance vectors into {@link FeatureImportance} objects.
 *
 * A length-1 vector is treated as a regression (or logistic regression) importance;
 * longer vectors are treated as multi-class, with one entry per class. When
 * {@code classificationLabels} is null the class index is used as the label.
 *
 * @param featureImportance    feature name to raw importance vector
 * @param classificationLabels optional class labels; when present its size must match the vector length
 * @return one {@link FeatureImportance} per input feature
 */
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
                                                                 @Nullable List<String> classificationLabels) {
    List<FeatureImportance> results = new ArrayList<>(featureImportance.size());
    for (Map.Entry<String, double[]> entry : featureImportance.entrySet()) {
        String featureName = entry.getKey();
        double[] values = entry.getValue();
        // A single value indicates regression (or logistic regression) — no per-class split.
        if (values.length == 1) {
            results.add(FeatureImportance.forRegression(featureName, values[0]));
            continue;
        }
        // Multi-class: pair each value with its label, preserving class order.
        // If the classificationLabels exist, their length must match the vector length.
        assert classificationLabels == null || classificationLabels.size() == values.length;
        Map<String, Double> classImportance = new LinkedHashMap<>(values.length, 1.0f);
        for (int i = 0; i < values.length; i++) {
            String label = classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i);
            classImportance.put(label, values[i]);
        }
        results.add(FeatureImportance.forClassification(featureName, classImportance));
    }
    return results;
}

/**
 * Adds {@code inc} element-wise into {@code sumTo}, in place.
 *
 * NOTE: mutates and returns {@code sumTo}; the arrays are expected to be non-null
 * and of equal length (checked only via assertion).
 *
 * @param sumTo accumulator array, modified in place
 * @param inc   increments to add
 * @return the mutated {@code sumTo} array
 */
public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
    assert sumTo != null && inc != null && sumTo.length == inc.length;
    int idx = 0;
    while (idx < inc.length) {
        sumTo[idx] += inc[idx];
        idx++;
    }
    return sumTo;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* NOTE: Must be thread safe
* @param fields The fields inferring against
* @param featureDecoder A Map translating processed feature names to their original feature names
* @return A {@code Map<String, Double>} mapping each featureName to its importance
* @return A {@code Map<String, double[]>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);

default Version getMinimalCompatibilityVersion() {
return Version.V_7_6_0;
Expand Down
Loading

0 comments on commit 756a297

Please sign in to comment.