From 781b2551512a4cc656d8d6ed3d605bc244066653 Mon Sep 17 00:00:00 2001 From: Mahmood Ali Date: Sun, 28 Apr 2019 17:25:27 -0400 Subject: [PATCH] server: server forwarding logic for nomad exec endpoint --- nomad/client_alloc_endpoint.go | 130 ++++++++++++++++++++ nomad/client_alloc_endpoint_test.go | 184 ++++++++++++++++++++++++++++ nomad/client_fs_endpoint.go | 80 ++++++------ nomad/server.go | 1 + 4 files changed, 355 insertions(+), 40 deletions(-) diff --git a/nomad/client_alloc_endpoint.go b/nomad/client_alloc_endpoint.go index 44713a9b75f8..b863695f6cb5 100644 --- a/nomad/client_alloc_endpoint.go +++ b/nomad/client_alloc_endpoint.go @@ -2,11 +2,16 @@ package nomad import ( "errors" + "fmt" + "io" + "net" "time" metrics "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/helper" + "github.com/ugorji/go/codec" "github.com/hashicorp/nomad/acl" "github.com/hashicorp/nomad/nomad/structs" @@ -19,6 +24,10 @@ type ClientAllocations struct { logger log.Logger } +func (a *ClientAllocations) register() { + a.srv.streamingRpcs.Register("Allocations.Exec", a.exec) +} + // GarbageCollectAll is used to garbage collect all allocations on a client. func (a *ClientAllocations) GarbageCollectAll(args *structs.NodeSpecificRequest, reply *structs.GenericResponse) error { // We only allow stale reads since the only potentially stale information is @@ -287,3 +296,124 @@ func (a *ClientAllocations) Stats(args *cstructs.AllocStatsRequest, reply *cstru // Make the RPC return NodeRpc(state.Session, "Allocations.Stats", args, reply) } + +func (a *ClientAllocations) exec(conn io.ReadWriteCloser) { + defer conn.Close() + defer metrics.MeasureSince([]string{"nomad", "alloc", "exec"}, time.Now()) + + // Decode the arguments + var args cstructs.AllocExecRequest + decoder := codec.NewDecoder(conn, structs.MsgpackHandle) + encoder := codec.NewEncoder(conn, structs.MsgpackHandle) + + if err := decoder.Decode(&args); err != nil { + handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + return + } + + // Check if we need to forward to a different region + if r := args.RequestRegion(); r != a.srv.Region() { + forwardRegionStreamingRpc(a.srv, conn, encoder, &args, "Allocations.Exec", + args.AllocID, &args.QueryOptions) + return + } + + // Check node read permissions + if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil { + handleStreamResultError(err, nil, encoder) + return + } else if aclObj != nil { + // client ultimately checks if AllocNodeExec is required + exec := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityAllocExec) + if !exec { + handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) + return + } + } + + // Verify the arguments. + if args.AllocID == "" { + handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder) + return + } + + // Retrieve the allocation + snap, err := a.srv.State().Snapshot() + if err != nil { + handleStreamResultError(err, nil, encoder) + return + } + + alloc, err := snap.AllocByID(nil, args.AllocID) + if err != nil { + handleStreamResultError(err, nil, encoder) + return + } + if alloc == nil { + handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder) + return + } + nodeID := alloc.NodeID + + // Make sure Node is valid and new enough to support RPC + node, err := snap.NodeByID(nil, nodeID) + if err != nil { + handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + return + } + + if node == nil { + err := fmt.Errorf("Unknown node %q", nodeID) + handleStreamResultError(err, helper.Int64ToPtr(400), encoder) + return + } + + if err := nodeSupportsRpc(node); err != nil { + handleStreamResultError(err, helper.Int64ToPtr(400), encoder) + return + } + + // Get the connection to the client either by forwarding to another server + // or creating a direct stream + var clientConn net.Conn + state, ok := a.srv.getNodeConn(nodeID) + if !ok { + // Determine the Server that has a connection to the node. + srv, err := a.srv.serverWithNodeConn(nodeID, a.srv.Region()) + if err != nil { + var code *int64 + if structs.IsErrNoNodeConn(err) { + code = helper.Int64ToPtr(404) + } + handleStreamResultError(err, code, encoder) + return + } + + // Get a connection to the server + conn, err := a.srv.streamingRpc(srv, "Allocations.Exec") + if err != nil { + handleStreamResultError(err, nil, encoder) + return + } + + clientConn = conn + } else { + stream, err := NodeStreamingRpc(state.Session, "Allocations.Exec") + if err != nil { + handleStreamResultError(err, nil, encoder) + return + } + clientConn = stream + } + defer clientConn.Close() + + // Send the request. + outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle) + if err := outEncoder.Encode(args); err != nil { + handleStreamResultError(err, nil, encoder) + return + } + + structs.Bridge(conn, clientConn) + return +} diff --git a/nomad/client_alloc_endpoint_test.go b/nomad/client_alloc_endpoint_test.go index 3e97740cf6a1..6ed7575a79ca 100644 --- a/nomad/client_alloc_endpoint_test.go +++ b/nomad/client_alloc_endpoint_test.go @@ -1,8 +1,13 @@ package nomad import ( + "encoding/json" "fmt" + "io" + "net" + "strings" "testing" + "time" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/acl" @@ -12,9 +17,12 @@ import ( "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + nstructs "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/drivers" "github.com/hashicorp/nomad/testutil" "github.com/kr/pretty" "github.com/stretchr/testify/require" + "github.com/ugorji/go/codec" ) func TestClientAllocations_GarbageCollectAll_Local(t *testing.T) { @@ -1040,3 +1048,179 @@ func TestClientAllocations_Restart_ACL(t *testing.T) { }) } } + +// TestAlloc_ExecStreaming asserts that exec task requests are forwarded +// to appropriate server or remote regions +func TestAlloc_ExecStreaming(t *testing.T) { + t.Skip("try skipping") + t.Parallel() + + ////// Nomad clusters topology - not specific to test + localServer := TestServer(t, nil) + defer localServer.Shutdown() + + remoteServer := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer remoteServer.Shutdown() + + remoteRegionServer := TestServer(t, func(c *Config) { + c.Region = "two" + }) + defer remoteRegionServer.Shutdown() + + TestJoin(t, localServer, remoteServer) + TestJoin(t, localServer, remoteRegionServer) + testutil.WaitForLeader(t, localServer.RPC) + testutil.WaitForLeader(t, remoteServer.RPC) + testutil.WaitForLeader(t, remoteRegionServer.RPC) + + c, cleanup := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{localServer.config.RPCAddr.String()} + }) + defer cleanup() + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := remoteServer.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + require.NoError(t, err, "failed to have a client") + }) + + // Force remove the connection locally in case it exists + remoteServer.nodeConnsLock.Lock() + delete(remoteServer.nodeConns, c.NodeID()) + remoteServer.nodeConnsLock.Unlock() + + ///// Start task + a := mock.BatchAlloc() + a.NodeID = c.NodeID() + a.Job.Type = structs.JobTypeBatch + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0].Config = map[string]interface{}{ + "run_for": "20s", + "exec_command": map[string]interface{}{ + "run_for": "1ms", + "stdout_string": "expected output", + "exit_code": 3, + }, + } + + // Upsert the allocation + localState := localServer.State() + require.Nil(t, localState.UpsertJob(999, a.Job)) + require.Nil(t, localState.UpsertAllocs(1003, []*structs.Allocation{a})) + remoteState := remoteServer.State() + require.Nil(t, remoteState.UpsertJob(999, a.Job)) + require.Nil(t, remoteState.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := localState.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusRunning { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + require.NoError(t, err, "task didn't start yet") + }) + + ///////// Actually run query now + cases := []struct { + name string + rpc func(string) (structs.StreamingRpcHandler, error) + }{ + {"client", c.StreamingRpcHandler}, + {"local_server", localServer.StreamingRpcHandler}, + {"remote_server", remoteServer.StreamingRpcHandler}, + {"remote_region", remoteRegionServer.StreamingRpcHandler}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + + // Make the request + req := &cstructs.AllocExecRequest{ + AllocID: a.ID, + Task: a.Job.TaskGroups[0].Tasks[0].Name, + Tty: true, + Cmd: []string{"placeholder command"}, + QueryOptions: nstructs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := tc.rpc("Allocations.Exec") + require.Nil(t, err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + frames := make(chan *drivers.ExecTaskStreamingResponseMsg) + + // Start the handler + go handler(p2) + go decodeFrames(t, p1, frames, errCh) + + // Send the request + encoder := codec.NewEncoder(p1, nstructs.MsgpackHandle) + require.Nil(t, encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + + OUTER: + for { + select { + case <-timeout: + require.FailNow(t, "timed out before getting exit code") + case err := <-errCh: + require.NoError(t, err) + case f := <-frames: + if f.Exited && f.Result != nil { + code := int(f.Result.ExitCode) + require.Equal(t, 3, code) + break OUTER + } + } + } + }) + } +} + +func decodeFrames(t *testing.T, p1 net.Conn, frames chan<- *drivers.ExecTaskStreamingResponseMsg, errCh chan<- error) { + // Start the decoder + decoder := codec.NewDecoder(p1, nstructs.MsgpackHandle) + + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + t.Logf("received error decoding: %#v", err) + + errCh <- fmt.Errorf("error decoding: %v", err) + return + } + + if msg.Error != nil { + errCh <- msg.Error + continue + } + + var frame drivers.ExecTaskStreamingResponseMsg + json.Unmarshal(msg.Payload, &frame) + t.Logf("received message: %#v", msg) + frames <- &frame + } +} diff --git a/nomad/client_fs_endpoint.go b/nomad/client_fs_endpoint.go index 120927a545d0..ef152e44a82c 100644 --- a/nomad/client_fs_endpoint.go +++ b/nomad/client_fs_endpoint.go @@ -33,7 +33,7 @@ func (f *FileSystem) register() { // handleStreamResultError is a helper for sending an error with a potential // error code. The transmission of the error is ignored if the error has been // generated by the closing of the underlying transport. -func (f *FileSystem) handleStreamResultError(err error, code *int64, encoder *codec.Encoder) { +func handleStreamResultError(err error, code *int64, encoder *codec.Encoder) { // Nothing to do as the conn is closed if err == io.EOF || strings.Contains(err.Error(), "closed") { return @@ -48,7 +48,7 @@ func (f *FileSystem) handleStreamResultError(err error, code *int64, encoder *co // forwardRegionStreamingRpc is used to make a streaming RPC to a different // region. It looks up the allocation in the remote region to determine what // remote server can route the request. -func (f *FileSystem) forwardRegionStreamingRpc(conn io.ReadWriteCloser, +func forwardRegionStreamingRpc(fsrv *Server, conn io.ReadWriteCloser, encoder *codec.Encoder, args interface{}, method, allocID string, qo *structs.QueryOptions) { // Request the allocation from the target region allocReq := &structs.AllocSpecificRequest{ @@ -56,31 +56,31 @@ func (f *FileSystem) forwardRegionStreamingRpc(conn io.ReadWriteCloser, QueryOptions: *qo, } var allocResp structs.SingleAllocResponse - if err := f.srv.forwardRegion(qo.RequestRegion(), "Alloc.GetAlloc", allocReq, &allocResp); err != nil { - f.handleStreamResultError(err, nil, encoder) + if err := fsrv.forwardRegion(qo.RequestRegion(), "Alloc.GetAlloc", allocReq, &allocResp); err != nil { + handleStreamResultError(err, nil, encoder) return } if allocResp.Alloc == nil { - f.handleStreamResultError(structs.NewErrUnknownAllocation(allocID), helper.Int64ToPtr(404), encoder) + handleStreamResultError(structs.NewErrUnknownAllocation(allocID), helper.Int64ToPtr(404), encoder) return } // Determine the Server that has a connection to the node. - srv, err := f.srv.serverWithNodeConn(allocResp.Alloc.NodeID, qo.RequestRegion()) + srv, err := fsrv.serverWithNodeConn(allocResp.Alloc.NodeID, qo.RequestRegion()) if err != nil { var code *int64 if structs.IsErrNoNodeConn(err) { code = helper.Int64ToPtr(404) } - f.handleStreamResultError(err, code, encoder) + handleStreamResultError(err, code, encoder) return } // Get a connection to the server - srvConn, err := f.srv.streamingRpc(srv, method) + srvConn, err := fsrv.streamingRpc(srv, method) if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } defer srvConn.Close() @@ -88,7 +88,7 @@ func (f *FileSystem) forwardRegionStreamingRpc(conn io.ReadWriteCloser, // Send the request. outEncoder := codec.NewEncoder(srvConn, structs.MsgpackHandle) if err := outEncoder.Encode(args); err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } @@ -217,46 +217,46 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) { encoder := codec.NewEncoder(conn, structs.MsgpackHandle) if err := decoder.Decode(&args); err != nil { - f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + handleStreamResultError(err, helper.Int64ToPtr(500), encoder) return } // Check if we need to forward to a different region if r := args.RequestRegion(); r != f.srv.Region() { - f.forwardRegionStreamingRpc(conn, encoder, &args, "FileSystem.Stream", + forwardRegionStreamingRpc(f.srv, conn, encoder, &args, "FileSystem.Stream", args.AllocID, &args.QueryOptions) return } // Check node read permissions if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadFS) { - f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) + handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) return } // Verify the arguments. if args.AllocID == "" { - f.handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder) + handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder) return } // Retrieve the allocation snap, err := f.srv.State().Snapshot() if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } alloc, err := snap.AllocByID(nil, args.AllocID) if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } if alloc == nil { - f.handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder) + handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder) return } nodeID := alloc.NodeID @@ -264,18 +264,18 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) { // Make sure Node is valid and new enough to support RPC node, err := snap.NodeByID(nil, nodeID) if err != nil { - f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + handleStreamResultError(err, helper.Int64ToPtr(500), encoder) return } if node == nil { err := fmt.Errorf("Unknown node %q", nodeID) - f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder) + handleStreamResultError(err, helper.Int64ToPtr(400), encoder) return } if err := nodeSupportsRpc(node); err != nil { - f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder) + handleStreamResultError(err, helper.Int64ToPtr(400), encoder) return } @@ -291,14 +291,14 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) { if structs.IsErrNoNodeConn(err) { code = helper.Int64ToPtr(404) } - f.handleStreamResultError(err, code, encoder) + handleStreamResultError(err, code, encoder) return } // Get a connection to the server conn, err := f.srv.streamingRpc(srv, "FileSystem.Stream") if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } @@ -306,7 +306,7 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) { } else { stream, err := NodeStreamingRpc(state.Session, "FileSystem.Stream") if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } clientConn = stream @@ -316,7 +316,7 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) { // Send the request. outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle) if err := outEncoder.Encode(args); err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } @@ -335,50 +335,50 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) { encoder := codec.NewEncoder(conn, structs.MsgpackHandle) if err := decoder.Decode(&args); err != nil { - f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + handleStreamResultError(err, helper.Int64ToPtr(500), encoder) return } // Check if we need to forward to a different region if r := args.RequestRegion(); r != f.srv.Region() { - f.forwardRegionStreamingRpc(conn, encoder, &args, "FileSystem.Logs", + forwardRegionStreamingRpc(f.srv, conn, encoder, &args, "FileSystem.Logs", args.AllocID, &args.QueryOptions) return } // Check node read permissions if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } else if aclObj != nil { readfs := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityReadFS) logs := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityReadLogs) if !readfs && !logs { - f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) + handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) return } } // Verify the arguments. if args.AllocID == "" { - f.handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder) + handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder) return } // Retrieve the allocation snap, err := f.srv.State().Snapshot() if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } alloc, err := snap.AllocByID(nil, args.AllocID) if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } if alloc == nil { - f.handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder) + handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder) return } nodeID := alloc.NodeID @@ -386,18 +386,18 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) { // Make sure Node is valid and new enough to support RPC node, err := snap.NodeByID(nil, nodeID) if err != nil { - f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + handleStreamResultError(err, helper.Int64ToPtr(500), encoder) return } if node == nil { err := fmt.Errorf("Unknown node %q", nodeID) - f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder) + handleStreamResultError(err, helper.Int64ToPtr(400), encoder) return } if err := nodeSupportsRpc(node); err != nil { - f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder) + handleStreamResultError(err, helper.Int64ToPtr(400), encoder) return } @@ -413,14 +413,14 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) { if structs.IsErrNoNodeConn(err) { code = helper.Int64ToPtr(404) } - f.handleStreamResultError(err, code, encoder) + handleStreamResultError(err, code, encoder) return } // Get a connection to the server conn, err := f.srv.streamingRpc(srv, "FileSystem.Logs") if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } @@ -428,7 +428,7 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) { } else { stream, err := NodeStreamingRpc(state.Session, "FileSystem.Logs") if err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } clientConn = stream @@ -438,7 +438,7 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) { // Send the request. outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle) if err := outEncoder.Encode(args); err != nil { - f.handleStreamResultError(err, nil, encoder) + handleStreamResultError(err, nil, encoder) return } diff --git a/nomad/server.go b/nomad/server.go index c35124c510cb..21dd5178b294 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -1027,6 +1027,7 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { // Client endpoints s.staticEndpoints.ClientStats = &ClientStats{srv: s, logger: s.logger.Named("client_stats")} s.staticEndpoints.ClientAllocations = &ClientAllocations{srv: s, logger: s.logger.Named("client_allocs")} + s.staticEndpoints.ClientAllocations.register() // Streaming endpoints s.staticEndpoints.FileSystem = &FileSystem{srv: s, logger: s.logger.Named("client_fs")}