Skip to content

Commit

Permalink
refactor: improving coding style
Browse files Browse the repository at this point in the history
  • Loading branch information
jpts committed Jun 11, 2023
1 parent 907df92 commit 27ea48c
Showing 1 changed file with 103 additions and 98 deletions.
201 changes: 103 additions & 98 deletions cmd/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import (
)

type WebsocketRoundTripper struct {
Dialer *websocket.Dialer
TermState *TerminalState
opts Options
Dialer *websocket.Dialer
TermState *TerminalState
opts Options
SendBuffer bytes.Buffer
}

type ApiServerError struct {
Expand Down Expand Up @@ -56,131 +57,135 @@ func (d *WebsocketRoundTripper) RoundTrip(r *http.Request) (*http.Response, erro

func (d *WebsocketRoundTripper) WsCallback(ws *websocket.Conn) error {
errChan := make(chan error, 4)
var sendBuffer bytes.Buffer

wg := sync.WaitGroup{}
wg.Add(3)

stdIn, stdOut, stdErr := term.StdStreams()
go d.concurrentSend(&wg, ws, errChan)
go d.concurrentRecv(&wg, ws, errChan)
go d.concurrentResize(&wg, ws, errChan)

// send
go func() {
defer wg.Done()
buf := make([]byte, 1025)
for {
n, err := stdIn.Read(buf[1:])
if err != nil {
errChan <- err
return
wg.Wait()
close(errChan)
}()

for err := range errChan {
if e, ok := err.(*websocket.CloseError); ok {
klog.V(4).Infof("Closing websocket connection with error code %d, err: %s", e.Code, err)
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return nil
} else if errors.Is(err, io.EOF) {
return nil
}
return err
}
return nil
}

func (d *WebsocketRoundTripper) concurrentSend(wg *sync.WaitGroup, ws *websocket.Conn, errChan chan error) {
defer wg.Done()

buf := make([]byte, 1025)
stdIn, _, _ := term.StdStreams()

for {
n, err := stdIn.Read(buf[1:])
if err != nil {
errChan <- err
return
}

d.SendBuffer.Write(buf[1:n])
d.SendBuffer.Write([]byte{13, 10})
err = ws.WriteMessage(websocket.BinaryMessage, buf[:n+1])
if err != nil {
errChan <- err
return
}
}
}

func (d *WebsocketRoundTripper) concurrentRecv(wg *sync.WaitGroup, ws *websocket.Conn, errChan chan error) {
defer wg.Done()

_, stdOut, stdErr := term.StdStreams()

for {
msgType, buf, err := ws.ReadMessage()
if err != nil {
errChan <- err
return
}
if msgType != websocket.BinaryMessage {
errChan <- errors.New("Received unexpected websocket message")
return
}
if len(buf) > 1 {
var w io.Writer
switch buf[0] {
case streamStdOut:
w = stdOut
case streamStdErr:
w = stdErr
case streamErr:
if err := parseStreamErr(buf[1:]); err != nil {
errChan <- err
return
}
default:
errChan <- fmt.Errorf("Unknown stream type: %d", buf[0])
continue
}

if w == nil {
continue
}

sendBuffer.Write(buf[1:n])
sendBuffer.Write([]byte{13, 10})
err = ws.WriteMessage(websocket.BinaryMessage, buf[:n+1])
out := buf[1:]
_, err = w.Write(out)
if err != nil {
errChan <- err
return
}
}
}()
d.SendBuffer.Reset()
}
}

// recv
go func() {
defer wg.Done()
func (d *WebsocketRoundTripper) concurrentResize(wg *sync.WaitGroup, ws *websocket.Conn, errChan chan error) {
defer wg.Done()
if d.opts.TTY {
resizeNotify := registerResizeSignal()

d.TermState.Initialised = false
for {
msgType, buf, err := ws.ReadMessage()
changed, err := updateSize(d.TermState)
if err != nil {
errChan <- err
errChan <- fmt.Errorf("Failed to update terminal size: %w", err)
return
}
if msgType != websocket.BinaryMessage {
errChan <- errors.New("Received unexpected websocket message")
return
}
if len(buf) > 1 {
var w io.Writer
switch buf[0] {
case streamStdOut:
w = stdOut
case streamStdErr:
w = stdErr
case streamErr:
if err := parseStreamErr(buf[1:]); err != nil {
errChan <- err
return
}
default:
errChan <- fmt.Errorf("Unknown stream type: %d", buf[0])
continue
}

if w == nil {
continue
}

out := buf[1:]
_, err = w.Write(out)
if changed || !d.TermState.Initialised {
res, err := json.Marshal(d.TermState.Size)
if err != nil {
errChan <- err
errChan <- fmt.Errorf("Failed to marshal JSON: %w", err)
return
}
}
sendBuffer.Reset()
}
}()

// resize
go func() {
defer wg.Done()
if d.opts.TTY {
resizeNotify := registerResizeSignal()
msg := []byte(fmt.Sprintf("%s%s", "\x04", res))

d.TermState.Initialised = false
for {
changed, err := updateSize(d.TermState)
err = ws.WriteMessage(websocket.BinaryMessage, msg)
if err != nil {
errChan <- fmt.Errorf("Failed to update terminal size: %w", err)
errChan <- fmt.Errorf("Failed to write msg to channel: %w", err)
return
}

if changed || !d.TermState.Initialised {
res, err := json.Marshal(d.TermState.Size)
if err != nil {
errChan <- fmt.Errorf("Failed to marshal JSON: %w", err)
return
}
msg := []byte(fmt.Sprintf("%s%s", "\x04", res))

err = ws.WriteMessage(websocket.BinaryMessage, msg)
if err != nil {
errChan <- fmt.Errorf("Failed to write msg to channel: %w", err)
return
}
d.TermState.Initialised = true
}

waitForResizeChange(resizeNotify)
d.TermState.Initialised = true
}
}
}()

go func() {
wg.Wait()
close(errChan)
}()

for err := range errChan {
if e, ok := err.(*websocket.CloseError); ok {
klog.V(4).Infof("Closing websocket connection with error code %d, err: %s", e.Code, err)
waitForResizeChange(resizeNotify)
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return nil
} else if errors.Is(err, io.EOF) {
return nil
}
return err
}
return nil
}

type streamError struct {
Expand Down

0 comments on commit 27ea48c

Please sign in to comment.