Skip to content

Commit

Permalink
Support Cookie JWT auth via WebSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
pas2k committed Jun 18, 2020
1 parent fa744fd commit 158b971
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 4 deletions.
8 changes: 8 additions & 0 deletions server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,15 @@ func (c *client) processConnect(arg []byte) error {
lang := c.opts.Lang
account := c.opts.Account
accountNew := c.opts.AccountNew

// If JWT is not in opts, imply cookie JWT
if c.opts.JWT == "" {
if ws := c.ws; ws != nil {
c.opts.JWT = ws.cookieJwt
}
}
ujwt := c.opts.JWT

// For headers both client and server need to support.
c.headers = supportsHeaders && c.opts.Headers
c.mu.Unlock()
Expand Down
6 changes: 6 additions & 0 deletions server/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ type WebsocketOpts struct {
// Users defined here or in the global options.
NoAuthUser string

// Name of the cookie, which if present, will be treated as JWT.
// If not valid or not present, protocol authorization is possible.
JWTCookie string

// Authentication section. If anything is configured in this section,
// it will override the authorization configuration for regular clients.
Username string
Expand Down Expand Up @@ -3152,6 +3156,8 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro
if auth.nkeys != nil {
o.Websocket.Nkeys = append(o.Websocket.Nkeys, auth.nkeys...)
}
case "jwt_cookie":
o.Websocket.JWTCookie = mv.(string)
case "no_auth_user":
o.Websocket.NoAuthUser = mv.(string)
default:
Expand Down
12 changes: 12 additions & 0 deletions server/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type websocket struct {
closeSent bool
browser bool
compressor *flate.Writer
cookieJwt string
}

type srvWebsocket struct {
Expand Down Expand Up @@ -597,6 +598,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe
if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") {
ws.browser = true
}
if opts.Websocket.JWTCookie != "" {
if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil {
ws.cookieJwt = c.Value
}
}
return &wsUpgradeResult{conn: conn, ws: ws}, nil
}

Expand Down Expand Up @@ -748,6 +754,12 @@ func validateWebsocketOptions(o *Options) error {
}
return fmt.Errorf("websocket no_auth_user %q not found in users configuration", wo.NoAuthUser)
}
// Using JWT requires Trusted Keys
if wo.JWTCookie != "" {
if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 {
return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie)
}
}
return nil
}

Expand Down
235 changes: 231 additions & 4 deletions server/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"testing"
"time"

"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
)

Expand Down Expand Up @@ -1584,9 +1585,16 @@ func TestWSAbnormalFailureOfWebServer(t *testing.T) {
}
}

func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
type testWSClientOptions struct {
compress, web bool
host string
port int
extraHeaders map[string]string
}

func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) {
t.Helper()
addr := fmt.Sprintf("%s:%d", host, port)
addr := fmt.Sprintf("%s:%d", o.host, o.port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
Expand All @@ -1596,12 +1604,17 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po
t.Fatalf("Error during handshake: %v", err)
}
req := testWSCreateValidReq()
if compress {
if o.compress {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
}
if web {
if o.web {
req.Header.Set("User-Agent", "Mozilla/5.0")
}
if o.extraHeaders != nil {
for hdr, val := range o.extraHeaders {
req.Header.Add(hdr, val)
}
}
req.URL, _ = url.Parse("wss://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
Expand All @@ -1623,6 +1636,102 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po
return wsc, br, info
}

type testClaimsOptions struct {
nac *jwt.AccountClaims
nuc *jwt.UserClaims
connectRequest interface{}
dontSign bool
expectAnswer string
}

func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) {
t.Helper()

okp, _ := nkeys.FromSeed(oSeed)

akp, _ := nkeys.CreateAccount()
apub, _ := akp.PublicKey()
if tclm.nac == nil {
tclm.nac = jwt.NewAccountClaims(apub)
} else {
tclm.nac.Subject = apub
}
ajwt, err := tclm.nac.Encode(okp)
if err != nil {
t.Fatalf("Error generating account JWT: %v", err)
}

nkp, _ := nkeys.CreateUser()
pub, _ := nkp.PublicKey()
if tclm.nuc == nil {
tclm.nuc = jwt.NewUserClaims(pub)
} else {
tclm.nuc.Subject = pub
}
jwt, err := tclm.nuc.Encode(akp)
if err != nil {
t.Fatalf("Error generating user JWT: %v", err)
}

addAccountToMemResolver(s, apub, ajwt)

c, cr, l := testNewWSClient(t, o)

var info struct {
Nonce string `json:"nonce,omitempty"`
AuthRequired bool `json:"auth_required,omitempty"`
}

if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
t.Fatal(err)
}
if info.AuthRequired {
cs := ""
if tclm.connectRequest != nil {
customReq, err := json.Marshal(tclm.connectRequest)
if err != nil {
t.Fatal(err)
}
// PING needed to flush the +OK/-ERR to us.
cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq))
} else if !tclm.dontSign {
// Sign Nonce
sigraw, _ := nkp.Sign([]byte(info.Nonce))
sig := base64.RawURLEncoding.EncodeToString(sigraw)
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig)
} else {
cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt)
}
wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs))
c.Write(wsmsg)
l = testWSReadFrame(t, cr)
if !strings.HasPrefix(string(l), tclm.expectAnswer) {
t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l)
}
}
return akp, c, cr, info.AuthRequired
}

func setupAddTrusted(o *Options) {
kp, _ := nkeys.FromSeed(oSeed)
pub, _ := kp.PublicKey()
o.TrustedKeys = []string{pub}
}

func setupAddCookie(o *Options) {
o.Websocket.JWTCookie = "jwt"
}

func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) {
t.Helper()
return testNewWSClient(t, testWSClientOptions{
compress: compress,
web: web,
host: host,
port: port,
})
}

func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
// Send CONNECT and PING
Expand Down Expand Up @@ -3157,6 +3266,124 @@ func TestWSNkeyAuth(t *testing.T) {
}
}

func TestJWTWSCookieUser(t *testing.T) {

nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() }
nucBearerFunc := func() *jwt.UserClaims {
ret := newJWTTestUserClaims()
ret.BearerToken = true
return ret
}

o := testWSOptions()
setupAddTrusted(o)
setupAddCookie(o)
s := RunServer(o)
buildMemAccResolver(s)
defer s.Shutdown()

genJwt := func(t *testing.T, nuc *jwt.UserClaims) string {
okp, _ := nkeys.FromSeed(oSeed)

akp, _ := nkeys.CreateAccount()
apub, _ := akp.PublicKey()

nac := jwt.NewAccountClaims(apub)
ajwt, err := nac.Encode(okp)
if err != nil {
t.Fatalf("Error generating account JWT: %v", err)
}

nkp, _ := nkeys.CreateUser()
pub, _ := nkp.PublicKey()
nuc.Subject = pub
jwt, err := nuc.Encode(akp)
if err != nil {
t.Fatalf("Error generating user JWT: %v", err)
}
addAccountToMemResolver(s, apub, ajwt)
return jwt
}

cliOpts := testWSClientOptions{
host: o.Websocket.Host,
port: o.Websocket.Port,
}
for _, test := range []struct {
name string
nuc *jwt.UserClaims
opts func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions)
expectAnswer string
}{
{
name: "protocol auth, non-bearer key, with signature",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
return cliOpts, testClaimsOptions{nuc: claims}
},
expectAnswer: "+OK",
},
{
name: "protocol auth, non-bearer key, w/o required signature",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
},
expectAnswer: "-ERR",
},
{
name: "protocol auth, bearer key, w/o signature",
nuc: nucBearerFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
},
expectAnswer: "+OK",
},
{
name: "cookie auth, non-bearer key, protocol auth fail",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
co := cliOpts
co.extraHeaders = map[string]string{}
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
return co, testClaimsOptions{connectRequest: struct{}{}}
},
expectAnswer: "-ERR",
},
{
name: "cookie auth, bearer key, protocol auth success with implied cookie jwt",
nuc: nucBearerFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
co := cliOpts
co.extraHeaders = map[string]string{}
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
return co, testClaimsOptions{connectRequest: struct{}{}}
},
expectAnswer: "+OK",
},
{
name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts",
nuc: nucSigFunc(),
opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
co := cliOpts
co.extraHeaders = map[string]string{}
co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims)
return co, testClaimsOptions{nuc: nucBearerFunc()}
},
expectAnswer: "+OK",
},
} {
t.Run(test.name, func(t *testing.T) {
cliOpt, claimOpt := test.opts(t, test.nuc)
claimOpt.expectAnswer = test.expectAnswer
_, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt)
c.Close()
})
}
s.Shutdown()

}

// ==================================================================
// = Benchmark tests
// ==================================================================
Expand Down

0 comments on commit 158b971

Please sign in to comment.