diff --git a/server/src/internalClusterTest/java/org/opensearch/search/backpressure/SearchBackpressureIT.java b/server/src/internalClusterTest/java/org/opensearch/search/backpressure/SearchBackpressureIT.java new file mode 100644 index 0000000000000..f8629e2c88b07 --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/search/backpressure/SearchBackpressureIT.java @@ -0,0 +1,313 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure; + +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.backpressure.settings.NodeDuressSettings; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.search.backpressure.settings.SearchShardTaskSettings; +import org.opensearch.search.backpressure.trackers.CpuUsageTracker; +import org.opensearch.search.backpressure.trackers.ElapsedTimeTracker; +import org.opensearch.search.backpressure.trackers.HeapUsageTracker; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancelledException; +import org.opensearch.tasks.TaskId; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE) +public class SearchBackpressureIT extends OpenSearchIntegTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(10, TimeUnit.SECONDS); + + @Override + protected Collection> nodePlugins() { + final List> plugins = new ArrayList<>(super.nodePlugins()); + plugins.add(TestPlugin.class); + return plugins; + } + + @Before + public final void setupNodeSettings() { + Settings request = Settings.builder() + .put(NodeDuressSettings.SETTING_CPU_THRESHOLD.getKey(), 0.0) + .put(NodeDuressSettings.SETTING_HEAP_THRESHOLD.getKey(), 0.0) + .put(NodeDuressSettings.SETTING_NUM_SUCCESSIVE_BREACHES.getKey(), 1) + .put(SearchShardTaskSettings.SETTING_TOTAL_HEAP_PERCENT_THRESHOLD.getKey(), 0.0) + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(request).get()); + } + + @After + public final void cleanupNodeSettings() { + assertAcked( + client().admin() + .cluster() + .prepareUpdateSettings() + .setPersistentSettings(Settings.builder().putNull("*")) + .setTransientSettings(Settings.builder().putNull("*")) + ); + } + + public void testSearchShardTaskCancellationWithHighElapsedTime() throws InterruptedException { + Settings request = Settings.builder() + .put(SearchBackpressureSettings.SETTING_MODE.getKey(), "enforced") + .put(ElapsedTimeTracker.SETTING_ELAPSED_TIME_MILLIS_THRESHOLD.getKey(), 1000) + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(request).get()); + + ExceptionCatchingListener listener = new ExceptionCatchingListener(); + client().execute(TestTransportAction.ACTION, new TestRequest(RequestType.HIGH_ELAPSED_TIME), listener); + assertTrue(listener.latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS)); + + Exception caughtException = listener.getException(); + assertNotNull("SearchShardTask should have been cancelled with TaskCancelledException", caughtException); + MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class)); + MatcherAssert.assertThat(caughtException.getMessage(), containsString("elapsed time exceeded")); + } + + public void testSearchShardTaskCancellationWithHighCpu() throws InterruptedException { + Settings request = Settings.builder() + .put(SearchBackpressureSettings.SETTING_MODE.getKey(), "enforced") + .put(CpuUsageTracker.SETTING_CPU_TIME_MILLIS_THRESHOLD.getKey(), 1000) + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(request).get()); + + ExceptionCatchingListener listener = new ExceptionCatchingListener(); + client().execute(TestTransportAction.ACTION, new TestRequest(RequestType.HIGH_CPU), listener); + assertTrue(listener.latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS)); + + Exception caughtException = listener.getException(); + assertNotNull("SearchShardTask should have been cancelled with TaskCancelledException", caughtException); + MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class)); + MatcherAssert.assertThat(caughtException.getMessage(), containsString("cpu usage exceeded")); + } + + public void testSearchShardTaskCancellationWithHighHeapUsage() throws InterruptedException { + // Before SearchBackpressureService cancels a task based on its heap usage, we need to build up the heap moving average + // To build up the heap moving average, we need to hit the same node with multiple requests and then hit the same node with a + // request having higher heap usage + String node = randomFrom(internalCluster().getNodeNames()); + final int MOVING_AVERAGE_WINDOW_SIZE = 10; + Settings request = Settings.builder() + .put(SearchBackpressureSettings.SETTING_MODE.getKey(), "enforced") + .put(HeapUsageTracker.SETTING_HEAP_PERCENT_THRESHOLD.getKey(), 0.0) + .put(HeapUsageTracker.SETTING_HEAP_VARIANCE_THRESHOLD.getKey(), 1.0) + .put(HeapUsageTracker.SETTING_HEAP_MOVING_AVERAGE_WINDOW_SIZE.getKey(), MOVING_AVERAGE_WINDOW_SIZE) + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(request).get()); + + ExceptionCatchingListener listener = new ExceptionCatchingListener(); + for (int i = 0; i < MOVING_AVERAGE_WINDOW_SIZE; i++) { + client(node).execute(TestTransportAction.ACTION, new TestRequest(RequestType.HIGH_HEAP), listener); + } + + listener = new ExceptionCatchingListener(); + client(node).execute(TestTransportAction.ACTION, new TestRequest(RequestType.HIGHER_HEAP), listener); + assertTrue(listener.latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS)); + + Exception caughtException = listener.getException(); + assertNotNull("SearchShardTask should have been cancelled with TaskCancelledException", caughtException); + MatcherAssert.assertThat(caughtException, instanceOf(TaskCancelledException.class)); + MatcherAssert.assertThat(caughtException.getMessage(), containsString("heap usage exceeded")); + } + + public void testSearchCancellationWithBackpressureDisabled() throws InterruptedException { + Settings request = Settings.builder().put(SearchBackpressureSettings.SETTING_MODE.getKey(), "monitor_only").build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setPersistentSettings(request).get()); + + ExceptionCatchingListener listener = new ExceptionCatchingListener(); + client().execute(TestTransportAction.ACTION, new TestRequest(RequestType.HIGH_ELAPSED_TIME), listener); + // waiting for the TIMEOUT * 3 time for the request to complete and the latch to countdown. + assertTrue( + "SearchShardTask should have been completed by now and countdown the latch", + listener.latch.await(TIMEOUT.getSeconds() * 3, TimeUnit.SECONDS) + ); + + Exception caughtException = listener.getException(); + assertNull("SearchShardTask shouldn't have cancelled for monitor_only mode", caughtException); + } + + private static class ExceptionCatchingListener implements ActionListener { + private final CountDownLatch latch; + private Exception exception = null; + + public ExceptionCatchingListener() { + this.latch = new CountDownLatch(1); + } + + @Override + public void onResponse(TestResponse r) { + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + this.exception = e; + latch.countDown(); + } + + private Exception getException() { + return exception; + } + } + + enum RequestType { + HIGH_CPU, + HIGH_HEAP, + HIGHER_HEAP, + HIGH_ELAPSED_TIME; + } + + public static class TestRequest extends ActionRequest { + private final RequestType type; + + public TestRequest(RequestType type) { + this.type = type; + } + + public TestRequest(StreamInput in) throws IOException { + super(in); + this.type = in.readEnum(RequestType.class); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SearchShardTask(id, type, action, "", parentTaskId, headers); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(type); + } + + public RequestType getType() { + return this.type; + } + } + + public static class TestResponse extends ActionResponse { + public TestResponse() {} + + public TestResponse(StreamInput in) {} + + @Override + public void writeTo(StreamOutput out) throws IOException {} + } + + public static class TestTransportAction extends HandledTransportAction { + public static final ActionType ACTION = new ActionType<>("internal::test_action", TestResponse::new); + private final ThreadPool threadPool; + + @Inject + public TestTransportAction(TransportService transportService, ThreadPool threadPool, ActionFilters actionFilters) { + super(ACTION.name(), transportService, actionFilters, TestRequest::new); + this.threadPool = threadPool; + } + + @Override + protected void doExecute(Task task, TestRequest request, ActionListener listener) { + threadPool.executor(ThreadPool.Names.SEARCH).execute(() -> { + try { + SearchShardTask searchShardTask = (SearchShardTask) task; + long startTime = System.nanoTime(); + + // Doing a busy-wait until task cancellation or timeout. + // We are running HIGH_HEAP requests to build up heap moving average and not expect it to get cancelled. + do { + doWork(request); + } while (request.type != RequestType.HIGH_HEAP + && searchShardTask.isCancelled() == false + && (System.nanoTime() - startTime) < TIMEOUT.getNanos()); + + if (searchShardTask.isCancelled()) { + throw new TaskCancelledException(searchShardTask.getReasonCancelled()); + } else { + listener.onResponse(new TestResponse()); + } + } catch (Exception e) { + listener.onFailure(e); + } + }); + } + + private void doWork(TestRequest request) throws InterruptedException { + switch (request.getType()) { + case HIGH_CPU: + long i = 0, j = 1, k = 1, iterations = 1000; + do { + j += i; + k *= j; + i++; + } while (i < iterations); + break; + case HIGH_HEAP: + Byte[] bytes = new Byte[100000]; + int[] ints = new int[1000]; + break; + case HIGHER_HEAP: + Byte[] more_bytes = new Byte[1000000]; + int[] more_ints = new int[10000]; + break; + case HIGH_ELAPSED_TIME: + Thread.sleep(100); + break; + } + } + } + + public static class TestPlugin extends Plugin implements ActionPlugin { + @Override + public List> getActions() { + return Collections.singletonList(new ActionHandler<>(TestTransportAction.ACTION, TestTransportAction.class)); + } + + @Override + public List> getClientActions() { + return Collections.singletonList(TestTransportAction.ACTION); + } + } +}