Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support HTTP Websocket Connections in wsclient #899

Merged
merged 2 commits into from
Jul 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions agent/tcs/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ func TestFormatURL(t *testing.T) {
func TestStartSession(t *testing.T) {
// Start test server.
closeWS := make(chan []byte)
server, serverChan, requestChan, serverErr, err := wsmock.StartMockServer(t, closeWS)
server, serverChan, requestChan, serverErr, err := wsmock.GetMockServer(t, closeWS)
server.StartTLS()
defer server.Close()
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -142,7 +143,8 @@ func TestStartSession(t *testing.T) {
func TestSessionConnectionClosedByRemote(t *testing.T) {
// Start test server.
closeWS := make(chan []byte)
server, serverChan, _, serverErr, err := wsmock.StartMockServer(t, closeWS)
server, serverChan, _, serverErr, err := wsmock.GetMockServer(t, closeWS)
server.StartTLS()
defer server.Close()
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -181,7 +183,8 @@ func TestSessionConnectionClosedByRemote(t *testing.T) {
func TestConnectionInactiveTimeout(t *testing.T) {
// Start test server.
closeWS := make(chan []byte)
server, _, requestChan, serverErr, err := wsmock.StartMockServer(t, closeWS)
server, _, requestChan, serverErr, err := wsmock.GetMockServer(t, closeWS)
server.StartTLS()
defer server.Close()
if err != nil {
t.Fatal(err)
Expand Down
20 changes: 17 additions & 3 deletions agent/wsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ const (
// writeBufSize is the size of the write buffer for the ws connection.
writeBufSize = 32768

// gorilla/websocket expects the websocket scheme (ws[s]://)
wsScheme = "wss"

// Default NO_PROXY env var IP addresses
defaultNoProxyIP = "169.254.169.254,169.254.170.2"
)
Expand Down Expand Up @@ -142,6 +139,10 @@ func (cs *ClientServerImpl) Connect() error {
return err
}

wsScheme, err := websocketScheme(parsedURL.Scheme)
if err != nil {
return err
}
parsedURL.Scheme = wsScheme

// NewRequest never returns an error if the url parses and we just verified
Expand Down Expand Up @@ -349,6 +350,19 @@ func (cs *ClientServerImpl) handleMessage(data []byte) {
}
}

func websocketScheme(httpScheme string) (wsScheme string, err error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, we've avoided using named returns. Please change this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// gorilla/websocket expects the websocket scheme (ws[s]://)
switch httpScheme {
case "http":
wsScheme = "ws"
case "https":
wsScheme = "wss"
default:
err = fmt.Errorf("Unknown httpScheme %s", httpScheme)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format $context: msg and package errors.New()

errors.New("wsclient: Unknown httpScheme %s", httpScheme)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
return
}

// See https://github.com/gorilla/websocket/blob/87f6f6a22ebfbc3f89b9ccdc7fddd1b914c095f9/conn.go#L650
func permissibleCloseCode(err error) bool {
return websocket.IsCloseError(err,
Expand Down
76 changes: 71 additions & 5 deletions agent/wsclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package wsclient

import (
"io"
"net/url"
"os"
"testing"

Expand All @@ -39,7 +40,8 @@ func TestConcurrentWritesDontPanic(t *testing.T) {
closeWS := make(chan []byte)
defer close(closeWS)

mockServer, _, requests, _, _ := utils.StartMockServer(t, closeWS)
mockServer, _, requests, _, _ := utils.GetMockServer(t, closeWS)
mockServer.StartTLS()
defer mockServer.Close()

req := ecsacs.AckRequest{Cluster: aws.String("test"), ContainerInstance: aws.String("test"), MessageId: aws.String("test")}
Expand Down Expand Up @@ -70,7 +72,8 @@ func TestProxyVariableCustomValue(t *testing.T) {
closeWS := make(chan []byte)
defer close(closeWS)

mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS)
mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS)
mockServer.StartTLS()
defer mockServer.Close()

testString := "Custom no proxy string"
Expand All @@ -86,7 +89,8 @@ func TestProxyVariableDefaultValue(t *testing.T) {
closeWS := make(chan []byte)
defer close(closeWS)

mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS)
mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS)
mockServer.StartTLS()
defer mockServer.Close()

os.Unsetenv("NO_PROXY")
Expand All @@ -104,7 +108,8 @@ func TestHandleMessagePermissibleCloseCode(t *testing.T) {
defer close(closeWS)

messageError := make(chan error)
mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS)
mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS)
mockServer.StartTLS()
cs := getClientServer(mockServer.URL)
cs.Connect()

Expand All @@ -123,7 +128,8 @@ func TestHandleMessageUnexpectedCloseCode(t *testing.T) {
defer close(closeWS)

messageError := make(chan error)
mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS)
mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS)
mockServer.StartTLS()
cs := getClientServer(mockServer.URL)
cs.Connect()

Expand All @@ -135,6 +141,66 @@ func TestHandleMessageUnexpectedCloseCode(t *testing.T) {
assert.True(t, websocket.IsCloseError(<-messageError, websocket.CloseTryAgainLater), "Expected error from websocket library")
}

// TestHandlNonHTTPSEndpoint verifies that the wsclient can handle communication over
// an HTTP (so WS) connection
func TestHandleNonHTTPSEndpoint(t *testing.T) {
closeWS := make(chan []byte)
defer close(closeWS)

mockServer, _, requests, _, _ := utils.GetMockServer(t, closeWS)
mockServer.Start()
defer mockServer.Close()

cs := getClientServer(mockServer.URL)
cs.Connect()

req := ecsacs.AckRequest{Cluster: aws.String("test"), ContainerInstance: aws.String("test"), MessageId: aws.String("test")}
cs.MakeRequest(&req)

t.Log("Waiting for single request to be visible server-side")
<-requests
}

// TestHandleIncorrectHttpScheme checks that an incorrect URL scheme results in
// an error
func TestHandleIncorrectURLScheme(t *testing.T) {
closeWS := make(chan []byte)
defer close(closeWS)

mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS)
mockServer.StartTLS()
defer mockServer.Close()

mockServerURL, _ := url.Parse(mockServer.URL)
mockServerURL.Scheme = "notaparticularlyrealscheme"

cs := getClientServer(mockServerURL.String())
err := cs.Connect()

assert.Error(t, err, "Expected error for incorrect URL scheme")
}

// TestWebsocketScheme checks that websocketScheme handles valid and invalid mappings
// correctly
func TestWebsocketScheme(t *testing.T) {
// test valid schemes
validMappings := map[string]string{
"http": "ws",
"https": "wss",
}

for input, expectedOutput := range validMappings {
actualOutput, err := websocketScheme(input)

assert.NoError(t, err, "Unexpected error for valid http scheme")
assert.Equal(t, actualOutput, expectedOutput, "Valid http schemes should map to a websocket scheme")
}

// test an invalid mapping
_, err := websocketScheme("highly-likely-to-be-junk")
assert.Error(t, err, "Expected error for invalid http scheme")
}

func getClientServer(url string) *ClientServerImpl {
types := []interface{}{ecsacs.AckRequest{}}

Expand Down
6 changes: 3 additions & 3 deletions agent/wsclient/mock/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
"github.com/gorilla/websocket"
)

// StartMockServer starts a mock websocket server.
// GetMockServer retuns a mock websocket server that can be started up as TLS or not.
// TODO replace with gomock
func StartMockServer(t *testing.T, closeWS <-chan []byte) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) {
func GetMockServer(t *testing.T, closeWS <-chan []byte) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Should we gate this code path with an 'allowUnsecureACS' option?"

So my argument against is that we do allow using an http endpoint to talk to ECS via the SDK. If there were a flag, it should apply to both.

serverChan := make(chan string)
requestsChan := make(chan string)
errChan := make(chan error)
Expand Down Expand Up @@ -63,6 +63,6 @@ func StartMockServer(t *testing.T, closeWS <-chan []byte) (*httptest.Server, cha
}
})

server := httptest.NewTLSServer(handler)
server := httptest.NewUnstartedServer(handler)
return server, serverChan, requestsChan, errChan, nil
}