Skip to content

Commit

Permalink
shift to using buffered shutdown channel, rework
Browse files Browse the repository at this point in the history
  • Loading branch information
zackattack01 committed Dec 19, 2024
1 parent aa0dc07 commit 01e4063
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
1 change: 0 additions & 1 deletion ee/agent/types/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@ type RegistrationTracker interface {
RegistrationIDs() []string
SetRegistrationIDs(registrationIDs []string) error
}

41 changes: 22 additions & 19 deletions pkg/osquery/runtime/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log/slog"
"slices"
"sync"
"sync/atomic"
"time"

"github.com/kolide/launcher/ee/agent/flags/keys"
Expand All @@ -27,9 +28,9 @@ type Runner struct {
knapsack types.Knapsack
serviceClient service.KolideService // shared service client for communication between osquery instance and Kolide SaaS
opts []OsqueryInstanceOption // global options applying to all osquery instances
shutdown chan struct{}
rerunRequired bool
interrupted bool
shutdown chan struct{} // buffered shutdown channel to enable shutting down to restart or exit
rerunRequired atomic.Bool
interrupted atomic.Bool
}

func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryInstanceOption) *Runner {
Expand All @@ -39,9 +40,9 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI
slogger: k.Slogger().With("component", "osquery_runner"),
knapsack: k,
serviceClient: serviceClient,
shutdown: make(chan struct{}),
rerunRequired: false,
opts: opts,
// the buffer length is arbitrarily set at 100; this number just needs to be higher than the total number of possible instances
shutdown: make(chan struct{}, 100),
opts: opts,
}

k.RegisterChangeObserver(runner,
Expand All @@ -60,8 +61,8 @@ func (r *Runner) Run() error {

// if we're in a state that required re-running all registered instances,
// reset the field and do that
if r.rerunRequired {
r.rerunRequired = false
if r.rerunRequired.Load() {
r.rerunRequired.Store(false)
continue
}

Expand Down Expand Up @@ -214,6 +215,13 @@ func (r *Runner) Query(query string) ([]map[string]string, error) {
}

func (r *Runner) Interrupt(_ error) {
if r.interrupted.Load() {
// Already shut down, nothing else to do
return
}

r.interrupted.Store(true)

if err := r.Shutdown(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelWarn,
"could not shut down runner on interrupt",
Expand All @@ -225,13 +233,12 @@ func (r *Runner) Interrupt(_ error) {
// Shutdown instructs the runner to permanently stop the running instance (no
// restart will be attempted).
func (r *Runner) Shutdown() error {
if r.interrupted {
// Already shut down, nothing else to do
return nil
// ensure one shutdown is sent for each instance to read
r.instanceLock.Lock()
for range r.instances {
r.shutdown <- struct{}{}
}

r.interrupted = true
close(r.shutdown)
r.instanceLock.Unlock()

if err := r.triggerShutdownForInstances(); err != nil {
return fmt.Errorf("triggering shutdown for instances during runner shutdown: %w", err)
Expand Down Expand Up @@ -385,7 +392,7 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
r.registrationIds = newRegistrationIDs
// mark rerun as required so that we can safely shutdown all workers and have the changes
// picked back up from within the main Run function
r.rerunRequired = true
r.rerunRequired.Store(true)

if err := r.Shutdown(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelWarn,
Expand All @@ -396,9 +403,5 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error {
return err
}

// reset the shutdown channel and interrupted state
r.shutdown = make(chan struct{})
r.interrupted = false

return nil
}
13 changes: 8 additions & 5 deletions pkg/osquery/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,11 +638,12 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {

// Add in an extra instance
extraRegistrationId := ulid.New()
runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId})
updateErr := runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId})
require.NoError(t, updateErr)
waitHealthy(t, runner, logBytes)
updatedInstanceStatuses := runner.InstanceStatuses()
// verify that rerunRequired has been reset for any future changes
require.False(t, runner.rerunRequired)
require.False(t, runner.rerunRequired.Load())
// now verify both instances are reported
require.Equal(t, 2, len(runner.instances))
require.Contains(t, updatedInstanceStatuses, types.DefaultRegistrationID)
Expand All @@ -655,7 +656,8 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {

// update registration IDs one more time, this time removing the additional registration
originalDefaultInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime
runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID})
updateErr = runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID})
require.NoError(t, updateErr)
waitHealthy(t, runner, logBytes)

// now verify only the default instance remains
Expand All @@ -666,7 +668,7 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) {
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up")
require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up")
// verify that rerunRequired has been reset for any future changes
require.False(t, runner.rerunRequired)
require.False(t, runner.rerunRequired.Load())
// verify the default instance was restarted
require.NotEqual(t, originalDefaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime)

Expand Down Expand Up @@ -726,7 +728,8 @@ func TestUpdatingRegistrationIDsOnlyRestartsForChanges(t *testing.T) {
extraInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime

// rerun with identical registrationIDs in swapped order and verify that the instances are not restarted
runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID})
updateErr := runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID})
require.NoError(t, updateErr)
waitHealthy(t, runner, logBytes)

require.Equal(t, 2, len(runner.instances))
Expand Down

0 comments on commit 01e4063

Please sign in to comment.