Skip to content

Commit

Permalink
sessions: Introduce sessions package
Browse files Browse the repository at this point in the history
Isolated the coreos/go-oidc dependencies within our new sessions
package. Under the sessions package we have all the session related
functions and variables.

GitHub-PR: #104

Signed-off-by: Athanasios Markou <athamark@arrikto.com>
  • Loading branch information
Athanasios Markou committed Jan 5, 2023
1 parent 6dbd1a4 commit a09d006
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 53 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ RUN go mod download
COPY *.go ./
COPY common common
COPY oidc oidc
COPY sessions sessions
RUN CGO_ENABLED=0 GOOS=linux go build -a -ldflags '-extldflags "-static"' -o /go/bin/oidc-authservice


Expand Down
12 changes: 6 additions & 6 deletions authenticator_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/arrikto/oidc-authservice/common"
"github.com/arrikto/oidc-authservice/oidc"
"github.com/gorilla/sessions"
"github.com/arrikto/oidc-authservice/sessions"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"k8s.io/apiserver/pkg/authentication/authenticator"
Expand Down Expand Up @@ -39,7 +39,7 @@ func (sa *sessionAuthenticator) AuthenticateRequest(r *http.Request) (*authentic
logger := common.LoggerForRequest(r, "session authenticator")

// Get session from header or cookie
session, authMethod, err := sessionFromRequest(r, sa.store, sa.cookie, sa.header)
session, authMethod, err := sessions.SessionFromRequest(r, sa.store, sa.cookie, sa.header)

// Check if user session is valid
if err != nil {
Expand All @@ -53,7 +53,7 @@ func (sa *sessionAuthenticator) AuthenticateRequest(r *http.Request) (*authentic
// User is logged in
if sa.strictSessionValidation {
ctx := common.SetTLSContext(r.Context(), sa.caBundle)
token := session.Values[userSessionOAuth2Tokens].(oauth2.Token)
token := session.Values[sessions.UserSessionOAuth2Tokens].(oauth2.Token)
// TokenSource takes care of automatically renewing the access token.
_, err := oidc.GetUserInfo(ctx, sa.provider, sa.oauth2Config.TokenSource(ctx, &token))
if err != nil {
Expand All @@ -70,7 +70,7 @@ func (sa *sessionAuthenticator) AuthenticateRequest(r *http.Request) (*authentic
// access to the ResponseWriter and thus can't set a cookie. This
// means that the cookie will remain at the user's browser but it
// will be replaced after the user logs in again.
err = revokeOIDCSession(ctx, httptest.NewRecorder(), session,
err = sessions.RevokeOIDCSession(ctx, httptest.NewRecorder(), session,
sa.provider, sa.oauth2Config, sa.caBundle)
if err != nil {
logger.Errorf("Failed to revoke tokens: %v", err)
Expand All @@ -82,7 +82,7 @@ func (sa *sessionAuthenticator) AuthenticateRequest(r *http.Request) (*authentic
// Data written at a previous version might not have groups stored, so
// default to an empty list of strings.
// TODO: Consolidate all session serialization/deserialization in one place.
groups, ok := session.Values[userSessionGroups].([]string)
groups, ok := session.Values[sessions.UserSessionGroups].([]string)
if !ok {
groups = []string{}
}
Expand All @@ -91,7 +91,7 @@ func (sa *sessionAuthenticator) AuthenticateRequest(r *http.Request) (*authentic

resp := &authenticator.Response{
User: &user.DefaultInfo{
Name: session.Values[userSessionUserID].(string),
Name: session.Values[sessions.UserSessionUserID].(string),
Groups: groups,
Extra: extra,
},
Expand Down
7 changes: 3 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/arrikto/oidc-authservice/common"
"github.com/arrikto/oidc-authservice/oidc"
"github.com/arrikto/oidc-authservice/sessions"
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/patrickmn/go-cache"
Expand All @@ -22,8 +23,6 @@ import (
clientconfig "sigs.k8s.io/controller-runtime/pkg/client/config"
)

// Issue: https://github.com/gorilla/sessions/issues/200
const secureCookieKeyPair = "notNeededBecauseCookieValueIsRandom"
const CacheCleanupInterval = 10

func main() {
Expand Down Expand Up @@ -99,7 +98,7 @@ func main() {

// Setup session store and state store using the configured session store
// type (BoltDB, or redis)
store, oidcStateStore := initiateSessionStores(c)
store, oidcStateStore := sessions.InitiateSessionStores(c)

defer store.Close()
defer oidcStateStore.Close()
Expand Down Expand Up @@ -136,7 +135,7 @@ func main() {
// Setup authenticators.
sessionAuthenticator := &sessionAuthenticator{
store: store,
cookie: userSessionCookie,
cookie: sessions.UserSessionCookie,
header: c.AuthHeader,
strictSessionValidation: c.StrictSessionValidation,
caBundle: caBundle,
Expand Down
32 changes: 16 additions & 16 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

"github.com/arrikto/oidc-authservice/common"
"github.com/arrikto/oidc-authservice/oidc"
"github.com/gorilla/sessions"
"github.com/arrikto/oidc-authservice/sessions"
cache "github.com/patrickmn/go-cache"
"github.com/pkg/errors"
"github.com/tevino/abool"
Expand All @@ -38,8 +38,8 @@ var (
type server struct {
provider oidc.IdProvider
oauth2Config *oauth2.Config
store ClosableStore
oidcStateStore ClosableStore
store sessions.ClosableStore
oidcStateStore sessions.ClosableStore
bearerUserInfoCache *cache.Cache
authenticators []authenticator.Request
authorizers []Authorizer
Expand Down Expand Up @@ -247,12 +247,12 @@ func (s *server) authorized(w http.ResponseWriter, r *http.Request, userInfo use
// the session authenticator.
if !allowed {
logger.Infof("Authorizer '%T' denied the request with reason: '%s'", authz, reason)
session, _, err := sessionFromRequest(r, s.store, userSessionCookie, s.authHeader)
session, _, err := sessions.SessionFromRequest(r, s.store, sessions.UserSessionCookie, s.authHeader)
if err != nil {
logger.Errorf("Error getting session for request: %v", err)
}
if !session.IsNew {
err = revokeOIDCSession(r.Context(), w, session, s.provider, s.oauth2Config, s.caBundle)
err = sessions.RevokeOIDCSession(r.Context(), w, session, s.provider, s.oauth2Config, s.caBundle)
if err != nil {
logger.Errorf("Failed to revoke session after authorization fail: %v", err)
}
Expand Down Expand Up @@ -309,7 +309,7 @@ func (s *server) authCodeFlowAuthenticationRequest(w http.ResponseWriter, r *htt
logger := common.LoggerForRequest(r, logModuleInfo)

// Initiate OIDC Flow with Authorization Request.
state, err := createState(r, w, s.oidcStateStore)
state, err := sessions.CreateState(r, w, s.oidcStateStore)
if err != nil {
logger.Errorf("Failed to save state in store: %v", err)
common.ReturnMessage(w, http.StatusInternalServerError, "Failed to save state in store.")
Expand Down Expand Up @@ -343,7 +343,7 @@ func (s *server) callback(w http.ResponseWriter, r *http.Request) {
}

// If state is loaded, then it's correct, as it is saved by its id.
state, err := verifyState(r, w, s.oidcStateStore)
state, err := sessions.VerifyState(r, w, s.oidcStateStore)
if err != nil {
logger.Errorf("Failed to verify state parameter: %v", err)
common.ReturnMessage(w, http.StatusBadRequest, "CSRF check failed."+
Expand Down Expand Up @@ -394,7 +394,7 @@ func (s *server) callback(w http.ResponseWriter, r *http.Request) {
}

// User is authenticated, create new session.
session := sessions.NewSession(s.store, userSessionCookie)
session := sessions.NewSession(s.store, sessions.UserSessionCookie)
session.Options.MaxAge = s.sessionMaxAgeSeconds
session.Options.Path = "/"
// Extra layer of CSRF protection
Expand All @@ -414,11 +414,11 @@ func (s *server) callback(w http.ResponseWriter, r *http.Request) {
groups = common.InterfaceSliceToStringSlice(groupsClaim.([]interface{}))
}

session.Values[userSessionUserID] = userID
session.Values[userSessionGroups] = groups
session.Values[userSessionClaims] = claims
session.Values[userSessionIDToken] = rawIDToken
session.Values[userSessionOAuth2Tokens] = oauth2Tokens
session.Values[sessions.UserSessionUserID] = userID
session.Values[sessions.UserSessionGroups] = groups
session.Values[sessions.UserSessionClaims] = claims
session.Values[sessions.UserSessionIDToken] = rawIDToken
session.Values[sessions.UserSessionOAuth2Tokens] = oauth2Tokens
if err := session.Save(r, w); err != nil {
logger.Errorf("Couldn't create user session: %v", err)
common.ReturnMessage(w, http.StatusInternalServerError, "Error creating user session")
Expand Down Expand Up @@ -477,7 +477,7 @@ func (s *server) logout(w http.ResponseWriter, r *http.Request) {
}

// Revoke user session.
session, err := sessionFromID(sessionID, s.store)
session, err := sessions.SessionFromID(sessionID, s.store)
if err != nil {
logger.Errorf("Couldn't get user session: %v", err)
w.WriteHeader(http.StatusInternalServerError)
Expand All @@ -488,9 +488,9 @@ func (s *server) logout(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
return
}
logger = logger.WithField("userid", session.Values[userSessionUserID].(string))
logger = logger.WithField("userid", session.Values[sessions.UserSessionUserID].(string))

err = revokeOIDCSession(r.Context(), w, session, s.provider, s.oauth2Config, s.caBundle)
err = sessions.RevokeOIDCSession(r.Context(), w, session, s.provider, s.oauth2Config, s.caBundle)
if err != nil {
logger.Errorf("Error revoking tokens: %v", err)
statusCode := http.StatusInternalServerError
Expand Down
2 changes: 1 addition & 1 deletion session_store_boltdb.go → sessions/boltdb.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package sessions

import (
"os"
Expand Down
2 changes: 1 addition & 1 deletion session_store_redis.go → sessions/redis.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package sessions

import (
"context"
Expand Down
49 changes: 29 additions & 20 deletions session.go → sessions/session.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package sessions

import (
"context"
Expand All @@ -14,48 +14,57 @@ import (
)

const (
userSessionCookie = "authservice_session"
userSessionUserID = "userid"
userSessionGroups = "groups"
userSessionClaims = "claims"
userSessionIDToken = "idtoken"
userSessionOAuth2Tokens = "oauth2tokens"
// Issue: https://github.com/gorilla/sessions/issues/200
secureCookieKeyPair = "notNeededBecauseCookieValueIsRandom"

UserSessionCookie = "authservice_session"
UserSessionUserID = "userid"
UserSessionGroups = "groups"
UserSessionClaims = "claims"
UserSessionIDToken = "idtoken"
UserSessionOAuth2Tokens = "oauth2tokens"
)

type Store sessions.Store

type ClosableStore interface {
sessions.Store
Close() error
}

// sessionFromRequestHeader returns a session which has its key in a header.
func NewSession(store Store, name string) *sessions.Session{
return sessions.NewSession(store, name)
}

// SessionFromID returns a session which has its key in a header.
// XXX: Because the session library we use doesn't support getting a session
// by key, we need to fake a cookie
func sessionFromID(id string, store sessions.Store) (*sessions.Session, error) {
func SessionFromID(id string, store sessions.Store) (*sessions.Session, error) {
r := &http.Request{Header: make(http.Header)}
r.AddCookie(&http.Cookie{
// XXX: This is needed because the sessions lib we use also encodes
// cookies with securecookie, which requires passing the correct cookie
// name during decryption.
Name: userSessionCookie,
Name: UserSessionCookie,
Value: id,
Path: "/",
MaxAge: 1,
})
return store.Get(r, userSessionCookie)
return store.Get(r, UserSessionCookie)
}

// sessionFromRequest looks for a session id in a header and a cookie, in that
// SessionFromRequest looks for a session id in a header and a cookie, in that
// order. If it doesn't find a valid session in the header, it will then check
// the cookie.
func sessionFromRequest(r *http.Request, store sessions.Store, cookie,
func SessionFromRequest(r *http.Request, store sessions.Store, cookie,
header string) (*sessions.Session, string, error) {

var authMethod string
logger := common.LoggerForRequest(r, "session authenticator")
// Try to get session from header
sessionID := common.GetBearerToken(r.Header.Get(header))
if sessionID != "" {
s, err := sessionFromID(sessionID, store)
s, err := SessionFromID(sessionID, store)
if err == nil && !s.IsNew {
logger.Infof("Loading session from header %s", header)
// Authentication using header successfully completed
Expand Down Expand Up @@ -90,11 +99,11 @@ func revokeSession(ctx context.Context, w http.ResponseWriter,
return nil
}

// revokeOIDCSession revokes the given session, which is assumed to be an OIDC
// RevokeOIDCSession revokes the given session, which is assumed to be an OIDC
// session, for which it also performs the necessary cleanup.
// TODO: In the future, we may want to make this function take a function as
// input, instead of polluting it with extra arguments.
func revokeOIDCSession(ctx context.Context, w http.ResponseWriter,
func RevokeOIDCSession(ctx context.Context, w http.ResponseWriter,
session *sessions.Session, provider oidc.IdProvider,
oauth2Config *oauth2.Config, caBundle []byte) error {

Expand All @@ -105,25 +114,25 @@ func revokeOIDCSession(ctx context.Context, w http.ResponseWriter,
if err != nil {
logger.Warnf("Error getting provider's revocation_endpoint: %v", err)
} else {
token := session.Values[userSessionOAuth2Tokens].(oauth2.Token)
token := session.Values[UserSessionOAuth2Tokens].(oauth2.Token)
err := oidc.RevokeTokens(common.SetTLSContext(ctx, caBundle),
_revocationEndpoint, &token, oauth2Config.ClientID, oauth2Config.ClientSecret)
if err != nil {
return errors.Wrap(err, "Error revoking tokens")
}
logger.WithField("userid", session.Values[userSessionUserID].(string)).Info("Access/Refresh tokens revoked")
logger.WithField("userid", session.Values[UserSessionUserID].(string)).Info("Access/Refresh tokens revoked")
}

return revokeSession(ctx, w, session)
}

// initiateSessionStores initiates both the required stores for the:
// InitiateSessionStores initiates both the required stores for the:
// * users sessions
// * OIDC states
// Based on the configured session store (boltdb, or redis) this function will
// return these two session stores, or will terminate the execution with a fatal
// log message.
func initiateSessionStores(c *common.Config) (ClosableStore, ClosableStore) {
func InitiateSessionStores(c *common.Config) (ClosableStore, ClosableStore) {
logger := logrus.StandardLogger()

var store, oidcStateStore ClosableStore
Expand Down
10 changes: 5 additions & 5 deletions state.go → sessions/state.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright © 2019 Arrikto Inc. All Rights Reserved.

package main
package sessions

import (
"encoding/gob"
Expand Down Expand Up @@ -40,11 +40,11 @@ func newState(firstVisitedURL string) *State {
}
}

// createState creates the state parameter from the incoming request, stores
// CreateState creates the state parameter from the incoming request, stores
// it in the session store and sets a cookie with the session key.
// It returns the session key, which can be used as the state value to start
// an OIDC authentication request.
func createState(r *http.Request, w http.ResponseWriter,
func CreateState(r *http.Request, w http.ResponseWriter,
store sessions.Store) (string, error) {

firstVisitedURL, err := url.Parse("")
Expand Down Expand Up @@ -75,7 +75,7 @@ func createState(r *http.Request, w http.ResponseWriter,
return c.Value, nil
}

// verifyState gets the state from the cookie 'initState' saved. It also gets
// VerifyState gets the state from the cookie 'initState' saved. It also gets
// the state from an http param and:
// 1. Confirms the two values match (CSRF check).
// 2. Confirms the value is still valid by retrieving the session it points to.
Expand All @@ -84,7 +84,7 @@ func createState(r *http.Request, w http.ResponseWriter,
//
// Finally, it returns a State struct, which contains information associated
// with the particular OIDC flow.
func verifyState(r *http.Request, w http.ResponseWriter,
func VerifyState(r *http.Request, w http.ResponseWriter,
store sessions.Store) (*State, error) {

// Get the state from the HTTP param.
Expand Down

0 comments on commit a09d006

Please sign in to comment.