From 11089b23cea2b07f258fc1a1a1f30c547deb15bd Mon Sep 17 00:00:00 2001 From: Chelsea Holland Komlo Date: Thu, 7 Dec 2017 12:07:00 -0500 Subject: [PATCH] reload raft transport layer fix up linting --- client/client.go | 11 +- client/client_test.go | 6 +- client/config/config.go | 2 +- command/agent/command.go | 8 +- nomad/server.go | 79 ++++---- nomad/server_test.go | 173 +++++++++++++++++- .../hashicorp/raft/net_transport.go | 33 +++- 7 files changed, 253 insertions(+), 59 deletions(-) diff --git a/client/client.go b/client/client.go index 0b73c63bbd5a..c58ff7fa5f29 100644 --- a/client/client.go +++ b/client/client.go @@ -365,9 +365,9 @@ func (c *Client) init() error { return nil } -// ReloadTLSConnections allows a client to reload RPC connections if the -// client's TLS configuration changes from plaintext to TLS -func (c *Client) ReloadTLSConnections(newConfig *nconfig.TLSConfig) error { +// reloadTLSConnections allows a client to reload its TLS configuration on the +// fly +func (c *Client) reloadTLSConnections(newConfig *nconfig.TLSConfig) error { var tlsWrap tlsutil.RegionWrapper if newConfig != nil && newConfig.EnableRPC { tw, err := c.config.NewTLSConfiguration(newConfig).OutgoingTLSWrapper() @@ -378,6 +378,8 @@ func (c *Client) ReloadTLSConnections(newConfig *nconfig.TLSConfig) error { tlsWrap = tw } + // Keep the client configuration up to date as we use configuration values to + // decide on what type of connections to accept c.configLock.Lock() c.config.TLSConfig = newConfig c.configLock.Unlock() @@ -387,8 +389,7 @@ func (c *Client) ReloadTLSConnections(newConfig *nconfig.TLSConfig) error { return nil } -// Reload allows a client to reload RPC connections if the -// client's TLS configuration changes +// Reload allows a client to reload its configuration on the fly func (c *Client) Reload(newConfig *config.Config) error { return c.reloadTLSConnections(newConfig.TLSConfig) } diff --git a/client/client_test.go b/client/client_test.go index ae2f2cdc4fac..d4cba7db5dd2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1007,7 +1007,7 @@ func TestClient_ReloadTLS_UpgradePlaintextToTLS(t *testing.T) { assert := assert.New(t) s1, addr := testServer(t, func(c *nomad.Config) { - c.Region = "dc1" + c.Region = "foo" }) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1059,7 +1059,7 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { assert := assert.New(t) s1, addr := testServer(t, func(c *nomad.Config) { - c.Region = "dc1" + c.Region = "foo" }) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1090,7 +1090,7 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { req := structs.NodeSpecificRequest{ NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "dc1"}, + QueryOptions: structs.QueryOptions{Region: "foo"}, } var out structs.SingleNodeResponse testutil.AssertUntil(100*time.Millisecond, diff --git a/client/config/config.go b/client/config/config.go index 89820edbef39..e2a271646ee0 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -354,7 +354,7 @@ func (c *Config) TLSConfiguration() *tlsutil.Config { } // NewTLSConfiguration returns a TLSUtil Config for a new TLS config object -// This allows a TLSConfig object to be created without first explicitely +// This allows a TLSConfig object to be created without first explicitly // setting it func (c *Config) NewTLSConfiguration(tlsConfig *config.TLSConfig) *tlsutil.Config { return &tlsutil.Config{ diff --git a/command/agent/command.go 
b/command/agent/command.go index 8844e78edefd..5a7c67831e70 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -12,7 +12,6 @@ import ( "sort" "strconv" "strings" - "sync" "syscall" "time" @@ -47,7 +46,6 @@ type Command struct { args []string agent *Agent httpServer *HTTPServer - httpServerLock sync.Mutex logFilter *logutils.LevelFilter logOutput io.Writer retryJoinErrCh chan struct{} @@ -602,8 +600,6 @@ WAIT: func (c *Command) reloadHTTPServer(newConfig *Config) error { c.agent.logger.Println("[INFO] agent: Reloading HTTP server with new TLS configuration") - c.httpServerLock.Lock() - defer c.httpServerLock.Unlock() c.httpServer.Shutdown() @@ -640,6 +636,7 @@ func (c *Command) handleReload() { shouldReload := c.agent.ShouldReload(newConf) if shouldReload { + c.agent.logger.Printf("[DEBUG] agent: starting reload of agent config") err := c.agent.Reload(newConf) if err != nil { c.agent.logger.Printf("[ERR] agent: failed to reload the config: %v", err) @@ -649,8 +646,10 @@ func (c *Command) handleReload() { if s := c.agent.Server(); s != nil { sconf, err := convertServerConfig(newConf, c.logOutput) + c.agent.logger.Printf("[DEBUG] agent: starting reload of server config") if err != nil { c.agent.logger.Printf("[ERR] agent: failed to convert server config: %v", err) + return } else { if err := s.Reload(sconf); err != nil { c.agent.logger.Printf("[ERR] agent: reloading server config failed: %v", err) @@ -661,6 +660,7 @@ func (c *Command) handleReload() { if s := c.agent.Client(); s != nil { clientConfig, err := c.agent.clientConfig() + c.agent.logger.Printf("[DEBUG] agent: starting reload of client config") if err != nil { c.agent.logger.Printf("[ERR] agent: reloading client config failed: %v", err) return diff --git a/nomad/server.go b/nomad/server.go index 4ea115694fdc..14c3ed4a5569 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -85,8 +85,9 @@ const ( // schedulers, and notification bus for agents. type Server struct { config *Config - configLock sync.RWMutex - logger *log.Logger + configLock sync.Mutex + + logger *log.Logger // Connection pool to other Nomad servers connPool *ConnPool @@ -96,32 +97,28 @@ type Server struct { // The raft instance is used among Nomad nodes within the // region to protect operations that require strong consistency - leaderCh <-chan bool - raft *raft.Raft - raftLayer *RaftLayer - raftLayerLock sync.Mutex + leaderCh <-chan bool + raft *raft.Raft + raftLayer *RaftLayer raftStore *raftboltdb.BoltStore raftInmem *raft.InmemStore - raftTransport *raft.NetworkTransport - raftTransportLock sync.Mutex + raftTransport *raft.NetworkTransport // fsm is the state machine used with Raft fsm *nomadFSM // rpcListener is used to listen for incoming connections - rpcListener net.Listener - rpcListenerLock sync.Mutex - listenerCh chan struct{} + rpcListener net.Listener + listenerCh chan struct{} rpcServer *rpc.Server rpcAdvertise net.Addr // rpcTLS is the TLS config for incoming TLS requests - rpcTLS *tls.Config - rpcCancel context.CancelFunc - rpcTLSLock sync.Mutex + rpcTLS *tls.Config + rpcCancel context.CancelFunc // peers is used to track the known Nomad servers. This is // used for region forwarding and clustering. 
@@ -329,6 +326,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg // Start ingesting events for Serf go s.serfEventHandler() + // start the RPC listener for the server s.startRPCListener() // Emit metrics for the eval broker @@ -353,10 +351,8 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg return s, nil } -// Start the RPC listeners +// startRPCListener starts the server's RPC listener func (s *Server) startRPCListener() { - s.rpcListenerLock.Lock() - defer s.rpcListenerLock.Unlock() ctx, cancel := context.WithCancel(context.Background()) s.rpcCancel = cancel go func() { @@ -365,10 +361,8 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg }() } +// createRPCListener creates the server's RPC listener func (s *Server) createRPCListener() error { - s.rpcListenerLock.Lock() - defer s.rpcListenerLock.Unlock() - s.listenerCh = make(chan struct{}) list, err := net.ListenTCP("tcp", s.config.RPCAddr) if err != nil || list == nil { @@ -380,6 +374,8 @@ func (s *Server) createRPCListener() error { return nil } +// getTLSConf gets the server's TLS configuration based on the config supplied +// by the operator func getTLSConf(enableRPC bool, tlsConf *tlsutil.Config) (*tls.Config, tlsutil.RegionWrapper, error) { var tlsWrap tlsutil.RegionWrapper var incomingTLS *tls.Config @@ -399,11 +395,13 @@ func getTLSConf(enableRPC bool, tlsConf *tlsutil.Config) (*tls.Config, tlsutil.R return incomingTLS, tlsWrap, nil } -// ReloadTLSConnections updates a server's TLS configuration and reloads RPC -// connections. This will handle both TLS upgrades and downgrades. +// reloadTLSConnections updates a server's TLS configuration and reloads RPC +// connections. func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error { s.logger.Printf("[INFO] nomad: reloading server connections due to configuration changes") + // the server config must be in sync with the latest config changes, due to + // testing for TLS configuration settings in rpc.go tlsConf := s.config.newTLSConfig(newTLSConfig) incomingTLS, tlsWrap, err := getTLSConf(newTLSConfig.EnableRPC, tlsConf) if err != nil { @@ -411,6 +409,13 @@ func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error { return err } + // Keeping configuration in sync is important for other places that require + // access to config information, such as rpc.go, where we decide on what kind + // of network connections to accept depending on the server configuration + s.configLock.Lock() + s.config.TLSConfig = newTLSConfig + s.configLock.Unlock() + if s.rpcCancel == nil { s.logger.Printf("[ERR] nomad: No TLS Context to reset") return fmt.Errorf("Unable to reset tls context") @@ -422,38 +427,27 @@ func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error { s.config.TLSConfig = newTLSConfig s.configLock.Unlock() - s.rpcTLSLock.Lock() s.rpcTLS = incomingTLS - s.rpcTLSLock.Unlock() - - s.raftTransportLock.Lock() - defer s.raftTransportLock.Unlock() - s.raftTransport.Close() - s.connPool.ReloadTLS(tlsWrap) // reinitialize our rpc listener - s.rpcListenerLock.Lock() s.rpcListener.Close() <-s.listenerCh - s.rpcListenerLock.Unlock() + s.raftTransport.Pause() + s.raftLayer.Close() err = s.createRPCListener() - if err != nil { - return err - } - s.startRPCListener() - s.raftLayerLock.Lock() - s.raftLayer.Close() + // Close existing streams wrapper := tlsutil.RegionSpecificWrapper(s.config.Region, tlsWrap) s.raftLayer = NewRaftLayer(s.rpcAdvertise, wrapper) - s.raftLayerLock.Unlock() -
// re-initialize the network transport with a re-initialized stream layer - trans := raft.NewNetworkTransport(s.raftLayer, 3, s.config.RaftTimeout, - s.config.LogOutput) - s.raftTransport = trans + s.startRPCListener() + + time.Sleep(3 * time.Second) + if err != nil { + return err + } s.logger.Printf("[DEBUG] nomad: finished reloading server connections") return nil @@ -621,6 +615,7 @@ func (s *Server) Reload(newConfig *Config) error { if !newConfig.TLSConfig.Equals(s.config.TLSConfig) { if err := s.reloadTLSConnections(newConfig.TLSConfig); err != nil { + s.logger.Printf("[DEBUG] nomad: reloading server TLS configuration") multierror.Append(&mErr, err) } } diff --git a/nomad/server_test.go b/nomad/server_test.go index 89e6810d7410..e7d063b97a40 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/hashicorp/consul/lib/freeport" + memdb "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper/uuid" @@ -330,7 +331,7 @@ func TestServer_Reload_TLSConnections_PlaintextToTLS(t *testing.T) { // Tests that the server will successfully reload its network connections, // downgrading from TLS to plaintext if the server's TLS configuration changes. -func TestServer_reload_TLSConnections_TLSToPlaintext(t *testing.T) { +func TestServer_Reload_TLSConnections_TLSToPlaintext_RPC(t *testing.T) { t.Parallel() assert := assert.New(t) @@ -375,3 +376,173 @@ func TestServer_reload_TLSConnections_TLSToPlaintext(t *testing.T) { err = msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp) assert.Nil(err) } + +// Test that Raft connections are reloaded as expected when a Nomad server is +// upgraded from plaintext to TLS +func TestServer_Reload_TLSConnections_Raft(t *testing.T) { + assert := assert.New(t) + t.Parallel() + const ( + cafile = "../../helper/tlsutil/testdata/ca.pem" + foocert = "../../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem" + barcert = "../dev/tls_cluster/certs/nomad.pem" + barkey = "../dev/tls_cluster/certs/nomad-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := testServer(t, func(c *Config) { + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node1") + c.NodeName = "node1" + c.Region = "regionFoo" + }) + defer s1.Shutdown() + + s2 := testServer(t, func(c *Config) { + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node2") + c.NodeName = "node2" + c.Region = "regionFoo" + }) + defer s2.Shutdown() + + testJoin(t, s1, s2) + + testutil.WaitForResult(func() (bool, error) { + peers, _ := s1.numPeers() + return peers == 2, nil + }, func(err error) { + t.Fatalf("should have 2 peers") + }) + + // the server should be connected to the rest of the cluster + testutil.WaitForLeader(t, s2.RPC) + + { + // assert that a job register request will succeed + codec := rpcClient(t, s2) + job := mock.Job() + req := &structs.JobRegisterRequest{ + Job: job, + WriteRequest: structs.WriteRequest{ + Region: "regionFoo", + Namespace: job.Namespace, + }, + } + + // Fetch the response + var resp structs.JobRegisterResponse + err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + assert.Nil(err) + + // Check for the job in the FSM of each server in the cluster + { + state := s2.fsm.State() + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.Namespace, 
job.ID) + assert.Nil(err) + assert.NotNil(out) + assert.Equal(out.CreateIndex, resp.JobModifyIndex) + } + { + state := s1.fsm.State() + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.Namespace, job.ID) + assert.Nil(err) + assert.NotNil(out) // TODO Occasionally is flaky + assert.Equal(out.CreateIndex, resp.JobModifyIndex) + } + } + + newTLSConfig := &config.TLSConfig{ + EnableHTTP: true, + VerifyHTTPSClient: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + + err := s1.reloadTLSConnections(newTLSConfig) + assert.Nil(err) + + { + // assert that a job register request will fail between servers that + // should not be able to communicate over Raft + codec := rpcClient(t, s2) + job := mock.Job() + req := &structs.JobRegisterRequest{ + Job: job, + WriteRequest: structs.WriteRequest{ + Region: "global", + Namespace: job.Namespace, + }, + } + + var resp structs.JobRegisterResponse + err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + assert.NotNil(err) + + // Check that the job was not persisted + state := s2.fsm.State() + ws := memdb.NewWatchSet() + out, _ := state.JobByID(ws, job.Namespace, job.ID) + assert.Nil(out) + } + + secondNewTLSConfig := &config.TLSConfig{ + EnableHTTP: true, + VerifyHTTPSClient: true, + CAFile: cafile, + CertFile: barcert, + KeyFile: barkey, + } + + // Now, transition the other server to TLS, which should restore their + // ability to communicate. + err = s2.reloadTLSConnections(secondNewTLSConfig) + assert.Nil(err) + + // the server should be connected to the rest of the cluster + testutil.WaitForLeader(t, s2.RPC) + + { + // assert that a job register request will succeed + codec := rpcClient(t, s2) + job := mock.Job() + req := &structs.JobRegisterRequest{ + Job: job, + WriteRequest: structs.WriteRequest{ + Region: "regionFoo", + Namespace: job.Namespace, + }, + } + + // Fetch the response + var resp structs.JobRegisterResponse + err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + assert.Nil(err) + + // Check for the job in the FSM of each server in the cluster + { + state := s2.fsm.State() + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.Namespace, job.ID) + assert.Nil(err) + assert.NotNil(out) + assert.Equal(out.CreateIndex, resp.JobModifyIndex) + } + { + state := s1.fsm.State() + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.Namespace, job.ID) + assert.Nil(err) + assert.NotNil(out) + assert.Equal(out.CreateIndex, resp.JobModifyIndex) + } + } +} diff --git a/vendor/github.com/hashicorp/raft/net_transport.go b/vendor/github.com/hashicorp/raft/net_transport.go index 7c55ac5371ff..cbfabfc31606 100644 --- a/vendor/github.com/hashicorp/raft/net_transport.go +++ b/vendor/github.com/hashicorp/raft/net_transport.go @@ -2,6 +2,7 @@ package raft import ( "bufio" + "context" "errors" "fmt" "io" @@ -72,7 +73,8 @@ type NetworkTransport struct { shutdownCh chan struct{} shutdownLock sync.Mutex - stream StreamLayer + stream StreamLayer + streamCancel context.CancelFunc timeout time.Duration TimeoutScale int @@ -141,6 +143,7 @@ func NewNetworkTransportWithLogger( if logger == nil { logger = log.New(os.Stderr, "", log.LstdFlags) } + trans := &NetworkTransport{ connPool: make(map[ServerAddress][]*netConn), consumeCh: make(chan RPC), @@ -151,10 +154,26 @@ func NewNetworkTransportWithLogger( timeout: timeout, TimeoutScale: DefaultTimeoutScale, } - go trans.listen() + + ctx, cancel := context.WithCancel(context.Background()) + trans.streamCancel = cancel + go trans.listen(ctx) return 
trans } +// Pause closes the current stream for a NetworkTransport instance +func (n *NetworkTransport) Pause() { + n.streamCancel() + n.stream.Close() +} + +// Reload creates a new stream for a NetworkTransport instance +func (n *NetworkTransport) Reload(s StreamLayer) { + ctx, cancel := context.WithCancel(context.Background()) + n.streamCancel = cancel + go n.listen(ctx) +} + // SetHeartbeatHandler is used to setup a heartbeat handler // as a fast-pass. This is to avoid head-of-line blocking from // disk IO. @@ -356,14 +375,22 @@ func (n *NetworkTransport) DecodePeer(buf []byte) ServerAddress { } // listen is used to handling incoming connections. -func (n *NetworkTransport) listen() { +func (n *NetworkTransport) listen(ctx context.Context) { for { + select { + case <-ctx.Done(): + n.logger.Println("[INFO] raft-net: stream layer is closed") + return + default: + } + // Accept incoming connections conn, err := n.stream.Accept() if err != nil { if n.IsShutdown() { return } + // TODO Getting an error here n.logger.Printf("[ERR] raft-net: Failed to accept connection: %v", err) continue }
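
Reviewer note: below is a minimal, self-contained Go sketch (not the vendored code itself) of the pause/reload pattern this patch introduces in NetworkTransport: the accept loop is bound to a cancellable context, Pause cancels the context and closes the current stream, and Reload starts a fresh loop. The transport type and its fields here are hypothetical stand-ins for StreamLayer and NetworkTransport, kept small so the example runs on its own; unlike the patch as posted, Reload in this sketch also records the new stream so the restarted loop actually accepts from it.

package main

import (
	"context"
	"log"
	"net"
	"os"
	"time"
)

// transport is a stripped-down stand-in for raft.NetworkTransport: it owns a
// net.Listener (playing the role of StreamLayer) plus a cancel func for the
// accept loop, so the loop can be stopped and restarted without tearing the
// whole transport down.
type transport struct {
	cancel context.CancelFunc
	stream net.Listener
	logger *log.Logger
}

func newTransport(stream net.Listener, logger *log.Logger) *transport {
	ctx, cancel := context.WithCancel(context.Background())
	t := &transport{cancel: cancel, stream: stream, logger: logger}
	go t.listen(ctx, stream)
	return t
}

// Pause cancels the accept loop and closes the current stream, mirroring
// NetworkTransport.Pause in the patch.
func (t *transport) Pause() {
	t.cancel()
	t.stream.Close()
}

// Reload installs a new stream and restarts the accept loop on top of it.
func (t *transport) Reload(stream net.Listener) {
	ctx, cancel := context.WithCancel(context.Background())
	t.cancel = cancel
	t.stream = stream
	go t.listen(ctx, stream)
}

// listen accepts connections until the context is cancelled, matching the
// select-on-ctx.Done() check the patch adds to NetworkTransport.listen.
func (t *transport) listen(ctx context.Context, stream net.Listener) {
	for {
		select {
		case <-ctx.Done():
			t.logger.Println("[INFO] sketch: accept loop stopped")
			return
		default:
		}

		conn, err := stream.Accept()
		if err != nil {
			// Closing the listener in Pause makes Accept fail; the context
			// distinguishes an intentional stop from a real error.
			select {
			case <-ctx.Done():
				return
			default:
			}
			t.logger.Printf("[ERR] sketch: failed to accept connection: %v", err)
			continue
		}
		t.logger.Printf("[INFO] sketch: accepted connection from %s", conn.RemoteAddr())
		conn.Close()
	}
}

func main() {
	logger := log.New(os.Stderr, "", log.LstdFlags)

	l1, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		logger.Fatal(err)
	}
	tr := newTransport(l1, logger)
	logger.Printf("[INFO] sketch: listening on %s", l1.Addr())

	// Simulate a TLS configuration change: pause the transport, build a new
	// listener (in the real patch this is a freshly wrapped RaftLayer), and
	// reload the accept loop on top of it.
	tr.Pause()
	l2, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		logger.Fatal(err)
	}
	tr.Reload(l2)
	logger.Printf("[INFO] sketch: reloaded, now listening on %s", l2.Addr())

	time.Sleep(100 * time.Millisecond)
}

In server.go this corresponds to the Pause, close raftLayer, recreate listener and raftLayer, restart sequence in reloadTLSConnections; the patch as posted still relies on a time.Sleep and leaves a TODO in listen(), so the ordering in this sketch is only one plausible way to sequence the swap, not the definitive implementation.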