diff --git a/server/client.go b/server/client.go index 8196cb44e01..8fd775f5c59 100644 --- a/server/client.go +++ b/server/client.go @@ -1497,7 +1497,15 @@ func (c *client) processConnect(arg []byte) error { lang := c.opts.Lang account := c.opts.Account accountNew := c.opts.AccountNew + + // If JWT is not in opts, imply cookie JWT + if c.opts.JWT == "" { + if ws := c.ws; ws != nil { + c.opts.JWT = ws.cookieJwt + } + } ujwt := c.opts.JWT + // For headers both client and server need to support. c.headers = supportsHeaders && c.opts.Headers c.mu.Unlock() diff --git a/server/opts.go b/server/opts.go index 88a63a762aa..d3ad78359ce 100644 --- a/server/opts.go +++ b/server/opts.go @@ -265,6 +265,10 @@ type WebsocketOpts struct { // Users defined here or in the global options. NoAuthUser string + // Name of the cookie, which if present, will be treated as JWT. + // If not valid or not present, protocol authorization is possible. + JWTCookie string + // Authentication section. If anything is configured in this section, // it will override the authorization configuration for regular clients. Username string @@ -3152,6 +3156,8 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro if auth.nkeys != nil { o.Websocket.Nkeys = append(o.Websocket.Nkeys, auth.nkeys...) } + case "jwt_cookie": + o.Websocket.JWTCookie = mv.(string) case "no_auth_user": o.Websocket.NoAuthUser = mv.(string) default: diff --git a/server/websocket.go b/server/websocket.go index 4cf5cec1062..8c16783059e 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -90,6 +90,7 @@ type websocket struct { closeSent bool browser bool compressor *flate.Writer + cookieJwt string } type srvWebsocket struct { @@ -597,6 +598,11 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") { ws.browser = true } + if opts.Websocket.JWTCookie != "" { + if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil { + ws.cookieJwt = c.Value + } + } return &wsUpgradeResult{conn: conn, ws: ws}, nil } @@ -748,6 +754,12 @@ func validateWebsocketOptions(o *Options) error { } return fmt.Errorf("websocket no_auth_user %q not found in users configuration", wo.NoAuthUser) } + // Using JWT requires Trusted Keys + if wo.JWTCookie != "" { + if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 { + return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie) + } + } return nil } diff --git a/server/websocket_test.go b/server/websocket_test.go index 9a625a38d7b..2ee4230322f 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -36,6 +36,7 @@ import ( "testing" "time" + "github.com/nats-io/jwt/v2" "github.com/nats-io/nkeys" ) @@ -1584,9 +1585,16 @@ func TestWSAbnormalFailureOfWebServer(t *testing.T) { } } -func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) { +type testWSClientOptions struct { + compress, web bool + host string + port int + extraHeaders map[string]string +} + +func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) { t.Helper() - addr := fmt.Sprintf("%s:%d", host, port) + addr := fmt.Sprintf("%s:%d", o.host, o.port) wsc, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("Error creating ws connection: %v", err) @@ -1596,12 +1604,17 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po t.Fatalf("Error during handshake: %v", err) } req := testWSCreateValidReq() - if compress { + if o.compress { req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate") } - if web { + if o.web { req.Header.Set("User-Agent", "Mozilla/5.0") } + if o.extraHeaders != nil { + for hdr, val := range o.extraHeaders { + req.Header.Add(hdr, val) + } + } req.URL, _ = url.Parse("wss://" + addr) if err := req.Write(wsc); err != nil { t.Fatalf("Error sending request: %v", err) @@ -1623,6 +1636,102 @@ func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, po return wsc, br, info } +type testClaimsOptions struct { + nac *jwt.AccountClaims + nuc *jwt.UserClaims + connectRequest interface{} + dontSign bool + expectAnswer string +} + +func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) { + t.Helper() + + okp, _ := nkeys.FromSeed(oSeed) + + akp, _ := nkeys.CreateAccount() + apub, _ := akp.PublicKey() + if tclm.nac == nil { + tclm.nac = jwt.NewAccountClaims(apub) + } else { + tclm.nac.Subject = apub + } + ajwt, err := tclm.nac.Encode(okp) + if err != nil { + t.Fatalf("Error generating account JWT: %v", err) + } + + nkp, _ := nkeys.CreateUser() + pub, _ := nkp.PublicKey() + if tclm.nuc == nil { + tclm.nuc = jwt.NewUserClaims(pub) + } else { + tclm.nuc.Subject = pub + } + jwt, err := tclm.nuc.Encode(akp) + if err != nil { + t.Fatalf("Error generating user JWT: %v", err) + } + + addAccountToMemResolver(s, apub, ajwt) + + c, cr, l := testNewWSClient(t, o) + + var info struct { + Nonce string `json:"nonce,omitempty"` + AuthRequired bool `json:"auth_required,omitempty"` + } + + if err := json.Unmarshal([]byte(l[5:]), &info); err != nil { + t.Fatal(err) + } + if info.AuthRequired { + cs := "" + if tclm.connectRequest != nil { + customReq, err := json.Marshal(tclm.connectRequest) + if err != nil { + t.Fatal(err) + } + // PING needed to flush the +OK/-ERR to us. + cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq)) + } else if !tclm.dontSign { + // Sign Nonce + sigraw, _ := nkp.Sign([]byte(info.Nonce)) + sig := base64.RawURLEncoding.EncodeToString(sigraw) + cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig) + } else { + cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt) + } + wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs)) + c.Write(wsmsg) + l = testWSReadFrame(t, cr) + if !strings.HasPrefix(string(l), tclm.expectAnswer) { + t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l) + } + } + return akp, c, cr, info.AuthRequired +} + +func setupAddTrusted(o *Options) { + kp, _ := nkeys.FromSeed(oSeed) + pub, _ := kp.PublicKey() + o.TrustedKeys = []string{pub} +} + +func setupAddCookie(o *Options) { + o.Websocket.JWTCookie = "jwt" +} + +func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader, []byte) { + t.Helper() + return testNewWSClient(t, testWSClientOptions{ + compress: compress, + web: web, + host: host, + port: port, + }) +} + func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) { wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port) // Send CONNECT and PING @@ -3157,6 +3266,124 @@ func TestWSNkeyAuth(t *testing.T) { } } +func TestJWTWSCookieUser(t *testing.T) { + + nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() } + nucBearerFunc := func() *jwt.UserClaims { + ret := newJWTTestUserClaims() + ret.BearerToken = true + return ret + } + + o := testWSOptions() + setupAddTrusted(o) + setupAddCookie(o) + s := RunServer(o) + buildMemAccResolver(s) + defer s.Shutdown() + + genJwt := func(t *testing.T, nuc *jwt.UserClaims) string { + okp, _ := nkeys.FromSeed(oSeed) + + akp, _ := nkeys.CreateAccount() + apub, _ := akp.PublicKey() + + nac := jwt.NewAccountClaims(apub) + ajwt, err := nac.Encode(okp) + if err != nil { + t.Fatalf("Error generating account JWT: %v", err) + } + + nkp, _ := nkeys.CreateUser() + pub, _ := nkp.PublicKey() + nuc.Subject = pub + jwt, err := nuc.Encode(akp) + if err != nil { + t.Fatalf("Error generating user JWT: %v", err) + } + addAccountToMemResolver(s, apub, ajwt) + return jwt + } + + cliOpts := testWSClientOptions{ + host: o.Websocket.Host, + port: o.Websocket.Port, + } + for _, test := range []struct { + name string + nuc *jwt.UserClaims + opts func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) + expectAnswer string + }{ + { + name: "protocol auth, non-bearer key, with signature", + nuc: nucSigFunc(), + opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { + return cliOpts, testClaimsOptions{nuc: claims} + }, + expectAnswer: "+OK", + }, + { + name: "protocol auth, non-bearer key, w/o required signature", + nuc: nucSigFunc(), + opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { + return cliOpts, testClaimsOptions{nuc: claims, dontSign: true} + }, + expectAnswer: "-ERR", + }, + { + name: "protocol auth, bearer key, w/o signature", + nuc: nucBearerFunc(), + opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { + return cliOpts, testClaimsOptions{nuc: claims, dontSign: true} + }, + expectAnswer: "+OK", + }, + { + name: "cookie auth, non-bearer key, protocol auth fail", + nuc: nucSigFunc(), + opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { + co := cliOpts + co.extraHeaders = map[string]string{} + co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims) + return co, testClaimsOptions{connectRequest: struct{}{}} + }, + expectAnswer: "-ERR", + }, + { + name: "cookie auth, bearer key, protocol auth success with implied cookie jwt", + nuc: nucBearerFunc(), + opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { + co := cliOpts + co.extraHeaders = map[string]string{} + co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims) + return co, testClaimsOptions{connectRequest: struct{}{}} + }, + expectAnswer: "+OK", + }, + { + name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts", + nuc: nucSigFunc(), + opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { + co := cliOpts + co.extraHeaders = map[string]string{} + co.extraHeaders["Cookie"] = o.Websocket.JWTCookie + "=" + genJwt(t, claims) + return co, testClaimsOptions{nuc: nucBearerFunc()} + }, + expectAnswer: "+OK", + }, + } { + t.Run(test.name, func(t *testing.T) { + cliOpt, claimOpt := test.opts(t, test.nuc) + claimOpt.expectAnswer = test.expectAnswer + _, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt) + c.Close() + }) + } + s.Shutdown() + +} + // ================================================================== // = Benchmark tests // ==================================================================