Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Reduce backend load created from session trackers #42695

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (s *SessionRegistry) OpenSession(ctx context.Context, ch ssh.Channel, scx *

// This logic allows concurrent request to create a new session
// to fail, what is ok because we should never have this condition
sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch)
sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch, sessionTypeInteractive)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -367,7 +367,7 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann

// This logic allows concurrent request to create a new session
// to fail, what is ok because we should never have this condition.
sess, _, err := newSession(ctx, sessionID, s, scx, channel)
sess, _, err := newSession(ctx, sessionID, s, scx, channel, sessionTypeNonInteractive)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -721,8 +721,15 @@ type session struct {
started atomic.Bool
}

type sessionType bool

const (
sessionTypeInteractive sessionType = true
sessionTypeNonInteractive sessionType = false
)

// newSession creates a new session with a given ID within a given context.
func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *ServerContext, ch ssh.Channel) (*session, *party, error) {
func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *ServerContext, ch ssh.Channel, sessType sessionType) (*session, *party, error) {
serverSessions.Inc()
startTime := time.Now().UTC()
rsess := rsession.Session{
Expand Down Expand Up @@ -795,7 +802,7 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se
sess.participants[p.id] = p

var err error
if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets, p); err != nil {
if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets, p, sessType); err != nil {
if trace.IsNotImplemented(err) {
return nil, nil, trace.NotImplemented("Attempted to use Moderated Sessions with an Auth Server below the minimum version of 9.0.0.")
}
Expand Down Expand Up @@ -2106,7 +2113,7 @@ func (p *party) closeUnderSessionLock() error {
// trackSession creates a new session tracker for the ssh session.
// While ctx is open, the session tracker's expiration will be extended
// on an interval until the session tracker is closed.
func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet, p *party) error {
func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet, p *party, sessType sessionType) error {
s.log.Debugf("Tracking participant: %s", p.id)
var initialCommand []string
if execRequest, err := s.scx.GetExecRequest(); err == nil {
Expand Down Expand Up @@ -2144,9 +2151,13 @@ func (s *session) trackSession(ctx context.Context, teleportUser string, policyS
}

svc := s.registry.SessionTrackerService
// only propagate the session tracker when the recording mode and component are in sync
// Only propagate the session tracker when the recording mode and component are in sync
// AND the sesssion is interactive
// AND the session was not initiated by a bot
if (s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) ||
(s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) {
(s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) ||
sessType == sessionTypeNonInteractive ||
s.scx.Identity.BotName != "" {
svc = nil
}

Expand Down
40 changes: 39 additions & 1 deletion lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -992,13 +992,16 @@ func TestTrackingSession(t *testing.T) {
recordingMode string
createError error
moderated bool
interactive bool
botUser bool
assertion require.ErrorAssertionFunc
createAssertion func(t *testing.T, count int)
}{
{
name: "node with proxy recording mode",
component: teleport.ComponentNode,
recordingMode: types.RecordAtProxy,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
Expand All @@ -1008,6 +1011,7 @@ func TestTrackingSession(t *testing.T) {
name: "node with node recording mode",
component: teleport.ComponentNode,
recordingMode: types.RecordAtNode,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
Expand All @@ -1017,6 +1021,7 @@ func TestTrackingSession(t *testing.T) {
name: "proxy with proxy recording mode",
component: teleport.ComponentProxy,
recordingMode: types.RecordAtProxy,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
Expand All @@ -1026,6 +1031,7 @@ func TestTrackingSession(t *testing.T) {
name: "proxy with node recording mode",
component: teleport.ComponentProxy,
recordingMode: types.RecordAtNode,
interactive: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
Expand All @@ -1036,6 +1042,7 @@ func TestTrackingSession(t *testing.T) {
component: teleport.ComponentNode,
recordingMode: types.RecordAtNodeSync,
assertion: require.NoError,
interactive: true,
createError: trace.ConnectionProblem(context.DeadlineExceeded, ""),
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
Expand All @@ -1046,12 +1053,33 @@ func TestTrackingSession(t *testing.T) {
component: teleport.ComponentNode,
recordingMode: types.RecordAtNodeSync,
moderated: true,
interactive: true,
assertion: require.Error,
createError: trace.ConnectionProblem(context.DeadlineExceeded, ""),
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 1, count)
},
},
{
name: "bot session",
component: teleport.ComponentNode,
recordingMode: types.RecordAtNode,
interactive: true,
botUser: true,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
},
},
{
name: "non-interactive session",
component: teleport.ComponentNode,
recordingMode: types.RecordAtNode,
assertion: require.NoError,
createAssertion: func(t *testing.T, count int) {
require.Equal(t, 0, count)
},
},
}

for _, tt := range cases {
Expand All @@ -1075,6 +1103,10 @@ func TestTrackingSession(t *testing.T) {
},
}

if tt.botUser {
scx.Identity.BotName = "test-bot"
}

sess := &session{
id: rsession.NewID(),
log: utils.NewLoggerForTests().WithField(trace.Component, "test-session"),
Expand All @@ -1101,7 +1133,13 @@ func TestTrackingSession(t *testing.T) {
id: rsession.NewID(),
mode: types.SessionPeerMode,
}
err = sess.trackSession(ctx, me.Name, nil, p)

sessType := sessionTypeNonInteractive
if tt.interactive {
sessType = sessionTypeInteractive
}

err = sess.trackSession(ctx, me.Name, nil, p, sessType)
tt.assertion(t, err)
tt.createAssertion(t, trackingService.CreatedCount())
})
Expand Down
Loading