diff --git a/api/allocations.go b/api/allocations.go index 8dc837b390e2..128541ca370b 100644 --- a/api/allocations.go +++ b/api/allocations.go @@ -2,16 +2,10 @@ package api import ( "context" - "encoding/json" - "errors" "fmt" "io" "sort" - "strconv" - "sync" "time" - - "github.com/gorilla/websocket" ) var ( @@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) { - ctx, cancelFn := context.WithCancel(ctx) - defer cancelFn() - - errCh := make(chan error, 4) - - sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q) - - select { - case err := <-errCh: - return -2, err - default: - } - - // Errors resulting from sending input (in goroutines) are silently dropped. - // To mitigate this, extra care is needed to distinguish between actual send errors - // and from send errors due to command terminating and our race to detect failures. - // If we have an actual network failure or send a bad input, we'd get an - // error in the reading side of websocket. - - go func() { - - bytes := make([]byte, 2048) - for { - if ctx.Err() != nil { - return - } - - input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}} - - n, err := stdin.Read(bytes) - - // always send data if we read some - if n != 0 { - input.Stdin.Data = bytes[:n] - sender(&input) - } - - // then handle error - if err == io.EOF { - // if n != 0, send data and we'll get n = 0 on next read - if n == 0 { - input.Stdin.Close = true - sender(&input) - return - } - } else if err != nil { - errCh <- err - return - } - } - }() - - // forwarding terminal size - go func() { - for { - resizeInput := ExecStreamingInput{} - - select { - case <-ctx.Done(): - return - case size, ok := <-terminalSizeCh: - if !ok { - return - } - resizeInput.TTYSize = &size - sender(&resizeInput) - } - - } - }() - - // send a heartbeat every 10 seconds - go func() { - for { - select { - case <-ctx.Done(): - return - // heartbeat message - case <-time.After(10 * time.Second): - sender(&execStreamingInputHeartbeat) - } - - } - }() - - for { - select { - case err := <-errCh: - // drop websocket code, not relevant to user - if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" { - return -2, errors.New(wsErr.Text) - } - return -2, err - case <-ctx.Done(): - return -2, ctx.Err() - case frame, ok := <-output: - if !ok { - return -2, errors.New("disconnected without receiving the exit code") - } - - switch { - case frame.Stdout != nil: - if len(frame.Stdout.Data) != 0 { - stdout.Write(frame.Stdout.Data) - } - // don't really do anything if stdout is closing - case frame.Stderr != nil: - if len(frame.Stderr.Data) != 0 { - stderr.Write(frame.Stderr.Data) - } - // don't really do anything if stderr is closing - case frame.Exited && frame.Result != nil: - return frame.Result.ExitCode, nil - default: - // noop - heartbeat - } - } - } -} + s := &execSession{ + client: a.client, + alloc: alloc, + task: task, + tty: tty, + command: command, -func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string, - errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) { - nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q) + stdin: stdin, + stdout: stdout, + stderr: stderr, - if q == nil { - q = &QueryOptions{} - } - if q.Params == nil { - q.Params = make(map[string]string) + terminalSizeCh: terminalSizeCh, + q: q, } - commandBytes, err := json.Marshal(command) - if err != nil { - errCh <- fmt.Errorf("failed to marshal command: %s", err) - return nil, nil - } - - q.Params["tty"] = strconv.FormatBool(tty) - q.Params["task"] = task - q.Params["command"] = string(commandBytes) - - reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID) - - var conn *websocket.Conn - - if nodeClient != nil { - conn, _, _ = nodeClient.websocket(reqPath, q) - } - - if conn == nil { - conn, _, err = a.client.websocket(reqPath, q) - if err != nil { - errCh <- err - return nil, nil - } - } - - // Create the output channel - frames := make(chan *ExecStreamingOutput, 10) - - go func() { - defer conn.Close() - for ctx.Err() == nil { - - // Decode the next frame - var frame ExecStreamingOutput - err := conn.ReadJSON(&frame) - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - close(frames) - return - } else if err != nil { - errCh <- err - return - } - - frames <- &frame - } - }() - - var sendLock sync.Mutex - send := func(v *ExecStreamingInput) error { - sendLock.Lock() - defer sendLock.Unlock() - - return conn.WriteJSON(v) - } - - return send, frames - + return s.run(ctx) } func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) { diff --git a/api/allocations_exec.go b/api/allocations_exec.go new file mode 100644 index 000000000000..9f5e0e299784 --- /dev/null +++ b/api/allocations_exec.go @@ -0,0 +1,236 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type execSession struct { + client *Client + alloc *Allocation + task string + tty bool + command []string + + stdin io.Reader + stdout io.Writer + stderr io.Writer + + terminalSizeCh <-chan TerminalSize + + q *QueryOptions +} + +func (s *execSession) run(ctx context.Context) (exitCode int, err error) { + ctx, cancelFn := context.WithCancel(ctx) + defer cancelFn() + + conn, err := s.startConnection() + if err != nil { + return -2, err + } + defer conn.Close() + + sendErrCh := s.startTransmit(ctx, conn) + exitCh, recvErrCh := s.startReceiving(ctx, conn) + + for { + select { + case <-ctx.Done(): + return -2, ctx.Err() + case exitCode := <-exitCh: + return exitCode, nil + case recvErr := <-recvErrCh: + // drop websocket code, not relevant to user + if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" { + return -2, errors.New(wsErr.Text) + } + + return -2, recvErr + case sendErr := <-sendErrCh: + return -2, fmt.Errorf("failed to send input: %w", sendErr) + } + } +} + +func (s *execSession) startConnection() (*websocket.Conn, error) { + // First, attempt to connect to the node directly, but may fail due to network isolation + // and network errors. Fallback to using server-side forwarding instead. + nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q) + if err == NodeDownErr { + return nil, NodeDownErr + } + + q := s.q + if q == nil { + q = &QueryOptions{} + } + if q.Params == nil { + q.Params = make(map[string]string) + } + + commandBytes, err := json.Marshal(s.command) + if err != nil { + return nil, fmt.Errorf("failed to marshal command: %W", err) + } + + q.Params["tty"] = strconv.FormatBool(s.tty) + q.Params["task"] = s.task + q.Params["command"] = string(commandBytes) + + reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID) + + var conn *websocket.Conn + + if nodeClient != nil { + conn, _, _ = nodeClient.websocket(reqPath, q) + } + + if conn == nil { + conn, _, err = s.client.websocket(reqPath, q) + if err != nil { + return nil, err + } + } + + return conn, nil +} + +func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error { + + // FIXME: Handle websocket send errors. + // Currently, websocket write failures are dropped. As sending and + // receiving are running concurrently, it's expected that some send + // requests may fail with connection errors when connection closes. + // Connection errors should surface in the receive paths already, + // but I'm unsure about one-sided communication errors. + var sendLock sync.Mutex + send := func(v *ExecStreamingInput) { + sendLock.Lock() + defer sendLock.Unlock() + + conn.WriteJSON(v) + } + + errCh := make(chan error, 4) + + // propagate stdin + go func() { + + bytes := make([]byte, 2048) + for { + if ctx.Err() != nil { + return + } + + input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}} + + n, err := s.stdin.Read(bytes) + + // always send data if we read some + if n != 0 { + input.Stdin.Data = bytes[:n] + send(&input) + } + + // then handle error + if err == io.EOF { + // if n != 0, send data and we'll get n = 0 on next read + if n == 0 { + input.Stdin.Close = true + send(&input) + return + } + } else if err != nil { + errCh <- err + return + } + } + }() + + // propagate terminal sizing updates + go func() { + for { + resizeInput := ExecStreamingInput{} + + select { + case <-ctx.Done(): + return + case size, ok := <-s.terminalSizeCh: + if !ok { + return + } + resizeInput.TTYSize = &size + send(&resizeInput) + } + + } + }() + + // send a heartbeat every 10 seconds + go func() { + for { + select { + case <-ctx.Done(): + return + // heartbeat message + case <-time.After(10 * time.Second): + send(&execStreamingInputHeartbeat) + } + + } + }() + + return errCh +} + +func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) { + exitCodeCh := make(chan int, 1) + errCh := make(chan error, 1) + + go func() { + for ctx.Err() == nil { + + // Decode the next frame + var frame ExecStreamingOutput + err := conn.ReadJSON(&frame) + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err) + return + } else if err != nil { + errCh <- err + return + } + + switch { + case frame.Stdout != nil: + if len(frame.Stdout.Data) != 0 { + s.stdout.Write(frame.Stdout.Data) + } + // don't really do anything if stdout is closing + case frame.Stderr != nil: + if len(frame.Stderr.Data) != 0 { + s.stderr.Write(frame.Stderr.Data) + } + // don't really do anything if stderr is closing + case frame.Exited && frame.Result != nil: + exitCodeCh <- frame.Result.ExitCode + return + default: + // noop - heartbeat + } + + } + + }() + + return exitCodeCh, errCh +} diff --git a/command/agent/alloc_endpoint.go b/command/agent/alloc_endpoint.go index e8eda6025171..d1a7e210c5d2 100644 --- a/command/agent/alloc_endpoint.go +++ b/command/agent/alloc_endpoint.go @@ -515,13 +515,6 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec go forwardExecInput(encoder, ws, errCh) for { - select { - case <-ctx.Done(): - errCh <- nil - return - default: - } - var res cstructs.StreamErrWrapper err := decoder.Decode(&res) if isClosedError(err) { diff --git a/e2e/nomadexec/exec.go b/e2e/nomadexec/exec.go index 73a113b88dec..e54995a8fd6a 100644 --- a/e2e/nomadexec/exec.go +++ b/e2e/nomadexec/exec.go @@ -7,7 +7,6 @@ import ( "io" "reflect" "regexp" - "strings" "testing" "time" @@ -90,13 +89,7 @@ func (tc *NomadExecE2ETest) TestExecBasicResponses(f *framework.F) { stdin, &stdout, &stderr, resizeCh, nil) - // TODO: Occasionally, we get "Unexpected EOF" error, but with the correct output. - // investigate why - if err != nil && strings.Contains(err.Error(), io.ErrUnexpectedEOF.Error()) { - f.T().Logf("got unexpected EOF error, ignoring: %v", err) - } else { - assert.NoError(t, err) - } + assert.NoError(t, err) assert.Equal(t, c.ExitCode, exitCode) diff --git a/vendor/github.com/hashicorp/nomad/api/allocations.go b/vendor/github.com/hashicorp/nomad/api/allocations.go index 8dc837b390e2..128541ca370b 100644 --- a/vendor/github.com/hashicorp/nomad/api/allocations.go +++ b/vendor/github.com/hashicorp/nomad/api/allocations.go @@ -2,16 +2,10 @@ package api import ( "context" - "encoding/json" - "errors" "fmt" "io" "sort" - "strconv" - "sync" "time" - - "github.com/gorilla/websocket" ) var ( @@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) { - ctx, cancelFn := context.WithCancel(ctx) - defer cancelFn() - - errCh := make(chan error, 4) - - sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q) - - select { - case err := <-errCh: - return -2, err - default: - } - - // Errors resulting from sending input (in goroutines) are silently dropped. - // To mitigate this, extra care is needed to distinguish between actual send errors - // and from send errors due to command terminating and our race to detect failures. - // If we have an actual network failure or send a bad input, we'd get an - // error in the reading side of websocket. - - go func() { - - bytes := make([]byte, 2048) - for { - if ctx.Err() != nil { - return - } - - input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}} - - n, err := stdin.Read(bytes) - - // always send data if we read some - if n != 0 { - input.Stdin.Data = bytes[:n] - sender(&input) - } - - // then handle error - if err == io.EOF { - // if n != 0, send data and we'll get n = 0 on next read - if n == 0 { - input.Stdin.Close = true - sender(&input) - return - } - } else if err != nil { - errCh <- err - return - } - } - }() - - // forwarding terminal size - go func() { - for { - resizeInput := ExecStreamingInput{} - - select { - case <-ctx.Done(): - return - case size, ok := <-terminalSizeCh: - if !ok { - return - } - resizeInput.TTYSize = &size - sender(&resizeInput) - } - - } - }() - - // send a heartbeat every 10 seconds - go func() { - for { - select { - case <-ctx.Done(): - return - // heartbeat message - case <-time.After(10 * time.Second): - sender(&execStreamingInputHeartbeat) - } - - } - }() - - for { - select { - case err := <-errCh: - // drop websocket code, not relevant to user - if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" { - return -2, errors.New(wsErr.Text) - } - return -2, err - case <-ctx.Done(): - return -2, ctx.Err() - case frame, ok := <-output: - if !ok { - return -2, errors.New("disconnected without receiving the exit code") - } - - switch { - case frame.Stdout != nil: - if len(frame.Stdout.Data) != 0 { - stdout.Write(frame.Stdout.Data) - } - // don't really do anything if stdout is closing - case frame.Stderr != nil: - if len(frame.Stderr.Data) != 0 { - stderr.Write(frame.Stderr.Data) - } - // don't really do anything if stderr is closing - case frame.Exited && frame.Result != nil: - return frame.Result.ExitCode, nil - default: - // noop - heartbeat - } - } - } -} + s := &execSession{ + client: a.client, + alloc: alloc, + task: task, + tty: tty, + command: command, -func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string, - errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) { - nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q) + stdin: stdin, + stdout: stdout, + stderr: stderr, - if q == nil { - q = &QueryOptions{} - } - if q.Params == nil { - q.Params = make(map[string]string) + terminalSizeCh: terminalSizeCh, + q: q, } - commandBytes, err := json.Marshal(command) - if err != nil { - errCh <- fmt.Errorf("failed to marshal command: %s", err) - return nil, nil - } - - q.Params["tty"] = strconv.FormatBool(tty) - q.Params["task"] = task - q.Params["command"] = string(commandBytes) - - reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID) - - var conn *websocket.Conn - - if nodeClient != nil { - conn, _, _ = nodeClient.websocket(reqPath, q) - } - - if conn == nil { - conn, _, err = a.client.websocket(reqPath, q) - if err != nil { - errCh <- err - return nil, nil - } - } - - // Create the output channel - frames := make(chan *ExecStreamingOutput, 10) - - go func() { - defer conn.Close() - for ctx.Err() == nil { - - // Decode the next frame - var frame ExecStreamingOutput - err := conn.ReadJSON(&frame) - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - close(frames) - return - } else if err != nil { - errCh <- err - return - } - - frames <- &frame - } - }() - - var sendLock sync.Mutex - send := func(v *ExecStreamingInput) error { - sendLock.Lock() - defer sendLock.Unlock() - - return conn.WriteJSON(v) - } - - return send, frames - + return s.run(ctx) } func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) { diff --git a/vendor/github.com/hashicorp/nomad/api/allocations_exec.go b/vendor/github.com/hashicorp/nomad/api/allocations_exec.go new file mode 100644 index 000000000000..9f5e0e299784 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/api/allocations_exec.go @@ -0,0 +1,236 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type execSession struct { + client *Client + alloc *Allocation + task string + tty bool + command []string + + stdin io.Reader + stdout io.Writer + stderr io.Writer + + terminalSizeCh <-chan TerminalSize + + q *QueryOptions +} + +func (s *execSession) run(ctx context.Context) (exitCode int, err error) { + ctx, cancelFn := context.WithCancel(ctx) + defer cancelFn() + + conn, err := s.startConnection() + if err != nil { + return -2, err + } + defer conn.Close() + + sendErrCh := s.startTransmit(ctx, conn) + exitCh, recvErrCh := s.startReceiving(ctx, conn) + + for { + select { + case <-ctx.Done(): + return -2, ctx.Err() + case exitCode := <-exitCh: + return exitCode, nil + case recvErr := <-recvErrCh: + // drop websocket code, not relevant to user + if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" { + return -2, errors.New(wsErr.Text) + } + + return -2, recvErr + case sendErr := <-sendErrCh: + return -2, fmt.Errorf("failed to send input: %w", sendErr) + } + } +} + +func (s *execSession) startConnection() (*websocket.Conn, error) { + // First, attempt to connect to the node directly, but may fail due to network isolation + // and network errors. Fallback to using server-side forwarding instead. + nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q) + if err == NodeDownErr { + return nil, NodeDownErr + } + + q := s.q + if q == nil { + q = &QueryOptions{} + } + if q.Params == nil { + q.Params = make(map[string]string) + } + + commandBytes, err := json.Marshal(s.command) + if err != nil { + return nil, fmt.Errorf("failed to marshal command: %W", err) + } + + q.Params["tty"] = strconv.FormatBool(s.tty) + q.Params["task"] = s.task + q.Params["command"] = string(commandBytes) + + reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID) + + var conn *websocket.Conn + + if nodeClient != nil { + conn, _, _ = nodeClient.websocket(reqPath, q) + } + + if conn == nil { + conn, _, err = s.client.websocket(reqPath, q) + if err != nil { + return nil, err + } + } + + return conn, nil +} + +func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error { + + // FIXME: Handle websocket send errors. + // Currently, websocket write failures are dropped. As sending and + // receiving are running concurrently, it's expected that some send + // requests may fail with connection errors when connection closes. + // Connection errors should surface in the receive paths already, + // but I'm unsure about one-sided communication errors. + var sendLock sync.Mutex + send := func(v *ExecStreamingInput) { + sendLock.Lock() + defer sendLock.Unlock() + + conn.WriteJSON(v) + } + + errCh := make(chan error, 4) + + // propagate stdin + go func() { + + bytes := make([]byte, 2048) + for { + if ctx.Err() != nil { + return + } + + input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}} + + n, err := s.stdin.Read(bytes) + + // always send data if we read some + if n != 0 { + input.Stdin.Data = bytes[:n] + send(&input) + } + + // then handle error + if err == io.EOF { + // if n != 0, send data and we'll get n = 0 on next read + if n == 0 { + input.Stdin.Close = true + send(&input) + return + } + } else if err != nil { + errCh <- err + return + } + } + }() + + // propagate terminal sizing updates + go func() { + for { + resizeInput := ExecStreamingInput{} + + select { + case <-ctx.Done(): + return + case size, ok := <-s.terminalSizeCh: + if !ok { + return + } + resizeInput.TTYSize = &size + send(&resizeInput) + } + + } + }() + + // send a heartbeat every 10 seconds + go func() { + for { + select { + case <-ctx.Done(): + return + // heartbeat message + case <-time.After(10 * time.Second): + send(&execStreamingInputHeartbeat) + } + + } + }() + + return errCh +} + +func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) { + exitCodeCh := make(chan int, 1) + errCh := make(chan error, 1) + + go func() { + for ctx.Err() == nil { + + // Decode the next frame + var frame ExecStreamingOutput + err := conn.ReadJSON(&frame) + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err) + return + } else if err != nil { + errCh <- err + return + } + + switch { + case frame.Stdout != nil: + if len(frame.Stdout.Data) != 0 { + s.stdout.Write(frame.Stdout.Data) + } + // don't really do anything if stdout is closing + case frame.Stderr != nil: + if len(frame.Stderr.Data) != 0 { + s.stderr.Write(frame.Stderr.Data) + } + // don't really do anything if stderr is closing + case frame.Exited && frame.Result != nil: + exitCodeCh <- frame.Result.ExitCode + return + default: + // noop - heartbeat + } + + } + + }() + + return exitCodeCh, errCh +}