Skip to content

Commit

Permalink
Enforce use of state parameter in browser flow
Browse files Browse the repository at this point in the history
  • Loading branch information
kian99 committed May 8, 2024
1 parent 9bbfaf9 commit 63ec24c
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 46 deletions.
26 changes: 16 additions & 10 deletions internal/auth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package auth

import (
"context"
"crypto/rand"
"encoding/base64"
stderrors "errors"
"fmt"
Expand Down Expand Up @@ -39,6 +40,9 @@ const (
// SessionIdentityKey is the key for the identity value stored within the
// session.
SessionIdentityKey = "identity-id"

// StateKey is the key for the OAuth callback state stored within a user's cookie.
StateKey = "jimm-oauth-state"
)

type sessionIdentityContextKey struct{}
Expand Down Expand Up @@ -149,15 +153,17 @@ func NewAuthenticationService(ctx context.Context, params AuthenticationServiceP
}

// AuthCodeURL returns a URL that will be used to redirect a browser to the identity provider.
func (as *AuthenticationService) AuthCodeURL() string {
// As we're not the browser creating the auth code url and then communicating back
// to the server, it is OK not to set a state as there's no communication
// between say many "tabs" and a JIMM deployment, but rather
// just JIMM creating the auth code URL itself, and then handling the exchanging
// itself. Of course, middleman attacks between the IdP and JIMM are possible,
// but we'd have much larger problems than an auth code interception at that
// point. As such, we're opting out of using auth code URL state.
return as.oauthConfig.AuthCodeURL("")
// It also generates a random state string that was used as part of the auth code URL. The state string
// is returned alongside the auth code URL and any errors that occured during state generation.
func (as *AuthenticationService) AuthCodeURL() (string, string, error) {
// Hydra requires at least 8 characters in the state parameter.
b := make([]byte, 8)
_, err := rand.Read(b)
if err != nil {
return "", "", err
}
state := base64.URLEncoding.EncodeToString(b)
return as.oauthConfig.AuthCodeURL(state), state, nil
}

// Exchange exchanges an authorisation code for an access token.
Expand Down Expand Up @@ -394,7 +400,7 @@ func (as *AuthenticationService) AuthenticateBrowserSession(ctx context.Context,

identityId, ok := session.Values[SessionIdentityKey]
if !ok {
return ctx, errors.E(op, "session is missing identity key")
return ctx, errors.E(op, errors.CodeForbidden, "session is missing identity key")
}

err = as.validateAndUpdateAccessToken(ctx, identityId)
Expand Down
8 changes: 5 additions & 3 deletions internal/auth/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ func TestAuthCodeURL(t *testing.T) {

authSvc, _, _ := setupTestAuthSvc(ctx, c, time.Hour)

url := authSvc.AuthCodeURL()
url, state, err := authSvc.AuthCodeURL()
c.Assert(err, qt.IsNil)
c.Assert(
url,
qt.Equals,
`http://localhost:8082/realms/jimm/protocol/openid-connect/auth?client_id=jimm-device&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauth%2Fcallback&response_type=code&scope=openid+profile+email`,
qt.Matches,
regexp.MustCompile(`http:\/\/localhost:8082\/realms\/jimm\/protocol\/openid-connect\/auth\?client_id=jimm-device&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauth%2Fcallback&response_type=code&scope=openid\+profile\+email&state=.*`),
)
c.Assert(len(state), qt.Not(qt.Equals), 0)
}

// TestDevice is a unique test in that it runs through the entire device oauth2.0
Expand Down
59 changes: 49 additions & 10 deletions internal/jimmhttp/auth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@ import (
"github.com/canonical/jimm/internal/errors"
)

// CallbackEndpoint holds the endpoint path for OAuth2.0 authorisation
// flow callbacks.
// These consts holds the endpoint paths for OAuth2.0 related auth.
// AuthResourceBasePath forms the base path and the remainder are
// appended onto the base in practice.
const (
CallbackEndpoint = "/callback"
AuthResourceBasePath = "/auth"
CallbackEndpoint = "/callback"
WhoAmIEndpoint = "/whoami"
LogOutEndpoint = "/logout"
LoginEndpoint = "/login"
)

// OAuthHandler handles the oauth2.0 browser flow for JIMM.
Expand Down Expand Up @@ -47,7 +52,7 @@ type OAuthHandlerParams struct {
// BrowserOAuthAuthenticator handles authorisation code authentication within JIMM
// via OIDC.
type BrowserOAuthAuthenticator interface {
AuthCodeURL() string
AuthCodeURL() (string, string, error)
Exchange(ctx context.Context, code string) (*oauth2.Token, error)
UserInfo(ctx context.Context, oauth2Token *oauth2.Token) (string, error)
UpdateIdentity(ctx context.Context, email string, token *oauth2.Token) error
Expand Down Expand Up @@ -82,10 +87,10 @@ func NewOAuthHandler(p OAuthHandlerParams) (*OAuthHandler, error) {
// Routes returns the grouped routers routes with group specific middlewares.
func (oah *OAuthHandler) Routes() chi.Router {
oah.SetupMiddleware()
oah.Router.Get("/login", oah.Login)
oah.Router.Get(LoginEndpoint, oah.Login)
oah.Router.Get(CallbackEndpoint, oah.Callback)
oah.Router.Get("/logout", oah.Logout)
oah.Router.Get("/whoami", oah.Whoami)
oah.Router.Get(LogOutEndpoint, oah.Logout)
oah.Router.Get(WhoAmIEndpoint, oah.Whoami)
return oah.Router
}

Expand All @@ -95,17 +100,43 @@ func (oah *OAuthHandler) SetupMiddleware() {

// Login handles /auth/login.
func (oah *OAuthHandler) Login(w http.ResponseWriter, r *http.Request) {
redirectURL := oah.authenticator.AuthCodeURL()
ctx := r.Context()
redirectURL, state, err := oah.authenticator.AuthCodeURL()
if err != nil {
writeError(ctx, w, http.StatusInternalServerError, err, "failed to generate random state")
}
http.SetCookie(w, &http.Cookie{
Name: auth.StateKey,
Value: state,
MaxAge: 900, // 15 min.
Path: AuthResourceBasePath + CallbackEndpoint, // Only send the cookie back on /auth paths.
HttpOnly: true, // Restrict access from JS.
SameSite: http.SameSiteStrictMode, // Cannot be sent cross-origin.
})
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
}

// Callback handles /auth/callback.
func (oah *OAuthHandler) Callback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

stateByCookie, err := r.Cookie(auth.StateKey)
if err != nil {
usrErr := errors.E("no state cookie present")
writeError(ctx, w, http.StatusBadRequest, usrErr, "no state cookie present")
return
}
stateByURL := r.URL.Query().Get("state")
if stateByCookie.Value != stateByURL {
err := errors.E("state does not match")
writeError(ctx, w, http.StatusBadRequest, err, "state does not match")
return
}

code := r.URL.Query().Get("code")
if code == "" {
writeError(ctx, w, http.StatusBadRequest, nil, "no authorisation code present")
err := errors.E("missing auth code")
writeError(ctx, w, http.StatusBadRequest, err, "no authorisation code present")
return
}

Expand Down Expand Up @@ -172,6 +203,10 @@ func (oah *OAuthHandler) Whoami(w http.ResponseWriter, r *http.Request) {

ctx, err := authSvc.AuthenticateBrowserSession(ctx, w, r)
if err != nil {
if errors.ErrorCode(err) == errors.CodeForbidden {
w.WriteHeader(http.StatusForbidden)
return
}
writeError(ctx, w, http.StatusInternalServerError, err, "failed to authenticate users session")
return
}
Expand Down Expand Up @@ -201,5 +236,9 @@ func (oah *OAuthHandler) Whoami(w http.ResponseWriter, r *http.Request) {
func writeError(ctx context.Context, w http.ResponseWriter, status int, err error, logMessage string) {
zapctx.Error(ctx, logMessage, zap.Error(err))
w.WriteHeader(status)
w.Write([]byte(http.StatusText(status)))
errMsg := ""
if err != nil {
errMsg = " - " + err.Error()
}
w.Write([]byte(http.StatusText(status) + errMsg))
}
65 changes: 49 additions & 16 deletions internal/jimmhttp/auth_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"io"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"testing"
"time"

Expand All @@ -12,7 +15,9 @@ import (
"github.com/gorilla/sessions"

"github.com/canonical/jimm/api/params"
"github.com/canonical/jimm/internal/auth"
"github.com/canonical/jimm/internal/db"
"github.com/canonical/jimm/internal/jimmhttp"
"github.com/canonical/jimm/internal/jimmtest"
)

Expand All @@ -32,6 +37,16 @@ func setupDbAndSessionStore(c *qt.C) (*db.Database, sessions.Store) {
return db, store
}

func createClientWithStateCookie(c *qt.C, s *httptest.Server) *http.Client {
jar, err := cookiejar.New(nil)
c.Assert(err, qt.IsNil)
jimmURL, err := url.Parse(s.URL)
c.Assert(err, qt.IsNil)
stateCookie := http.Cookie{Name: auth.StateKey, Value: "123"}
jar.SetCookies(jimmURL, []*http.Cookie{&stateCookie})
return &http.Client{Jar: jar}
}

// TestBrowserLoginAndLogout goes through the flow of a browser logging in, simulating
// the cookie state and handling the callbacks are as expected. Additionally handling
// the final callback to the dashboard emulating an endpoint. See RunBrowserLogin
Expand Down Expand Up @@ -59,7 +74,7 @@ func TestBrowserLoginAndLogout(t *testing.T) {
c.Assert(cookie, qt.Not(qt.Equals), "")

// Run a whoami logged in
req, err := http.NewRequest("GET", jimmHTTPServer.URL+"/whoami", nil)
req, err := http.NewRequest("GET", jimmHTTPServer.URL+jimmhttp.AuthResourceBasePath+jimmhttp.WhoAmIEndpoint, nil)
c.Assert(err, qt.IsNil)
parsedCookies := jimmtest.ParseCookies(cookie)
c.Assert(parsedCookies, qt.HasLen, 1)
Expand All @@ -77,7 +92,7 @@ func TestBrowserLoginAndLogout(t *testing.T) {
})

// Logout
req, err = http.NewRequest("GET", jimmHTTPServer.URL+"/logout", nil)
req, err = http.NewRequest("GET", jimmHTTPServer.URL+jimmhttp.AuthResourceBasePath+jimmhttp.LogOutEndpoint, nil)
c.Assert(err, qt.IsNil)
req.AddCookie(parsedCookies[0])

Expand All @@ -87,7 +102,7 @@ func TestBrowserLoginAndLogout(t *testing.T) {
c.Assert(res.StatusCode, qt.Equals, http.StatusOK)

// Run a whoami logged out
req, err = http.NewRequest("GET", jimmHTTPServer.URL+"/whoami", nil)
req, err = http.NewRequest("GET", jimmHTTPServer.URL+jimmhttp.AuthResourceBasePath+jimmhttp.WhoAmIEndpoint, nil)
c.Assert(err, qt.IsNil)
parsedCookies = jimmtest.ParseCookies(cookie)
c.Assert(parsedCookies, qt.HasLen, 1)
Expand All @@ -96,22 +111,36 @@ func TestBrowserLoginAndLogout(t *testing.T) {
res, err = http.DefaultClient.Do(req)
c.Assert(err, qt.IsNil)
defer res.Body.Close()
c.Assert(res.StatusCode, qt.Equals, http.StatusInternalServerError)
b, err = io.ReadAll(res.Body)
c.Assert(err, qt.IsNil)
// TODO(ale8k): Really it isn't an internal server error here, the session is just
// missing in our store, we should probably bring this error up and return a forbidden.
c.Assert(string(b), qt.Equals, "Internal Server Error")
c.Assert(res.StatusCode, qt.Equals, http.StatusForbidden)

// Run a logout with no identity
req, err = http.NewRequest("GET", jimmHTTPServer.URL+"/logout", nil)
req, err = http.NewRequest("GET", jimmHTTPServer.URL+jimmhttp.AuthResourceBasePath+jimmhttp.LogOutEndpoint, nil)
c.Assert(err, qt.IsNil)
res, err = http.DefaultClient.Do(req)
c.Assert(err, qt.IsNil)
defer res.Body.Close()
c.Assert(res.StatusCode, qt.Equals, http.StatusForbidden)
}

func TestCallbackFailsNoState(t *testing.T) {
c := qt.New(t)

db, sessionStore := setupDbAndSessionStore(c)
s, err := jimmtest.SetupTestDashboardCallbackHandler("<no dashboard needed for this test>", db, sessionStore)
c.Assert(err, qt.IsNil)
defer s.Close()

callbackURL := s.URL + jimmhttp.AuthResourceBasePath + jimmhttp.CallbackEndpoint
res, err := http.Get(callbackURL)
c.Assert(err, qt.IsNil)

defer res.Body.Close()

b, err := io.ReadAll(res.Body)
c.Assert(err, qt.IsNil)
c.Assert(string(b), qt.Equals, http.StatusText(http.StatusBadRequest)+" - no state cookie present")
}

func TestCallbackFailsNoCodePresent(t *testing.T) {
c := qt.New(t)

Expand All @@ -120,15 +149,17 @@ func TestCallbackFailsNoCodePresent(t *testing.T) {
c.Assert(err, qt.IsNil)
defer s.Close()

// Test with no code present at all
res, err := http.Get(s.URL + "/callback")
client := createClientWithStateCookie(c, s)

callbackURL := s.URL + jimmhttp.AuthResourceBasePath + jimmhttp.CallbackEndpoint
res, err := client.Get(callbackURL + "?state=123")
c.Assert(err, qt.IsNil)

defer res.Body.Close()

b, err := io.ReadAll(res.Body)
c.Assert(err, qt.IsNil)
c.Assert(string(b), qt.Equals, http.StatusText(http.StatusBadRequest))
c.Assert(string(b), qt.Equals, http.StatusText(http.StatusBadRequest)+" - missing auth code")
}

func TestCallbackFailsExchange(t *testing.T) {
Expand All @@ -139,13 +170,15 @@ func TestCallbackFailsExchange(t *testing.T) {
c.Assert(err, qt.IsNil)
defer s.Close()

// Test with no code present at all
res, err := http.Get(s.URL + "/callback?code=idonotexist")
client := createClientWithStateCookie(c, s)
callbackURL := s.URL + jimmhttp.AuthResourceBasePath + jimmhttp.CallbackEndpoint
c.Assert(err, qt.IsNil)
res, err := client.Get(callbackURL + "?code=idonotexist&state=123")
c.Assert(err, qt.IsNil)

defer res.Body.Close()

b, err := io.ReadAll(res.Body)
c.Assert(err, qt.IsNil)
c.Assert(string(b), qt.Equals, http.StatusText(http.StatusBadRequest))
c.Assert(string(b), qt.Equals, http.StatusText(http.StatusBadRequest)+" - authorisation code exchange failed")
}
13 changes: 9 additions & 4 deletions internal/jimmtest/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/chi/v5"
"github.com/gorilla/sessions"
"github.com/juju/juju/api"
jujuparams "github.com/juju/juju/rpc/params"
Expand Down Expand Up @@ -118,7 +119,7 @@ func SetupTestDashboardCallbackHandler(browserURL string, db *db.Database, sessi
s.Listener = listener

// Remember redirect url to check it matches after test server starts
redirectURL := "http://127.0.0.1:" + port + "/callback"
redirectURL := "http://127.0.0.1:" + port + jimmhttp.AuthResourceBasePath + jimmhttp.CallbackEndpoint
authSvc, err := auth.NewAuthenticationService(context.Background(), auth.AuthenticationServiceParams{
IssuerURL: "http://localhost:8082/realms/jimm",
ClientID: "jimm-device",
Expand All @@ -144,12 +145,15 @@ func SetupTestDashboardCallbackHandler(browserURL string, db *db.Database, sessi
return nil, err
}

s.Config.Handler = h.Routes()
mux := chi.NewMux()
mux.Mount(jimmhttp.AuthResourceBasePath, h.Routes())
s.Config.Handler = mux

s.Start()

// Ensure redirectURL is matching port on listener
if s.URL+"/callback" != redirectURL {
callbackURL := s.URL + jimmhttp.AuthResourceBasePath + jimmhttp.CallbackEndpoint
if callbackURL != redirectURL {
return s, errors.New("server callback does not match redirectURL")
}

Expand Down Expand Up @@ -208,7 +212,8 @@ func runBrowserLogin(db *db.Database, sessionStore sessions.Store, username, pas
},
}

res, err := client.Get(s.URL + "/login")
loginURL := s.URL + jimmhttp.AuthResourceBasePath + jimmhttp.LoginEndpoint
res, err := client.Get(loginURL)
if err != nil {
return cookieString, s, err
}
Expand Down
1 change: 1 addition & 0 deletions local/traefik/certs/certs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ if [ "$1" != "--force" ]; then
if [ -f "server.crt" ] && [ -f "server.key" ]; then
echo "Server certs already exist. Skipping cert generation."
echo "Run with --force to regenerate."
exit 0
fi
fi

Expand Down
Loading

0 comments on commit 63ec24c

Please sign in to comment.