Skip to content

Commit

Permalink
Option to spy
Browse files Browse the repository at this point in the history
  • Loading branch information
albertzaharovits committed Jun 7, 2024
1 parent 90f1b0b commit 3e0a635
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import static java.util.Collections.emptySet;
import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
import static org.elasticsearch.test.ClusterServiceUtils.setState;
import static org.elasticsearch.test.transport.MockTransportService.createTaskManager;

/**
* The test case for unit testing task manager and related transport actions
Expand Down Expand Up @@ -176,12 +177,7 @@ public TestNode(String name, ThreadPool threadPool, Settings settings) {
discoveryNode.set(DiscoveryNodeUtils.create(name, address.publishAddress(), emptyMap(), emptySet()));
return discoveryNode.get();
};
TaskManager taskManager;
if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) {
taskManager = new MockTaskManager(settings, threadPool, emptySet());
} else {
taskManager = new TaskManager(settings, threadPool, emptySet());
}
TaskManager taskManager = createTaskManager(settings, threadPool, emptySet(), Tracer.NOOP);
transportService = new TransportService(
settings,
new Netty4Transport(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

import static org.elasticsearch.test.tasks.MockTaskManager.USE_SPY_TASK_MANAGER_SETTING;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
Expand Down Expand Up @@ -74,7 +75,7 @@ private MockTransportService startTransport(String id, List<DiscoveryNode> known

public void testRemoteTaskCancellationOnFailedResponse() throws Exception {
Settings.Builder remoteTransportSettingsBuilder = Settings.builder();
remoteTransportSettingsBuilder.put("tests.mock.taskmanager.enabled", true);
remoteTransportSettingsBuilder.put(USE_SPY_TASK_MANAGER_SETTING.getKey(), true);
try (
MockTransportService remoteTransport = RemoteClusterConnectionTests.startTransport(
"seed_node",
Expand Down Expand Up @@ -122,22 +123,22 @@ public void testRemoteTaskCancellationOnFailedResponse() throws Exception {
randomBoolean()
);

AtomicBoolean cancelChildReceived = new AtomicBoolean(false);
CountDownLatch cancelChildReceived = new CountDownLatch(1);
remoteTransport.addRequestHandlingBehavior(
TaskCancellationService.CANCEL_CHILD_ACTION_NAME,
(handler, request, channel, task) -> {
handler.messageReceived(request, channel, task);
cancelChildReceived.set(true);
cancelChildReceived.countDown();
}
);
AtomicLong searchShardsRequestId = new AtomicLong(-1);
AtomicBoolean cancelChildSent = new AtomicBoolean(false);
CountDownLatch cancelChildSent = new CountDownLatch(1);
localService.addSendBehavior(remoteTransport, (connection, requestId, action, request, options) -> {
connection.sendRequest(requestId, action, request, options);
if (action.equals("indices:admin/search/search_shards")) {
searchShardsRequestId.set(requestId);
} else if (action.equals(TaskCancellationService.CANCEL_CHILD_ACTION_NAME)) {
cancelChildSent.set(true);
cancelChildSent.countDown();
}
});

Expand All @@ -148,15 +149,9 @@ public void testRemoteTaskCancellationOnFailedResponse() throws Exception {
assertThat(e.getCause(), instanceOf(RemoteTransportException.class));

// assert remote task is cancelled
assertBusy(cancelChildSent::get);
assertBusy(cancelChildReceived::get);
assertBusy(
() -> verify(remoteTransport.getTaskManager()).cancelChildLocal(
eq(parentTaskId),
eq(searchShardsRequestId.get()),
anyString()
)
);
safeAwait(cancelChildSent);
safeAwait(cancelChildReceived);
verify(remoteTransport.getTaskManager()).cancelChildLocal(eq(parentTaskId), eq(searchShardsRequestId.get()), anyString());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ public class MockTaskManager extends TaskManager {
Property.NodeScope
);

public static final Setting<Boolean> USE_SPY_TASK_MANAGER_SETTING = Setting.boolSetting(
"tests.spy.taskmanager.enabled",
false,
Property.NodeScope
);

private final Collection<MockTaskManagerListener> listeners = new CopyOnWriteArrayList<>();

public MockTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public class MockTransportService extends TransportService {
public static class TestPlugin extends Plugin {
@Override
public List<Setting<?>> getSettings() {
return List.of(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING);
return List.of(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING, MockTaskManager.USE_SPY_TASK_MANAGER_SETTING);
}
}

Expand Down Expand Up @@ -311,9 +311,17 @@ private static TransportAddress[] extractTransportAddresses(TransportService tra
return transportAddresses.toArray(new TransportAddress[transportAddresses.size()]);
}

private static TaskManager createTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders, Tracer tracer) {
public static TaskManager createTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders, Tracer tracer) {
if (MockTaskManager.USE_SPY_TASK_MANAGER_SETTING.get(settings)) {
return spy(createMockTaskManager(settings, threadPool, taskHeaders, tracer));
} else {
return createMockTaskManager(settings, threadPool, taskHeaders, tracer);
}
}

private static TaskManager createMockTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders, Tracer tracer) {
if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) {
return spy(new MockTaskManager(settings, threadPool, taskHeaders));
return new MockTaskManager(settings, threadPool, taskHeaders);
} else {
return new TaskManager(settings, threadPool, taskHeaders, tracer);
}
Expand Down

0 comments on commit 3e0a635

Please sign in to comment.