diff --git a/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java index dc35ff7c127dc..af2f925f89726 100644 --- a/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java @@ -379,6 +379,12 @@ private void snapshot( if (indexShard.routingEntry().primary() == false) { throw new IndexShardSnapshotFailedException(shardId, "snapshot should be performed only on primary"); } + if (indexShard.indexSettings().isSegRepEnabled() && indexShard.isPrimaryMode() == false) { + throw new IndexShardSnapshotFailedException( + shardId, + "snapshot triggered on a new primary following failover and cannot proceed until promotion is complete" + ); + } if (indexShard.routingEntry().relocating()) { // do not snapshot when in the process of relocation of primaries so we won't get conflicts throw new IndexShardSnapshotFailedException(shardId, "cannot snapshot while relocating"); diff --git a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java index b7972810dddb9..d5ad48e80400c 100644 --- a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java @@ -16,6 +16,9 @@ import org.opensearch.action.admin.indices.forcemerge.ForceMergeRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.cluster.ClusterChangedEvent; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.SnapshotsInProgress; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.routing.IndexShardRoutingTable; import org.opensearch.cluster.routing.ShardRouting; @@ -28,7 +31,9 
@@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.CancellableThreads; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.index.IndexService; import org.opensearch.index.IndexSettings; import org.opensearch.index.SegmentReplicationShardStats; import org.opensearch.index.engine.DocIdSeqNoAndSource; @@ -40,10 +45,12 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.replication.OpenSearchIndexLevelReplicationTestCase; import org.opensearch.index.replication.TestReplicationSource; +import org.opensearch.index.snapshots.IndexShardSnapshotStatus; import org.opensearch.index.store.Store; import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.index.translog.SnapshotMatchers; import org.opensearch.index.translog.Translog; +import org.opensearch.indices.IndicesService; import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.indices.recovery.RecoveryTarget; import org.opensearch.indices.replication.CheckpointInfoResponse; @@ -60,6 +67,12 @@ import org.opensearch.indices.replication.common.ReplicationListener; import org.opensearch.indices.replication.common.ReplicationState; import org.opensearch.indices.replication.common.ReplicationType; +import org.opensearch.repositories.IndexId; +import org.opensearch.snapshots.Snapshot; +import org.opensearch.snapshots.SnapshotId; +import org.opensearch.snapshots.SnapshotInfoTests; +import org.opensearch.snapshots.SnapshotShardsService; +import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import org.junit.Assert; @@ -69,6 +82,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -83,6 +97,7 @@ 
import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasToString;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doThrow;
@@ -859,6 +874,122 @@ public void testSegmentReplicationStats() throws Exception {
         }
     }
 
+    /**
+     * Verifies that a shard snapshot started while a replica is still being promoted to primary is failed.
+     * <p>
+     * The snapshot entry is moved from {@code INIT} to {@code STARTED} from inside the stubbed
+     * {@code newReadWriteEngine} call, at which point the shard is asserted to still hold a
+     * {@link ReadOnlyEngine}. The shard snapshot status must then reach {@code FAILURE} and carry the
+     * failure message raised by {@link SnapshotShardsService} for an incomplete promotion.
+     */
+    public void testSnapshotWhileFailoverIncomplete() throws Exception {
+        final NRTReplicationEngineFactory engineFactory = new NRTReplicationEngineFactory();
+        final NRTReplicationEngineFactory spy = spy(engineFactory);
+        try (ReplicationGroup shards = createGroup(1, settings, indexMapping, spy, createTempDir())) {
+            final IndexShard primaryShard = shards.getPrimary();
+            final IndexShard replicaShard = shards.getReplicas().get(0);
+            shards.startAll();
+            shards.indexDocs(10);
+            shards.refresh("test");
+            replicateSegments(primaryShard, shards.getReplicas());
+            shards.assertAllEqual(10);
+
+            final SnapshotShardsService shardsService = getSnapshotShardsService(replicaShard);
+            final Snapshot snapshot = new Snapshot(randomAlphaOfLength(10), new SnapshotId(randomAlphaOfLength(5), randomAlphaOfLength(5)));
+
+            final ClusterState initState = addSnapshotIndex(clusterService.state(), snapshot, replicaShard, SnapshotsInProgress.State.INIT);
+            shardsService.clusterChanged(new ClusterChangedEvent("test", initState, clusterService.state()));
+
+            final CountDownLatch latch = new CountDownLatch(1);
+            // The answer runs when the read-write engine is requested, so the snapshot is started mid-promotion.
+            doAnswer(ans -> {
+                final Engine engineOrNull = replicaShard.getEngineOrNull();
+                assertNotNull(engineOrNull);
+                assertThat(engineOrNull, instanceOf(ReadOnlyEngine.class));
+                shards.assertAllEqual(10);
+                shardsService.clusterChanged(
+                    new ClusterChangedEvent(
+                        "test",
+                        addSnapshotIndex(clusterService.state(), snapshot, replicaShard, SnapshotsInProgress.State.STARTED),
+                        initState
+                    )
+                );
+                latch.countDown();
+                return ans.callRealMethod();
+            }).when(spy).newReadWriteEngine(any());
+            shards.promoteReplicaToPrimary(replicaShard).get();
+            latch.await();
+            assertBusy(() -> {
+                final IndexShardSnapshotStatus.Copy copy = shardsService.currentSnapshotShards(snapshot).get(replicaShard.shardId).asCopy();
+                assertEquals(IndexShardSnapshotStatus.Stage.FAILURE, copy.getStage());
+                assertNotNull(copy.getFailure());
+                assertThat(
+                    copy.getFailure(),
+                    containsString("snapshot triggered on a new primary following failover and cannot proceed until promotion is complete")
+                );
+            });
+        }
+    }
+
+    /**
+     * Builds a {@link SnapshotShardsService} whose mocked {@link IndicesService} resolves every index and
+     * every shard id to the given shard.
+     *
+     * @param replicaShard the shard returned by the mocked {@code IndexService#getShardOrNull}
+     * @return a service created with the test's settings, cluster service, repositories service and a
+     *         mocked transport service that exposes the test thread pool
+     */
+    private SnapshotShardsService getSnapshotShardsService(IndexShard replicaShard) {
+        final TransportService transportService = mock(TransportService.class);
+        when(transportService.getThreadPool()).thenReturn(threadPool);
+        final IndicesService indicesService = mock(IndicesService.class);
+        final IndexService indexService = mock(IndexService.class);
+        when(indicesService.indexServiceSafe(any())).thenReturn(indexService);
+        when(indexService.getShardOrNull(anyInt())).thenReturn(replicaShard);
+        return new SnapshotShardsService(settings, clusterService, createRepositoriesService(), transportService, indicesService);
+    }
+
+    /**
+     * Returns a copy of {@code state} whose {@link SnapshotsInProgress} custom holds a single entry for
+     * {@code snapshot}, containing one shard status for {@code shard} created with the local node's id.
+     *
+     * @param state         cluster state to copy; its local node id is used for the shard status
+     * @param snapshot      snapshot the entry is created for
+     * @param shard         shard whose routing entry supplies the key of the shard status map
+     * @param snapshotState state recorded on the entry
+     * @return the new cluster state
+     */
+    private ClusterState addSnapshotIndex(
+        ClusterState state,
+        Snapshot snapshot,
+        IndexShard shard,
+        SnapshotsInProgress.State snapshotState
+    ) {
+        // Typed map: a raw Map here triggers rawtypes/unchecked warnings when handed to the Entry constructor.
+        final Map<ShardId, SnapshotsInProgress.ShardSnapshotStatus> shardsBuilder = new HashMap<>();
+        ShardRouting shardRouting = shard.shardRouting;
+        shardsBuilder.put(
+            shardRouting.shardId(),
+            new SnapshotsInProgress.ShardSnapshotStatus(state.getNodes().getLocalNode().getId(), "1")
+        );
+        final SnapshotsInProgress.Entry entry = new SnapshotsInProgress.Entry(
+            snapshot,
+            randomBoolean(),
+            false,
+            snapshotState,
+            Collections.singletonList(new IndexId(index.getName(), index.getUUID())),
+            Collections.emptyList(),
+            randomNonNegativeLong(),
+            randomLong(),
+            shardsBuilder,
+            null,
+            SnapshotInfoTests.randomUserMetadata(),
+            VersionUtils.randomVersion(random()),
+            false
+        );
+        return ClusterState.builder(state)
+            .putCustom(SnapshotsInProgress.TYPE, SnapshotsInProgress.of(Collections.singletonList(entry)))
+            .build();
+    }
+
     private void assertReplicaCaughtUp(IndexShard primaryShard) {
         Set initialStats = primaryShard.getReplicationStats();
         assertEquals(initialStats.size(), 1);