Skip to content

Commit

Permalink
Segment Replication - Block snapshot creation if the target primary s…
Browse files Browse the repository at this point in the history
…hard has not completed failover. (#9629)

* Segment Replication - Block snapshot creation if the target primary shard has not completed failover.

Signed-off-by: Marc Handalian <handalm@amazon.com>

* Add unit test and remove non-deterministic IT.

Signed-off-by: Marc Handalian <handalm@amazon.com>

* Apply spotless formatting.

Signed-off-by: Marc Handalian <handalm@amazon.com>

---------

Signed-off-by: Marc Handalian <handalm@amazon.com>
  • Loading branch information
mch2 authored and pull[bot] committed Apr 17, 2024
1 parent 62f1d55 commit 1200000
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,12 @@ private void snapshot(
if (indexShard.routingEntry().primary() == false) {
throw new IndexShardSnapshotFailedException(shardId, "snapshot should be performed only on primary");
}
if (indexShard.indexSettings().isSegRepEnabled() && indexShard.isPrimaryMode() == false) {
throw new IndexShardSnapshotFailedException(
shardId,
"snapshot triggered on a new primary following failover and cannot proceed until promotion is complete"
);
}
if (indexShard.routingEntry().relocating()) {
// do not snapshot when in the process of relocation of primaries so we won't get conflicts
throw new IndexShardSnapshotFailedException(shardId, "cannot snapshot while relocating");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.opensearch.action.admin.indices.forcemerge.ForceMergeRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.SnapshotsInProgress;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.routing.IndexShardRoutingTable;
import org.opensearch.cluster.routing.ShardRouting;
Expand All @@ -28,7 +31,9 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.CancellableThreads;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.SegmentReplicationShardStats;
import org.opensearch.index.engine.DocIdSeqNoAndSource;
Expand All @@ -40,10 +45,12 @@
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.replication.OpenSearchIndexLevelReplicationTestCase;
import org.opensearch.index.replication.TestReplicationSource;
import org.opensearch.index.snapshots.IndexShardSnapshotStatus;
import org.opensearch.index.store.Store;
import org.opensearch.index.store.StoreFileMetadata;
import org.opensearch.index.translog.SnapshotMatchers;
import org.opensearch.index.translog.Translog;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.recovery.RecoverySettings;
import org.opensearch.indices.recovery.RecoveryTarget;
import org.opensearch.indices.replication.CheckpointInfoResponse;
Expand All @@ -60,6 +67,12 @@
import org.opensearch.indices.replication.common.ReplicationListener;
import org.opensearch.indices.replication.common.ReplicationState;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.repositories.IndexId;
import org.opensearch.snapshots.Snapshot;
import org.opensearch.snapshots.SnapshotId;
import org.opensearch.snapshots.SnapshotInfoTests;
import org.opensearch.snapshots.SnapshotShardsService;
import org.opensearch.test.VersionUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.junit.Assert;
Expand All @@ -69,6 +82,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -83,6 +97,7 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
Expand Down Expand Up @@ -859,6 +874,97 @@ public void testSegmentReplicationStats() throws Exception {
}
}

/**
 * Checks that a shard snapshot delivered to a replica that is in the middle of being promoted
 * is recorded as failed instead of being executed.
 *
 * <p>The test hooks {@code newReadWriteEngine} on a spied {@link NRTReplicationEngineFactory}.
 * Inside that hook it asserts the replica still holds a {@link ReadOnlyEngine}, then publishes
 * the {@code STARTED} snapshot state to the {@link SnapshotShardsService} before letting the
 * real engine be created. Afterwards it waits for the shard snapshot status to reach
 * {@code FAILURE} with the "promotion is complete" message.
 */
public void testSnapshotWhileFailoverIncomplete() throws Exception {
    final String expectedFailure =
        "snapshot triggered on a new primary following failover and cannot proceed until promotion is complete";
    final NRTReplicationEngineFactory engineFactorySpy = spy(new NRTReplicationEngineFactory());
    try (ReplicationGroup group = createGroup(1, settings, indexMapping, engineFactorySpy, createTempDir())) {
        final IndexShard primaryShard = group.getPrimary();
        final IndexShard replicaShard = group.getReplicas().get(0);

        // bring primary and replica to the same 10 documents before triggering the failover
        group.startAll();
        group.indexDocs(10);
        group.refresh("test");
        replicateSegments(primaryShard, group.getReplicas());
        group.assertAllEqual(10);

        final SnapshotShardsService shardsService = getSnapshotShardsService(replicaShard);
        final String repositoryName = randomAlphaOfLength(10);
        final SnapshotId snapshotId = new SnapshotId(randomAlphaOfLength(5), randomAlphaOfLength(5));
        final Snapshot snapshot = new Snapshot(repositoryName, snapshotId);

        // first announce the snapshot in INIT state
        final ClusterState initState = addSnapshotIndex(clusterService.state(), snapshot, replicaShard, SnapshotsInProgress.State.INIT);
        final ClusterChangedEvent initEvent = new ClusterChangedEvent("test", initState, clusterService.state());
        shardsService.clusterChanged(initEvent);

        final CountDownLatch startedStateDelivered = new CountDownLatch(1);
        doAnswer(invocation -> {
            // when the read/write engine is requested the replica is expected to still be read-only
            final Engine currentEngine = replicaShard.getEngineOrNull();
            assertNotNull(currentEngine);
            assertTrue(currentEngine instanceof ReadOnlyEngine);
            group.assertAllEqual(10);

            // move the snapshot to STARTED while the engine swap has not happened yet
            final ClusterState startedState = addSnapshotIndex(
                clusterService.state(),
                snapshot,
                replicaShard,
                SnapshotsInProgress.State.STARTED
            );
            shardsService.clusterChanged(new ClusterChangedEvent("test", startedState, initState));
            startedStateDelivered.countDown();
            return invocation.callRealMethod();
        }).when(engineFactorySpy).newReadWriteEngine(any());

        group.promoteReplicaToPrimary(replicaShard).get();
        startedStateDelivered.await();

        assertBusy(() -> {
            final IndexShardSnapshotStatus.Copy status = shardsService.currentSnapshotShards(snapshot)
                .get(replicaShard.shardId)
                .asCopy();
            assertEquals(IndexShardSnapshotStatus.Stage.FAILURE, status.getStage());
            final String failure = status.getFailure();
            assertNotNull(failure);
            assertTrue(failure.contains(expectedFailure));
        });
    }
}

/**
 * Creates a {@link SnapshotShardsService} backed by Mockito mocks.
 *
 * <p>The mocked {@link IndicesService} returns a mocked {@link IndexService} for any index, and
 * that index service returns {@code replicaShard} for any shard number. The mocked
 * {@link TransportService} hands out the test's {@code threadPool}.
 *
 * @param replicaShard the shard every shard lookup resolves to
 * @return a service built from the test's settings, cluster service and repositories service
 */
private SnapshotShardsService getSnapshotShardsService(IndexShard replicaShard) {
    // shard lookup chain: indices service -> index service -> the supplied shard
    final IndexService mockIndexService = mock(IndexService.class);
    when(mockIndexService.getShardOrNull(anyInt())).thenReturn(replicaShard);
    final IndicesService mockIndicesService = mock(IndicesService.class);
    when(mockIndicesService.indexServiceSafe(any())).thenReturn(mockIndexService);

    final TransportService mockTransportService = mock(TransportService.class);
    when(mockTransportService.getThreadPool()).thenReturn(threadPool);

    return new SnapshotShardsService(
        settings,
        clusterService,
        createRepositoriesService(),
        mockTransportService,
        mockIndicesService
    );
}

/**
 * Returns a copy of {@code state} whose {@link SnapshotsInProgress} custom holds a single entry
 * for {@code snapshot}.
 *
 * <p>The entry lists the test index and one shard status for {@code shard}, assigned to the
 * local node of {@code state}. The remaining entry fields are filled with random values.
 *
 * @param state         cluster state to derive the new state from
 * @param snapshot      snapshot the in-progress entry is created for
 * @param shard         shard whose routing supplies the shard id of the status entry
 * @param snapshotState state to record on the in-progress entry
 * @return the new cluster state containing the snapshot-in-progress entry
 */
private ClusterState addSnapshotIndex(
    ClusterState state,
    Snapshot snapshot,
    IndexShard shard,
    SnapshotsInProgress.State snapshotState
) {
    final String localNodeId = state.getNodes().getLocalNode().getId();
    final ShardId snapshotShardId = shard.shardRouting.shardId();

    final Map<ShardId, SnapshotsInProgress.ShardSnapshotStatus> shardStatuses = new HashMap<>();
    shardStatuses.put(snapshotShardId, new SnapshotsInProgress.ShardSnapshotStatus(localNodeId, "1"));

    final List<IndexId> snapshotIndices = Collections.singletonList(new IndexId(index.getName(), index.getUUID()));

    // argument order is kept as-is so the random values are drawn in the same sequence
    final SnapshotsInProgress.Entry inProgressEntry = new SnapshotsInProgress.Entry(
        snapshot,
        randomBoolean(),
        false,
        snapshotState,
        snapshotIndices,
        Collections.emptyList(),
        randomNonNegativeLong(),
        randomLong(),
        shardStatuses,
        null,
        SnapshotInfoTests.randomUserMetadata(),
        VersionUtils.randomVersion(random()),
        false
    );

    final SnapshotsInProgress snapshotsInProgress = SnapshotsInProgress.of(Collections.singletonList(inProgressEntry));
    return ClusterState.builder(state).putCustom(SnapshotsInProgress.TYPE, snapshotsInProgress).build();
}

private void assertReplicaCaughtUp(IndexShard primaryShard) {
Set<SegmentReplicationShardStats> initialStats = primaryShard.getReplicationStats();
assertEquals(initialStats.size(), 1);
Expand Down

0 comments on commit 1200000

Please sign in to comment.