From 05b6530e4af0d3b9f2eddcd3310768f58ce854d6 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Sun, 24 Nov 2019 19:41:13 +0200 Subject: [PATCH] [ML] Apply source query on data frame analytics memory estimation Closes #49454 --- .../ml/integration/EstimateMemoryUsageIT.java | 64 +++++++++++++++++++ ...NativeDataFrameAnalyticsIntegTestCase.java | 6 ++ .../DataFrameDataExtractorFactory.java | 22 +++++-- 3 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EstimateMemoryUsageIT.java diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EstimateMemoryUsageIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EstimateMemoryUsageIT.java new file mode 100644 index 0000000000000..f006910bfe25d --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/EstimateMemoryUsageIT.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; + +import java.io.IOException; + +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class EstimateMemoryUsageIT extends MlNativeDataFrameAnalyticsIntegTestCase { + + public void testSourceQueryIsApplied() throws IOException { + // To test the source query is applied when we extract data, + // we set up a job where we have a query which excludes all but one document. + // We then assert the memory estimation is low enough. + + String sourceIndex = "test-source-query-is-applied"; + + client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical", "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (int i = 0; i < 30; i++) { + IndexRequest indexRequest = new IndexRequest(sourceIndex); + + // We insert one odd value out of 5 for one feature + indexRequest.source("numeric_1", 1.0, "numeric_2", 2.0, "categorical", i == 0 ? "only-one" : "normal"); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_source_query_is_applied"; + + DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder() + .setId(id) + .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, + QueryProvider.fromParsedQuery(QueryBuilders.termQuery("categorical", "only-one")))) + .setAnalysis(new Classification("categorical")) + .buildForMemoryEstimation(); + + EstimateMemoryUsageAction.Response explainResponse = estimateMemoryUsage(config); + + assertThat(explainResponse.getExpectedMemoryWithoutDisk().getKb(), lessThanOrEqualTo(500L)); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index bf375c7e5478f..13ee06e1c3a86 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -18,6 +18,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; @@ -142,6 +143,11 @@ protected GetDataFrameAnalyticsStatsAction.Response.Stats getAnalyticsStats(Stri return stats.get(0); } + protected EstimateMemoryUsageAction.Response estimateMemoryUsage(DataFrameAnalyticsConfig config) { + PutDataFrameAnalyticsAction.Request request = new PutDataFrameAnalyticsAction.Request(config); + return client().execute(EstimateMemoryUsageAction.INSTANCE, request).actionGet(); + } + protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evaluation evaluation) { EvaluateDataFrameAction.Request request = new EvaluateDataFrameAction.Request() diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index afa9e51b626de..4910334d64de8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -48,15 +48,18 @@ public class DataFrameDataExtractorFactory { private final Client client; private final String analyticsId; private final List indices; + private final QueryBuilder sourceQuery; private final ExtractedFields extractedFields; private final Map headers; private final boolean includeRowsWithMissingValues; - private DataFrameDataExtractorFactory(Client client, String analyticsId, List indices, ExtractedFields extractedFields, - Map headers, boolean includeRowsWithMissingValues) { + private DataFrameDataExtractorFactory(Client client, String analyticsId, List indices, QueryBuilder sourceQuery, + ExtractedFields extractedFields, Map headers, + boolean includeRowsWithMissingValues) { this.client = Objects.requireNonNull(client); this.analyticsId = Objects.requireNonNull(analyticsId); this.indices = Objects.requireNonNull(indices); + this.sourceQuery = Objects.requireNonNull(sourceQuery); this.extractedFields = Objects.requireNonNull(extractedFields); this.headers = headers; this.includeRowsWithMissingValues = includeRowsWithMissingValues; @@ -77,7 +80,12 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { } private QueryBuilder createQuery() { - return includeRowsWithMissingValues ? QueryBuilders.matchAllQuery() : allExtractedFieldsExistQuery(); + BoolQueryBuilder query = QueryBuilders.boolQuery(); + query.filter(sourceQuery); + if (includeRowsWithMissingValues == false) { + query.filter(allExtractedFieldsExistQuery()); + } + return query; } private QueryBuilder allExtractedFieldsExistQuery() { @@ -110,8 +118,8 @@ public static void createForSourceIndices(Client client, ActionListener.wrap( extractedFields -> listener.onResponse( new DataFrameDataExtractorFactory( - client, taskId, Arrays.asList(config.getSource().getIndex()), extractedFields, config.getHeaders(), - config.getAnalysis().supportsMissingValues())), + client, taskId, Arrays.asList(config.getSource().getIndex()), config.getSource().getParsedQuery(), extractedFields, + config.getHeaders(), config.getAnalysis().supportsMissingValues())), listener::onFailure ) ); @@ -140,8 +148,8 @@ public static void createForDestinationIndex(Client client, ActionListener.wrap( extractedFields -> listener.onResponse( new DataFrameDataExtractorFactory( - client, config.getId(), Arrays.asList(config.getDest().getIndex()), extractedFields, config.getHeaders(), - config.getAnalysis().supportsMissingValues())), + client, config.getId(), Arrays.asList(config.getDest().getIndex()), config.getSource().getParsedQuery(), + extractedFields, config.getHeaders(), config.getAnalysis().supportsMissingValues())), listener::onFailure ) );