From 5a2eec24fde4e247574040df4dc9172c3788dd13 Mon Sep 17 00:00:00 2001 From: Drew Bailey <2614075+drewbailey@users.noreply.github.com> Date: Thu, 12 Dec 2019 16:52:13 -0500 Subject: [PATCH] region forwarding; prevent recursive forwards for impossible requests prevent region forwarding loop, backfill tests --- api/agent_test.go | 18 +++- client/agent_endpoint.go | 4 +- client/agent_endpoint_test.go | 113 ++++++++++++++++++++++ command/agent/agent_endpoint.go | 7 +- command/agent/agent_endpoint_test.go | 62 ++++++------ command/agent/profile/pprof.go | 2 +- nomad/client_agent_endpoint.go | 29 +++++- nomad/client_agent_endpoint_test.go | 136 +++++++++++++++++++++++++++ 8 files changed, 323 insertions(+), 48 deletions(-) diff --git a/api/agent_test.go b/api/agent_test.go index a39d179c4f14..21a99eeb0473 100644 --- a/api/agent_test.go +++ b/api/agent_test.go @@ -389,9 +389,21 @@ func TestAgentCPUProfile(t *testing.T) { AuthToken: token.SecretID, } - resp, err := agent.CPUProfile("", "", 1, q) - require.NoError(t, err) - require.NotNil(t, resp) + // Valid local request + { + resp, err := agent.CPUProfile("", "", 1, q) + require.NoError(t, err) + require.NotNil(t, resp) + } + + // Invalid server request + { + resp, err := agent.CPUProfile("unknown.global", "", 1, q) + require.Error(t, err) + require.Contains(t, err.Error(), "500 (unknown nomad server unknown.global)") + require.Nil(t, resp) + } + } func TestAgentTrace(t *testing.T) { diff --git a/client/agent_endpoint.go b/client/agent_endpoint.go index d507c5487d94..238d414d91b6 100644 --- a/client/agent_endpoint.go +++ b/client/agent_endpoint.go @@ -32,9 +32,9 @@ func NewAgentEndpoint(c *Client) *Agent { func (a *Agent) Profile(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error { // Check ACL for agent write if aclObj, err := a.c.ResolveToken(args.AuthToken); err != nil { - return structs.NewErrRPCCoded(500, err.Error()) + return err } else if aclObj != nil && !aclObj.AllowAgentWrite() { - return structs.NewErrRPCCoded(403, structs.ErrPermissionDenied.Error()) + return structs.ErrPermissionDenied } var resp []byte diff --git a/client/agent_endpoint_test.go b/client/agent_endpoint_test.go index 6762b3e5fa04..5c5269da3e55 100644 --- a/client/agent_endpoint_test.go +++ b/client/agent_endpoint_test.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/nomad/client/config" sframer "github.com/hashicorp/nomad/client/lib/streamframer" cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/command/agent/profile" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -213,3 +214,115 @@ func TestMonitor_Monitor_ACL(t *testing.T) { }) } } + +func TestAgentProfile(t *testing.T) { + t.Parallel() + require := require.New(t) + + // start server and client + s1, cleanup := nomad.TestServer(t, nil) + defer cleanup() + + testutil.WaitForLeader(t, s1.RPC) + + c, cleanupC := TestClient(t, func(c *config.Config) { + c.Servers = []string{s1.GetConfig().RPCAddr.String()} + }) + defer cleanupC() + + // Successful request + { + req := structs.AgentPprofRequest{ + ReqType: profile.CPUReq, + NodeID: c.NodeID(), + } + + reply := structs.AgentPprofResponse{} + + err := c.ClientRPC("Agent.Profile", &req, &reply) + require.NoError(err) + + require.NotNil(reply.Payload) + require.Equal(c.NodeID(), reply.AgentID) + } + + // Unknown profile request + { + req := structs.AgentPprofRequest{ + ReqType: profile.LookupReq, + Profile: "unknown", + NodeID: c.NodeID(), + } + + reply := structs.AgentPprofResponse{} + + err := c.ClientRPC("Agent.Profile", &req, &reply) + require.EqualError(err, "RPC Error:: 404,Pprof profile not found profile: unknown") + } +} + +func TestAgentProfile_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // start server + // start server + s, root, cleanupS := nomad.TestACLServer(t, nil) + defer cleanupS() + testutil.WaitForLeader(t, s.RPC) + + c, cleanupC := TestClient(t, func(c *config.Config) { + c.ACLEnabled = true + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer cleanupC() + + policyBad := mock.AgentPolicy(acl.PolicyRead) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.AgentPolicy(acl.PolicyWrite) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid", policyGood) + + cases := []struct { + Name string + Token string + authErr bool + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + authErr: true, + }, + { + Name: "good token", + Token: tokenGood.SecretID, + }, + { + Name: "root token", + Token: root.SecretID, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + req := &structs.AgentPprofRequest{ + ReqType: profile.CmdReq, + QueryOptions: structs.QueryOptions{ + Namespace: structs.DefaultNamespace, + Region: "global", + AuthToken: tc.Token, + }, + } + + reply := &structs.AgentPprofResponse{} + + err := c.ClientRPC("Agent.Profile", req, reply) + if tc.authErr { + require.EqualError(err, structs.ErrPermissionDenied.Error()) + } else { + require.NoError(err) + require.NotNil(reply.Payload) + } + }) + } +} diff --git a/command/agent/agent_endpoint.go b/command/agent/agent_endpoint.go index a154abb0d4ec..aba912149c5e 100644 --- a/command/agent/agent_endpoint.go +++ b/command/agent/agent_endpoint.go @@ -426,12 +426,7 @@ func (s *HTTPServer) agentPprof(reqType profile.ReqType, resp http.ResponseWrite } if rpcErr != nil { - code, msg, ok := structs.CodeFromRPCCodedErr(rpcErr) - if !ok { - return nil, CodedError(500, rpcErr.Error()) - } - // Return CodedError - return nil, CodedError(code, msg) + return nil, rpcErr } // Set headers from profile request diff --git a/command/agent/agent_endpoint_test.go b/command/agent/agent_endpoint_test.go index a317228a7d62..ed09e1ced7b3 100644 --- a/command/agent/agent_endpoint_test.go +++ b/command/agent/agent_endpoint_test.go @@ -489,43 +489,46 @@ func TestAgent_PprofRequest_Permissions(t *testing.T) { func TestAgent_PprofRequest(t *testing.T) { cases := []struct { - desc string - url string - addNodeID bool - addServerID bool - expectedErr string - expectedStatus int + desc string + url string + addNodeID bool + addServerID bool + expectedErr string }{ { - desc: "cmdline request", - url: "/v1/agent/pprof/cmdline", - addNodeID: true, - expectedStatus: 200, + desc: "cmdline local request", + url: "/v1/agent/pprof/cmdline", }, { - desc: "cpu profile request", - url: "/v1/agent/pprof/profile", - addNodeID: true, - expectedStatus: 200, + desc: "cmdline node request", + url: "/v1/agent/pprof/cmdline", + addNodeID: true, }, { - desc: "trace request", - url: "/v1/agent/pprof/trace", - addNodeID: true, - expectedStatus: 200, + desc: "cmdline server request", + url: "/v1/agent/pprof/cmdline", + addServerID: true, }, { - desc: "pprof lookup request", - url: "/v1/agent/pprof/goroutine", - addNodeID: true, - expectedStatus: 200, + desc: "cpu profile request", + url: "/v1/agent/pprof/profile", + addNodeID: true, }, { - desc: "unknown pprof lookup request", - url: "/v1/agent/pprof/latency", - addNodeID: true, - expectedStatus: 404, - expectedErr: "Unknown profile: latency", + desc: "trace request", + url: "/v1/agent/pprof/trace", + addNodeID: true, + }, + { + desc: "pprof lookup request", + url: "/v1/agent/pprof/goroutine", + addNodeID: true, + }, + { + desc: "unknown pprof lookup request", + url: "/v1/agent/pprof/latency", + addNodeID: true, + expectedErr: "RPC Error:: 404,Pprof profile not found profile: latency", }, } @@ -549,10 +552,7 @@ func TestAgent_PprofRequest(t *testing.T) { if tc.expectedErr != "" { require.Error(t, err) - - httpErr, ok := err.(HTTPCodedError) - require.True(t, ok) - require.Equal(t, httpErr.Code(), tc.expectedStatus) + require.EqualError(t, err, tc.expectedErr) } else { require.NoError(t, err) require.NotNil(t, resp) diff --git a/command/agent/profile/pprof.go b/command/agent/profile/pprof.go index eafb40dec232..e3d4fe0acf76 100644 --- a/command/agent/profile/pprof.go +++ b/command/agent/profile/pprof.go @@ -25,7 +25,7 @@ const ( TraceReq ReqType = "trace" LookupReq ReqType = "profile" - ErrProfileNotFoundPrefix = "Pprof profile not found" + ErrProfileNotFoundPrefix = "Pprof profile not found profile:" ) // NewErrProfileNotFound returns a new error caused by a pprof.Lookup diff --git a/nomad/client_agent_endpoint.go b/nomad/client_agent_endpoint.go index 6c256a51ea41..cb411dfa83fc 100644 --- a/nomad/client_agent_endpoint.go +++ b/nomad/client_agent_endpoint.go @@ -29,6 +29,19 @@ func (a *Agent) register() { } func (a *Agent) Profile(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error { + // handle when serverID does not exist for requested region + region := args.RequestRegion() + if region == "" { + return fmt.Errorf("missing target RPC") + } + + // Handle region forwarding + if region != a.srv.config.Region { + // Mark that we are forwarding + args.SetForwarded() + return a.srv.forwardRegion(region, "Agent.Profile", args, reply) + } + // Targeting a node, forward request to node if args.NodeID != "" { return a.forwardProfileClient(args, reply) @@ -36,7 +49,7 @@ func (a *Agent) Profile(args *structs.AgentPprofRequest, reply *structs.AgentPpr // Handle serverID not equal to ours if args.ServerID != "" { - serverToFwd, err := a.serverFor(args.ServerID) + serverToFwd, err := a.serverFor(args.ServerID, region) if err != nil { return err } @@ -47,9 +60,9 @@ func (a *Agent) Profile(args *structs.AgentPprofRequest, reply *structs.AgentPpr // Check ACL for agent write if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil { - return structs.NewErrRPCCoded(500, err.Error()) + return err } else if aclObj != nil && !aclObj.AllowAgentWrite() { - return structs.NewErrRPCCoded(403, structs.ErrPermissionDenied.Error()) + return structs.ErrPermissionDenied } // Process the request on this server @@ -247,7 +260,7 @@ OUTER: } } -func (a *Agent) serverFor(serverID string) (*serverParts, error) { +func (a *Agent) serverFor(serverID, region string) (*serverParts, error) { var target *serverParts if serverID == "leader" { @@ -267,6 +280,12 @@ func (a *Agent) serverFor(serverID string) (*serverParts, error) { // with a serf member if mem.Name == serverID || mem.Tags["id"] == serverID { if ok, srv := isNomadServer(mem); ok { + if srv.Region != region { + return nil, + fmt.Errorf( + "Requested server:%s region:%s does not exist in requested region: %s", + serverID, srv.Region, region) + } target = srv } @@ -281,7 +300,7 @@ func (a *Agent) serverFor(serverID string) (*serverParts, error) { // ServerID is this current server, // No need to forward request - if target.ID == a.srv.GetConfig().NodeID { + if target.Name == a.srv.LocalMember().Name { return nil, nil } diff --git a/nomad/client_agent_endpoint_test.go b/nomad/client_agent_endpoint_test.go index f4bff3da87e9..0ddd084b49f5 100644 --- a/nomad/client_agent_endpoint_test.go +++ b/nomad/client_agent_endpoint_test.go @@ -502,6 +502,83 @@ func TestAgentProfile_RemoteClient(t *testing.T) { require.Equal(c.NodeID(), reply.AgentID) } +// Test that we prevent a forwarding loop if the requested +// serverID does not exist in the requested region +func TestAgentProfile_RemoteRegionMisMatch(t *testing.T) { + t.Parallel() + require := require.New(t) + + // start server and client + s1, cleanupS1 := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 + c.Region = "foo" + }) + defer cleanupS1() + + s2, cleanup := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 + c.Region = "bar" + }) + defer cleanup() + + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + + req := structs.AgentPprofRequest{ + ReqType: profile.CPUReq, + ServerID: s1.serf.LocalMember().Name, + QueryOptions: structs.QueryOptions{ + Region: "bar", + }, + } + + reply := structs.AgentPprofResponse{} + + err := s1.RPC("Agent.Profile", &req, &reply) + require.Contains(err.Error(), "does not exist in requested region") + + require.NotNil(reply.Payload) + require.Equal(s1.config.NodeID, reply.AgentID) +} + +// Test that Agent.Profile can forward to a different region +func TestAgentProfile_RemoteRegion(t *testing.T) { + t.Parallel() + require := require.New(t) + + // start server and client + s1, cleanupS1 := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 + c.Region = "foo" + }) + defer cleanupS1() + + s2, cleanup := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 + c.Region = "bar" + }) + defer cleanup() + + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + + req := structs.AgentPprofRequest{ + ReqType: profile.CPUReq, + ServerID: s2.serf.LocalMember().Name, + QueryOptions: structs.QueryOptions{ + Region: "bar", + }, + } + + reply := structs.AgentPprofResponse{} + + err := s1.RPC("Agent.Profile", &req, &reply) + require.NoError(err) + + require.NotNil(reply.Payload) + require.Equal(s2.serf.LocalMember().Name, reply.AgentID) +} + func TestAgentProfile_Server(t *testing.T) { t.Parallel() @@ -599,3 +676,62 @@ func TestAgentProfile_Server(t *testing.T) { }) } } + +func TestAgentProfile_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // start server + s, root, cleanupS := TestACLServer(t, nil) + defer cleanupS() + testutil.WaitForLeader(t, s.RPC) + + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.AgentPolicy(acl.PolicyWrite) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid", policyGood) + + cases := []struct { + Name string + Token string + ExpectedErr string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedErr: "RPC Error:: 403,Permission denied", + }, + { + Name: "good token", + Token: tokenGood.SecretID, + }, + { + Name: "root token", + Token: root.SecretID, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + req := &structs.AgentPprofRequest{ + ReqType: profile.CmdReq, + QueryOptions: structs.QueryOptions{ + Namespace: structs.DefaultNamespace, + Region: "global", + AuthToken: tc.Token, + }, + } + + reply := &structs.AgentPprofResponse{} + + err := s.RPC("Agent.Profile", req, reply) + if tc.ExpectedErr != "" { + require.Equal(tc.ExpectedErr, err.Error()) + } else { + require.NoError(err) + require.NotNil(reply.Payload) + } + }) + } +}