Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[7.5][ML] Deduplicate multi-fields for data frame analytics (#48799) #48807

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -238,7 +239,9 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
// We sort the fields to ensure the checksum for each document is deterministic
Collections.sort(sortedFields);
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
extractedFields = deduplicateMultiFields(extractedFields, preferSource);
if (preferSource) {
extractedFields = fetchFromSourceIfSupported(extractedFields);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " +
Expand All @@ -250,9 +253,59 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
return extractedFields;
}

/**
 * Reduces the extracted fields so that a parent field and its multi-fields are represented
 * by a single field.
 *
 * Fields are grouped by the parent name (for multi-fields) or by their own name (otherwise).
 * The first field seen for a group is kept as-is; every later field of the same group is
 * compared against the currently kept one via {@link #chooseMultiFieldOrParent}.
 *
 * @param extractedFields the fields to deduplicate
 * @param preferSource    forwarded to {@link #chooseMultiFieldOrParent} as a tie-breaking hint
 * @return a new {@code ExtractedFields} with one field per group, in first-seen group order
 */
private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource) {
    Set<String> requiredNames = config.getAnalysis().getRequiredFields().stream()
        .map(RequiredField::getName)
        .collect(Collectors.toSet());

    // LinkedHashMap: replacing the value of an existing key keeps the key's original position,
    // so the output order follows the order in which each group was first encountered.
    Map<String, ExtractedField> chosenByGroup = new LinkedHashMap<>();
    for (ExtractedField candidate : extractedFields.getAllFields()) {
        boolean candidateIsMultiField = candidate.isMultiField();
        String groupKey = candidateIsMultiField ? candidate.getParentField() : candidate.getName();

        ExtractedField incumbent = chosenByGroup.get(groupKey);
        if (incumbent == null) {
            // First field of this group: nothing to compare against yet.
            chosenByGroup.put(groupKey, candidate);
            continue;
        }

        // When the candidate is a multi-field, the incumbent plays the "parent" role (it may
        // itself be a previously selected multi-field); otherwise the candidate is the parent.
        ExtractedField parent;
        ExtractedField multiField;
        if (candidateIsMultiField) {
            parent = incumbent;
            multiField = candidate;
        } else {
            parent = candidate;
            multiField = incumbent;
        }
        chosenByGroup.put(groupKey, chooseMultiFieldOrParent(preferSource, requiredNames, parent, multiField));
    }
    return new ExtractedFields(new ArrayList<>(chosenByGroup.values()));
}

/**
 * Picks which of two fields from the same parent/multi-field group should be kept.
 *
 * Decision order:
 * 1. a field named in {@code requiredFields} wins (parent checked first);
 * 2. if both are multi-fields, {@code parent} (the previously selected one) is kept;
 * 3. if source is preferred and {@code parent} supports it, {@code parent} is kept;
 * 4. a doc_value field is preferred (parent checked first);
 * 5. otherwise {@code parent} is kept.
 *
 * @param preferSource   whether a field that can be fetched from source should be favoured
 * @param requiredFields names of fields the analysis requires
 * @param parent         the parent field, or the previously selected multi-field
 * @param multiField     the competing multi-field
 * @return either {@code parent} or {@code multiField}
 */
private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields,
                                                ExtractedField parent, ExtractedField multiField) {
    // Requirements trump every other consideration.
    if (requiredFields.contains(parent.getName())) {
        return parent;
    }
    if (requiredFields.contains(multiField.getName())) {
        return multiField;
    }

    // Short-circuit evaluation keeps the checks in their priority order:
    //  - two multi-fields: stick with the one selected earlier (passed in as parent);
    //  - source preferred: only the parent may support it, so take it if it does;
    //  - parent is a doc_value field: it supports aggregations and has the shorter name.
    boolean keepParent = (parent.isMultiField() && multiField.isMultiField())
        || (preferSource && parent.supportsFromSource())
        || parent.getMethod() == ExtractedField.Method.DOC_VALUE;
    if (keepParent) {
        return parent;
    }

    // The multi-field only wins when it is a doc_value field; when neither is aggregatable
    // the parent is picked for its shorter name.
    return multiField.getMethod() == ExtractedField.Method.DOC_VALUE ? multiField : parent;
}

private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getDocValueFields()) {
for (ExtractedField field : extractedFields.getAllFields()) {
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
}
return new ExtractedFields(adjusted);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,151 @@ public void testDetect_GivenBooleanField_BooleanMappedAsString() {
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
}

/**
 * Mixed mapping of parents and multi-fields: exactly one field per parent/multi-field group
 * is expected, and in each group shown here it is the aggregatable one.
 */
public void testDetect_GivenMultiFields() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addAggregatableField("a_float", "float");
    capsBuilder.addNonAggregatableField("text_without_keyword", "text");
    capsBuilder.addNonAggregatableField("text_1", "text");
    capsBuilder.addAggregatableField("text_1.keyword", "keyword");
    capsBuilder.addNonAggregatableField("text_2", "text");
    capsBuilder.addAggregatableField("text_2.keyword", "keyword");
    capsBuilder.addAggregatableField("keyword_1", "keyword");
    capsBuilder.addNonAggregatableField("keyword_1.text", "text");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("a_float"), RESULTS_FIELD, true, 100, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(5));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword"));
}

/**
 * When the config is built around the parent field ("field_1"), the parent is expected to be
 * kept and its multi-field dropped.
 */
public void testDetect_GivenMultiFieldAndParentIsRequired() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addAggregatableField("field_1", "keyword");
    capsBuilder.addAggregatableField("field_1.keyword", "keyword");
    capsBuilder.addAggregatableField("field_2", "float");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildClassificationConfig("field_1"), RESULTS_FIELD, true, 100, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(2));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}

/**
 * When the config is built around the multi-field ("field_1.keyword"), the multi-field is
 * expected to be kept and its parent dropped.
 */
public void testDetect_GivenMultiFieldAndMultiFieldIsRequired() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addAggregatableField("field_1", "keyword");
    capsBuilder.addAggregatableField("field_1.keyword", "keyword");
    capsBuilder.addAggregatableField("field_2", "float");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), RESULTS_FIELD, true, 100, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(2));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1.keyword", "field_2"));
}

/**
 * A non-aggregatable parent with three aggregatable multi-fields registered in reverse name
 * order: the multi-field that sorts first ("field_1.keyword_1") is expected to be kept.
 */
public void testDetect_GivenSeveralMultiFields_ShouldPickFirstSorted() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addNonAggregatableField("field_1", "text");
    capsBuilder.addAggregatableField("field_1.keyword_3", "keyword");
    capsBuilder.addAggregatableField("field_1.keyword_2", "keyword");
    capsBuilder.addAggregatableField("field_1.keyword_1", "keyword");
    capsBuilder.addAggregatableField("field_2", "float");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(2));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1.keyword_1", "field_2"));
}

/**
 * Same shape as the other multi-field tests but the detector is created with 0 where they
 * pass 100 (presumably the doc_value fields limit - confirm against the constructor).
 * The parent "field_1" is expected to be kept instead of its aggregatable multi-field.
 */
public void testDetect_GivenMultiFields_OverDocValueLimit() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addNonAggregatableField("field_1", "text");
    capsBuilder.addAggregatableField("field_1.keyword_1", "keyword");
    capsBuilder.addAggregatableField("field_2", "float");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 0, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(2));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}

/**
 * Two groups of aggregatable fields: for "field_1" the parent is expected to be kept; for
 * the two "field_2.*" multi-fields the one the config is built around ("field_2.double")
 * is expected to be kept.
 */
public void testDetect_GivenParentAndMultiFieldBothAggregatable() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addAggregatableField("field_1", "keyword");
    capsBuilder.addAggregatableField("field_1.keyword", "keyword");
    capsBuilder.addAggregatableField("field_2.keyword", "float");
    capsBuilder.addAggregatableField("field_2.double", "double");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2.double"), RESULTS_FIELD, true, 100, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(2));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2.double"));
}

/**
 * Neither the parent nor its multi-field is aggregatable: the parent ("field_1") is expected
 * to be kept.
 */
public void testDetect_GivenParentAndMultiFieldNoneAggregatable() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addNonAggregatableField("field_1", "text");
    capsBuilder.addNonAggregatableField("field_1.text", "text");
    capsBuilder.addAggregatableField("field_2", "float");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(2));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}

/**
 * The analyzed-fields filter explicitly includes "field_1" and "field_2": exactly those two
 * fields are expected, i.e. the aggregatable multi-field "field_1.keyword" is not selected.
 */
public void testDetect_GivenMultiFields_AndExplicitlyIncludedFields() {
    MockFieldCapsResponseBuilder capsBuilder = new MockFieldCapsResponseBuilder();
    capsBuilder.addNonAggregatableField("field_1", "text");
    capsBuilder.addAggregatableField("field_1.keyword", "keyword");
    capsBuilder.addAggregatableField("field_2", "float");
    FieldCapabilitiesResponse caps = capsBuilder.build();

    String[] includes = new String[] { "field_1", "field_2" };
    String[] excludes = new String[0];
    FetchSourceContext analyzedFields = new FetchSourceContext(true, includes, excludes);

    ExtractedFields detected = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2", analyzedFields), RESULTS_FIELD, false, 100, caps).detect();

    assertThat(detected.getAllFields().size(), equalTo(2));
    List<String> names = detected.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}

/**
 * Convenience overload: delegates to the one-argument {@code buildOutlierDetectionConfig},
 * passing {@code null} for that argument.
 */
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
return buildOutlierDetectionConfig(null);
}
Expand Down Expand Up @@ -576,9 +721,17 @@ private static class MockFieldCapsResponseBuilder {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();

/**
 * Registers {@code field} with the given {@code types}, flagged as aggregatable.
 *
 * @return this builder, for chaining
 */
private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) {
    final boolean aggregatable = true;
    return addField(field, aggregatable, types);
}

/**
 * Registers {@code field} with the given {@code types}, flagged as not aggregatable.
 *
 * @return this builder, for chaining
 */
private MockFieldCapsResponseBuilder addNonAggregatableField(String field, String... types) {
    final boolean aggregatable = false;
    return addField(field, aggregatable, types);
}

private MockFieldCapsResponseBuilder addField(String field, boolean isAggregatable, String... types) {
Map<String, FieldCapabilities> caps = new HashMap<>();
for (String type : types) {
caps.put(type, new FieldCapabilities(field, type, true, true));
caps.put(type, new FieldCapabilities(field, type, true, isAggregatable));
}
fieldCaps.put(field, caps);
return this;
Expand Down