Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[7.x] [ML][Inference] Adding classification_weights to ensemble models (#50874) #50994

Merged
merged 1 commit into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
Expand All @@ -41,6 +42,7 @@ public class Ensemble implements TrainedModel {
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");

private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
NAME,
Expand All @@ -60,6 +62,7 @@ public class Ensemble implements TrainedModel {
AGGREGATE_OUTPUT);
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
}

public static Ensemble fromXContent(XContentParser parser) {
Expand All @@ -71,17 +74,20 @@ public static Ensemble fromXContent(XContentParser parser) {
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
private final double[] classificationWeights;

Ensemble(List<String> featureNames,
List<TrainedModel> models,
@Nullable OutputAggregator outputAggregator,
TargetType targetType,
@Nullable List<String> classificationLabels) {
@Nullable List<String> classificationLabels,
@Nullable double[] classificationWeights) {
this.featureNames = featureNames;
this.models = models;
this.outputAggregator = outputAggregator;
this.targetType = targetType;
this.classificationLabels = classificationLabels;
this.classificationWeights = classificationWeights;
}

@Override
Expand Down Expand Up @@ -116,6 +122,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
if (classificationWeights != null) {
builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
}
builder.endObject();
return builder;
}
Expand All @@ -129,12 +138,18 @@ public boolean equals(Object o) {
&& Objects.equals(models, that.models)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(classificationLabels, that.classificationLabels)
&& Arrays.equals(classificationWeights, that.classificationWeights)
&& Objects.equals(outputAggregator, that.outputAggregator);
}

@Override
public int hashCode() {
return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
return Objects.hash(featureNames,
models,
outputAggregator,
classificationLabels,
targetType,
Arrays.hashCode(classificationWeights));
}

public static Builder builder() {
Expand All @@ -147,6 +162,7 @@ public static class Builder {
private OutputAggregator outputAggregator;
private TargetType targetType;
private List<String> classificationLabels;
private double[] classificationWeights;

public Builder setFeatureNames(List<String> featureNames) {
this.featureNames = featureNames;
Expand All @@ -173,6 +189,11 @@ public Builder setClassificationLabels(List<String> classificationLabels) {
return this;
}

public Builder setClassificationWeights(List<Double> classificationWeights) {
this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
return this;
}

private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
this.setOutputAggregator(outputAggregators.get(0));
}
Expand All @@ -182,7 +203,7 @@ private void setTargetType(String targetType) {
}

public Ensemble build() {
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,19 @@ public static Ensemble createRandom(TargetType targetType) {
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
Stream.generate(ESTestCase::randomDouble)
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
.mapToDouble(Double::valueOf)
.toArray() :
null;

return new Ensemble(featureNames,
models,
outputAggregator,
targetType,
categoryLabels);
categoryLabels,
thresholds);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,26 @@ public static class TopClassEntry implements Writeable {

public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");

private final String classification;
private final double probability;
private final double score;

public TopClassEntry(String classification, Double probability) {
public TopClassEntry(String classification, double probability) {
this(classification, probability, probability);
}

public TopClassEntry(String classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = ExceptionsHelper.requireNonNull(probability, CLASS_PROBABILITY);
this.probability = probability;
this.score = score;
}

public TopClassEntry(StreamInput in) throws IOException {
this.classification = in.readString();
this.probability = in.readDouble();
this.score = in.readDouble();
}

public String getClassification() {
Expand All @@ -134,31 +142,36 @@ public double getProbability() {
return probability;
}

public double getScore() {
return score;
}

public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(2);
Map<String, Object> map = new HashMap<>(3, 1.0f);
map.put(CLASS_NAME.getPreferredName(), classification);
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
map.put(CLASS_SCORE.getPreferredName(), score);
return map;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(classification);
out.writeDouble(probability);
out.writeDouble(score);
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) &&
Objects.equals(probability, that.probability);
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
}

@Override
public int hashCode() {
return Objects.hash(classification, probability);
return Objects.hash(classification, probability, score);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

Expand All @@ -20,25 +21,38 @@ public final class InferenceHelpers {

private InferenceHelpers() { }

public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List<Double> probabilities,
List<String> classificationLabels,
int numToInclude) {
if (numToInclude == 0) {
return Collections.emptyList();
}
int[] sortedIndices = IntStream.range(0, probabilities.size())
.boxed()
.sorted(Comparator.comparing(probabilities::get).reversed())
.mapToInt(i -> i)
.toArray();
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude) {

if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
throw ExceptionsHelper
.serverError(
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
null,
probabilities.size(),
classificationLabels);
classificationLabels.size());
}

List<Double> scores = classificationWeights == null ?
probabilities :
IntStream.range(0, probabilities.size())
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
.boxed()
.collect(Collectors.toList());

int[] sortedIndices = IntStream.range(0, probabilities.size())
.boxed()
.sorted(Comparator.comparing(scores::get).reversed())
.mapToInt(i -> i)
.toArray();

if (numToInclude == 0) {
return Tuple.tuple(sortedIndices[0], Collections.emptyList());
}

List<String> labels = classificationLabels == null ?
Expand All @@ -50,26 +64,24 @@ public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
}

return topClassEntries;
return Tuple.tuple(sortedIndices[0], topClassEntries);
}

public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
assert inferenceValue == Math.rint(inferenceValue);
public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
if (classificationLabels == null) {
return String.valueOf(inferenceValue);
}
int label = Double.valueOf(inferenceValue).intValue();
if (label < 0 || label >= classificationLabels.size()) {
if (inferenceValue < 0 || inferenceValue >= classificationLabels.size()) {
throw ExceptionsHelper.serverError(
"model returned classification value of [{}] which is not a valid index in classification labels [{}]",
null,
label,
inferenceValue,
classificationLabels);
}
return classificationLabels.get(label);
return classificationLabels.get(inferenceValue);
}

public static Double toDouble(Object value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,14 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

import java.util.List;
import java.util.Map;

public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {

/**
* @return List of featureNames expected by the model. In the order that they are expected
*/
List<String> getFeatureNames();

/**
* Infer against the provided fields
*
Expand All @@ -36,12 +29,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
*/
TargetType targetType();

/**
* @return Ordinal encoded list of classification labels.
*/
@Nullable
List<String> classificationLabels();

/**
* Runs validations against the model.
*
Expand Down
Loading