Skip to content

Commit

Permalink
add api support for nomad exec
Browse files Browse the repository at this point in the history
Adds nomad exec support in our API, by hitting the websocket endpoint.

We introduce API structs that correspond to the drivers streaming exec structs.

For creating the websocket connection, we reuse the transport setting from api
http client.
  • Loading branch information
Mahmood Ali committed Apr 30, 2019
1 parent a8e460a commit e78f7ce
Show file tree
Hide file tree
Showing 3 changed files with 389 additions and 0 deletions.
231 changes: 231 additions & 0 deletions api/allocations.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package api

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"sort"
"strconv"
"sync"
"time"

"github.com/gorilla/websocket"
)

var (
Expand Down Expand Up @@ -61,6 +70,189 @@ func (a *Allocations) Info(allocID string, q *QueryOptions) (*Allocation, *Query
return &resp, qm, nil
}

func (a *Allocations) Exec(ctx context.Context,
alloc *Allocation, task string, tty bool, command []string,
stdin io.Reader, stdout, stderr io.Writer,
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {

ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()

errCh := make(chan error, 1)

sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)

select {
case err := <-errCh:
return -1, err
default:
}

// forwarding stdin
go func() {

bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}

input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}

n, err := stdin.Read(bytes)
if err == io.EOF {
input.Stdin.Close = true
sender(&input)
return
} else if err != nil {
errCh <- err
return
}

input.Stdin.Data = bytes[:n]
sender(&input)
}
}()

// forwarding terminal size and heartbeats
go func() {
for {
resizeInput := ExecStreamingInput{}

select {
case <-ctx.Done():
return
case size, ok := <-terminalSizeCh:
if !ok {
continue
}
resizeInput.TTYSize = &size
sender(&resizeInput)

// heartbeat message
case <-time.After(10 * time.Second):
sender(&execStreamingInputHeartbeat)
}

}
}()

for {
select {
case err := <-errCh:
// drop websocket code, not relevant to user
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
return -1, errors.New(wsErr.Text)
}
return -1, err
case <-ctx.Done():
return -1, ctx.Err()
case frame, ok := <-output:
if !ok {
return -1, nil
}

switch {
case frame.Stdout != nil && len(frame.Stdout.Data) != 0:
stdout.Write(frame.Stdout.Data)
case frame.Stderr != nil && len(frame.Stderr.Data) != 0:
stderr.Write(frame.Stderr.Data)
case frame.Stdout != nil && frame.Stdout.Close:
// don't really do anything
case frame.Stderr != nil && frame.Stderr.Close:
// don't really do anything
case frame.Exited && frame.Result != nil:
return frame.Result.ExitCode, nil
default:
// unexpected event, TODO: log it?!
}
}
}
}

func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
errCh chan<- error, q *QueryOptions) (sendFn func(interface{}) error, output <-chan *ExecStreamingOutput) {

nodeClient, err := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
if err != nil {
errCh <- err
return nil, nil
}

if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
}

commandBytes, err := json.Marshal(command)
if err != nil {
errCh <- fmt.Errorf("failed to marshal command: %s", err)
return nil, nil
}

q.Params["tty"] = strconv.FormatBool(tty)
q.Params["task"] = task
q.Params["command"] = string(commandBytes)

reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)

conn, _, err := nodeClient.websocket(reqPath, q)
if err != nil {
// There was a networking error when talking directly to the client.
if _, ok := err.(net.Error); !ok {
errCh <- err
return nil, nil
}

conn, _, err = a.client.websocket(reqPath, q)
if err != nil {
errCh <- err
return nil, nil
}
}

// Create the output channel
frames := make(chan *ExecStreamingOutput, 10)

go func() {
defer conn.Close()
for {
// Check if we have been cancelled
select {
case <-ctx.Done():
return
default:
}

// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
close(frames)
return
} else if err != nil {
errCh <- err
return
}

frames <- &frame
}
}()

var sendLock sync.Mutex
send := func(v interface{}) error {
sendLock.Lock()
defer sendLock.Unlock()

return conn.WriteJSON(v)
}

return send, frames

}

func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {
var resp AllocResourceUsage
path := fmt.Sprintf("/v1/client/allocation/%s/stats", alloc.ID)
Expand Down Expand Up @@ -339,3 +531,42 @@ type DesiredTransition struct {
func (d DesiredTransition) ShouldMigrate() bool {
return d.Migrate != nil && *d.Migrate
}

// ExecStreamingIOOperation represents a stream write operation: either appending data or close.
type ExecStreamingIOOperation struct {
Data []byte `json:"data,omitempty"`
Close bool `json:"close,omitempty"`
}

// TerminalSize represents the size of the terminal
type TerminalSize struct {
Height int32 `json:"height,omitempty"`
Width int32 `json:"width,omitempty"`
}

var execStreamingInputHeartbeat = ExecStreamingInput{}

// ExecStreamingInput represents user input to be sent to nomad exec handler.
//
// At most one field should be set.
type ExecStreamingInput struct {
Stdin *ExecStreamingIOOperation `json:"stdin,omitempty"`
TTYSize *TerminalSize `json:"tty_size,omitempty"`
}

// ExecStreamingExitResults captures the exit code of just completed nomad exec command
type ExecStreamingExitResult struct {
ExitCode int `json:"exit_code"`
}

// ExecStreamingInput represents an output streaming entity, e.g. stdout/stderr update or termination
//
// At most one of these fields should be set: `Stdout`, `Stderr`, or `Result`.
// If `Exited` is true, then `Result` is non-nil, and other fields are nil.
type ExecStreamingOutput struct {
Stdout *ExecStreamingIOOperation `json:"stdout,omitempty"`
Stderr *ExecStreamingIOOperation `json:"stderr,omitempty"`

Exited bool `json:"exited,omitempty"`
Result *ExecStreamingExitResult `json:"result,omitempty"`
}
58 changes: 58 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"strings"
"time"

"github.com/gorilla/websocket"
cleanhttp "github.com/hashicorp/go-cleanhttp"
rootcerts "github.com/hashicorp/go-rootcerts"
)
Expand Down Expand Up @@ -655,6 +656,63 @@ func (c *Client) rawQuery(endpoint string, q *QueryOptions) (io.ReadCloser, erro
return resp.Body, nil
}

// websocket makes a websocket request to the specific endpoint
func (c *Client) websocket(endpoint string, q *QueryOptions) (*websocket.Conn, *http.Response, error) {

transport, ok := c.config.httpClient.Transport.(*http.Transport)
if !ok {
return nil, nil, fmt.Errorf("unsupported transport")
}
dialer := websocket.Dialer{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
HandshakeTimeout: c.config.httpClient.Timeout,

// values to inherit from http client configuration
NetDial: transport.Dial,
NetDialContext: transport.DialContext,
Proxy: transport.Proxy,
TLSClientConfig: transport.TLSClientConfig,
}

// build request object for header and parameters
r, err := c.newRequest("GET", endpoint)
if err != nil {
return nil, nil, err
}
r.setQueryOptions(q)

rhttp, err := r.toHTTP()
if err != nil {
return nil, nil, err
}

// convert scheme
wsScheme := ""
switch rhttp.URL.Scheme {
case "http":
wsScheme = "ws"
case "https":
wsScheme = "wss"
default:
return nil, nil, fmt.Errorf("unsupported scheme: %v", rhttp.URL.Scheme)
}
rhttp.URL.Scheme = wsScheme

conn, resp, err := dialer.Dial(rhttp.URL.String(), rhttp.Header)

// check resp status code, as it's more informative than handshake error we get from ws library
if resp != nil && resp.StatusCode != 101 {
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
resp.Body.Close()

return nil, nil, fmt.Errorf("Unexpected response code: %d (%s)", resp.StatusCode, buf.Bytes())
}

return conn, resp, err
}

// query is used to do a GET request against an endpoint
// and deserialize the response into an interface using
// standard Nomad conventions.
Expand Down
Loading

0 comments on commit e78f7ce

Please sign in to comment.