Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] adds multi-class feature importance support #53803

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
/**
 * Convenience constructor for results without feature importance; delegates with
 * an empty importance list.
 */
public ClassificationInferenceResults(double value,
                                      String classificationLabel,
                                      List<TopClassEntry> topClasses,
                                      InferenceConfig config) {
    this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig) config);
}

/**
 * Builds a classification result with per-feature importance.
 *
 * @param featureImportance importance entries, already decoded to original feature names
 * @param config must be a {@link ClassificationConfig}; cast enforced by the private delegate
 */
public ClassificationInferenceResults(double value,
                                      String classificationLabel,
                                      List<TopClassEntry> topClasses,
                                      List<FeatureImportance> featureImportance,
                                      InferenceConfig config) {
    this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig) config);
}

private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
List<FeatureImportance> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
Expand Down Expand Up @@ -118,7 +118,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class FeatureImportance implements Writeable {

private final Map<String, Double> classImportance;
private final double importance;
private final String featureName;
private static final String IMPORTANCE = "importance";
private static final String FEATURE_NAME = "feature_name";

public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
}

public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not 100% convinced this should be abs.

We don't write the feature importance value on the native side by looking at the norm of the vector.

Do we want to make this the norm too? Or do we thing abs is good enough?

@tveasey @valeriy42

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please provide more context. What are you calculating here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@valeriy42 @tveasey this is calculating the "overall importance" of all the classes combined for a given feature. This is so we can measure "most important feature" independent of the classes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm would make it an L2 norm, abs makes it an L1 norm. Either way is suitable. I think, abs is better, since norm over-treats larger importances and ignores smaller once.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 abs

}

private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
}

public FeatureImportance(StreamInput in) throws IOException {
this.featureName = in.readString();
this.importance = in.readDouble();
if (in.readBoolean()) {
this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.classImportance = null;
}
}

public Map<String, Double> getClassImportance() {
return classImportance;
}

public double getImportance() {
return importance;
}

public String getFeatureName() {
return featureName;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.featureName);
out.writeDouble(this.importance);
out.writeBoolean(this.classImportance != null);
if (this.classImportance != null) {
out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}
}

public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(FEATURE_NAME, featureName);
map.put(IMPORTANCE, importance);
if (classImportance != null) {
classImportance.forEach(map::put);
}
return map;
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
FeatureImportance that = (FeatureImportance) object;
return Objects.equals(featureName, that.featureName)
&& Objects.equals(importance, that.importance)
&& Objects.equals(classImportance, that.classImportance);
}

@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ public class RawInferenceResults implements InferenceResults {
public static final String NAME = "raw";

private final double[] value;
private final Map<String, Double> featureImportance;
private final Map<String, double[]> featureImportance;

/**
 * Intermediate (pre-aggregation) inference result: raw model output values plus
 * per-feature importance vectors keyed by processed feature name.
 */
public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
    this.value = value;
    this.featureImportance = featureImportance;
}
Expand All @@ -29,7 +29,7 @@ public double[] getValue() {
return value;
}

public Map<String, Double> getFeatureImportance() {
public Map<String, double[]> getFeatureImportance() {
return featureImportance;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class RegressionInferenceResults extends SingleValueInferenceResults {

Expand All @@ -24,14 +25,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;

/**
 * Convenience constructor for results without feature importance; delegates with
 * an empty importance list.
 */
public RegressionInferenceResults(double value, InferenceConfig config) {
    this(value, (RegressionConfig) config, Collections.emptyList());
}

/**
 * Builds a regression result with per-feature importance.
 *
 * @param config must be a {@link RegressionConfig}; cast enforced by the private delegate
 */
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
    this(value, (RegressionConfig) config, featureImportance);
}

private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
Expand Down Expand Up @@ -70,7 +71,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,46 @@
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.List;
import java.util.stream.Collectors;

public abstract class SingleValueInferenceResults implements InferenceResults {

private final double value;
private final Map<String, Double> featureImportance;
private final List<FeatureImportance> featureImportance;

static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
return unsortedFeatureImportances;
}
return unsortedFeatureImportances.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the abs necessary when the score is a norm? If the score can be -ve why is it wrong to use the -ve value?

Copy link
Member Author

@benwtrent benwtrent Mar 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidkyle

Score is not absolutely the norm. Additionally, we want to have the MOST influential values, regardless of direction. We could have feature importances like this:

{
A: -1.2,
B: -0.2,
C: 0.5
}

If we want the top two influential features, we want A and C.

The getImportance is only the norm when it comes to multi-class. This is not the case for (logistic) regression.

.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
.collect(Collectors.toList());
}

SingleValueInferenceResults(StreamInput in) throws IOException {
    value = in.readDouble();
    // Feature importance was added to the wire format in 7.7.0; older streams carry none.
    if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
        this.featureImportance = in.readList(FeatureImportance::new);
    } else {
        this.featureImportance = Collections.emptyList();
    }
}

SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
    this.value = value;
    // Tolerate null so callers need not pre-check; normalize to an immutable empty list.
    this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
}

/** The raw numeric value of this inference result. */
public Double value() {
    return value;
}

public Map<String, Double> getFeatureImportance() {
public List<FeatureImportance> getFeatureImportance() {
return featureImportance;
}

Expand All @@ -58,7 +59,7 @@ public String valueAsString() {
@Override
public void writeTo(StreamOutput out) throws IOException {
    out.writeDouble(value);
    // Mirrors the StreamInput constructor: importance only serialized for >= 7.7.0 peers.
    if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
        out.writeList(this.featureImportance);
    }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -100,18 +102,46 @@ public static Double toDouble(Object value) {
return null;
}

public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, Double> featureImportances) {
public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, double[]> featureImportances) {
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
return featureImportances;
}

Map<String, Double> originalFeatureImportance = new HashMap<>();
Map<String, double[]> originalFeatureImportance = new HashMap<>();
featureImportances.forEach((feature, importance) -> {
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : sumDoubleArrays(importance, v1));
});

return originalFeatureImportance;
}

/**
 * Converts raw importance vectors into {@link FeatureImportance} objects:
 * single-element vectors become regression entries, longer vectors become
 * multi-class classification entries labeled by {@code classificationLabels}
 * (or by class index when labels are absent).
 */
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
                                                                 @Nullable List<String> classificationLabels) {
    List<FeatureImportance> results = new ArrayList<>(featureImportance.size());
    for (Map.Entry<String, double[]> entry : featureImportance.entrySet()) {
        String featureName = entry.getKey();
        double[] importances = entry.getValue();
        if (importances.length == 1) {
            // A single value indicates regression, or logistic regression.
            results.add(FeatureImportance.forRegression(featureName, importances[0]));
        } else {
            // A length > 1 is assumed to be multi-class classification.
            Map<String, Double> classImportance = new LinkedHashMap<>(importances.length, 1.0f);
            // If the classificationLabels exist, their length must match leaf_value length.
            assert classificationLabels == null || classificationLabels.size() == importances.length;
            for (int classIndex = 0; classIndex < importances.length; classIndex++) {
                String label = classificationLabels == null ? String.valueOf(classIndex) : classificationLabels.get(classIndex);
                classImportance.put(label, importances[classIndex]);
            }
            results.add(FeatureImportance.forClassification(featureName, classImportance));
        }
    }
    return results;
}

/**
 * Adds {@code inc} element-wise into {@code sumTo} and returns {@code sumTo}.
 * Note: the first argument is mutated in place; arrays must be non-null and
 * of equal length (assertion-checked only).
 */
public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
    assert sumTo != null && inc != null && sumTo.length == inc.length;
    int length = inc.length;
    for (int index = 0; index < length; index++) {
        sumTo[index] += inc[index];
    }
    return sumTo;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* NOTE: Must be thread safe
* @param fields The fields inferring against
* @param featureDecoder A Map translating processed feature names to their original feature names
* @return A {@code Map<String, Double>} mapping each featureName to its importance
* @return A {@code Map<String, double[]>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);

/** Oldest version this model is compatible with; defaults to 7.6.0 unless overridden. */
default Version getMinimalCompatibilityVersion() {
    return Version.V_7_6_0;
}
Loading