diff --git a/api.go b/api.go index 6cb8065db..a3b1f72f0 100644 --- a/api.go +++ b/api.go @@ -291,19 +291,21 @@ func RecoverCluster(conf *Config, fsm FSM, logs LogStore, stable StableStore, return fmt.Errorf("failed to list snapshots: %v", err) } for _, snapshot := range snapshots { - _, source, err := snaps.Open(snapshot.ID) - if err != nil { - // Skip this one and try the next. We will detect if we - // couldn't open any snapshots. - continue - } + if !conf.NoSnapshotRestoreOnStart { + _, source, err := snaps.Open(snapshot.ID) + if err != nil { + // Skip this one and try the next. We will detect if we + // couldn't open any snapshots. + continue + } - err = fsm.Restore(source) - // Close the source after the restore has completed - source.Close() - if err != nil { - // Same here, skip and try the next one. - continue + err = fsm.Restore(source) + // Close the source after the restore has completed + source.Close() + if err != nil { + // Same here, skip and try the next one. + continue + } } snapshotIndex = snapshot.Index @@ -545,23 +547,23 @@ func (r *Raft) restoreSnapshot() error { // Try to load in order of newest to oldest for _, snapshot := range snapshots { - _, source, err := r.snapshots.Open(snapshot.ID) - if err != nil { - r.logger.Error(fmt.Sprintf("Failed to open snapshot %v: %v", snapshot.ID, err)) - continue - } - - err = r.fsm.Restore(source) - // Close the source after the restore has completed - source.Close() - if err != nil { - r.logger.Error(fmt.Sprintf("Failed to restore snapshot %v: %v", snapshot.ID, err)) - continue - } + if !r.conf.NoSnapshotRestoreOnStart { + _, source, err := r.snapshots.Open(snapshot.ID) + if err != nil { + r.logger.Error(fmt.Sprintf("Failed to open snapshot %v: %v", snapshot.ID, err)) + continue + } - // Log success - r.logger.Info(fmt.Sprintf("Restored from snapshot %v", snapshot.ID)) + err = r.fsm.Restore(source) + // Close the source after the restore has completed + source.Close() + if err != nil { + 
r.logger.Error(fmt.Sprintf("Failed to restore snapshot %v: %v", snapshot.ID, err)) + continue + } + r.logger.Info(fmt.Sprintf("Restored from snapshot %v", snapshot.ID)) + } // Update the lastApplied so we don't replay old logs r.setLastApplied(snapshot.Index) diff --git a/config.go b/config.go index 66d4d0fa0..e43ba5449 100644 --- a/config.go +++ b/config.go @@ -199,6 +199,13 @@ type Config struct { // Logger is a user-provided hc-log logger. If nil, a logger writing to // LogOutput with LogLevel is used. Logger hclog.Logger + + // NoSnapshotRestoreOnStart controls if raft will restore a snapshot to the + // FSM on start. This is useful if your FSM recovers from other mechanisms + // than raft snapshotting. Snapshot metadata will still be used to initialize + // raft's configuration and index values. This is used in NewRaft and + // RecoverCluster. + NoSnapshotRestoreOnStart bool } // DefaultConfig returns a Config with usable defaults. diff --git a/raft_test.go b/raft_test.go index 11bd84087..b4994e287 100644 --- a/raft_test.go +++ b/raft_test.go @@ -1671,6 +1671,48 @@ func TestRaft_SnapshotRestore(t *testing.T) { // TODO: Need a test to process old-style entries in the Raft log when starting // up. +func TestRaft_NoRestoreOnStart(t *testing.T) { + conf := inmemConfig(t) + conf.TrailingLogs = 10 + conf.NoSnapshotRestoreOnStart = true + c := MakeCluster(1, t, conf) + + // Commit a lot of things. + leader := c.Leader() + var future Future + for i := 0; i < 100; i++ { + future = leader.Apply([]byte(fmt.Sprintf("test%d", i)), 0) + } + + // Wait for the last future to apply + if err := future.Error(); err != nil { + c.FailNowf("[ERR] err: %v", err) + } + + // Take a snapshot. + snapFuture := leader.Snapshot() + if err := snapFuture.Error(); err != nil { + c.FailNowf("[ERR] err: %v", err) + } + + // Shutdown. 
+ shutdown := leader.Shutdown() + if err := shutdown.Error(); err != nil { + c.FailNowf("[ERR] err: %v", err) + } + + _, trans := NewInmemTransport(leader.localAddr) + newFSM := &MockFSM{} + _, err := NewRaft(&leader.conf, newFSM, leader.logs, leader.stable, leader.snapshots, trans) + if err != nil { + c.FailNowf("[ERR] err: %v", err) + } + + if len(newFSM.logs) != 0 { + c.FailNowf("[ERR] expected empty FSM, got %v", newFSM) + } +} + func TestRaft_SnapshotRestore_PeerChange(t *testing.T) { // Make the cluster. conf := inmemConfig(t)