From 1680d60c720476c14b59295ec00e430d88c0ddce Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Tue, 4 Jun 2024 13:08:08 -0400 Subject: [PATCH] Reduce backend load created from session trackers (#42324) Session trackers were originally added to facilitate joining sessions and enforcing moderation policies. When a session is created, a new tracker is written to the backend and a background routine is spawned to periodically update the status of the tracker until the session is terminated. This can cause a massive amount of backend activity for a cluster that is spawning large quantities of sessions per second. While this isn't a problem in most cases where humans are starting the sessions, any Machine ID-heavy use case could trigger backend throttling. Since non-interactive sessions and sessions started by tbot are not meant to be joined or moderated, the existence of a session tracker for them doesn't provide much benefit, especially now that session recordings are disabled for non-interactive sessions. To prevent excess backend writes, session trackers are no longer created for non-interactive and tbot sessions. 
--- lib/srv/sess.go | 25 ++++++++++++++++++------- lib/srv/sess_test.go | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 8fb411f8a6f5c..dbc22562b3473 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -332,7 +332,7 @@ func (s *SessionRegistry) OpenSession(ctx context.Context, ch ssh.Channel, scx * // This logic allows concurrent request to create a new session // to fail, what is ok because we should never have this condition - sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch) + sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch, sessionTypeInteractive) if err != nil { return trace.Wrap(err) } @@ -367,7 +367,7 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann // This logic allows concurrent request to create a new session // to fail, what is ok because we should never have this condition. - sess, _, err := newSession(ctx, sessionID, s, scx, channel) + sess, _, err := newSession(ctx, sessionID, s, scx, channel, sessionTypeNonInteractive) if err != nil { return trace.Wrap(err) } @@ -721,8 +721,15 @@ type session struct { started atomic.Bool } +type sessionType bool + +const ( + sessionTypeInteractive sessionType = true + sessionTypeNonInteractive sessionType = false +) + // newSession creates a new session with a given ID within a given context. 
-func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *ServerContext, ch ssh.Channel) (*session, *party, error) { +func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *ServerContext, ch ssh.Channel, sessType sessionType) (*session, *party, error) { serverSessions.Inc() startTime := time.Now().UTC() rsess := rsession.Session{ @@ -795,7 +802,7 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se sess.participants[p.id] = p var err error - if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets, p); err != nil { + if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets, p, sessType); err != nil { if trace.IsNotImplemented(err) { return nil, nil, trace.NotImplemented("Attempted to use Moderated Sessions with an Auth Server below the minimum version of 9.0.0.") } @@ -2106,7 +2113,7 @@ func (p *party) closeUnderSessionLock() error { // trackSession creates a new session tracker for the ssh session. // While ctx is open, the session tracker's expiration will be extended // on an interval until the session tracker is closed. 
-func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet, p *party) error { +func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet, p *party, sessType sessionType) error { s.log.Debugf("Tracking participant: %s", p.id) var initialCommand []string if execRequest, err := s.scx.GetExecRequest(); err == nil { @@ -2144,9 +2151,13 @@ func (s *session) trackSession(ctx context.Context, teleportUser string, policyS } svc := s.registry.SessionTrackerService - // only propagate the session tracker when the recording mode and component are in sync + // Only propagate the session tracker when the recording mode and component are in sync + // AND the session is interactive + // AND the session was not initiated by a bot if (s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) || - (s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) { + (s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) || + sessType == sessionTypeNonInteractive || + s.scx.Identity.BotName != "" { svc = nil } diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index db4bd50c67659..98b650532c9e0 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -992,6 +992,8 @@ func TestTrackingSession(t *testing.T) { recordingMode string createError error moderated bool + interactive bool + botUser bool assertion require.ErrorAssertionFunc createAssertion func(t *testing.T, count int) }{ @@ -999,6 +1001,7 @@ func TestTrackingSession(t *testing.T) { name: "node with proxy recording mode", component: teleport.ComponentNode, recordingMode: types.RecordAtProxy, + interactive: true, assertion: require.NoError, createAssertion: func(t *testing.T, count int) { require.Equal(t, 0, 
count) @@ -1008,6 +1011,7 @@ func TestTrackingSession(t *testing.T) { name: "node with node recording mode", component: teleport.ComponentNode, recordingMode: types.RecordAtNode, + interactive: true, assertion: require.NoError, createAssertion: func(t *testing.T, count int) { require.Equal(t, 1, count) @@ -1017,6 +1021,7 @@ func TestTrackingSession(t *testing.T) { name: "proxy with proxy recording mode", component: teleport.ComponentProxy, recordingMode: types.RecordAtProxy, + interactive: true, assertion: require.NoError, createAssertion: func(t *testing.T, count int) { require.Equal(t, 1, count) @@ -1026,6 +1031,7 @@ func TestTrackingSession(t *testing.T) { name: "proxy with node recording mode", component: teleport.ComponentProxy, recordingMode: types.RecordAtNode, + interactive: true, assertion: require.NoError, createAssertion: func(t *testing.T, count int) { require.Equal(t, 0, count) @@ -1036,6 +1042,7 @@ func TestTrackingSession(t *testing.T) { component: teleport.ComponentNode, recordingMode: types.RecordAtNodeSync, assertion: require.NoError, + interactive: true, createError: trace.ConnectionProblem(context.DeadlineExceeded, ""), createAssertion: func(t *testing.T, count int) { require.Equal(t, 1, count) @@ -1046,12 +1053,33 @@ func TestTrackingSession(t *testing.T) { component: teleport.ComponentNode, recordingMode: types.RecordAtNodeSync, moderated: true, + interactive: true, assertion: require.Error, createError: trace.ConnectionProblem(context.DeadlineExceeded, ""), createAssertion: func(t *testing.T, count int) { require.Equal(t, 1, count) }, }, + { + name: "bot session", + component: teleport.ComponentNode, + recordingMode: types.RecordAtNode, + interactive: true, + botUser: true, + assertion: require.NoError, + createAssertion: func(t *testing.T, count int) { + require.Equal(t, 0, count) + }, + }, + { + name: "non-interactive session", + component: teleport.ComponentNode, + recordingMode: types.RecordAtNode, + assertion: require.NoError, + 
createAssertion: func(t *testing.T, count int) { + require.Equal(t, 0, count) + }, + }, } for _, tt := range cases { @@ -1075,6 +1103,10 @@ func TestTrackingSession(t *testing.T) { }, } + if tt.botUser { + scx.Identity.BotName = "test-bot" + } + sess := &session{ id: rsession.NewID(), log: utils.NewLoggerForTests().WithField(trace.Component, "test-session"), @@ -1101,7 +1133,13 @@ func TestTrackingSession(t *testing.T) { id: rsession.NewID(), mode: types.SessionPeerMode, } - err = sess.trackSession(ctx, me.Name, nil, p) + + sessType := sessionTypeNonInteractive + if tt.interactive { + sessType = sessionTypeInteractive + } + + err = sess.trackSession(ctx, me.Name, nil, p, sessType) tt.assertion(t, err) tt.createAssertion(t, trackingService.CreatedCount()) })