server: server forwarding logic for nomad exec endpoint
Mahmood Ali committed Apr 30, 2019
1 parent 4ade58e commit 781b255
Showing 4 changed files with 355 additions and 40 deletions.
130 changes: 130 additions & 0 deletions nomad/client_alloc_endpoint.go
@@ -2,11 +2,16 @@ package nomad

import (
	"errors"
	"fmt"
	"io"
	"net"
	"time"

	metrics "github.com/armon/go-metrics"
	log "github.com/hashicorp/go-hclog"
	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/helper"
	"github.com/ugorji/go/codec"

	"github.com/hashicorp/nomad/acl"
	"github.com/hashicorp/nomad/nomad/structs"
@@ -19,6 +24,10 @@ type ClientAllocations struct {
	logger log.Logger
}

func (a *ClientAllocations) register() {
	a.srv.streamingRpcs.Register("Allocations.Exec", a.exec)
}

// GarbageCollectAll is used to garbage collect all allocations on a client.
func (a *ClientAllocations) GarbageCollectAll(args *structs.NodeSpecificRequest, reply *structs.GenericResponse) error {
	// We only allow stale reads since the only potentially stale information is
@@ -287,3 +296,124 @@ func (a *ClientAllocations) Stats(args *cstructs.AllocStatsRequest, reply *cstru
	// Make the RPC
	return NodeRpc(state.Session, "Allocations.Stats", args, reply)
}

// exec serves a streaming exec request for an allocation's task, forwarding
// the connection to another server or region when the target node is not
// connected locally.
func (a *ClientAllocations) exec(conn io.ReadWriteCloser) {
	defer conn.Close()
	defer metrics.MeasureSince([]string{"nomad", "alloc", "exec"}, time.Now())

	// Decode the arguments
	var args cstructs.AllocExecRequest
	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

	if err := decoder.Decode(&args); err != nil {
		handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
		return
	}

	// Check if we need to forward to a different region
	if r := args.RequestRegion(); r != a.srv.Region() {
		forwardRegionStreamingRpc(a.srv, conn, encoder, &args, "Allocations.Exec",
			args.AllocID, &args.QueryOptions)
		return
	}

	// Check node read permissions
	if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil {
		handleStreamResultError(err, nil, encoder)
		return
	} else if aclObj != nil {
		// client ultimately checks if AllocNodeExec is required
		exec := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityAllocExec)
		if !exec {
			handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
			return
		}
	}

	// Verify the arguments.
	if args.AllocID == "" {
		handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder)
		return
	}

	// Retrieve the allocation
	snap, err := a.srv.State().Snapshot()
	if err != nil {
		handleStreamResultError(err, nil, encoder)
		return
	}

	alloc, err := snap.AllocByID(nil, args.AllocID)
	if err != nil {
		handleStreamResultError(err, nil, encoder)
		return
	}
	if alloc == nil {
		handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder)
		return
	}
	nodeID := alloc.NodeID

	// Make sure Node is valid and new enough to support RPC
	node, err := snap.NodeByID(nil, nodeID)
	if err != nil {
		handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
		return
	}

	if node == nil {
		err := fmt.Errorf("unknown node %q", nodeID)
		handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
		return
	}

	if err := nodeSupportsRpc(node); err != nil {
		handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
		return
	}

	// Get the connection to the client either by forwarding to another server
	// or creating a direct stream
	var clientConn net.Conn
	state, ok := a.srv.getNodeConn(nodeID)
	if !ok {
		// Determine the Server that has a connection to the node.
		srv, err := a.srv.serverWithNodeConn(nodeID, a.srv.Region())
		if err != nil {
			var code *int64
			if structs.IsErrNoNodeConn(err) {
				code = helper.Int64ToPtr(404)
			}
			handleStreamResultError(err, code, encoder)
			return
		}

		// Get a connection to the server
		conn, err := a.srv.streamingRpc(srv, "Allocations.Exec")
		if err != nil {
			handleStreamResultError(err, nil, encoder)
			return
		}

		clientConn = conn
	} else {
		stream, err := NodeStreamingRpc(state.Session, "Allocations.Exec")
		if err != nil {
			handleStreamResultError(err, nil, encoder)
			return
		}
		clientConn = stream
	}
	defer clientConn.Close()

	// Send the request.
	outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
	if err := outEncoder.Encode(args); err != nil {
		handleStreamResultError(err, nil, encoder)
		return
	}

	structs.Bridge(conn, clientConn)
}
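
The protocol this handler speaks is symmetric on both ends of the bridge: the caller sends a single msgpack-encoded cstructs.AllocExecRequest, then reads cstructs.StreamErrWrapper frames until the stream closes. The sketch below drives the endpoint by hand; it is a hypothetical helper that mirrors the test in client_alloc_endpoint_test.go rather than code from this commit, and getHandler stands in for any of the StreamingRpcHandler lookups exercised there.

// execByHand is a sketch, not part of this commit. It shows the wire
// protocol of Allocations.Exec: one msgpack-encoded AllocExecRequest in,
// a stream of StreamErrWrapper frames out.
func execByHand(getHandler func(string) (structs.StreamingRpcHandler, error), allocID, task string) error {
	handler, err := getHandler("Allocations.Exec")
	if err != nil {
		return err
	}

	// The handler owns one end of the pipe; we speak msgpack on the other.
	p1, p2 := net.Pipe()
	defer p1.Close()
	defer p2.Close()
	go handler(p2)

	// The first message on the wire is the request itself.
	req := &cstructs.AllocExecRequest{
		AllocID:      allocID,
		Task:         task,
		Cmd:          []string{"/bin/sh"},
		QueryOptions: structs.QueryOptions{Region: "global"},
	}
	if err := codec.NewEncoder(p1, structs.MsgpackHandle).Encode(req); err != nil {
		return err
	}

	// Everything after that is StreamErrWrapper frames; errors produced by
	// handleStreamResultError arrive as msg.Error.
	decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
	for {
		var msg cstructs.StreamErrWrapper
		if err := decoder.Decode(&msg); err != nil {
			if err == io.EOF {
				return nil // structs.Bridge tore down the stream
			}
			return err
		}
		if msg.Error != nil {
			return msg.Error
		}
		// msg.Payload carries the exec output frames
		// (drivers.ExecTaskStreamingResponseMsg encoded as JSON).
	}
}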
184 changes: 184 additions & 0 deletions nomad/client_alloc_endpoint_test.go
@@ -1,8 +1,13 @@
package nomad

import (
	"encoding/json"
	"fmt"
	"io"
	"net"
	"strings"
	"testing"
	"time"

	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	"github.com/hashicorp/nomad/acl"
@@ -12,9 +17,12 @@ import (
	"github.com/hashicorp/nomad/helper/uuid"
	"github.com/hashicorp/nomad/nomad/mock"
	"github.com/hashicorp/nomad/nomad/structs"
	nstructs "github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/plugins/drivers"
	"github.com/hashicorp/nomad/testutil"
	"github.com/kr/pretty"
	"github.com/stretchr/testify/require"
	"github.com/ugorji/go/codec"
)

func TestClientAllocations_GarbageCollectAll_Local(t *testing.T) {
@@ -1040,3 +1048,179 @@ func TestClientAllocations_Restart_ACL(t *testing.T) {
		})
	}
}

// TestAlloc_ExecStreaming asserts that exec task requests are forwarded
// to the appropriate server or remote region
func TestAlloc_ExecStreaming(t *testing.T) {
	t.Parallel()

	////// Nomad cluster topology - not specific to this test
	localServer := TestServer(t, nil)
	defer localServer.Shutdown()

	remoteServer := TestServer(t, func(c *Config) {
		c.DevDisableBootstrap = true
	})
	defer remoteServer.Shutdown()

	remoteRegionServer := TestServer(t, func(c *Config) {
		c.Region = "two"
	})
	defer remoteRegionServer.Shutdown()

	TestJoin(t, localServer, remoteServer)
	TestJoin(t, localServer, remoteRegionServer)
	testutil.WaitForLeader(t, localServer.RPC)
	testutil.WaitForLeader(t, remoteServer.RPC)
	testutil.WaitForLeader(t, remoteRegionServer.RPC)

	c, cleanup := client.TestClient(t, func(c *config.Config) {
		c.Servers = []string{localServer.config.RPCAddr.String()}
	})
	defer cleanup()

	// Wait for the client to connect
	testutil.WaitForResult(func() (bool, error) {
		nodes := remoteServer.connectedNodes()
		return len(nodes) == 1, nil
	}, func(err error) {
		require.NoError(t, err, "failed to have a client")
	})

	// Force remove the connection locally in case it exists
	remoteServer.nodeConnsLock.Lock()
	delete(remoteServer.nodeConns, c.NodeID())
	remoteServer.nodeConnsLock.Unlock()

	///// Start task
	a := mock.BatchAlloc()
	a.NodeID = c.NodeID()
	a.Job.Type = structs.JobTypeBatch
	a.Job.TaskGroups[0].Count = 1
	a.Job.TaskGroups[0].Tasks[0].Config = map[string]interface{}{
		"run_for": "20s",
		"exec_command": map[string]interface{}{
			"run_for":       "1ms",
			"stdout_string": "expected output",
			"exit_code":     3,
		},
	}

	// Upsert the allocation
	localState := localServer.State()
	require.Nil(t, localState.UpsertJob(999, a.Job))
	require.Nil(t, localState.UpsertAllocs(1003, []*structs.Allocation{a}))
	remoteState := remoteServer.State()
	require.Nil(t, remoteState.UpsertJob(999, a.Job))
	require.Nil(t, remoteState.UpsertAllocs(1003, []*structs.Allocation{a}))

	// Wait for the client to run the allocation
	testutil.WaitForResult(func() (bool, error) {
		alloc, err := localState.AllocByID(nil, a.ID)
		if err != nil {
			return false, err
		}
		if alloc == nil {
			return false, fmt.Errorf("unknown alloc")
		}
		if alloc.ClientStatus != structs.AllocClientStatusRunning {
			return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus)
		}

		return true, nil
	}, func(err error) {
		require.NoError(t, err, "task didn't start yet")
	})

	///////// Actually run query now
	cases := []struct {
		name string
		rpc  func(string) (structs.StreamingRpcHandler, error)
	}{
		{"client", c.StreamingRpcHandler},
		{"local_server", localServer.StreamingRpcHandler},
		{"remote_server", remoteServer.StreamingRpcHandler},
		{"remote_region", remoteRegionServer.StreamingRpcHandler},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {

			// Make the request
			req := &cstructs.AllocExecRequest{
				AllocID:      a.ID,
				Task:         a.Job.TaskGroups[0].Tasks[0].Name,
				Tty:          true,
				Cmd:          []string{"placeholder command"},
				QueryOptions: nstructs.QueryOptions{Region: "global"},
			}

			// Get the handler
			handler, err := tc.rpc("Allocations.Exec")
			require.Nil(t, err)

			// Create a pipe
			p1, p2 := net.Pipe()
			defer p1.Close()
			defer p2.Close()

			errCh := make(chan error)
			frames := make(chan *drivers.ExecTaskStreamingResponseMsg)

			// Start the handler
			go handler(p2)
			go decodeFrames(t, p1, frames, errCh)

			// Send the request
			encoder := codec.NewEncoder(p1, nstructs.MsgpackHandle)
			require.Nil(t, encoder.Encode(req))

			timeout := time.After(3 * time.Second)

		OUTER:
			for {
				select {
				case <-timeout:
					require.FailNow(t, "timed out before getting exit code")
				case err := <-errCh:
					require.NoError(t, err)
				case f := <-frames:
					if f.Exited && f.Result != nil {
						code := int(f.Result.ExitCode)
						require.Equal(t, 3, code)
						break OUTER
					}
				}
			}
		})
	}
}

func decodeFrames(t *testing.T, p1 net.Conn, frames chan<- *drivers.ExecTaskStreamingResponseMsg, errCh chan<- error) {
	// Start the decoder
	decoder := codec.NewDecoder(p1, nstructs.MsgpackHandle)

	for {
		var msg cstructs.StreamErrWrapper
		if err := decoder.Decode(&msg); err != nil {
			if err == io.EOF || strings.Contains(err.Error(), "closed") {
				return
			}
			t.Logf("received error decoding: %#v", err)

			errCh <- fmt.Errorf("error decoding: %v", err)
			return
		}

		if msg.Error != nil {
			errCh <- msg.Error
			continue
		}

		var frame drivers.ExecTaskStreamingResponseMsg
		if err := json.Unmarshal(msg.Payload, &frame); err != nil {
			errCh <- fmt.Errorf("error unmarshalling payload: %v", err)
			return
		}
		t.Logf("received message: %#v", msg)
		frames <- &frame
	}
}
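
Assuming standard Go tooling and the repository root as the working directory, the four forwarding paths the test covers (client, local_server, remote_server, remote_region) can be exercised on their own with:

	go test ./nomad -run TestAlloc_ExecStreaming -v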
