diff --git a/nomad/client_rpc.go b/nomad/client_rpc.go index ca8db233619d..37ced690c4d0 100644 --- a/nomad/client_rpc.go +++ b/nomad/client_rpc.go @@ -68,9 +68,9 @@ func (s *Server) connectedNodes() map[string]time.Time { } // addNodeConn adds the mapping between a node and its session. -func (s *Server) addNodeConn(ctx *RPCContext) { +func (s *Server) addNodeConn(ctx *RPCContext, args structs.RPCInfo) { // Hotpath the no-op - if ctx == nil || ctx.NodeID == "" { + if ctx == nil || ctx.NodeID == "" || args.IsForwarded() { return } diff --git a/nomad/client_rpc_test.go b/nomad/client_rpc_test.go index a653440280e4..81dc66824002 100644 --- a/nomad/client_rpc_test.go +++ b/nomad/client_rpc_test.go @@ -27,6 +27,33 @@ func (n namedConnWrapper) LocalAddr() net.Addr { return namedAddr(n.name) } +func TestServier_addNodeConn_ignoresForwardedRequests(t *testing.T) { + t.Parallel() + + s, cleanupS1 := TestServer(t, nil) + defer cleanupS1() + testutil.WaitForLeader(t, s.RPC) + + p, _ := net.Pipe() + nodeID := uuid.Generate() + + ctx := &RPCContext{ + Conn: p, + NodeID: nodeID, + } + + require.Empty(t, s.connectedNodes()) + + q := &structs.QueryOptions{} + q.Forwarded = true + s.addNodeConn(ctx, q) + require.Empty(t, s.connectedNodes()) + + s.addNodeConn(ctx, &structs.QueryOptions{}) + require.Len(t, s.connectedNodes(), 1) + require.Contains(t, s.connectedNodes(), nodeID) +} + func TestServer_removeNodeConn_differentAddrs(t *testing.T) { t.Parallel() require := require.New(t) @@ -56,8 +83,8 @@ func TestServer_removeNodeConn_differentAddrs(t *testing.T) { NodeID: nodeID, } - s1.addNodeConn(ctx1) - s1.addNodeConn(ctx2) + s1.addNodeConn(ctx1, &structs.QueryOptions{}) + s1.addNodeConn(ctx2, &structs.QueryOptions{}) require.Len(s1.connectedNodes(), 1) require.Len(s1.nodeConns[nodeID], 2) @@ -140,7 +167,7 @@ func TestServerWithNodeConn_Path(t *testing.T) { nodeID := uuid.Generate() s2.addNodeConn(&RPCContext{ NodeID: nodeID, - }) + }, &structs.QueryOptions{}) srv, err := s1.serverWithNodeConn(nodeID, s1.Region()) require.NotNil(srv) @@ -166,7 +193,7 @@ func TestServerWithNodeConn_Path_Region(t *testing.T) { nodeID := uuid.Generate() s2.addNodeConn(&RPCContext{ NodeID: nodeID, - }) + }, &structs.QueryOptions{}) srv, err := s1.serverWithNodeConn(nodeID, s2.Region()) require.NotNil(srv) @@ -199,10 +226,10 @@ func TestServerWithNodeConn_Path_Newest(t *testing.T) { nodeID := uuid.Generate() s2.addNodeConn(&RPCContext{ NodeID: nodeID, - }) + }, &structs.QueryOptions{}) s3.addNodeConn(&RPCContext{ NodeID: nodeID, - }) + }, &structs.QueryOptions{}) srv, err := s1.serverWithNodeConn(nodeID, s1.Region()) require.NotNil(srv) @@ -235,7 +262,7 @@ func TestServerWithNodeConn_PathAndErr(t *testing.T) { nodeID := uuid.Generate() s2.addNodeConn(&RPCContext{ NodeID: nodeID, - }) + }, &structs.QueryOptions{}) // Shutdown the RPC layer for server 3 s3.rpcListener.Close() diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index f635abc2791b..b78a84e280d9 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -82,9 +82,9 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp // We have a valid node connection since there is no error from the // forwarded server, so add the mapping to cache the // connection and allow the server to send RPCs to the client. - if err == nil && n.ctx != nil && n.ctx.NodeID == "" && !args.IsForwarded() { + if err == nil && n.ctx != nil && n.ctx.NodeID == "" { n.ctx.NodeID = args.Node.ID - n.srv.addNodeConn(n.ctx) + n.srv.addNodeConn(n.ctx, args) } return err @@ -154,9 +154,9 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp // We have a valid node connection, so add the mapping to cache the // connection and allow the server to send RPCs to the client. We only cache // the connection if it is not being forwarded from another server. - if n.ctx != nil && n.ctx.NodeID == "" && !args.IsForwarded() { + if n.ctx != nil && n.ctx.NodeID == "" { n.ctx.NodeID = args.Node.ID - n.srv.addNodeConn(n.ctx) + n.srv.addNodeConn(n.ctx, args) } // Commit this update via Raft @@ -374,9 +374,9 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // We have a valid node connection since there is no error from the // forwarded server, so add the mapping to cache the // connection and allow the server to send RPCs to the client. - if err == nil && n.ctx != nil && n.ctx.NodeID == "" && !args.IsForwarded() { + if err == nil && n.ctx != nil && n.ctx.NodeID == "" { n.ctx.NodeID = args.NodeID - n.srv.addNodeConn(n.ctx) + n.srv.addNodeConn(n.ctx, args) } return err @@ -409,9 +409,9 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct // We have a valid node connection, so add the mapping to cache the // connection and allow the server to send RPCs to the client. We only cache // the connection if it is not being forwarded from another server. - if n.ctx != nil && n.ctx.NodeID == "" && !args.IsForwarded() { + if n.ctx != nil && n.ctx.NodeID == "" { n.ctx.NodeID = args.NodeID - n.srv.addNodeConn(n.ctx) + n.srv.addNodeConn(n.ctx, args) } // XXX: Could use the SecretID here but have to update the heartbeat system @@ -925,9 +925,9 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, // We have a valid node connection since there is no error from the // forwarded server, so add the mapping to cache the // connection and allow the server to send RPCs to the client. - if err == nil && n.ctx != nil && n.ctx.NodeID == "" && !args.IsForwarded() { + if err == nil && n.ctx != nil && n.ctx.NodeID == "" { n.ctx.NodeID = args.NodeID - n.srv.addNodeConn(n.ctx) + n.srv.addNodeConn(n.ctx, args) } return err @@ -967,9 +967,9 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, // We have a valid node connection, so add the mapping to cache the // connection and allow the server to send RPCs to the client. We only cache // the connection if it is not being forwarded from another server. - if n.ctx != nil && n.ctx.NodeID == "" && !args.IsForwarded() { + if n.ctx != nil && n.ctx.NodeID == "" { n.ctx.NodeID = args.NodeID - n.srv.addNodeConn(n.ctx) + n.srv.addNodeConn(n.ctx, args) } var err error diff --git a/nomad/status_endpoint_test.go b/nomad/status_endpoint_test.go index 0c724dbb7d08..bf988978d729 100644 --- a/nomad/status_endpoint_test.go +++ b/nomad/status_endpoint_test.go @@ -206,7 +206,7 @@ func TestStatus_HasClientConn(t *testing.T) { // Create a connection on that node s1.addNodeConn(&RPCContext{ NodeID: arg.NodeID, - }) + }, &structs.QueryOptions{}) var out3 structs.NodeConnQueryResponse require.Nil(msgpackrpc.CallWithCodec(codec, "Status.HasNodeConn", arg, &out3)) require.True(out3.Connected)