From 29ffa0268340acf81ee39785b06a934176e1c001 Mon Sep 17 00:00:00 2001 From: Luiz Aoqui Date: Fri, 4 Feb 2022 20:35:20 -0500 Subject: [PATCH] fix mTLS certificate check on agent to agent RPCs (#11998) PR #11956 implemented a new mTLS RPC check to validate the role of the certificate used in the request, but further testing revealed two flaws: 1. client-only endpoints did not accept server certificates so the request would fail when forwarded from one server to another. 2. the certificate was being checked after the request was forwarded, so the check would happen over the server certificate, not the actual source. This commit checks for the desired mTLS level, where the client level accepts both, a server or a client certificate. It also validates the cercertificate before the request is forwarded. --- .semgrep/rpc_endpoint.yml | 17 +------- nomad/alloc_endpoint.go | 12 +++--- nomad/deployment_endpoint.go | 12 +++--- nomad/eval_endpoint.go | 83 +++++++++++++++++++++--------------- nomad/node_endpoint.go | 22 +++++----- nomad/plan_endpoint.go | 11 ++--- nomad/rpc.go | 12 +++--- nomad/rpc_test.go | 55 ++++++++++++++++-------- nomad/util.go | 42 +++++++++++++++++- 9 files changed, 165 insertions(+), 101 deletions(-) diff --git a/.semgrep/rpc_endpoint.yml b/.semgrep/rpc_endpoint.yml index 2277a6b19843..9f22f67a2dca 100644 --- a/.semgrep/rpc_endpoint.yml +++ b/.semgrep/rpc_endpoint.yml @@ -30,26 +30,11 @@ rules: # Pattern used by endpoints called exclusively between agents # (server -> server or client -> server) - pattern-not-inside: | - if done, err := $A.$B.forward($METHOD, ...); done { - return err - } - ... - ... := validateLocalClientTLSCertificate(...) + ... := validateTLSCertificateLevel(...) ... - - pattern-not-inside: | if done, err := $A.$B.forward($METHOD, ...); done { return err } - ... - ... := validateLocalServerTLSCertificate(...) - ... - - pattern-not-inside: | - if done, err := $A.$B.forward($METHOD, ...); done { - return err - } - ... - ... := validateTLSCertificate(...) - ... # Pattern used by some Node endpoints. - pattern-not-inside: | if done, err := $A.$B.forward($METHOD, ...); done { diff --git a/nomad/alloc_endpoint.go b/nomad/alloc_endpoint.go index 6c8231c6b1f0..92abee62f4a5 100644 --- a/nomad/alloc_endpoint.go +++ b/nomad/alloc_endpoint.go @@ -222,15 +222,17 @@ func (a *Alloc) GetAlloc(args *structs.AllocSpecificRequest, // GetAllocs is used to lookup a set of allocations func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest, reply *structs.AllocsGetResponse) error { - if done, err := a.srv.forward("Alloc.GetAllocs", args, args, reply); done { + + // Ensure the connection was initiated by a client if TLS is used. + err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelClient) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err) + if done, err := a.srv.forward("Alloc.GetAllocs", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now()) allocs := make([]*structs.Allocation, len(args.AllocIDs)) diff --git a/nomad/deployment_endpoint.go b/nomad/deployment_endpoint.go index 0bc073768561..2c18de98d540 100644 --- a/nomad/deployment_endpoint.go +++ b/nomad/deployment_endpoint.go @@ -504,15 +504,17 @@ func (d *Deployment) Allocations(args *structs.DeploymentSpecificRequest, reply // Reap is used to cleanup terminal deployments func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest, reply *structs.GenericResponse) error { - if done, err := d.srv.forward("Deployment.Reap", args, args, reply); done { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(d.srv, d.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err) + if done, err := d.srv.forward("Deployment.Reap", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now()) // Update via Raft _, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args) diff --git a/nomad/eval_endpoint.go b/nomad/eval_endpoint.go index 18b83c45de00..8a48e27c1dee 100644 --- a/nomad/eval_endpoint.go +++ b/nomad/eval_endpoint.go @@ -85,15 +85,17 @@ func (e *Eval) GetEval(args *structs.EvalSpecificRequest, // Dequeue is used to dequeue a pending evaluation func (e *Eval) Dequeue(args *structs.EvalDequeueRequest, reply *structs.EvalDequeueResponse) error { - if done, err := e.srv.forward("Eval.Dequeue", args, args, reply); done { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + if done, err := e.srv.forward("Eval.Dequeue", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now()) // Ensure there is at least one scheduler if len(args.Schedulers) == 0 { @@ -175,15 +177,17 @@ func (e *Eval) getWaitIndex(namespace, job string, evalModifyIndex uint64) (uint // Ack is used to acknowledge completion of a dequeued evaluation func (e *Eval) Ack(args *structs.EvalAckRequest, reply *structs.GenericResponse) error { - if done, err := e.srv.forward("Eval.Ack", args, args, reply); done { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + if done, err := e.srv.forward("Eval.Ack", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now()) // Ack the EvalID if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil { @@ -195,15 +199,17 @@ func (e *Eval) Ack(args *structs.EvalAckRequest, // Nack is used to negative acknowledge completion of a dequeued evaluation. func (e *Eval) Nack(args *structs.EvalAckRequest, reply *structs.GenericResponse) error { - if done, err := e.srv.forward("Eval.Nack", args, args, reply); done { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + if done, err := e.srv.forward("Eval.Nack", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now()) // Nack the EvalID if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil { @@ -215,15 +221,17 @@ func (e *Eval) Nack(args *structs.EvalAckRequest, // Update is used to perform an update of an Eval if it is outstanding. func (e *Eval) Update(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { - if done, err := e.srv.forward("Eval.Update", args, args, reply); done { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + if done, err := e.srv.forward("Eval.Update", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now()) // Ensure there is only a single update with token if len(args.Evals) != 1 { @@ -250,15 +258,17 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest, // Create is used to make a new evaluation func (e *Eval) Create(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { - if done, err := e.srv.forward("Eval.Create", args, args, reply); done { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + if done, err := e.srv.forward("Eval.Create", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now()) // Ensure there is only a single update with token if len(args.Evals) != 1 { @@ -300,15 +310,16 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest, // Reblock is used to reinsert an existing blocked evaluation into the blocked // evaluation tracker. func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { - if done, err := e.srv.forward("Eval.Reblock", args, args, reply); done { + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + if done, err := e.srv.forward("Eval.Reblock", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now()) // Ensure there is only a single update with token if len(args.Evals) != 1 { @@ -347,15 +358,17 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe // Reap is used to cleanup dead evaluations and allocations func (e *Eval) Reap(args *structs.EvalDeleteRequest, reply *structs.GenericResponse) error { - if done, err := e.srv.forward("Eval.Reap", args, args, reply); done { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + if done, err := e.srv.forward("Eval.Reap", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now()) // Update via Raft _, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args) diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 8ed43b2dc057..d5a1725b4852 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1098,15 +1098,16 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, // UpdateAlloc is used to update the client status of an allocation func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.GenericResponse) error { - if done, err := n.srv.forward("Node.UpdateAlloc", args, args, reply); done { + // Ensure the connection was initiated by another client if TLS is used. + err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) + if done, err := n.srv.forward("Node.UpdateAlloc", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now()) // Ensure at least a single alloc if len(args.Alloc) == 0 { @@ -1920,15 +1921,16 @@ func taskUsesConnect(task *structs.Task) bool { } func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.EmitNodeEventsResponse) error { - if done, err := n.srv.forward("Node.EmitEvents", args, args, reply); done { + // Ensure the connection was initiated by another client if TLS is used. + err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) + if done, err := n.srv.forward("Node.EmitEvents", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now()) if len(args.NodeEvents) == 0 { return fmt.Errorf("no node events given") diff --git a/nomad/plan_endpoint.go b/nomad/plan_endpoint.go index a6cd8dbefa67..4979270e439e 100644 --- a/nomad/plan_endpoint.go +++ b/nomad/plan_endpoint.go @@ -21,15 +21,16 @@ type Plan struct { // Submit is used to submit a plan to the leader func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) error { - if done, err := p.srv.forward("Plan.Submit", args, args, reply); done { + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(p.srv, p.ctx, tlsCertificateLevelServer) + if err != nil { return err } - defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err) + if done, err := p.srv.forward("Plan.Submit", args, args, reply); done { + return err } + defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now()) if args.Plan == nil { return fmt.Errorf("cannot submit nil plan") diff --git a/nomad/rpc.go b/nomad/rpc.go index 37446d53ee77..96db49b641bd 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -127,16 +127,16 @@ func (ctx *RPCContext) ValidateCertificateForName(name string) error { if cert == nil { return errors.New("missing certificate information") } - for _, dnsName := range cert.DNSNames { - if dnsName == name { + + validNames := []string{cert.Subject.CommonName} + validNames = append(validNames, cert.DNSNames...) + for _, valid := range validNames { + if name == valid { return nil } } - if cert.Subject.CommonName == name { - return nil - } - return fmt.Errorf("certificate not valid for %q", name) + return fmt.Errorf("invalid certificate, %s not in %s", name, strings.Join(validNames, ",")) } // listen is used to listen for incoming RPC connections diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 07f2d9492ee4..bd738f2793a3 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -1081,7 +1081,7 @@ func TestRPC_TLS_Enforcement_Raft(t *testing.T) { } t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) { - err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg) + err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer1, cfg) // the expected error depends on location of failure. // We expect "bad certificate" if connection fails during handshake, @@ -1186,7 +1186,7 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { name: "local server/clients only rpc", cn: "server.global.nomad", rpcs: localClientsOnlyRPCs, - canRPC: false, + canRPC: true, }, // Local client. { @@ -1274,18 +1274,22 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { } for method, arg := range tc.rpcs { - t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) { - err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg) - - if tc.canRPC { - if err != nil { - require.NotContains(t, err, "certificate") + for _, srv := range []*Server{tlsHelper.mtlsServer1, tlsHelper.mtlsServer2} { + name := fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true leader=%v", method, srv.IsLeader()) + t.Run(name, func(t *testing.T) { + err := tlsHelper.nomadRPC(t, srv, cfg, method, arg) + + if tc.canRPC { + if err != nil { + require.NotContains(t, err, "certificate") + } + } else { + require.Error(t, err) + require.Contains(t, err.Error(), "certificate") } - } else { - require.Error(t, err) - require.Contains(t, err.Error(), "certificate") - } - }) + }) + } + t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) { err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg) if err != nil { @@ -1301,8 +1305,10 @@ type tlsTestHelper struct { dir string nodeID int - mtlsServer *Server - mtlsServerCleanup func() + mtlsServer1 *Server + mtlsServer1Cleanup func() + mtlsServer2 *Server + mtlsServer2Cleanup func() nonVerifyServer *Server nonVerifyServerCleanup func() @@ -1329,7 +1335,18 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { // Generate servers and their certificate. h.serverCert = h.newCert(t, "server.global.nomad") - h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) { + h.mtlsServer1, h.mtlsServer1Cleanup = TestServer(t, func(c *Config) { + c.BootstrapExpect = 2 + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: true, + CAFile: filepath.Join(h.dir, "ca.pem"), + CertFile: h.serverCert + ".pem", + KeyFile: h.serverCert + ".key", + } + }) + h.mtlsServer2, h.mtlsServer2Cleanup = TestServer(t, func(c *Config) { + c.BootstrapExpect = 2 c.TLSConfig = &config.TLSConfig{ EnableRPC: true, VerifyServerHostname: true, @@ -1338,6 +1355,9 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { KeyFile: h.serverCert + ".key", } }) + TestJoin(t, h.mtlsServer1, h.mtlsServer2) + testutil.WaitForLeader(t, h.mtlsServer1.RPC) + testutil.WaitForLeader(t, h.mtlsServer2.RPC) h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) { c.TLSConfig = &config.TLSConfig{ @@ -1353,7 +1373,8 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { } func (h tlsTestHelper) cleanup() { - h.mtlsServerCleanup() + h.mtlsServer1Cleanup() + h.mtlsServer2Cleanup() h.nonVerifyServerCleanup() os.RemoveAll(h.dir) } diff --git a/nomad/util.go b/nomad/util.go index daa6999f807a..210a202d9590 100644 --- a/nomad/util.go +++ b/nomad/util.go @@ -302,18 +302,56 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) { return alloc, nil } +// tlsCertificateLevel represents a role level for mTLS certificates. +type tlsCertificateLevel int8 + +const ( + tlsCertificateLevelServer tlsCertificateLevel = iota + tlsCertificateLevelClient +) + +// validateTLSCertificateLevel checks if the provided RPC connection was +// initiated with a certificate that matches the given TLS role level. +// +// - tlsCertificateLevelServer requires a server certificate. +// - tlsCertificateLevelServer requires a client or server certificate. +func validateTLSCertificateLevel(srv *Server, ctx *RPCContext, lvl tlsCertificateLevel) error { + switch lvl { + case tlsCertificateLevelClient: + err := validateLocalClientTLSCertificate(srv, ctx) + if err != nil { + return validateLocalServerTLSCertificate(srv, ctx) + } + return nil + case tlsCertificateLevelServer: + return validateLocalServerTLSCertificate(srv, ctx) + } + + return fmt.Errorf("invalid TLS certificate level %v", lvl) +} + // validateLocalClientTLSCertificate checks if the provided RPC connection was // initiated by a client in the same region as the target server. func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error { expected := fmt.Sprintf("client.%s.nomad", srv.Region()) - return validateTLSCertificate(srv, ctx, expected) + + err := validateTLSCertificate(srv, ctx, expected) + if err != nil { + return fmt.Errorf("invalid client connection in region %s: %v", srv.Region(), err) + } + return nil } // validateLocalServerTLSCertificate checks if the provided RPC connection was // initiated by a server in the same region as the target server. func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error { expected := fmt.Sprintf("server.%s.nomad", srv.Region()) - return validateTLSCertificate(srv, ctx, expected) + + err := validateTLSCertificate(srv, ctx, expected) + if err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", srv.Region(), err) + } + return nil } // validateTLSCertificate checks if the RPC connection mTLS certificates are