From 27ea48c2234972b972ad264c38bc502a72841680 Mon Sep 17 00:00:00 2001 From: jpts Date: Sun, 11 Jun 2023 13:03:52 +0100 Subject: [PATCH] refactor: improving coding style --- cmd/websocket.go | 201 ++++++++++++++++++++++++----------------------- 1 file changed, 103 insertions(+), 98 deletions(-) diff --git a/cmd/websocket.go b/cmd/websocket.go index 0232da5..27f504e 100644 --- a/cmd/websocket.go +++ b/cmd/websocket.go @@ -18,9 +18,10 @@ import ( ) type WebsocketRoundTripper struct { - Dialer *websocket.Dialer - TermState *TerminalState - opts Options + Dialer *websocket.Dialer + TermState *TerminalState + opts Options + SendBuffer bytes.Buffer } type ApiServerError struct { @@ -56,131 +57,135 @@ func (d *WebsocketRoundTripper) RoundTrip(r *http.Request) (*http.Response, erro func (d *WebsocketRoundTripper) WsCallback(ws *websocket.Conn) error { errChan := make(chan error, 4) - var sendBuffer bytes.Buffer wg := sync.WaitGroup{} wg.Add(3) - stdIn, stdOut, stdErr := term.StdStreams() + go d.concurrentSend(&wg, ws, errChan) + go d.concurrentRecv(&wg, ws, errChan) + go d.concurrentResize(&wg, ws, errChan) - // send go func() { - defer wg.Done() - buf := make([]byte, 1025) - for { - n, err := stdIn.Read(buf[1:]) - if err != nil { - errChan <- err - return + wg.Wait() + close(errChan) + }() + + for err := range errChan { + if e, ok := err.(*websocket.CloseError); ok { + klog.V(4).Infof("Closing websocket connection with error code %d, err: %s", e.Code, err) + } + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + return nil + } else if errors.Is(err, io.EOF) { + return nil + } + return err + } + return nil +} + +func (d *WebsocketRoundTripper) concurrentSend(wg *sync.WaitGroup, ws *websocket.Conn, errChan chan error) { + defer wg.Done() + + buf := make([]byte, 1025) + stdIn, _, _ := term.StdStreams() + + for { + n, err := stdIn.Read(buf[1:]) + if err != nil { + errChan <- err + return + } + + d.SendBuffer.Write(buf[1:n]) + d.SendBuffer.Write([]byte{13, 10}) + err = ws.WriteMessage(websocket.BinaryMessage, buf[:n+1]) + if err != nil { + errChan <- err + return + } + } +} + +func (d *WebsocketRoundTripper) concurrentRecv(wg *sync.WaitGroup, ws *websocket.Conn, errChan chan error) { + defer wg.Done() + + _, stdOut, stdErr := term.StdStreams() + + for { + msgType, buf, err := ws.ReadMessage() + if err != nil { + errChan <- err + return + } + if msgType != websocket.BinaryMessage { + errChan <- errors.New("Received unexpected websocket message") + return + } + if len(buf) > 1 { + var w io.Writer + switch buf[0] { + case streamStdOut: + w = stdOut + case streamStdErr: + w = stdErr + case streamErr: + if err := parseStreamErr(buf[1:]); err != nil { + errChan <- err + return + } + default: + errChan <- fmt.Errorf("Unknown stream type: %d", buf[0]) + continue + } + + if w == nil { + continue } - sendBuffer.Write(buf[1:n]) - sendBuffer.Write([]byte{13, 10}) - err = ws.WriteMessage(websocket.BinaryMessage, buf[:n+1]) + out := buf[1:] + _, err = w.Write(out) if err != nil { errChan <- err return } } - }() + d.SendBuffer.Reset() + } +} - // recv - go func() { - defer wg.Done() +func (d *WebsocketRoundTripper) concurrentResize(wg *sync.WaitGroup, ws *websocket.Conn, errChan chan error) { + defer wg.Done() + if d.opts.TTY { + resizeNotify := registerResizeSignal() + + d.TermState.Initialised = false for { - msgType, buf, err := ws.ReadMessage() + changed, err := updateSize(d.TermState) if err != nil { - errChan <- err + errChan <- fmt.Errorf("Failed to update terminal size: %w", err) return } - if msgType != websocket.BinaryMessage { - errChan <- errors.New("Received unexpected websocket message") - return - } - if len(buf) > 1 { - var w io.Writer - switch buf[0] { - case streamStdOut: - w = stdOut - case streamStdErr: - w = stdErr - case streamErr: - if err := parseStreamErr(buf[1:]); err != nil { - errChan <- err - return - } - default: - errChan <- fmt.Errorf("Unknown stream type: %d", buf[0]) - continue - } - if w == nil { - continue - } - - out := buf[1:] - _, err = w.Write(out) + if changed || !d.TermState.Initialised { + res, err := json.Marshal(d.TermState.Size) if err != nil { - errChan <- err + errChan <- fmt.Errorf("Failed to marshal JSON: %w", err) return } - } - sendBuffer.Reset() - } - }() - - // resize - go func() { - defer wg.Done() - if d.opts.TTY { - resizeNotify := registerResizeSignal() + msg := []byte(fmt.Sprintf("%s%s", "\x04", res)) - d.TermState.Initialised = false - for { - changed, err := updateSize(d.TermState) + err = ws.WriteMessage(websocket.BinaryMessage, msg) if err != nil { - errChan <- fmt.Errorf("Failed to update terminal size: %w", err) + errChan <- fmt.Errorf("Failed to write msg to channel: %w", err) return } - - if changed || !d.TermState.Initialised { - res, err := json.Marshal(d.TermState.Size) - if err != nil { - errChan <- fmt.Errorf("Failed to marshal JSON: %w", err) - return - } - msg := []byte(fmt.Sprintf("%s%s", "\x04", res)) - - err = ws.WriteMessage(websocket.BinaryMessage, msg) - if err != nil { - errChan <- fmt.Errorf("Failed to write msg to channel: %w", err) - return - } - d.TermState.Initialised = true - } - - waitForResizeChange(resizeNotify) + d.TermState.Initialised = true } - } - }() - - go func() { - wg.Wait() - close(errChan) - }() - for err := range errChan { - if e, ok := err.(*websocket.CloseError); ok { - klog.V(4).Infof("Closing websocket connection with error code %d, err: %s", e.Code, err) + waitForResizeChange(resizeNotify) } - if websocket.IsCloseError(err, websocket.CloseNormalClosure) { - return nil - } else if errors.Is(err, io.EOF) { - return nil - } - return err } - return nil } type streamError struct {