diff --git a/server/application/terminal.go b/server/application/terminal.go index bea1f6ea6a110..3147d15ca0030 100644 --- a/server/application/terminal.go +++ b/server/application/terminal.go @@ -38,12 +38,12 @@ type terminalHandler struct { allowedShells []string namespace string enabledNamespaces []string - sessionManager util_session.SessionManager + sessionManager *util_session.SessionManager } // NewHandler returns a new terminal handler. func NewHandler(appLister applisters.ApplicationLister, namespace string, enabledNamespaces []string, db db.ArgoDB, enf *rbac.Enforcer, cache *servercache.Cache, - appResourceTree AppResourceTreeFn, allowedShells []string, sessionManager util_session.SessionManager) *terminalHandler { + appResourceTree AppResourceTreeFn, allowedShells []string, sessionManager *util_session.SessionManager) *terminalHandler { return &terminalHandler{ appLister: appLister, db: db, diff --git a/server/application/websocket.go b/server/application/websocket.go index faee91c4f47e4..b04330c45c3d7 100644 --- a/server/application/websocket.go +++ b/server/application/websocket.go @@ -37,7 +37,7 @@ type terminalSession struct { tty bool readLock sync.Mutex writeLock sync.Mutex - sessionManager util_session.SessionManager + sessionManager *util_session.SessionManager token *string } @@ -48,7 +48,7 @@ func getToken(r *http.Request) (string, error) { } // newTerminalSession create terminalSession -func newTerminalSession(w http.ResponseWriter, r *http.Request, responseHeader http.Header, sessionManager util_session.SessionManager) (*terminalSession, error) { +func newTerminalSession(w http.ResponseWriter, r *http.Request, responseHeader http.Header, sessionManager *util_session.SessionManager) (*terminalSession, error) { token, err := getToken(r) if err != nil { return nil, err diff --git a/server/server.go b/server/server.go index 26182af123185..a70391ea01b1f 100644 --- a/server/server.go +++ b/server/server.go @@ -982,7 +982,7 @@ func (a *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWebHandl } mux.Handle("/api/", handler) - terminal := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells, *a.sessionMgr). + terminal := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells, a.sessionMgr). WithFeatureFlagMiddleware(a.settingsMgr.GetSettings) th := util_session.WithAuthMiddleware(a.DisableAuth, a.sessionMgr, terminal) mux.Handle("/terminal", th) diff --git a/util/session/sessionmanager.go b/util/session/sessionmanager.go index d11c96c6cf5aa..af22ca0f2502e 100644 --- a/util/session/sessionmanager.go +++ b/util/session/sessionmanager.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "strings" + "sync" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -41,6 +42,7 @@ type SessionManager struct { storage UserStateStorage sleep func(d time.Duration) verificationDelayNoiseEnabled bool + failedLock sync.RWMutex } // LoginAttempts is a timestamped counter for failed login attempts @@ -284,7 +286,7 @@ func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, string, error) return token.Claims, newToken, nil } -// GetLoginFailures retrieves the login failure information from the cache +// GetLoginFailures retrieves the login failure information from the cache. Any modifications to the LoginAttemps map must be done in a thread-safe manner. func (mgr *SessionManager) GetLoginFailures() map[string]LoginAttempts { // Get failures from the cache var failures map[string]LoginAttempts @@ -299,12 +301,12 @@ func (mgr *SessionManager) GetLoginFailures() map[string]LoginAttempts { return failures } -func expireOldFailedAttempts(maxAge time.Duration, failures *map[string]LoginAttempts) int { +func expireOldFailedAttempts(maxAge time.Duration, failures map[string]LoginAttempts) int { expiredCount := 0 - for key, attempt := range *failures { + for key, attempt := range failures { if time.Since(attempt.LastFailed) > maxAge*time.Second { expiredCount += 1 - delete(*failures, key) + delete(failures, key) } } return expiredCount @@ -328,12 +330,14 @@ func pickRandomNonAdminLoginFailure(failures map[string]LoginAttempts, username // Updates the failure count for a given username. If failed is true, increases the counter. Otherwise, sets counter back to 0. func (mgr *SessionManager) updateFailureCount(username string, failed bool) { + mgr.failedLock.Lock() + defer mgr.failedLock.Unlock() failures := mgr.GetLoginFailures() // Expire old entries in the cache if we have a failure window defined. if window := getLoginFailureWindow(); window > 0 { - count := expireOldFailedAttempts(window, &failures) + count := expireOldFailedAttempts(window, failures) if count > 0 { log.Infof("Expired %d entries from session cache due to max age reached", count) } @@ -380,6 +384,8 @@ func (mgr *SessionManager) updateFailureCount(username string, failed bool) { // Get the current login failure attempts for given username func (mgr *SessionManager) getFailureCount(username string) LoginAttempts { + mgr.failedLock.RLock() + defer mgr.failedLock.RUnlock() failures := mgr.GetLoginFailures() attempt, ok := failures[username] if !ok {