From c0750abc6a0914ba38a73c518be755529c21c82d Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Wed, 3 Jan 2018 14:59:52 -0800 Subject: [PATCH 1/5] Helper to populate RPC server endpoints --- nomad/endpoints_oss.go | 4 +- nomad/heartbeat.go | 2 +- nomad/node_endpoint_test.go | 4 +- nomad/server.go | 89 ++++++++++++++++++++++--------------- 4 files changed, 59 insertions(+), 40 deletions(-) diff --git a/nomad/endpoints_oss.go b/nomad/endpoints_oss.go index 3d59b57ead0f..006b05552c56 100644 --- a/nomad/endpoints_oss.go +++ b/nomad/endpoints_oss.go @@ -2,6 +2,8 @@ package nomad +import "net/rpc" + // EnterpriseEndpoints holds the set of enterprise only endpoints to register type EnterpriseEndpoints struct{} @@ -12,4 +14,4 @@ func NewEnterpriseEndpoints(s *Server) *EnterpriseEndpoints { } // Register is a no-op in oss. -func (e *EnterpriseEndpoints) Register(s *Server) {} +func (e *EnterpriseEndpoints) Register(s *rpc.Server) {} diff --git a/nomad/heartbeat.go b/nomad/heartbeat.go index 89bc86010152..54e885337cbc 100644 --- a/nomad/heartbeat.go +++ b/nomad/heartbeat.go @@ -100,7 +100,7 @@ func (s *Server) invalidateHeartbeat(id string) { }, } var resp structs.NodeUpdateResponse - if err := s.endpoints.Node.UpdateStatus(&req, &resp); err != nil { + if err := s.staticEndpoints.Node.UpdateStatus(&req, &resp); err != nil { s.logger.Printf("[ERR] nomad.heartbeat: update status failed: %v", err) } } diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 909e2a63715c..c53ea9f97d64 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -1746,7 +1746,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { // Call to do the batch update bf := NewBatchFuture() - endpoint := s1.endpoints.Node + endpoint := s1.staticEndpoints.Node endpoint.batchUpdate(bf, []*structs.Allocation{clientAlloc}) if err := bf.Wait(); err != nil { t.Fatalf("err: %v", err) @@ -1864,7 +1864,7 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) { } // Create some evaluations - ids, index, err := s1.endpoints.Node.createNodeEvals(alloc.NodeID, 1) + ids, index, err := s1.staticEndpoints.Node.createNodeEvals(alloc.NodeID, 1) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/server.go b/nomad/server.go index 7648c74360f3..7b298b217288 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -88,9 +88,6 @@ type Server struct { // Connection pool to other Nomad servers connPool *ConnPool - // Endpoints holds our RPC endpoints - endpoints endpoints - // The raft instance is used among Nomad nodes within the // region to protect operations that require strong consistency leaderCh <-chan bool @@ -104,13 +101,21 @@ type Server struct { fsm *nomadFSM // rpcListener is used to listen for incoming connections - rpcListener net.Listener - rpcServer *rpc.Server + rpcListener net.Listener + + // rpcServer is the static RPC server that is used by the local agent. + rpcServer *rpc.Server + + // rpcAdvertise is the advertised address for the RPC listener. rpcAdvertise net.Addr // rpcTLS is the TLS config for incoming TLS requests rpcTLS *tls.Config + // staticEndpoints is the set of static endpoints that can be reused across + // all RPC connections + staticEndpoints endpoints + // peers is used to track the known Nomad servers. This is // used for region forwarding and clustering. 
peers map[string][]*serverParts @@ -739,37 +744,8 @@ func (s *Server) setupVaultClient() error { // setupRPC is used to setup the RPC listener func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { - // Create endpoints - s.endpoints.ACL = &ACL{s} - s.endpoints.Alloc = &Alloc{s} - s.endpoints.Eval = &Eval{s} - s.endpoints.Job = &Job{s} - s.endpoints.Node = &Node{srv: s} - s.endpoints.Deployment = &Deployment{srv: s} - s.endpoints.Operator = &Operator{s} - s.endpoints.Periodic = &Periodic{s} - s.endpoints.Plan = &Plan{s} - s.endpoints.Region = &Region{s} - s.endpoints.Status = &Status{s} - s.endpoints.System = &System{s} - s.endpoints.Search = &Search{s} - s.endpoints.Enterprise = NewEnterpriseEndpoints(s) - - // Register the handlers - s.rpcServer.Register(s.endpoints.ACL) - s.rpcServer.Register(s.endpoints.Alloc) - s.rpcServer.Register(s.endpoints.Eval) - s.rpcServer.Register(s.endpoints.Job) - s.rpcServer.Register(s.endpoints.Node) - s.rpcServer.Register(s.endpoints.Deployment) - s.rpcServer.Register(s.endpoints.Operator) - s.rpcServer.Register(s.endpoints.Periodic) - s.rpcServer.Register(s.endpoints.Plan) - s.rpcServer.Register(s.endpoints.Region) - s.rpcServer.Register(s.endpoints.Status) - s.rpcServer.Register(s.endpoints.System) - s.rpcServer.Register(s.endpoints.Search) - s.endpoints.Enterprise.Register(s) + // Populate the static RPC server + s.setupRpcServer(s.rpcServer) list, err := net.ListenTCP("tcp", s.config.RPCAddr) if err != nil { @@ -799,6 +775,47 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { return nil } +// setupRpcServer is used to populate an RPC server with endpoints +func (s *Server) setupRpcServer(server *rpc.Server) { + // Add the static endpoints to the RPC server. + if s.staticEndpoints.Status == nil { + // Initialize the list just once + s.staticEndpoints.ACL = &ACL{s} + s.staticEndpoints.Alloc = &Alloc{s} + s.staticEndpoints.Eval = &Eval{s} + s.staticEndpoints.Job = &Job{s} + s.staticEndpoints.Node = &Node{srv: s} + s.staticEndpoints.Deployment = &Deployment{srv: s} + s.staticEndpoints.Operator = &Operator{s} + s.staticEndpoints.Periodic = &Periodic{s} + s.staticEndpoints.Plan = &Plan{s} + s.staticEndpoints.Region = &Region{s} + s.staticEndpoints.Status = &Status{s} + s.staticEndpoints.System = &System{s} + s.staticEndpoints.Search = &Search{s} + s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s) + } + + // Register the static handlers + server.Register(s.staticEndpoints.ACL) + server.Register(s.staticEndpoints.Alloc) + server.Register(s.staticEndpoints.Eval) + server.Register(s.staticEndpoints.Job) + server.Register(s.staticEndpoints.Node) + server.Register(s.staticEndpoints.Deployment) + server.Register(s.staticEndpoints.Operator) + server.Register(s.staticEndpoints.Periodic) + server.Register(s.staticEndpoints.Plan) + server.Register(s.staticEndpoints.Region) + server.Register(s.staticEndpoints.Status) + server.Register(s.staticEndpoints.System) + server.Register(s.staticEndpoints.Search) + s.staticEndpoints.Enterprise.Register(server) + + // Create new dynamic endpoints and add them to the RPC server. + // TODO +} + // setupRaft is used to setup and initialize Raft func (s *Server) setupRaft() error { // If we have an unclean exit then attempt to close the Raft store. 
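Patch 1 works because registering endpoints on a net/rpc server is simply a matter of calling Register on each endpoint struct: the concrete type name becomes the service name, and every exported method with an (args, reply) signature returning error becomes callable, which is why a single setupRpcServer helper can populate any *rpc.Server. A minimal, standalone sketch of that pattern (this is not Nomad code; QueryArgs, QueryReply, and populate are illustrative stand-ins, with Status.Leader mirroring the endpoint the tests call):

package main

import (
	"fmt"
	"log"
	"net"
	"net/rpc"
)

// QueryArgs and QueryReply are illustrative request/response types.
type QueryArgs struct{}

type QueryReply struct {
	Leader string
}

// Status mimics the shape of a Nomad endpoint struct: exported methods with
// an (args, reply) signature returning error are exposed over RPC.
type Status struct{}

func (s *Status) Leader(args *QueryArgs, reply *QueryReply) error {
	reply.Leader = "127.0.0.1:4647"
	return nil
}

// populate plays the role of setupRpcServer: it registers every endpoint on
// the given server, so the same helper can fill any *rpc.Server instance.
func populate(server *rpc.Server) error {
	return server.Register(&Status{})
}

func main() {
	server := rpc.NewServer()
	if err := populate(server); err != nil {
		log.Fatal(err)
	}

	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	go server.Accept(l)

	client, err := rpc.Dial("tcp", l.Addr().String())
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	var reply QueryReply
	if err := client.Call("Status.Leader", &QueryArgs{}, &reply); err != nil {
		log.Fatal(err)
	}
	fmt.Println("leader:", reply.Leader)
}

Because registration is per *rpc.Server instance, the same helper can later be pointed at a fresh server created for an individual connection, which is exactly what the following patches do.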
From e68420fb9586f59afd1e8e115cd8b29f8fbcaf42 Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Wed, 3 Jan 2018 16:00:55 -0800 Subject: [PATCH 2/5] Dynamic RPC servers with context --- nomad/rpc.go | 52 +++++++++++++++++++++++++++++++++++++++---------- nomad/server.go | 4 ++-- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/nomad/rpc.go b/nomad/rpc.go index 828ee0c94c0a..49dea7a9fd18 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -55,6 +55,21 @@ const ( enqueueLimit = 30 * time.Second ) +// RPCContext provides metadata about the RPC connection. +type RPCContext struct { + // Session exposes the multiplexed connection session. + Session *yamux.Session + + // TLS marks whether the RPC is over a TLS based connection + TLS bool + + // TLSRole is the certificate role making the TLS connection. + TLSRole string + + // TLSRegion is the region on the certificate making theTLS connection + TLSRegion string +} + // NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls to // the Nomad Server. func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { @@ -80,14 +95,14 @@ func (s *Server) listen() { continue } - go s.handleConn(conn, false) + go s.handleConn(conn, &RPCContext{}) metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1) } } // handleConn is used to determine if this is a Raft or // Nomad type RPC connection and invoke the correct handler -func (s *Server) handleConn(conn net.Conn, isTLS bool) { +func (s *Server) handleConn(conn net.Conn, ctx *RPCContext) { // Read a single byte buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { @@ -99,7 +114,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { } // Enforce TLS if EnableRPC is set - if s.config.TLSConfig.EnableRPC && !isTLS && RPCType(buf[0]) != rpcTLS { + if s.config.TLSConfig.EnableRPC && !ctx.TLS && RPCType(buf[0]) != rpcTLS { if !s.config.TLSConfig.RPCUpgradeMode { s.logger.Printf("[WARN] nomad.rpc: Non-TLS connection attempted from %s with RequireTLS set", conn.RemoteAddr().String()) conn.Close() @@ -110,14 +125,17 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { // Switch on the byte switch RPCType(buf[0]) { case rpcNomad: - s.handleNomadConn(conn) + // Create an RPC Server and handle the request + server := rpc.NewServer() + s.setupRpcServer(server, ctx) + s.handleNomadConn(conn, server) case rpcRaft: metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1) s.raftLayer.Handoff(conn) case rpcMultiplex: - s.handleMultiplex(conn) + s.handleMultiplex(conn, ctx) case rpcTLS: if s.rpcTLS == nil { @@ -126,7 +144,13 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { return } conn = tls.Server(conn, s.rpcTLS) - s.handleConn(conn, true) + + // Update the connection context with the fact that the connection is + // using TLS + // TODO pull out more TLS information into the context + ctx.TLS = true + + s.handleConn(conn, ctx) default: s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) @@ -137,11 +161,19 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { // handleMultiplex is used to multiplex a single incoming connection // using the Yamux multiplexer -func (s *Server) handleMultiplex(conn net.Conn) { +func (s *Server) handleMultiplex(conn net.Conn, ctx *RPCContext) { defer conn.Close() conf := yamux.DefaultConfig() conf.LogOutput = s.config.LogOutput server, _ := yamux.Server(conn, conf) + + // Update the context to store the yamux session + ctx.Session = server + + // Create the RPC server for this connection + 
rpcServer := rpc.NewServer() + s.setupRpcServer(rpcServer, ctx) + for { sub, err := server.Accept() if err != nil { @@ -150,12 +182,12 @@ func (s *Server) handleMultiplex(conn net.Conn) { } return } - go s.handleNomadConn(sub) + go s.handleNomadConn(sub, rpcServer) } } // handleNomadConn is used to service a single Nomad RPC connection -func (s *Server) handleNomadConn(conn net.Conn) { +func (s *Server) handleNomadConn(conn net.Conn, server *rpc.Server) { defer conn.Close() rpcCodec := NewServerCodec(conn) for { @@ -165,7 +197,7 @@ func (s *Server) handleNomadConn(conn net.Conn) { default: } - if err := s.rpcServer.ServeRequest(rpcCodec); err != nil { + if err := server.ServeRequest(rpcCodec); err != nil { if err != io.EOF && !strings.Contains(err.Error(), "closed") { s.logger.Printf("[ERR] nomad.rpc: RPC error: %v (%v)", err, conn) metrics.IncrCounter([]string{"nomad", "rpc", "request_error"}, 1) diff --git a/nomad/server.go b/nomad/server.go index 7b298b217288..32214cdd3e59 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -745,7 +745,7 @@ func (s *Server) setupVaultClient() error { // setupRPC is used to setup the RPC listener func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { // Populate the static RPC server - s.setupRpcServer(s.rpcServer) + s.setupRpcServer(s.rpcServer, nil) list, err := net.ListenTCP("tcp", s.config.RPCAddr) if err != nil { @@ -776,7 +776,7 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { } // setupRpcServer is used to populate an RPC server with endpoints -func (s *Server) setupRpcServer(server *rpc.Server) { +func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { // Add the static endpoints to the RPC server. if s.staticEndpoints.Status == nil { // Initialize the list just once From dd7b71e063530c12201a1b1e9f8784c4c60e3e8b Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Thu, 4 Jan 2018 16:33:07 -0800 Subject: [PATCH 3/5] Improve TLS cluster testing --- nomad/rpc.go | 26 +++++++++- nomad/server_test.go | 115 ++++++++++++++++++++++++++++++++----------- 2 files changed, 112 insertions(+), 29 deletions(-) diff --git a/nomad/rpc.go b/nomad/rpc.go index 49dea7a9fd18..ef46e0accf32 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -145,11 +145,35 @@ func (s *Server) handleConn(conn net.Conn, ctx *RPCContext) { } conn = tls.Server(conn, s.rpcTLS) + // Force a handshake so we can get information about the TLS connection + // state. 
+ tlsConn, ok := conn.(*tls.Conn) + if !ok { + s.logger.Printf("[ERR] nomad.rpc: expected TLS connection but got %T", conn) + conn.Close() + return + } + + if err := tlsConn.Handshake(); err != nil { + s.logger.Printf("[WARN] nomad.rpc: failed TLS handshake from connection from %v: %v", tlsConn.RemoteAddr(), err) + conn.Close() + return + } + // Update the connection context with the fact that the connection is // using TLS - // TODO pull out more TLS information into the context ctx.TLS = true + // Parse the region and role from the TLS certificate + state := tlsConn.ConnectionState() + parts := strings.SplitN(state.ServerName, ".", 3) + if len(parts) != 3 || (parts[0] != "server" && parts[0] != "client") || parts[2] != "nomad" { + s.logger.Printf("[WARN] nomad.rpc: invalid server name %q on verified TLS connection", state.ServerName) + } else { + ctx.TLSRole = parts[0] + ctx.TLSRegion = parts[1] + } + s.handleConn(conn, ctx) default: diff --git a/nomad/server_test.go b/nomad/server_test.go index 04175a2900ac..cea9a0c3bdf8 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -144,7 +144,7 @@ func TestServer_RPC(t *testing.T) { } } -func TestServer_RPC_MixedTLS(t *testing.T) { +func TestServer_RPC_TLS(t *testing.T) { t.Parallel() const ( cafile = "../helper/tlsutil/testdata/ca.pem" @@ -154,6 +154,7 @@ func TestServer_RPC_MixedTLS(t *testing.T) { dir := tmpDir(t) defer os.RemoveAll(dir) s1 := testServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true @@ -170,53 +171,111 @@ func TestServer_RPC_MixedTLS(t *testing.T) { defer s1.Shutdown() s2 := testServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } }) defer s2.Shutdown() s3 := testServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node3") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } }) defer s3.Shutdown() testJoin(t, s1, s2, s3) + testutil.WaitForLeader(t, s1.RPC) - l1, l2, l3, shutdown := make(chan error, 1), make(chan error, 1), make(chan error, 1), make(chan struct{}, 1) + // Part of a server joining is making an RPC request, so just by testing + // that there is a leader we verify that the RPCs are working over TLS. 
+} - wait := func(done chan error, rpc func(string, interface{}, interface{}) error) { - for { - select { - case <-shutdown: - return - default: - } +func TestServer_RPC_MixedTLS(t *testing.T) { + t.Parallel() + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := testServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node1") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() - args := &structs.GenericRequest{} - var leader string - err := rpc("Status.Leader", args, &leader) - if err != nil || leader != "" { - done <- err - } + s2 := testServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s2.Shutdown() + s3 := testServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node3") + }) + defer s3.Shutdown() + + testJoin(t, s1, s2, s3) + + // Ensure that we do not form a quorum + start := time.Now() + for { + if time.Now().After(start.Add(2 * time.Second)) { + break } - } - go wait(l1, s1.RPC) - go wait(l2, s2.RPC) - go wait(l3, s3.RPC) - - select { - case <-time.After(5 * time.Second): - case err := <-l1: - t.Fatalf("Server 1 has leader or error: %v", err) - case err := <-l2: - t.Fatalf("Server 2 has leader or error: %v", err) - case err := <-l3: - t.Fatalf("Server 3 has leader or error: %v", err) + args := &structs.GenericRequest{} + var leader string + err := s1.RPC("Status.Leader", args, &leader) + if err == nil || leader != "" { + t.Fatalf("Got leader or no error: %q %v", leader, err) + } } } From 7f4d929276505df9b635e743768e7189872663af Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Fri, 5 Jan 2018 13:50:04 -0800 Subject: [PATCH 4/5] Track client connections --- nomad/node_endpoint.go | 24 +++++++++++++ nomad/node_endpoint_test.go | 70 ++++++++++++++++++++++++++++++------- nomad/rpc.go | 22 ++++++++++-- nomad/server.go | 58 ++++++++++++++++++++++++++++-- 4 files changed, 155 insertions(+), 19 deletions(-) diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 7f4265fb972d..faa2b973da1d 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -36,6 +36,9 @@ const ( type Node struct { srv *Server + // ctx provides context regarding the underlying connection + ctx *RPCContext + // updates holds pending client status updates for allocations updates []*structs.Allocation @@ -114,6 +117,13 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp } } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. 
+ if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.Node.ID + n.srv.addNodeConn(n.ctx) + } + // Commit this update via Raft _, index, err := n.srv.raftApply(structs.NodeRegisterRequestType, args) if err != nil { @@ -305,6 +315,13 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct return fmt.Errorf("node not found") } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. + if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.NodeID + n.srv.addNodeConn(n.ctx) + } + // XXX: Could use the SecretID here but have to update the heartbeat system // to track SecretIDs. @@ -724,6 +741,13 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, return fmt.Errorf("node secret ID does not match") } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. + if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.NodeID + n.srv.addNodeConn(n.ctx) + } + var err error allocs, err = state.AllocsByNode(ws, args.NodeID) if err != nil { diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index c53ea9f97d64..d3121a77f9da 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -16,15 +16,20 @@ import ( "github.com/hashicorp/nomad/testutil" vapi "github.com/hashicorp/vault/api" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestClientEndpoint_Register(t *testing.T) { t.Parallel() + require := require.New(t) s1 := testServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() req := &structs.NodeRegisterRequest{ @@ -41,6 +46,11 @@ func TestClientEndpoint_Register(t *testing.T) { t.Fatalf("bad index: %d", resp.Index) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Equal(node.ID, nodes[0]) + // Check for the node in the FSM state := s1.fsm.State() ws := memdb.NewWatchSet() @@ -57,6 +67,15 @@ func TestClientEndpoint_Register(t *testing.T) { if out.ComputedClass == "" { t.Fatal("ComputedClass not set") } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_Register_SecretMismatch(t *testing.T) { @@ -260,11 +279,15 @@ func TestClientEndpoint_Deregister_Vault(t *testing.T) { func TestClientEndpoint_UpdateStatus(t *testing.T) { t.Parallel() + require := require.New(t) s1 := testServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() reg := &structs.NodeRegisterRequest{ @@ -304,6 +327,11 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { t.Fatalf("bad: %#v", ttl) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Equal(node.ID, nodes[0]) + // Check for the node in the FSM state := s1.fsm.State() ws := memdb.NewWatchSet() @@ -317,6 +345,15 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { if 
out.ModifyIndex != resp2.Index { t.Fatalf("index mis-match") } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_UpdateStatus_Vault(t *testing.T) { @@ -1230,30 +1267,23 @@ func TestClientEndpoint_GetAllocs_ACL_Basic(t *testing.T) { func TestClientEndpoint_GetClientAllocs(t *testing.T) { t.Parallel() + require := require.New(t) s1 := testServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() - reg := &structs.NodeRegisterRequest{ - Node: node, - WriteRequest: structs.WriteRequest{Region: "global"}, - } - - // Fetch the response - var resp structs.GenericResponse - if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil { - t.Fatalf("err: %v", err) - } - node.CreateIndex = resp.Index - node.ModifyIndex = resp.Index + state := s1.fsm.State() + require.Nil(state.UpsertNode(98, node)) // Inject fake evaluations alloc := mock.Alloc() alloc.NodeID = node.ID - state := s1.fsm.State() state.UpsertJobSummary(99, mock.JobSummary(alloc.JobID)) err := state.UpsertAllocs(100, []*structs.Allocation{alloc}) if err != nil { @@ -1278,6 +1308,11 @@ func TestClientEndpoint_GetClientAllocs(t *testing.T) { t.Fatalf("bad: %#v", resp2.Allocs) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Equal(node.ID, nodes[0]) + // Lookup node with bad SecretID get.SecretID = "foobarbaz" var resp3 structs.NodeClientAllocsResponse @@ -1298,6 +1333,15 @@ func TestClientEndpoint_GetClientAllocs(t *testing.T) { if len(resp4.Allocs) != 0 { t.Fatalf("unexpected node %#v", resp3.Allocs) } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_GetClientAllocs_Blocking(t *testing.T) { diff --git a/nomad/rpc.go b/nomad/rpc.go index ef46e0accf32..3dd1c2c46372 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -57,6 +57,9 @@ const ( // RPCContext provides metadata about the RPC connection. type RPCContext struct { + // Conn exposes the raw connection. + Conn net.Conn + // Session exposes the multiplexed connection session. Session *yamux.Session @@ -66,8 +69,11 @@ type RPCContext struct { // TLSRole is the certificate role making the TLS connection. TLSRole string - // TLSRegion is the region on the certificate making theTLS connection + // TLSRegion is the region on the certificate making the TLS connection TLSRegion string + + // NodeID marks the NodeID that initiated the connection. 
+ NodeID string } // NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls to @@ -95,7 +101,7 @@ func (s *Server) listen() { continue } - go s.handleConn(conn, &RPCContext{}) + go s.handleConn(conn, &RPCContext{Conn: conn}) metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1) } } @@ -130,6 +136,10 @@ func (s *Server) handleConn(conn net.Conn, ctx *RPCContext) { s.setupRpcServer(server, ctx) s.handleNomadConn(conn, server) + // Remove any potential mapping between a NodeID to this connection and + // close the underlying connection. + s.removeNodeConn(ctx) + case rpcRaft: metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1) s.raftLayer.Handoff(conn) @@ -186,7 +196,13 @@ func (s *Server) handleConn(conn net.Conn, ctx *RPCContext) { // handleMultiplex is used to multiplex a single incoming connection // using the Yamux multiplexer func (s *Server) handleMultiplex(conn net.Conn, ctx *RPCContext) { - defer conn.Close() + defer func() { + // Remove any potential mapping between a NodeID to this connection and + // close the underlying connection. + s.removeNodeConn(ctx) + conn.Close() + }() + conf := yamux.DefaultConfig() conf.LogOutput = s.config.LogOutput server, _ := yamux.Server(conn, conf) diff --git a/nomad/server.go b/nomad/server.go index 32214cdd3e59..a826e7601c66 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -29,6 +29,7 @@ import ( "github.com/hashicorp/raft" raftboltdb "github.com/hashicorp/raft-boltdb" "github.com/hashicorp/serf/serf" + "github.com/hashicorp/yamux" ) const ( @@ -116,6 +117,11 @@ type Server struct { // all RPC connections staticEndpoints endpoints + // nodeConns is the set of multiplexed node connections we have keyed by + // NodeID + nodeConns map[string]*yamux.Session + nodeConnsLock sync.RWMutex + // peers is used to track the known Nomad servers. This is // used for region forwarding and clustering. peers map[string][]*serverParts @@ -261,6 +267,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), logger: logger, rpcServer: rpc.NewServer(), + nodeConns: make(map[string]*yamux.Session), peers: make(map[string][]*serverParts), localPeers: make(map[raft.ServerAddress]*serverParts), reconcileCh: make(chan serf.Member, 32), @@ -784,7 +791,7 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { s.staticEndpoints.Alloc = &Alloc{s} s.staticEndpoints.Eval = &Eval{s} s.staticEndpoints.Job = &Job{s} - s.staticEndpoints.Node = &Node{srv: s} + s.staticEndpoints.Node = &Node{srv: s} // Add but don't register s.staticEndpoints.Deployment = &Deployment{srv: s} s.staticEndpoints.Operator = &Operator{s} s.staticEndpoints.Periodic = &Periodic{s} @@ -801,7 +808,6 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { server.Register(s.staticEndpoints.Alloc) server.Register(s.staticEndpoints.Eval) server.Register(s.staticEndpoints.Job) - server.Register(s.staticEndpoints.Node) server.Register(s.staticEndpoints.Deployment) server.Register(s.staticEndpoints.Operator) server.Register(s.staticEndpoints.Periodic) @@ -813,7 +819,10 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { s.staticEndpoints.Enterprise.Register(server) // Create new dynamic endpoints and add them to the RPC server. 
-	// TODO
+	node := &Node{srv: s, ctx: ctx}
+
+	// Register the dynamic endpoints
+	server.Register(node)
 }
 
 // setupRaft is used to setup and initialize Raft
@@ -1172,6 +1181,49 @@ func (s *Server) RPC(method string, args interface{}, reply interface{}) error {
 	return codec.err
 }
 
+// getNodeConn returns the connection to the given node and whether it exists.
+func (s *Server) getNodeConn(nodeID string) (*yamux.Session, bool) {
+	s.nodeConnsLock.RLock()
+	defer s.nodeConnsLock.RUnlock()
+	session, ok := s.nodeConns[nodeID]
+	return session, ok
+}
+
+// connectedNodes returns the set of nodes we have a connection with.
+func (s *Server) connectedNodes() []string {
+	s.nodeConnsLock.RLock()
+	defer s.nodeConnsLock.RUnlock()
+	nodes := make([]string, 0, len(s.nodeConns))
+	for nodeID := range s.nodeConns {
+		nodes = append(nodes, nodeID)
+	}
+	return nodes
+}
+
+// addNodeConn adds the mapping between a node and its session.
+func (s *Server) addNodeConn(ctx *RPCContext) {
+	// Hotpath the no-op
+	if ctx == nil || ctx.NodeID == "" {
+		return
+	}
+
+	s.nodeConnsLock.Lock()
+	defer s.nodeConnsLock.Unlock()
+	s.nodeConns[ctx.NodeID] = ctx.Session
+}
+
+// removeNodeConn removes the mapping between a node and its session.
+func (s *Server) removeNodeConn(ctx *RPCContext) {
+	// Hotpath the no-op
+	if ctx == nil || ctx.NodeID == "" {
+		return
+	}
+
+	s.nodeConnsLock.Lock()
+	defer s.nodeConnsLock.Unlock()
+	delete(s.nodeConns, ctx.NodeID)
+}
+
 // Stats is used to return statistics for debugging and insight
 // for various sub-systems
 func (s *Server) Stats() map[string]map[string]string {

From a495c83a68d0b3b2b70964c2222eccf0dbcdf81d Mon Sep 17 00:00:00 2001
From: Alex Dadgar
Date: Wed, 10 Jan 2018 16:11:36 -0800
Subject: [PATCH 5/5] Store the whole verified certificate chain

---
 nomad/rpc.go | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/nomad/rpc.go b/nomad/rpc.go
index 3dd1c2c46372..f765e288ca42 100644
--- a/nomad/rpc.go
+++ b/nomad/rpc.go
@@ -3,6 +3,7 @@ package nomad
 import (
 	"context"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"io"
 	"math/rand"
@@ -66,11 +67,9 @@ type RPCContext struct {
 	// TLS marks whether the RPC is over a TLS based connection
 	TLS bool
 
-	// TLSRole is the certificate role making the TLS connection.
-	TLSRole string
-
-	// TLSRegion is the region on the certificate making the TLS connection
-	TLSRegion string
+	// VerifiedChains holds the verified certificate chains presented by the
+	// incoming connection.
+	VerifiedChains [][]*x509.Certificate
 
 	// NodeID marks the NodeID that initiated the connection.
 	NodeID string
 }
@@ -174,15 +173,9 @@ func (s *Server) handleConn(conn net.Conn, ctx *RPCContext) {
 		// using TLS
 		ctx.TLS = true
 
-		// Parse the region and role from the TLS certificate
+		// Store the verified chains so they can be inspected later.
 		state := tlsConn.ConnectionState()
-		parts := strings.SplitN(state.ServerName, ".", 3)
-		if len(parts) != 3 || (parts[0] != "server" && parts[0] != "client") || parts[2] != "nomad" {
-			s.logger.Printf("[WARN] nomad.rpc: invalid server name %q on verified TLS connection", state.ServerName)
-		} else {
-			ctx.TLSRole = parts[0]
-			ctx.TLSRegion = parts[1]
-		}
+		ctx.VerifiedChains = state.VerifiedChains
 
 		s.handleConn(conn, ctx)
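With patch 5 the connection handler no longer interprets the certificate itself; it stores the verified chains on the RPCContext so later code can decide what to extract. A standalone sketch of how a consumer of that field might recover the role and region, reusing the "<role>.<region>.nomad" naming convention that patch 3 parsed out of the TLS server name (tlsIdentity is a hypothetical helper, not part of these patches):

package main

import (
	"crypto/x509"
	"fmt"
	"strings"
)

// tlsIdentity extracts the role ("server" or "client") and region from the
// leaf certificate of a verified chain, assuming Nomad-style certificate
// names of the form "<role>.<region>.nomad". It mirrors the parsing that
// patch 3 applied to the TLS server name, applied instead to the stored
// VerifiedChains.
func tlsIdentity(verifiedChains [][]*x509.Certificate) (role, region string, ok bool) {
	if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 {
		return "", "", false
	}
	leaf := verifiedChains[0][0]

	// Prefer the DNS SANs, falling back to the subject common name.
	names := append([]string{}, leaf.DNSNames...)
	names = append(names, leaf.Subject.CommonName)

	for _, name := range names {
		parts := strings.SplitN(name, ".", 3)
		if len(parts) == 3 && (parts[0] == "server" || parts[0] == "client") && parts[2] == "nomad" {
			return parts[0], parts[1], true
		}
	}
	return "", "", false
}

func main() {
	// With no verified chain (e.g. a non-TLS connection) there is no identity.
	role, region, ok := tlsIdentity(nil)
	fmt.Println(role, region, ok)
}

Storing the raw chains keeps the connection handler small and lets later consumers decide how much of the certificate they actually need.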