diff --git a/.changelog/12362.txt b/.changelog/12362.txt new file mode 100644 index 000000000000..7a8dcca5e585 --- /dev/null +++ b/.changelog/12362.txt @@ -0,0 +1,3 @@ +```release-note:improvement +server: store and check previous Raft protocol version to prevent downgrades +``` diff --git a/nomad/server.go b/nomad/server.go index 5e3d2eb51ac2..30647a72f3de 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -1284,6 +1284,16 @@ func (s *Server) setupRaft() error { return err } + // Check Raft version and update the version file. + raftVersionFilePath := filepath.Join(path, "version") + raftVersionFileContent := strconv.Itoa(int(s.config.RaftConfig.ProtocolVersion)) + if err := s.checkRaftVersionFile(raftVersionFilePath); err != nil { + return err + } + if err := ioutil.WriteFile(raftVersionFilePath, []byte(raftVersionFileContent), 0644); err != nil { + return fmt.Errorf("failed to write Raft version file: %v", err) + } + // Create the BoltDB backend, with NoFreelistSync option store, raftErr := raftboltdb.New(raftboltdb.Options{ Path: filepath.Join(path, "raft.db"), @@ -1399,6 +1409,42 @@ func (s *Server) setupRaft() error { return nil } +// checkRaftVersionFile reads the Raft version file and returns an error if +// the Raft version is incompatible with the current version configured. +// Provide best-effort check if the file cannot be read. +func (s *Server) checkRaftVersionFile(path string) error { + raftVersion := s.config.RaftConfig.ProtocolVersion + baseWarning := "use the 'nomad operator raft list-peers' command to make sure the Raft protocol versions are consistent" + + _, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + + s.logger.Warn(fmt.Sprintf("unable to read Raft version file, %s", baseWarning), "error", err) + return nil + } + + v, err := ioutil.ReadFile(path) + if err != nil { + s.logger.Warn(fmt.Sprintf("unable to read Raft version file, %s", baseWarning), "error", err) + return nil + } + + previousVersion, err := strconv.Atoi(strings.TrimSpace(string(v))) + if err != nil { + s.logger.Warn(fmt.Sprintf("invalid Raft protocol version in Raft version file, %s", baseWarning), "error", err) + return nil + } + + if raft.ProtocolVersion(previousVersion) > raftVersion { + return fmt.Errorf("downgrading Raft is not supported, current version is %d, previous version was %d", raftVersion, previousVersion) + } + + return nil +} + // setupSerf is used to setup and initialize a Serf func (s *Server) setupSerf(conf *serf.Config, ch chan serf.Event, path string) (*serf.Serf, error) { conf.Init() diff --git a/nomad/server_test.go b/nomad/server_test.go index db1b1091e22d..858ac0dd715f 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -645,3 +645,27 @@ func TestServer_ReloadSchedulers_InvalidSchedulers(t *testing.T) { currentWC = s.GetSchedulerWorkerConfig() require.Equal(t, origWC, currentWC) } + +func TestServer_PreventRaftDowngrade(t *testing.T) { + ci.Parallel(t) + + dir := t.TempDir() + _, cleanupv3 := TestServer(t, func(c *Config) { + c.DevMode = false + c.DataDir = dir + c.RaftConfig.ProtocolVersion = 3 + }) + cleanupv3() + + _, cleanupv2, err := TestServerErr(t, func(c *Config) { + c.DevMode = false + c.DataDir = dir + c.RaftConfig.ProtocolVersion = 2 + }) + if cleanupv2 != nil { + defer cleanupv2() + } + + // Downgrading Raft should prevent the server from starting. + require.Error(t, err) +} diff --git a/nomad/testing.go b/nomad/testing.go index 9fbe2ca02e41..3dd3d5ba2c71 100644 --- a/nomad/testing.go +++ b/nomad/testing.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/version" "github.com/pkg/errors" + "github.com/stretchr/testify/require" ) var ( @@ -39,6 +40,12 @@ func TestACLServer(t *testing.T, cb func(*Config)) (*Server, *structs.ACLToken, } func TestServer(t *testing.T, cb func(*Config)) (*Server, func()) { + s, c, err := TestServerErr(t, cb) + require.NoError(t, err, "failed to start test server") + return s, c +} + +func TestServerErr(t *testing.T, cb func(*Config)) (*Server, func(), error) { // Setup the default settings config := DefaultConfig() @@ -137,10 +144,10 @@ func TestServer(t *testing.T, cb func(*Config)) (*Server, func()) { case <-time.After(1 * time.Minute): t.Fatal("timed out while shutting down server") } - } + }, nil } else if i == 0 { freeport.Return(ports) - t.Fatalf("err: %v", err) + return nil, nil, err } else { if server != nil { _ = server.Shutdown() @@ -151,7 +158,7 @@ func TestServer(t *testing.T, cb func(*Config)) (*Server, func()) { } } - return nil, nil + return nil, nil, nil } func TestJoin(t *testing.T, servers ...*Server) {