diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataAction.java index bc4bb6e16e53c..1ff0cf2208853 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataAction.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.tasks.Task; @@ -46,7 +47,7 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction { - BatchedBucketsIterator(Client client, String jobId) { + BatchedBucketsIterator(OriginSettingClient client, String jobId) { super(client, jobId, Bucket.RESULT_TYPE_VALUE); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedInfluencersIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedInfluencersIterator.java index fe8bd3aaa3af7..35a88ed0f3e14 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedInfluencersIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedInfluencersIterator.java @@ -6,7 +6,7 @@ package org.elasticsearch.xpack.ml.job.persistence; import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -21,7 +21,7 @@ import java.io.InputStream; class BatchedInfluencersIterator extends BatchedResultsIterator { - BatchedInfluencersIterator(Client client, String jobId) { + BatchedInfluencersIterator(OriginSettingClient client, String jobId) { super(client, jobId, Influencer.RESULT_TYPE_VALUE); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedJobsIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedJobsIterator.java index 1b72c1901d9bb..f933769c9454f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedJobsIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedJobsIterator.java @@ -6,7 +6,7 @@ package org.elasticsearch.xpack.ml.job.persistence; import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentFactory; @@ -23,7 +23,7 @@ public class BatchedJobsIterator extends BatchedDocumentsIterator { - public BatchedJobsIterator(Client client, String index) { + public BatchedJobsIterator(OriginSettingClient client, String index) { super(client, index); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedRecordsIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedRecordsIterator.java index 22c107f771ba5..989dd61c72d8b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedRecordsIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedRecordsIterator.java @@ -6,7 +6,7 @@ package org.elasticsearch.xpack.ml.job.persistence; import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -22,7 +22,7 @@ class BatchedRecordsIterator extends BatchedResultsIterator { - BatchedRecordsIterator(Client client, String jobId) { + BatchedRecordsIterator(OriginSettingClient client, String jobId) { super(client, jobId, AnomalyRecord.RESULT_TYPE_VALUE); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedResultsIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedResultsIterator.java index 1c0fdbe08c9c6..61ca1dcc2c8af 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedResultsIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedResultsIterator.java @@ -5,7 +5,7 @@ */ package org.elasticsearch.xpack.ml.job.persistence; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -16,7 +16,7 @@ public abstract class BatchedResultsIterator extends BatchedDocumentsIterator private final ResultsFilterBuilder filterBuilder; - public BatchedResultsIterator(Client client, String jobId, String resultType) { + public BatchedResultsIterator(OriginSettingClient client, String jobId, String resultType) { super(client, AnomalyDetectorsIndex.jobResultsAliasedName(jobId)); this.filterBuilder = new ResultsFilterBuilder(new TermsQueryBuilder(Result.RESULT_TYPE.getPreferredName(), resultType)); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedStateDocIdsIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedStateDocIdsIterator.java index 65e8b75671151..4c147f3431b28 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedStateDocIdsIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedStateDocIdsIterator.java @@ -5,7 +5,7 @@ */ package org.elasticsearch.xpack.ml.job.persistence; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; @@ -16,7 +16,7 @@ */ public class BatchedStateDocIdsIterator extends BatchedDocumentsIterator { - public BatchedStateDocIdsIterator(Client client, String index) { + public BatchedStateDocIdsIterator(OriginSettingClient client, String index) { super(client, index); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java index b9359d2b97cd6..38e3e037ab024 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java @@ -37,6 +37,7 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlock; import org.elasticsearch.cluster.block.ClusterBlockException; @@ -130,7 +131,6 @@ import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.clientWithOrigin; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; public class JobResultsProvider { @@ -715,7 +715,7 @@ private void expandBuckets(String jobId, BucketsQueryBuilder query, QueryPage newBatchedBucketsIterator(String jobId) { - return new BatchedBucketsIterator(clientWithOrigin(client, ML_ORIGIN), jobId); + return new BatchedBucketsIterator(new OriginSettingClient(client, ML_ORIGIN), jobId); } /** @@ -727,7 +727,7 @@ public BatchedResultsIterator newBatchedBucketsIterator(String jobId) { * @return a record {@link BatchedResultsIterator} */ public BatchedResultsIterator newBatchedRecordsIterator(String jobId) { - return new BatchedRecordsIterator(clientWithOrigin(client, ML_ORIGIN), jobId); + return new BatchedRecordsIterator(new OriginSettingClient(client, ML_ORIGIN), jobId); } /** @@ -924,7 +924,7 @@ public void influencers(String jobId, InfluencersQuery query, Consumer newBatchedInfluencersIterator(String jobId) { - return new BatchedInfluencersIterator(clientWithOrigin(client, ML_ORIGIN), jobId); + return new BatchedInfluencersIterator(new OriginSettingClient(client, ML_ORIGIN), jobId); } /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemover.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemover.java index 2650f3018d951..c6e3fe9dbf670 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemover.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemover.java @@ -6,7 +6,7 @@ package org.elasticsearch.xpack.ml.job.retention; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -34,9 +34,9 @@ */ abstract class AbstractExpiredJobDataRemover implements MlDataRemover { - private final Client client; + private final OriginSettingClient client; - AbstractExpiredJobDataRemover(Client client) { + AbstractExpiredJobDataRemover(OriginSettingClient client) { this.client = client; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredForecastsRemover.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredForecastsRemover.java index a80b00aaa0792..40611438fda59 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredForecastsRemover.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredForecastsRemover.java @@ -13,7 +13,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.ThreadedActionListener; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentFactory; @@ -62,11 +62,11 @@ public class ExpiredForecastsRemover implements MlDataRemover { private static final int MAX_FORECASTS = 10000; private static final String RESULTS_INDEX_PATTERN = AnomalyDetectorsIndex.jobResultsIndexPrefix() + "*"; - private final Client client; + private final OriginSettingClient client; private final ThreadPool threadPool; private final long cutoffEpochMs; - public ExpiredForecastsRemover(Client client, ThreadPool threadPool) { + public ExpiredForecastsRemover(OriginSettingClient client, ThreadPool threadPool) { this.client = Objects.requireNonNull(client); this.threadPool = Objects.requireNonNull(threadPool); this.cutoffEpochMs = Instant.now(Clock.systemDefaultZone()).toEpochMilli(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemover.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemover.java index 1153407d5125e..221f9d9debf87 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemover.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemover.java @@ -14,7 +14,7 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; @@ -55,10 +55,10 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover */ private static final int MODEL_SNAPSHOT_SEARCH_SIZE = 10000; - private final Client client; + private final OriginSettingClient client; private final ThreadPool threadPool; - public ExpiredModelSnapshotsRemover(Client client, ThreadPool threadPool) { + public ExpiredModelSnapshotsRemover(OriginSettingClient client, ThreadPool threadPool) { super(client); this.client = Objects.requireNonNull(client); this.threadPool = Objects.requireNonNull(threadPool); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemover.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemover.java index 6a17382db0e8c..fff2c23ab75a6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemover.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemover.java @@ -9,7 +9,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest; @@ -46,10 +46,10 @@ public class ExpiredResultsRemover extends AbstractExpiredJobDataRemover { private static final Logger LOGGER = LogManager.getLogger(ExpiredResultsRemover.class); - private final Client client; + private final OriginSettingClient client; private final AnomalyDetectionAuditor auditor; - public ExpiredResultsRemover(Client client, AnomalyDetectionAuditor auditor) { + public ExpiredResultsRemover(OriginSettingClient client, AnomalyDetectionAuditor auditor) { super(client); this.client = Objects.requireNonNull(client); this.auditor = Objects.requireNonNull(auditor); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/UnusedStateRemover.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/UnusedStateRemover.java index 8a1d30382489f..cf1a9aaae4ede 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/UnusedStateRemover.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/UnusedStateRemover.java @@ -9,7 +9,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.index.query.QueryBuilders; @@ -49,10 +49,10 @@ public class UnusedStateRemover implements MlDataRemover { private static final Logger LOGGER = LogManager.getLogger(UnusedStateRemover.class); - private final Client client; + private final OriginSettingClient client; private final ClusterService clusterService; - public UnusedStateRemover(Client client, ClusterService clusterService) { + public UnusedStateRemover(OriginSettingClient client, ClusterService clusterService) { this.client = Objects.requireNonNull(client); this.clusterService = Objects.requireNonNull(clusterService); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIterator.java index 119dcbdb42822..9b8c1345af1ea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIterator.java @@ -10,7 +10,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchScrollRequest; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -34,14 +34,14 @@ public abstract class BatchedDocumentsIterator { private static final String CONTEXT_ALIVE_DURATION = "5m"; private static final int BATCH_SIZE = 10000; - private final Client client; + private final OriginSettingClient client; private final String index; private volatile long count; private volatile long totalHits; private volatile String scrollId; private volatile boolean isScrollInitialised; - protected BatchedDocumentsIterator(Client client, String index) { + protected BatchedDocumentsIterator(OriginSettingClient client, String index) { this.client = Objects.requireNonNull(client); this.index = Objects.requireNonNull(index); this.totalHits = 0; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/DocIdBatchedDocumentIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/DocIdBatchedDocumentIterator.java index 55b2cee2ff16d..3dcee716f11af 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/DocIdBatchedDocumentIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/DocIdBatchedDocumentIterator.java @@ -5,7 +5,7 @@ */ package org.elasticsearch.xpack.ml.utils.persistence; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.SearchHit; @@ -18,7 +18,7 @@ public class DocIdBatchedDocumentIterator extends BatchedDocumentsIterator extends BatchedResultsIterator { private Boolean requireIncludeInterim; public MockBatchedDocumentsIterator(List>> batches, String resultType) { - super(mock(Client.class), "foo", resultType); + super(MockOriginSettingClient.mockOriginSettingClient(mock(Client.class), ClientHelper.ML_ORIGIN), "foo", resultType); this.batches = batches; index = 0; wasTimeRangeCalled = false; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/normalizer/ScoresUpdaterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/normalizer/ScoresUpdaterTests.java index 9836cf93718e5..410c15e52c093 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/normalizer/ScoresUpdaterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/normalizer/ScoresUpdaterTests.java @@ -59,12 +59,12 @@ public class ScoresUpdaterTests extends ESTestCase { private Job job; private ScoresUpdater scoresUpdater; - private Bucket generateBucket(Date timestamp) throws IOException { + private Bucket generateBucket(Date timestamp) { return new Bucket(JOB_ID, timestamp, DEFAULT_BUCKET_SPAN); } @Before - public void setUpMocks() throws IOException { + public void setUpMocks() { MockitoAnnotations.initMocks(this); Job.Builder jobBuilder = new Job.Builder(JOB_ID); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemoverTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemoverTests.java index c5a24fc9e0609..eb29ba06b17ca 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemoverTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemoverTests.java @@ -6,10 +6,11 @@ package org.elasticsearch.xpack.ml.job.retention; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -17,8 +18,10 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobTests; +import org.elasticsearch.xpack.ml.test.MockOriginSettingClient; import org.junit.Before; import java.io.IOException; @@ -32,6 +35,7 @@ import static org.hamcrest.Matchers.is; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -44,7 +48,7 @@ private class ConcreteExpiredJobDataRemover extends AbstractExpiredJobDataRemove private int getRetentionDaysCallCount = 0; - ConcreteExpiredJobDataRemover(Client client) { + ConcreteExpiredJobDataRemover(OriginSettingClient client) { super(client); } @@ -61,17 +65,30 @@ protected void removeDataBefore(Job job, long cutoffEpochMs, ActionListener toXContents) throws IOException { return createSearchResponse(toXContents, toXContents.size()); } + @SuppressWarnings("unchecked") + static void givenJobs(Client client, List jobs) throws IOException { + SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs); + + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }).when(client).execute(eq(SearchAction.INSTANCE), any(), any()); + } + private static SearchResponse createSearchResponse(List toXContents, int totalHits) throws IOException { SearchHit[] hitsArray = new SearchHit[toXContents.size()]; for (int i = 0; i < toXContents.size(); i++) { @@ -88,14 +105,10 @@ private static SearchResponse createSearchResponse(List to public void testRemoveGivenNoJobs() throws IOException { SearchResponse response = createSearchResponse(Collections.emptyList()); - - @SuppressWarnings("unchecked") - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); - when(client.search(any())).thenReturn(future); + mockSearchResponse(response); TestListener listener = new TestListener(); - ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client); + ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient); remover.remove(listener, () -> false); listener.waitToCompletion(); @@ -103,6 +116,7 @@ public void testRemoveGivenNoJobs() throws IOException { assertEquals(0, remover.getRetentionDaysCallCount); } + @SuppressWarnings("unchecked") public void testRemoveGivenMultipleBatches() throws IOException { // This is testing AbstractExpiredJobDataRemover.WrappedBatchedJobsIterator int totalHits = 7; @@ -126,13 +140,14 @@ public void testRemoveGivenMultipleBatches() throws IOException { AtomicInteger searchCount = new AtomicInteger(0); - @SuppressWarnings("unchecked") - ActionFuture future = mock(ActionFuture.class); - doAnswer(invocationOnMock -> responses.get(searchCount.getAndIncrement())).when(future).actionGet(); - when(client.search(any())).thenReturn(future); + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(responses.get(searchCount.getAndIncrement())); + return null; + }).when(client).execute(eq(SearchAction.INSTANCE), any(), any()); TestListener listener = new TestListener(); - ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client); + ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient); remover.remove(listener, () -> false); listener.waitToCompletion(); @@ -153,13 +168,10 @@ public void testRemoveGivenTimeOut() throws IOException { final int timeoutAfter = randomIntBetween(0, totalHits - 1); AtomicInteger attemptsLeft = new AtomicInteger(timeoutAfter); - @SuppressWarnings("unchecked") - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); - when(client.search(any())).thenReturn(future); + mockSearchResponse(response); TestListener listener = new TestListener(); - ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(client); + ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient); remover.remove(listener, () -> (attemptsLeft.getAndDecrement() <= 0)); listener.waitToCompletion(); @@ -167,6 +179,15 @@ public void testRemoveGivenTimeOut() throws IOException { assertEquals(timeoutAfter, remover.getRetentionDaysCallCount); } + @SuppressWarnings("unchecked") + private void mockSearchResponse(SearchResponse searchResponse) { + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(searchResponse); + return null; + }).when(client).execute(eq(SearchAction.INSTANCE), any(), any()); + } + static class TestListener implements ActionListener { boolean success; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemoverTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemoverTests.java index 56c2333cae016..6e332bf148d17 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemoverTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemoverTests.java @@ -5,24 +5,25 @@ */ package org.elasticsearch.xpack.ml.job.retention; -import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.mock.orig.Mockito; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobTests; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.test.MockOriginSettingClient; import org.junit.After; import org.junit.Before; import org.mockito.invocation.InvocationOnMock; @@ -33,21 +34,23 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.xpack.ml.job.retention.AbstractExpiredJobDataRemoverTests.TestListener; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; public class ExpiredModelSnapshotsRemoverTests extends ESTestCase { private Client client; + private OriginSettingClient originSettingClient; private ThreadPool threadPool; private List capturedSearchRequests; private List capturedDeleteModelSnapshotRequests; @@ -59,7 +62,10 @@ public void setUpTests() { capturedSearchRequests = new ArrayList<>(); capturedDeleteModelSnapshotRequests = new ArrayList<>(); searchResponsesPerCall = new ArrayList<>(); + client = mock(Client.class); + originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN); + listener = new TestListener(); // Init thread pool @@ -76,8 +82,7 @@ public void shutdownThreadPool() { } public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException { - givenClientRequestsSucceed(); - givenJobs(Arrays.asList( + givenClientRequestsSucceed(Arrays.asList( JobTests.buildJobBuilder("foo").build(), JobTests.buildJobBuilder("bar").build() )); @@ -86,25 +91,22 @@ public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException { listener.waitToCompletion(); assertThat(listener.success, is(true)); - verify(client).search(any()); - Mockito.verifyNoMoreInteractions(client); + verify(client).execute(eq(SearchAction.INSTANCE), any(), any()); } public void testRemove_GivenJobWithoutActiveSnapshot() throws IOException { - givenClientRequestsSucceed(); - givenJobs(Collections.singletonList(JobTests.buildJobBuilder("foo").setModelSnapshotRetentionDays(7L).build())); + givenClientRequestsSucceed(Collections.singletonList(JobTests.buildJobBuilder("foo").setModelSnapshotRetentionDays(7L).build())); createExpiredModelSnapshotsRemover().remove(listener, () -> false); listener.waitToCompletion(); assertThat(listener.success, is(true)); - verify(client).search(any()); - Mockito.verifyNoMoreInteractions(client); + verify(client).execute(eq(SearchAction.INSTANCE), any(), any()); } public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException { - givenClientRequestsSucceed(); - givenJobs(Arrays.asList( + givenClientRequestsSucceed( + Arrays.asList( JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() @@ -140,8 +142,8 @@ public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException } public void testRemove_GivenTimeout() throws IOException { - givenClientRequestsSucceed(); - givenJobs(Arrays.asList( + givenClientRequestsSucceed( + Arrays.asList( JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() )); @@ -162,8 +164,8 @@ public void testRemove_GivenTimeout() throws IOException { } public void testRemove_GivenClientSearchRequestsFail() throws IOException { - givenClientSearchRequestsFail(); - givenJobs(Arrays.asList( + givenClientSearchRequestsFail( + Arrays.asList( JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() @@ -188,8 +190,8 @@ public void testRemove_GivenClientSearchRequestsFail() throws IOException { } public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOException { - givenClientDeleteModelSnapshotRequestsFail(); - givenJobs(Arrays.asList( + givenClientDeleteModelSnapshotRequestsFail( + Arrays.asList( JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(), JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build() @@ -216,59 +218,47 @@ public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOExceptio assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-1_1")); } - @SuppressWarnings("unchecked") - private void givenJobs(List jobs) throws IOException { - SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs); - - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); - when(client.search(any())).thenReturn(future); - } - private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover() { - return new ExpiredModelSnapshotsRemover(client, threadPool); + return new ExpiredModelSnapshotsRemover(originSettingClient, threadPool); } private static ModelSnapshot createModelSnapshot(String jobId, String snapshotId) { return new ModelSnapshot.Builder(jobId).setSnapshotId(snapshotId).build(); } -// private static SearchResponse createSearchResponse(List modelSnapshots) throws IOException { -// SearchHit[] hitsArray = new SearchHit[modelSnapshots.size()]; -// for (int i = 0; i < modelSnapshots.size(); i++) { -// hitsArray[i] = new SearchHit(randomInt()); -// XContentBuilder jsonBuilder = JsonXContent.contentBuilder(); -// modelSnapshots.get(i).toXContent(jsonBuilder, ToXContent.EMPTY_PARAMS); -// hitsArray[i].sourceRef(BytesReference.bytes(jsonBuilder)); -// } -// SearchHits hits = new SearchHits(hitsArray, new TotalHits(hitsArray.length, TotalHits.Relation.EQUAL_TO), 1.0f); -// SearchResponse searchResponse = mock(SearchResponse.class); -// when(searchResponse.getHits()).thenReturn(hits); -// return searchResponse; -// } - - private void givenClientRequestsSucceed() { - givenClientRequests(true, true); + private void givenClientRequestsSucceed(List jobs) throws IOException { + givenClientRequests(jobs, true, true); } - private void givenClientSearchRequestsFail() { - givenClientRequests(false, true); + private void givenClientSearchRequestsFail(List jobs) throws IOException { + givenClientRequests(jobs, false, true); } - private void givenClientDeleteModelSnapshotRequestsFail() { - givenClientRequests(true, false); + private void givenClientDeleteModelSnapshotRequestsFail(List jobs) throws IOException { + givenClientRequests(jobs, true, false); } @SuppressWarnings("unchecked") - private void givenClientRequests(boolean shouldSearchRequestsSucceed, boolean shouldDeleteSnapshotRequestsSucceed) { + private void givenClientRequests(List jobs, + boolean shouldSearchRequestsSucceed, boolean shouldDeleteSnapshotRequestsSucceed) throws IOException { + SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs); + doAnswer(new Answer() { int callCount = 0; + AtomicBoolean isJobQuery = new AtomicBoolean(true); @Override public Void answer(InvocationOnMock invocationOnMock) { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + + if (isJobQuery.get()) { + listener.onResponse(response); + isJobQuery.set(false); + return null; + } + SearchRequest searchRequest = (SearchRequest) invocationOnMock.getArguments()[1]; capturedSearchRequests.add(searchRequest); - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; if (shouldSearchRequestsSucceed) { listener.onResponse(searchResponsesPerCall.get(callCount++)); } else { @@ -277,6 +267,7 @@ public Void answer(InvocationOnMock invocationOnMock) { return null; } }).when(client).execute(same(SearchAction.INSTANCE), any(), any()); + doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemoverTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemoverTests.java index f5acae02b4f87..b4c5a051fb8c1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemoverTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemoverTests.java @@ -5,22 +5,19 @@ */ package org.elasticsearch.xpack.ml.job.retention; -import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.client.Client; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.index.reindex.BulkByScrollResponse; import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryRequest; -import org.elasticsearch.mock.orig.Mockito; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.job.config.JobTests; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; +import org.elasticsearch.xpack.ml.test.MockOriginSettingClient; import org.junit.Before; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -34,6 +31,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -43,6 +41,7 @@ public class ExpiredResultsRemoverTests extends ESTestCase { private Client client; + private OriginSettingClient originSettingClient; private List capturedDeleteByQueryRequests; private ActionListener listener; @@ -50,37 +49,26 @@ public class ExpiredResultsRemoverTests extends ESTestCase { @SuppressWarnings("unchecked") public void setUpTests() { capturedDeleteByQueryRequests = new ArrayList<>(); - client = mock(Client.class); - ThreadPool threadPool = mock(ThreadPool.class); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - capturedDeleteByQueryRequests.add((DeleteByQueryRequest) invocationOnMock.getArguments()[1]); - ActionListener listener = - (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(null); - return null; - } - }).when(client).execute(same(DeleteByQueryAction.INSTANCE), any(), any()); + + client = org.mockito.Mockito.mock(Client.class); + originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN); listener = mock(ActionListener.class); } public void testRemove_GivenNoJobs() throws IOException { givenClientRequestsSucceed(); - givenJobs(Collections.emptyList()); + AbstractExpiredJobDataRemoverTests.givenJobs(client, Collections.emptyList()); createExpiredResultsRemover().remove(listener, () -> false); + verify(client).execute(eq(SearchAction.INSTANCE), any(), any()); verify(listener).onResponse(true); - verify(client).search(any()); - Mockito.verifyNoMoreInteractions(client); } public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException { givenClientRequestsSucceed(); - givenJobs(Arrays.asList( + AbstractExpiredJobDataRemoverTests.givenJobs(client, + Arrays.asList( JobTests.buildJobBuilder("foo").build(), JobTests.buildJobBuilder("bar").build() )); @@ -88,13 +76,13 @@ public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException { createExpiredResultsRemover().remove(listener, () -> false); verify(listener).onResponse(true); - verify(client).search(any()); - Mockito.verifyNoMoreInteractions(client); + verify(client).execute(eq(SearchAction.INSTANCE), any(), any()); } public void testRemove_GivenJobsWithAndWithoutRetentionPolicy() throws Exception { givenClientRequestsSucceed(); - givenJobs(Arrays.asList( + AbstractExpiredJobDataRemoverTests.givenJobs(client, + Arrays.asList( JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(), JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build() @@ -112,7 +100,8 @@ public void testRemove_GivenJobsWithAndWithoutRetentionPolicy() throws Exception public void testRemove_GivenTimeout() throws Exception { givenClientRequestsSucceed(); - givenJobs(Arrays.asList( + AbstractExpiredJobDataRemoverTests.givenJobs(client, + Arrays.asList( JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(), JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build() )); @@ -128,7 +117,8 @@ public void testRemove_GivenTimeout() throws Exception { public void testRemove_GivenClientRequestsFailed() throws IOException { givenClientRequestsFailed(); - givenJobs(Arrays.asList( + AbstractExpiredJobDataRemoverTests.givenJobs(client, + Arrays.asList( JobTests.buildJobBuilder("none").build(), JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(), JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build() @@ -154,7 +144,7 @@ private void givenClientRequestsFailed() { private void givenClientRequests(boolean shouldSucceed) { doAnswer(new Answer() { @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + public Void answer(InvocationOnMock invocationOnMock) { capturedDeleteByQueryRequests.add((DeleteByQueryRequest) invocationOnMock.getArguments()[1]); ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; @@ -170,16 +160,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { }).when(client).execute(same(DeleteByQueryAction.INSTANCE), any(), any()); } - @SuppressWarnings("unchecked") - private void givenJobs(List jobs) throws IOException { - SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs); - - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); - when(client.search(any())).thenReturn(future); - } - private ExpiredResultsRemover createExpiredResultsRemover() { - return new ExpiredResultsRemover(client, mock(AnomalyDetectionAuditor.class)); + return new ExpiredResultsRemover(originSettingClient, mock(AnomalyDetectionAuditor.class)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/MockOriginSettingClient.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/MockOriginSettingClient.java new file mode 100644 index 0000000000000..b47245a40b93a --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/MockOriginSettingClient.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.test; + + +import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.threadpool.ThreadPool; +import org.mockito.Mockito; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * OriginSettingClient is a final class that cannot be mocked by mockito. + * The solution is to wrap a non-mocked OriginSettingClient around a + * mocked Client. All the mocking should take place on the client parameter. + */ +public class MockOriginSettingClient { + + /** + * Create a OriginSettingClient on a mocked client. + * + * @param client The mocked client + * @param origin Whatever + * @return A OriginSettingClient using a mocked client + */ + public static OriginSettingClient mockOriginSettingClient(Client client, String origin) { + + if (Mockito.mockingDetails(client).isMock() == false) { + throw new AssertionError("client should be a mock"); + } + ThreadContext tc = new ThreadContext(Settings.EMPTY); + + ThreadPool tp = mock(ThreadPool.class); + when(tp.getThreadContext()).thenReturn(tc); + + when(client.threadPool()).thenReturn(tp); + + return new OriginSettingClient(client, origin); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java index 381ff0612abe2..8373a75bfa117 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java @@ -6,12 +6,16 @@ package org.elasticsearch.xpack.ml.utils.persistence; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.ActionFuture; -import org.elasticsearch.action.search.ClearScrollRequestBuilder; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.ClearScrollAction; +import org.elasticsearch.action.search.ClearScrollResponse; +import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchScrollAction; import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -19,6 +23,8 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.ml.test.MockOriginSettingClient; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -30,9 +36,13 @@ import java.util.Deque; import java.util.List; import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -42,6 +52,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase { private static final String SCROLL_ID = "someScrollId"; private Client client; + private OriginSettingClient originSettingClient; private boolean wasScrollCleared; private TestIterator testIterator; @@ -52,8 +63,9 @@ public class BatchedDocumentsIteratorTests extends ESTestCase { @Before public void setUpMocks() { client = Mockito.mock(Client.class); + originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN); wasScrollCleared = false; - testIterator = new TestIterator(client, INDEX_NAME); + testIterator = new TestIterator(originSettingClient, INDEX_NAME); givenClearScrollRequest(); } @@ -122,14 +134,14 @@ private String createJsonDoc(String value) { return "{\"foo\":\"" + value + "\"}"; } + @SuppressWarnings("unchecked") private void givenClearScrollRequest() { - ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class); - when(client.prepareClearScroll()).thenReturn(requestBuilder); - when(requestBuilder.setScrollIds(Collections.singletonList(SCROLL_ID))).thenReturn(requestBuilder); - when(requestBuilder.get()).thenAnswer((invocation) -> { + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; wasScrollCleared = true; + listener.onResponse(mock(ClearScrollResponse.class)); return null; - }); + }).when(client).execute(eq(ClearScrollAction.INSTANCE), any(), any()); } private void assertSearchRequest() { @@ -156,6 +168,8 @@ private class ScrollResponsesMocker { private long totalHits = 0; private List responses = new ArrayList<>(); + private AtomicInteger responseIndex = new AtomicInteger(0); + ScrollResponsesMocker addBatch(String... hits) { totalHits += hits.length; batches.add(hits); @@ -173,33 +187,23 @@ void finishMock() { givenNextResponse(batches.get(i)); } if (responses.size() > 0) { - ActionFuture first = wrapResponse(responses.get(0)); - if (responses.size() > 1) { - List> rest = new ArrayList<>(); - for (int i = 1; i < responses.size(); ++i) { - rest.add(wrapResponse(responses.get(i))); - } - - when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn( - first, rest.toArray(new ActionFuture[rest.size() - 1])); - } else { - when(client.searchScroll(searchScrollRequestCaptor.capture())).thenReturn(first); - } + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(responses.get(responseIndex.getAndIncrement())); + return null; + }).when(client).execute(eq(SearchScrollAction.INSTANCE), searchScrollRequestCaptor.capture(), any()); } } + @SuppressWarnings("unchecked") private void givenInitialResponse(String... hits) { SearchResponse searchResponse = createSearchResponseWithHits(hits); - ActionFuture future = wrapResponse(searchResponse); - when(future.actionGet()).thenReturn(searchResponse); - when(client.search(searchRequestCaptor.capture())).thenReturn(future); - } - @SuppressWarnings("unchecked") - private ActionFuture wrapResponse(SearchResponse searchResponse) { - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(searchResponse); - return future; + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(searchResponse); + return null; + }).when(client).execute(eq(SearchAction.INSTANCE), searchRequestCaptor.capture(), any()); } private void givenNextResponse(String... hits) { @@ -224,7 +228,7 @@ private SearchHits createHits(String... values) { } private static class TestIterator extends BatchedDocumentsIterator { - TestIterator(Client client, String jobId) { + TestIterator(OriginSettingClient client, String jobId) { super(client, jobId); }