Skip to content

Commit

Permalink
[7.x] Pass prediction_field_type to C++ analytics process (#49861) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Dec 9, 2019
1 parent 049d854 commit 0965a10
Show file tree
Hide file tree
Showing 21 changed files with 313 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
.flatMap(Set::stream)
.collect(Collectors.toSet()));

/**
* Name of the parameter passed down to C++.
* This parameter is used to decide which JSON data type from {string, int, bool} to use when writing the prediction.
*/
private static final String PREDICTION_FIELD_TYPE = "prediction_field_type";

/**
* As long as we only support binary classification it makes sense to always report both classes with their probabilities.
* This way the user can see whether the prediction was made with the confidence they need.
Expand Down Expand Up @@ -154,17 +160,38 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
public Map<String, Object> getParams() {
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
    // Builds the parameter map for this classification analysis: the dependent variable,
    // the boosted tree parameters, the number of top classes and, when they are known,
    // the prediction field name and the prediction field type.
    Map<String, Object> params = new HashMap<>();
    params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
    params.putAll(boostedTreeParams.getParams());
    params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
    if (predictionFieldName != null) {
        params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
    }
    // A null extractedFields map is treated as "types unknown" instead of throwing a
    // NullPointerException; in that case the prediction field type is simply omitted.
    Set<String> dependentVariableTypes = extractedFields == null ? null : extractedFields.get(dependentVariable);
    String predictionFieldType = getPredictionFieldType(dependentVariableTypes);
    if (predictionFieldType != null) {
        params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
    }
    return params;
}

/**
 * Maps the mapping types of the dependent variable to the JSON data type used when writing the prediction.
 *
 * The checks run in a fixed order: categorical types give {@code "string"}, boolean types give {@code "bool"}
 * and discrete numerical types give {@code "int"}. Because {@code containsAll} is vacuously true for an empty
 * set, an empty set of types yields {@code "string"}.
 *
 * @param dependentVariableTypes the mapping types of the dependent variable, may be {@code null}
 * @return {@code "string"}, {@code "bool"} or {@code "int"}; {@code null} when the types are {@code null}
 *         or are not fully covered by any single one of the three groups
 */
private static String getPredictionFieldType(Set<String> dependentVariableTypes) {
    String fieldType = null;
    if (dependentVariableTypes != null) {
        if (Types.categorical().containsAll(dependentVariableTypes)) {
            fieldType = "string";
        } else if (Types.bool().containsAll(dependentVariableTypes)) {
            fieldType = "bool";
        } else if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
            // C++ process uses int64_t type, so it is safe for the dependent variable to use long numbers.
            fieldType = "int";
        }
    }
    return fieldType;
}

@Override
public boolean supportsCategoricalFields() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {

/**
* @param extractedFields map of (name, types) for all the extracted fields
* @return The analysis parameters as a map
*/
Map<String, Object> getParams();
Map<String, Object> getParams(Map<String, Set<String>> extractedFields);

/**
* @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public int hashCode() {
}

@Override
public Map<String, Object> getParams() {
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
Map<String, Object> params = new HashMap<>();
if (nNeighbors != null) {
params.put(N_NEIGHBORS.getPreferredName(), nNeighbors);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
public Map<String, Object> getParams() {
public Map<String, Object> getParams(Map<String, Set<String>> extractedFields) {
Map<String, Object> params = new HashMap<>();
params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
params.putAll(boostedTreeParams.getParams());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -115,6 +124,34 @@ public void testGetTrainingPercent() {
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}

/**
 * Checks that {@code getParams} reports the dependent variable, the number of top classes, the prediction
 * field name and a prediction field type matching the mapping type of the dependent variable
 * (boolean gives "bool", long gives "int", keyword gives "string").
 */
public void testGetParams() {
    Map<String, Set<String>> fieldTypesByName = new HashMap<>(3);
    fieldTypesByName.put("foo", Collections.singleton(BooleanFieldMapper.CONTENT_TYPE));
    fieldTypesByName.put("bar", Collections.singleton(NumberFieldMapper.NumberType.LONG.typeName()));
    fieldTypesByName.put("baz", Collections.singleton(KeywordFieldMapper.CONTENT_TYPE));

    // Boolean dependent variable
    Map<String, Object> booleanParams = new Classification("foo").getParams(fieldTypesByName);
    assertThat(
        booleanParams,
        Matchers.<Map<String, Object>>allOf(
            hasEntry("dependent_variable", "foo"),
            hasEntry("num_top_classes", 2),
            hasEntry("prediction_field_name", "foo_prediction"),
            hasEntry("prediction_field_type", "bool")));

    // Long dependent variable
    Map<String, Object> longParams = new Classification("bar").getParams(fieldTypesByName);
    assertThat(
        longParams,
        Matchers.<Map<String, Object>>allOf(
            hasEntry("dependent_variable", "bar"),
            hasEntry("num_top_classes", 2),
            hasEntry("prediction_field_name", "bar_prediction"),
            hasEntry("prediction_field_type", "int")));

    // Keyword dependent variable
    Map<String, Object> keywordParams = new Classification("baz").getParams(fieldTypesByName);
    assertThat(
        keywordParams,
        Matchers.<Map<String, Object>>allOf(
            hasEntry("dependent_variable", "baz"),
            hasEntry("num_top_classes", 2),
            hasEntry("prediction_field_name", "baz_prediction"),
            hasEntry("prediction_field_type", "string")));
}

public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ protected Writeable.Reader<OutlierDetection> instanceReader() {

public void testGetParams_GivenDefaults() {
OutlierDetection outlierDetection = new OutlierDetection.Builder().build();
Map<String, Object> params = outlierDetection.getParams();
Map<String, Object> params = outlierDetection.getParams(null);
assertThat(params.size(), equalTo(3));
assertThat(params.containsKey("compute_feature_influence"), is(true));
assertThat(params.get("compute_feature_influence"), is(true));
Expand All @@ -71,7 +71,7 @@ public void testGetParams_GivenExplicitValues() {
.setStandardizationEnabled(false)
.build();

Map<String, Object> params = outlierDetection.getParams();
Map<String, Object> params = outlierDetection.getParams(null);

assertThat(params.size(), equalTo(6));
assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

import java.io.IOException;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -83,6 +85,12 @@ public void testGetTrainingPercent() {
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}

/**
 * Checks that a regression on "foo" reports the dependent variable and the "foo_prediction" prediction
 * field name, even when no extracted-field information is supplied.
 */
public void testGetParams() {
    Regression regression = new Regression("foo");
    assertThat(
        regression.getParams(null),
        allOf(
            hasEntry("dependent_variable", "foo"),
            hasEntry("prediction_field_name", "foo_prediction")));
}

public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;

import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.test.ESTestCase;

import static org.hamcrest.Matchers.empty;

/**
 * Sanity checks on the relationships between the type sets exposed by {@code Types}.
 */
public class TypesTests extends ESTestCase {

/**
 * Asserts that the bool, categorical and numerical type sets are pairwise disjoint
 * and that every discrete numerical type is also a numerical type.
 */
public void testTypes() {
// Pairwise intersections of the three main groups must be empty.
assertThat(Sets.intersection(Types.bool(), Types.categorical()), empty());
assertThat(Sets.intersection(Types.categorical(), Types.numerical()), empty());
assertThat(Sets.intersection(Types.numerical(), Types.bool()), empty());
// Nothing may be left over once the numerical types are removed from the discrete numerical ones.
assertThat(Sets.difference(Types.discreteNumerical(), Types.numerical()), empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT

private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";

private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
private static final String ANIMAL_NAME_FIELD = "animal_name";
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
private static final String NO_LEGS_FIELD = "no_legs";
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
private static final String IS_PREDATOR_FIELD = "predator";
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";

@Before
public void setup() {
Expand All @@ -41,9 +45,9 @@ public void cleanup() {
cleanUp();
}

public void testEvaluate_MulticlassClassification_DefaultMetrics() {
public void testEvaluate_DefaultMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));

assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Expand All @@ -52,10 +56,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() {
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
}

public void testEvaluate_MulticlassClassification_Accuracy() {
public void testEvaluate_Accuracy_KeywordField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy())));
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));

assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Expand All @@ -74,11 +78,50 @@ public void testEvaluate_MulticlassClassification_Accuracy() {
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
}

public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
/**
 * Evaluates the accuracy metric on the integer fields {@code NO_LEGS_FIELD} / {@code NO_LEGS_PREDICTION_FIELD}
 * and checks the per-class results and the overall accuracy against the animals fixture.
 */
public void testEvaluate_Accuracy_IntegerField() {
    Classification evaluation = new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy()));
    EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);

    assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
    assertThat(response.getMetrics(), hasSize(1));

    Accuracy.Result accuracy = (Accuracy.Result) response.getMetrics().get(0);
    assertThat(accuracy.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
    // Integer class values are reported as strings.
    assertThat(
        accuracy.getActualClasses(),
        equalTo(Arrays.asList(
            new Accuracy.ActualClass("1", 15, 1.0 / 15),
            new Accuracy.ActualClass("2", 15, 2.0 / 15),
            new Accuracy.ActualClass("3", 15, 3.0 / 15),
            new Accuracy.ActualClass("4", 15, 4.0 / 15),
            new Accuracy.ActualClass("5", 15, 5.0 / 15))));
    assertThat(accuracy.getOverallAccuracy(), equalTo(15.0 / 75));
}

/**
 * Evaluates the accuracy metric on the boolean fields {@code IS_PREDATOR_FIELD} /
 * {@code IS_PREDATOR_PREDICTION_FIELD} and checks the per-class results and the overall accuracy
 * against the animals fixture.
 */
public void testEvaluate_Accuracy_BooleanField() {
    Classification evaluation =
        new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy()));
    EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);

    assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
    assertThat(response.getMetrics(), hasSize(1));

    Accuracy.Result accuracy = (Accuracy.Result) response.getMetrics().get(0);
    assertThat(accuracy.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
    // Boolean class values are reported as the strings "true" and "false".
    assertThat(
        accuracy.getActualClasses(),
        equalTo(Arrays.asList(
            new Accuracy.ActualClass("true", 45, 27.0 / 45),
            new Accuracy.ActualClass("false", 30, 18.0 / 30))));
    assertThat(accuracy.getOverallAccuracy(), equalTo(45.0 / 75));
}

public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));

assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Expand Down Expand Up @@ -137,11 +180,11 @@ public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetr
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}

public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));

assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Expand All @@ -168,20 +211,30 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP

private static void indexAnimalsData(String indexName) {
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
.addMapping("_doc",
ANIMAL_NAME_FIELD, "type=keyword",
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
NO_LEGS_FIELD, "type=integer",
NO_LEGS_PREDICTION_FIELD, "type=integer",
IS_PREDATOR_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
.get();

List<String> classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < classNames.size(); i++) {
for (int j = 0; j < classNames.size(); j++) {
for (int i = 0; i < animalNames.size(); i++) {
for (int j = 0; j < animalNames.size(); j++) {
for (int k = 0; k < j + 1; k++) {
bulkRequestBuilder.add(
new IndexRequest(indexName)
.source(
ACTUAL_CLASS_FIELD, classNames.get(i),
PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
ANIMAL_NAME_FIELD, animalNames.get(i),
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
NO_LEGS_FIELD, i + 1,
NO_LEGS_PREDICTION_FIELD, j + 1,
IS_PREDATOR_FIELD, i % 2 == 0,
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
}
}
}
Expand Down
Loading

0 comments on commit 0965a10

Please sign in to comment.