diff --git a/muted-tests.yml b/muted-tests.yml index 47b1bf907d797..9dd522ab0e617 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -121,9 +121,6 @@ tests: - class: org.elasticsearch.xpack.esql.action.ManyShardsIT method: testConcurrentQueries issue: https://github.com/elastic/elasticsearch/issues/112424 -- class: org.elasticsearch.xpack.inference.external.http.RequestBasedTaskRunnerTests - method: testLoopOneAtATime - issue: https://github.com/elastic/elasticsearch/issues/112471 - class: org.elasticsearch.ingest.geoip.IngestGeoIpClientYamlTestSuiteIT issue: https://github.com/elastic/elasticsearch/issues/111497 - class: org.elasticsearch.smoketest.SmokeTestIngestWithAllDepsClientYamlTestSuiteIT diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunner.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunner.java index 85aac661e6091..4aaf48213f99b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunner.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunner.java @@ -40,7 +40,7 @@ class RequestBasedTaskRunner { * Else, offload to a new thread so we do not block another threadpool's thread. */ public void requestNextRun() { - if (loopCount.getAndIncrement() == 0) { + if (isRunning.get() && loopCount.getAndIncrement() == 0) { var currentThreadPool = EsExecutors.executorName(Thread.currentThread().getName()); if (executorServiceName.equalsIgnoreCase(currentThreadPool)) { run(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunnerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunnerTests.java index d24bdbe444f52..7e93e68aae3e7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunnerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/RequestBasedTaskRunnerTests.java @@ -7,21 +7,23 @@ package org.elasticsearch.xpack.inference.external.http; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; -import org.junit.After; import org.junit.Before; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Mockito.spy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; public class RequestBasedTaskRunnerTests extends ESTestCase { private ThreadPool threadPool; @@ -29,128 +31,74 @@ public class RequestBasedTaskRunnerTests extends ESTestCase { @Before public void setUp() throws Exception { super.setUp(); - threadPool = spy(createThreadPool(inferenceUtilityPool())); + threadPool = mock(); + when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); } - @After - public void tearDown() throws Exception { - terminate(threadPool); - super.tearDown(); - } + public void testRequestWhileLoopingWillRerunCommand() { + var expectedTimesRerun = randomInt(5); + AtomicInteger counter = new AtomicInteger(0); - public void testLoopOneAtATime() throws Exception { - // count the number of times the runnable is called - var counter = new AtomicInteger(0); - - // block the runnable and wait for the test thread to take an action - var lock = new ReentrantLock(); - var condition = lock.newCondition(); - Runnable block = () -> { - try { - try { - lock.lock(); - condition.await(); - } finally { - lock.unlock(); - } - } catch (InterruptedException e) { - fail(e, "did not unblock the thread in time, likely during threadpool terminate"); - } - }; - Runnable unblock = () -> { - try { - lock.lock(); - condition.signalAll(); - } finally { - lock.unlock(); + var requestNextRun = new AtomicReference(); + Runnable command = () -> { + if (counter.getAndIncrement() < expectedTimesRerun) { + requestNextRun.get().run(); } }; - - var runner = new RequestBasedTaskRunner(() -> { - counter.incrementAndGet(); - block.run(); - }, threadPool, UTILITY_THREAD_POOL_NAME); - - // given we have not called requestNextRun, then no thread should have started - assertThat(counter.get(), equalTo(0)); - verify(threadPool, times(0)).executor(UTILITY_THREAD_POOL_NAME); - + var runner = new RequestBasedTaskRunner(command, threadPool, UTILITY_THREAD_POOL_NAME); + requestNextRun.set(runner::requestNextRun); runner.requestNextRun(); - // given that we have called requestNextRun, then 1 thread should run once - assertBusy(() -> { - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); - assertThat(counter.get(), equalTo(1)); - }); - - // given that we have called requestNextRun while a thread was running, and the thread was blocked - runner.requestNextRun(); - // then 1 thread should run once - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); - assertThat(counter.get(), equalTo(1)); + verify(threadPool, times(1)).executor(eq(UTILITY_THREAD_POOL_NAME)); + verifyNoMoreInteractions(threadPool); + assertThat(counter.get(), equalTo(expectedTimesRerun + 1)); + } - // given the thread is unblocked - unblock.run(); - // then 1 thread should run twice - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); - assertBusy(() -> assertThat(counter.get(), equalTo(2))); + public void testRequestWhileNotLoopingWillQueueCommand() { + AtomicInteger counter = new AtomicInteger(0); - // given the thread is unblocked again, but there were only two calls to requestNextRun - unblock.run(); - // then 1 thread should run twice - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); - assertBusy(() -> assertThat(counter.get(), equalTo(2))); + var runner = new RequestBasedTaskRunner(counter::incrementAndGet, threadPool, UTILITY_THREAD_POOL_NAME); - // given no thread is running, when we call requestNextRun - runner.requestNextRun(); - // then a second thread should start for the third run - assertBusy(() -> { - verify(threadPool, times(2)).executor(UTILITY_THREAD_POOL_NAME); - assertThat(counter.get(), equalTo(3)); - }); - - // given the thread is unblocked, then it should exit and rejoin the threadpool - unblock.run(); - assertTrue("Test thread should unblock after all runs complete", terminate(threadPool)); - - // final check - we ran three times on two threads - verify(threadPool, times(2)).executor(UTILITY_THREAD_POOL_NAME); - assertThat(counter.get(), equalTo(3)); + for (int i = 1; i < randomInt(10); i++) { + runner.requestNextRun(); + verify(threadPool, times(i)).executor(eq(UTILITY_THREAD_POOL_NAME)); + assertThat(counter.get(), equalTo(i)); + } + ; } - public void testCancel() throws Exception { - // count the number of times the runnable is called - var counter = new AtomicInteger(0); - var latch = new CountDownLatch(1); - var runner = new RequestBasedTaskRunner(() -> { - counter.incrementAndGet(); - try { - latch.await(); - } catch (InterruptedException e) { - fail(e, "did not unblock the thread in time, likely during threadpool terminate"); - } - }, threadPool, UTILITY_THREAD_POOL_NAME); + public void testCancelBeforeRunning() { + AtomicInteger counter = new AtomicInteger(0); - // given that we have called requestNextRun, then 1 thread should run once + var runner = new RequestBasedTaskRunner(counter::incrementAndGet, threadPool, UTILITY_THREAD_POOL_NAME); + runner.cancel(); runner.requestNextRun(); - assertBusy(() -> { - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); - assertThat(counter.get(), equalTo(1)); - }); - // given that a thread is running, three more calls will be queued - runner.requestNextRun(); - runner.requestNextRun(); + verifyNoInteractions(threadPool); + assertThat(counter.get(), equalTo(0)); + } + + public void testCancelWhileRunning() { + var expectedTimesRerun = randomInt(5); + AtomicInteger counter = new AtomicInteger(0); + + var runnerRef = new AtomicReference(); + Runnable command = () -> { + if (counter.getAndIncrement() < expectedTimesRerun) { + runnerRef.get().requestNextRun(); + } + runnerRef.get().cancel(); + }; + var runner = new RequestBasedTaskRunner(command, threadPool, UTILITY_THREAD_POOL_NAME); + runnerRef.set(runner); runner.requestNextRun(); - // when we cancel the thread, then the thread should immediately exit and rejoin - runner.cancel(); - latch.countDown(); - assertTrue("Test thread should unblock after all runs complete", terminate(threadPool)); + verify(threadPool, times(1)).executor(eq(UTILITY_THREAD_POOL_NAME)); + verifyNoMoreInteractions(threadPool); + assertThat(counter.get(), equalTo(1)); - // given that we called cancel, when we call requestNextRun then no thread should start runner.requestNextRun(); - verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME); + verifyNoMoreInteractions(threadPool); assertThat(counter.get(), equalTo(1)); }