diff --git a/auth/auth.go b/auth/auth.go index 613d6f0..d03b7a1 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -90,6 +90,6 @@ func NewAuthenticator(cj cookies.ICookieJar, OIDCconfig *oidc.Config, OauthConfi return &Authenticator{Cookiejar: cj, OIDCconfig: OIDCconfig, OauthConfig: OauthConfig, verifierProvider: verifierProvider} } -func (auth *Authenticator) GetVerifier() *oidc.Provider { - return auth.verifierProvider +func (auth *Authenticator) GetTokenVerifier() *oidc.IDTokenVerifier { + return auth.verifierProvider.Verifier(auth.OIDCconfig) } diff --git a/auth/cookies/cookie.go b/auth/cookies/cookie.go index 60ad5d7..8af4c3f 100644 --- a/auth/cookies/cookie.go +++ b/auth/cookies/cookie.go @@ -8,7 +8,8 @@ import ( ) type ICookieJar interface { - SetCallBackCookie(*gin.Context, string, string) + SetCallBackState(context *gin.Context, name string, stateValue string) + StateSession(contex *gin.Context, name string) (value string, isNew bool) } type CookieJar struct { store sessions.Store @@ -18,15 +19,30 @@ func NewCookieJar(secret []byte) *CookieJar { return &CookieJar{store: sessions.NewCookieStore(secret)} } -func (cj *CookieJar) SetCallBackCookie(g *gin.Context, name string, stateValue string) { +func (cj *CookieJar) SetCallBackState(context *gin.Context, name string, stateValue string) { session := sessions.NewSession(cj.store, name) session.Values["state"] = stateValue session.Options.MaxAge = 60 * 5 session.Options.Path = "/" - session.Options.Secure = g.Request.TLS != nil + session.Options.Secure = context.Request.TLS != nil - if err := cj.store.Save(g.Request, g.Writer, session); err != nil { + if err := cj.store.Save(context.Request, context.Writer, session); err != nil { fmt.Println("[error] failed to store session") return } } + +func (cj *CookieJar) StateSession(context *gin.Context, stateValue string) (value string, isNew bool) { + session, _ := cj.store.Get(context.Request, stateValue) + + if session.IsNew { + return "", true + } + + state, ok := session.Values["state"].(string) + if ok { + return state, false + } + + return "", true +} diff --git a/auth/gin_oidc.go b/auth/gin_oidc.go index 8a15f2b..35a4cc0 100644 --- a/auth/gin_oidc.go +++ b/auth/gin_oidc.go @@ -14,7 +14,7 @@ const ( CALLBACK_NONCE = "soarca_gui_nonce" ) -func (auth *Authenticator) RedirectToOIDCLogin(context *gin.Context) { +func (auth *Authenticator) OIDCRedirectToLogin(context *gin.Context) { state, err := randString(32) if err != nil { api.JSONErrorStatus(context, http.StatusInsufficientStorage, errors.New("failed to generate state")) @@ -25,8 +25,56 @@ func (auth *Authenticator) RedirectToOIDCLogin(context *gin.Context) { api.JSONErrorStatus(context, http.StatusInsufficientStorage, errors.New("failed to generate nonce")) return } - auth.Cookiejar.SetCallBackCookie(context, CALLBACK_STATE, state) - auth.Cookiejar.SetCallBackCookie(context, CALLBACK_NONCE, nonce) + auth.Cookiejar.SetCallBackState(context, CALLBACK_STATE, state) + auth.Cookiejar.SetCallBackState(context, CALLBACK_NONCE, nonce) context.Redirect(http.StatusFound, auth.OauthConfig.AuthCodeURL(state, oidc.Nonce(nonce))) } + +func (auth *Authenticator) OIDCCallBack(context *gin.Context) { + state, isNew := auth.Cookiejar.StateSession(context, CALLBACK_STATE) + if isNew || state == "" { + api.JSONErrorStatus(context, http.StatusInternalServerError, errors.New("state missing")) + return + } + + cookie, err := context.Request.Cookie(CALLBACK_STATE) + if err != nil { + api.JSONErrorStatus(context, http.StatusBadRequest, errors.New("state missing from client")) + return + } + + if cookie.Value != state { + api.JSONErrorStatus(context, http.StatusUnauthorized, errors.New("state mismatch")) + return + } + + oauth2Token, err := auth.OauthConfig.Exchange(context, context.Query("code")) + if err != nil { + api.JSONErrorStatus(context, http.StatusBadRequest, errors.New("could not get code from URL")) + return + } + rawIDtoken, ok := oauth2Token.Extra("id_token").(string) + + if !ok { + api.JSONErrorStatus(context, http.StatusBadRequest, errors.New("could not obtain code from URL")) + return + } + + verifier := auth.GetTokenVerifier() + verifiedIDToken, err := verifier.Verify(context, rawIDtoken) + if err != nil { + api.JSONErrorStatus(context, http.StatusInternalServerError, errors.New("failed to verify ID token")) + return + } + + nonce, err := context.Request.Cookie(CALLBACK_NONCE) + if err != nil { + api.JSONErrorStatus(context, http.StatusInternalServerError, errors.New("missing id token")) + return + } + if verifiedIDToken.Nonce != nonce.Value { + api.JSONErrorStatus(context, http.StatusBadRequest, errors.New("nonce for verified id token did not match")) + return + } +} diff --git a/handlers/oidc_handler.go b/handlers/oidc_handler.go index b1d8ad1..3b7555f 100644 --- a/handlers/oidc_handler.go +++ b/handlers/oidc_handler.go @@ -14,17 +14,20 @@ type OIDCAuthHandler struct { authenticator *auth.Authenticator } -func NewOIDCAuthHanlder(authenticator *auth.Authenticator) *OIDCAuthHandler { +func NewOIDCAuthHandler(authenticator *auth.Authenticator) *OIDCAuthHandler { return &OIDCAuthHandler{authenticator: authenticator} } -func (a *OIDCAuthHandler) OIDCAuthPageHandler(context *gin.Context) { +func (auth *OIDCAuthHandler) OIDCAuthPageHandler(context *gin.Context) { // context.Header("HX-Redirect", "/dashboard") // context.String(http.StatusFound, "") render := utils.NewTempl(context, http.StatusOK, authviews.OIDCLoginIndex()) context.Render(http.StatusOK, render) } -func (a *OIDCAuthHandler) OIDCLoginHandler(context *gin.Context) { - a.authenticator.RedirectToOIDCLogin(context) +func (auth *OIDCAuthHandler) OIDCLoginHandler(context *gin.Context) { + auth.authenticator.OIDCRedirectToLogin(context) +} + +func (auth *OIDCAuthHandler) OIDCCallBackHandler(context *gin.Context) { } diff --git a/routes/routes.go b/routes/routes.go index e3f1d76..75829f4 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -40,7 +40,7 @@ func Setup(app *gin.Engine) { func PublicOIDCRoutes(app *gin.RouterGroup) { auth := auth.SetupOIDCAuthHandler() - authHandler := handlers.NewOIDCAuthHanlder(auth) + authHandler := handlers.NewOIDCAuthHandler(auth) publicRoute := app.Group("/") { publicRoute.GET("/", authHandler.OIDCAuthPageHandler)