Skip to content

Commit

Permalink
Reduce backend load created from session trackers (#42324)
Browse files Browse the repository at this point in the history
Session trackers were originally added to facilitate joining
sessions and enforcing moderation policies. When a session is
created, a new tracker is written to the backend and a background
routine is spawned to periodically update the status of the tracker
until the session is terminated. This can cause a massive amount
of backend activity for a cluster that is spawning large
quantities of sessions per second. While in most cases where
humans are starting the sessions this isn't a problem, any machine
id heavy use cases could trigger backend throttling. Since
non-interactive sessions and sessions started by tbot are not
meant to be joined or moderated, the existence of a session
tracker for them doesn't provide much benefit, especially now that
session recordings are disabled for non-interactive sessions. To
prevent excess backend writes session trackers are no longer
created for non-interactive and tbot sessions.
  • Loading branch information
rosstimothy committed Jun 10, 2024
1 parent 291ba34 commit 1680d60
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 8 deletions.
25 changes: 18 additions & 7 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (s *SessionRegistry) OpenSession(ctx context.Context, ch ssh.Channel, scx *

// This logic allows concurrent request to create a new session
// to fail, what is ok because we should never have this condition
sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch)
sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch, sessionTypeInteractive)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -367,7 +367,7 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann

// This logic allows concurrent request to create a new session
// to fail, what is ok because we should never have this condition.
sess, _, err := newSession(ctx, sessionID, s, scx, channel)
sess, _, err := newSession(ctx, sessionID, s, scx, channel, sessionTypeNonInteractive)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -721,8 +721,15 @@ type session struct {
started atomic.Bool
}

type sessionType bool

const (
sessionTypeInteractive sessionType = true
sessionTypeNonInteractive sessionType = false
)

// newSession creates a new session with a given ID within a given context.
func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *ServerContext, ch ssh.Channel) (*session, *party, error) {
func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *ServerContext, ch ssh.Channel, sessType sessionType) (*session, *party, error) {
serverSessions.Inc()
startTime := time.Now().UTC()
rsess := rsession.Session{
Expand Down Expand Up @@ -795,7 +802,7 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se
sess.participants[p.id] = p

var err error
if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets, p); err != nil {
if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets, p, sessType); err != nil {
if trace.IsNotImplemented(err) {
return nil, nil, trace.NotImplemented("Attempted to use Moderated Sessions with an Auth Server below the minimum version of 9.0.0.")
}
Expand Down Expand Up @@ -2106,7 +2113,7 @@ func (p *party) closeUnderSessionLock() error {
// trackSession creates a new session tracker for the ssh session.
// While ctx is open, the session tracker's expiration will be extended
// on an interval until the session tracker is closed.
func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet, p *party) error {
func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet, p *party, sessType sessionType) error {
s.log.Debugf("Tracking participant: %s", p.id)
var initialCommand []string
if execRequest, err := s.scx.GetExecRequest(); err == nil {
Expand Down Expand Up @@ -2144,9 +2151,13 @@ func (s *session) trackSession(ctx context.Context, teleportUser string, policyS
}

svc := s.registry.SessionTrackerService
// only propagate the session tracker when the recording mode and component are in sync
// Only propagate the session tracker when the recording mode and component are in sync
// AND the sesssion is interactive
// AND the session was not initiated by a bot
if (s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) ||
(s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) {
(s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) ||
sessType == sessionTypeNonInteractive ||
s.scx.Identity.BotName != "" {
svc = nil
}

Expand Down
40 changes: 39 additions & 1 deletion lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -992,13 +992,16 @@ func TestTrackingSession(t *testing.T) {
recordingMode string
createError error
moderated bool
interactive bool
botUser bool
assertion require.ErrorAssertionFunc
createAssertion func(t *testing.T, count int)
}{
{
name: "node with proxy recording mode",
component: teleport.ComponentNode,
recordingMode: types.RecordAtProxy,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
Expand All @@ -1008,6 +1011,7 @@ func TestTrackingSession(t *testing.T) {
name: "node with node recording mode",
component: teleport.ComponentNode,
recordingMode: types.RecordAtNode,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
Expand All @@ -1017,6 +1021,7 @@ func TestTrackingSession(t *testing.T) {
name: "proxy with proxy recording mode",
component: teleport.ComponentProxy,
recordingMode: types.RecordAtProxy,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
Expand All @@ -1026,6 +1031,7 @@ func TestTrackingSession(t *testing.T) {
name: "proxy with node recording mode",
component: teleport.ComponentProxy,
recordingMode: types.RecordAtNode,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
Expand All @@ -1036,6 +1042,7 @@ func TestTrackingSession(t *testing.T) {
component: teleport.ComponentNode,
recordingMode: types.RecordAtNodeSync,
assertion: require.NoError,
interactive: true,
createError: trace.ConnectionProblem(context.DeadlineExceeded, ""),
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
Expand All @@ -1046,12 +1053,33 @@ func TestTrackingSession(t *testing.T) {
component: teleport.ComponentNode,
recordingMode: types.RecordAtNodeSync,
moderated: true,
interactive: true,
assertion: require.Error,
createError: trace.ConnectionProblem(context.DeadlineExceeded, ""),
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
},
},
{
name: "bot session",
component: teleport.ComponentNode,
recordingMode: types.RecordAtNode,
interactive: true,
botUser: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
},
},
{
name: "non-interactive session",
component: teleport.ComponentNode,
recordingMode: types.RecordAtNode,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
},
},
}

for _, tt := range cases {
Expand All @@ -1075,6 +1103,10 @@ func TestTrackingSession(t *testing.T) {
},
}

if tt.botUser {
scx.Identity.BotName = "test-bot"
}

sess := &session{
id: rsession.NewID(),
log: utils.NewLoggerForTests().WithField(trace.Component, "test-session"),
Expand All @@ -1101,7 +1133,13 @@ func TestTrackingSession(t *testing.T) {
id: rsession.NewID(),
mode: types.SessionPeerMode,
}
err = sess.trackSession(ctx, me.Name, nil, p)

sessType := sessionTypeNonInteractive
if tt.interactive {
sessType = sessionTypeInteractive
}

err = sess.trackSession(ctx, me.Name, nil, p, sessType)
tt.assertion(t, err)
tt.createAssertion(t, trackingService.CreatedCount())
})
Expand Down

0 comments on commit 1680d60

Please sign in to comment.