From d57e8f0d5abf8cace8b868beb115ad5c52104735 Mon Sep 17 00:00:00 2001
From: Andriy Redko <andriy.redko@aiven.io>
Date: Wed, 22 Jun 2022 18:47:05 -0400
Subject: [PATCH] [BUG] opensearch crashes on closed client connection before
 search reply (#3626) (#3645) (#3655)

* [BUG] opensearch crashes on closed client connection before search reply

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>

* Addressing code review comments

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
(cherry picked from commit 3dba46ecd14c5b1eb18500357400d6cefafc2836)

Co-authored-by: Andriy Redko <andriy.redko@aiven.io>
Signed-off-by: Andriy Redko <andriy.redko@aiven.io>

Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>
---
 .../search/AbstractSearchAsyncAction.java     |  12 +-
 .../AbstractSearchAsyncActionTests.java       | 155 +++++++++++++++++-
 2 files changed, 162 insertions(+), 5 deletions(-)

diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java
index 190904145b091..21b5f3244471b 100644
--- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java
+++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java
@@ -454,7 +454,11 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
         }
         final int totalOps = this.totalOps.incrementAndGet();
         if (totalOps == expectedTotalOps) {
-            onPhaseDone();
+            try {
+                onPhaseDone();
+            } catch (final Exception ex) {
+                onPhaseFailure(this, "The phase has failed", ex);
+            }
         } else if (totalOps > expectedTotalOps) {
             throw new AssertionError(
                 "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
@@ -559,7 +563,11 @@ private void successfulShardExecution(SearchShardIterator shardsIt) {
         }
         final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
         if (xTotalOps == expectedTotalOps) {
-            onPhaseDone();
+            try {
+                onPhaseDone();
+            } catch (final Exception ex) {
+                onPhaseFailure(this, "The phase has failed", ex);
+            }
         } else if (xTotalOps > expectedTotalOps) {
             throw new AssertionError(
                 "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",
diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java
index f4b45b9c36f96..2a990f8e3b65a 100644
--- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java
+++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java
@@ -32,6 +32,8 @@
 
 package org.opensearch.action.search;
 
+import org.junit.After;
+import org.junit.Before;
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.OriginalIndices;
 import org.opensearch.action.support.IndicesOptions;
@@ -43,25 +45,34 @@
 import org.opensearch.index.Index;
 import org.opensearch.index.query.MatchAllQueryBuilder;
 import org.opensearch.index.shard.ShardId;
+import org.opensearch.index.shard.ShardNotFoundException;
 import org.opensearch.search.SearchPhaseResult;
 import org.opensearch.search.SearchShardTarget;
 import org.opensearch.search.internal.AliasFilter;
 import org.opensearch.search.internal.InternalSearchResponse;
 import org.opensearch.search.internal.ShardSearchContextId;
 import org.opensearch.search.internal.ShardSearchRequest;
+import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.transport.Transport;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.CopyOnWriteArraySet;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiFunction;
+import java.util.stream.IntStream;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -71,6 +82,22 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase {
 
     private final List<Tuple<String, String>> resolvedNodes = new ArrayList<>();
     private final Set<ShardSearchContextId> releasedContexts = new CopyOnWriteArraySet<>();
+    private ExecutorService executor;
+
+    @Before
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        executor = Executors.newFixedThreadPool(1);
+    }
+
+    @After
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+        executor.shutdown();
+        assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS));
+    }
 
     private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
         SearchRequest request,
@@ -78,6 +105,26 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
         ActionListener<SearchResponse> listener,
         final boolean controlled,
         final AtomicLong expected
+    ) {
+        return createAction(
+            request,
+            results,
+            listener,
+            controlled,
+            false,
+            expected,
+            new SearchShardIterator(null, null, Collections.emptyList(), null)
+        );
+    }
+
+    private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
+        SearchRequest request,
+        ArraySearchPhaseResults<SearchPhaseResult> results,
+        ActionListener<SearchResponse> listener,
+        final boolean controlled,
+        final boolean failExecutePhaseOnShard,
+        final AtomicLong expected,
+        final SearchShardIterator... shards
     ) {
         final Runnable runnable;
         final TransportSearchAction.SearchTimeProvider timeProvider;
@@ -105,10 +152,10 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
             Collections.singletonMap("foo", new AliasFilter(new MatchAllQueryBuilder())),
             Collections.singletonMap("foo", 2.0f),
             Collections.singletonMap("name", Sets.newHashSet("bar", "baz")),
-            null,
+            executor,
             request,
             listener,
-            new GroupShardsIterator<>(Collections.singletonList(new SearchShardIterator(null, null, Collections.emptyList(), null))),
+            new GroupShardsIterator<>(Arrays.asList(shards)),
             timeProvider,
             ClusterState.EMPTY_STATE,
             null,
@@ -126,7 +173,13 @@ protected void executePhaseOnShard(
                 final SearchShardIterator shardIt,
                 final SearchShardTarget shard,
                 final SearchActionListener<SearchPhaseResult> listener
-            ) {}
+            ) {
+                if (failExecutePhaseOnShard) {
+                    listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
+                } else {
+                    listener.onResponse(new QuerySearchResult());
+                }
+            }
 
             @Override
             long buildTookInMillis() {
@@ -328,6 +381,102 @@ private static ArraySearchPhaseResults<SearchPhaseResult> phaseResults(
         return phaseResults;
     }
 
+    public void testOnShardFailurePhaseDoneFailure() throws InterruptedException {
+        final Index index = new Index("test", UUID.randomUUID().toString());
+        final CountDownLatch latch = new CountDownLatch(1);
+        final AtomicBoolean fail = new AtomicBoolean(true);
+
+        final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
+            .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), Arrays.asList("n1", "n2", "n3"), null, null, null))
+            .toArray(SearchShardIterator[]::new);
+
+        SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
+        searchRequest.setMaxConcurrentShardRequests(1);
+
+        final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
+        AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
+            searchRequest,
+            queryResult,
+            new ActionListener<SearchResponse>() {
+                @Override
+                public void onResponse(SearchResponse response) {
+
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    if (fail.compareAndSet(true, false)) {
+                        try {
+                            throw new RuntimeException("Simulated exception");
+                        } finally {
+                            executor.submit(() -> latch.countDown());
+                        }
+                    }
+                }
+            },
+            false,
+            true,
+            new AtomicLong(),
+            shards
+        );
+        action.run();
+        assertTrue(latch.await(1, TimeUnit.SECONDS));
+
+        InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
+        SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
+        assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
+        assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
+        assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
+        assertSame(searchResponse.getHits(), internalSearchResponse.hits());
+        assertThat(searchResponse.getSuccessfulShards(), equalTo(0));
+    }
+
+    public void testOnShardSuccessPhaseDoneFailure() throws InterruptedException {
+        final Index index = new Index("test", UUID.randomUUID().toString());
+        final CountDownLatch latch = new CountDownLatch(1);
+        final AtomicBoolean fail = new AtomicBoolean(true);
+
+        final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
+            .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), Arrays.asList("n1", "n2", "n3"), null, null, null))
+            .toArray(SearchShardIterator[]::new);
+
+        SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
+        searchRequest.setMaxConcurrentShardRequests(1);
+
+        final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
+        AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
+            searchRequest,
+            queryResult,
+            new ActionListener<SearchResponse>() {
+                @Override
+                public void onResponse(SearchResponse response) {
+                    if (fail.compareAndSet(true, false)) {
+                        throw new RuntimeException("Simulated exception");
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    executor.submit(() -> latch.countDown());
+                }
+            },
+            false,
+            false,
+            new AtomicLong(),
+            shards
+        );
+        action.run();
+        assertTrue(latch.await(1, TimeUnit.SECONDS));
+
+        InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
+        SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
+        assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
+        assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
+        assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
+        assertSame(searchResponse.getHits(), internalSearchResponse.hits());
+        assertThat(searchResponse.getSuccessfulShards(), equalTo(shards.length));
+    }
+
     private static final class PhaseResult extends SearchPhaseResult {
         PhaseResult(ShardSearchContextId contextId) {
             this.contextId = contextId;