diff --git a/ee/agent/types/registration.go b/ee/agent/types/registration.go index e6563e195..1c39ef1f7 100644 --- a/ee/agent/types/registration.go +++ b/ee/agent/types/registration.go @@ -11,4 +11,3 @@ type RegistrationTracker interface { RegistrationIDs() []string SetRegistrationIDs(registrationIDs []string) error } - \ No newline at end of file diff --git a/pkg/osquery/runtime/runner.go b/pkg/osquery/runtime/runner.go index 424bdd524..94b3c996c 100644 --- a/pkg/osquery/runtime/runner.go +++ b/pkg/osquery/runtime/runner.go @@ -7,6 +7,7 @@ import ( "log/slog" "slices" "sync" + "sync/atomic" "time" "github.com/kolide/launcher/ee/agent/flags/keys" @@ -27,9 +28,9 @@ type Runner struct { knapsack types.Knapsack serviceClient service.KolideService // shared service client for communication between osquery instance and Kolide SaaS opts []OsqueryInstanceOption // global options applying to all osquery instances - shutdown chan struct{} - rerunRequired bool - interrupted bool + shutdown chan struct{} // buffered shutdown channel for to enable shutting down to restart or exit + rerunRequired atomic.Bool + interrupted atomic.Bool } func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryInstanceOption) *Runner { @@ -39,9 +40,9 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI slogger: k.Slogger().With("component", "osquery_runner"), knapsack: k, serviceClient: serviceClient, - shutdown: make(chan struct{}), - rerunRequired: false, - opts: opts, + // the buffer length is arbitrarily set at 100, this number just needs to be higher than the total possible instances + shutdown: make(chan struct{}, 100), + opts: opts, } k.RegisterChangeObserver(runner, @@ -60,8 +61,8 @@ func (r *Runner) Run() error { // if we're in a state that required re-running all registered instances, // reset the field and do that - if r.rerunRequired { - r.rerunRequired = false + if r.rerunRequired.Load() { + r.rerunRequired.Store(false) continue } @@ -214,6 +215,13 @@ func (r *Runner) Query(query string) ([]map[string]string, error) { } func (r *Runner) Interrupt(_ error) { + if r.interrupted.Load() { + // Already shut down, nothing else to do + return + } + + r.interrupted.Store(true) + if err := r.Shutdown(); err != nil { r.slogger.Log(context.TODO(), slog.LevelWarn, "could not shut down runner on interrupt", @@ -225,13 +233,12 @@ func (r *Runner) Interrupt(_ error) { // Shutdown instructs the runner to permanently stop the running instance (no // restart will be attempted). func (r *Runner) Shutdown() error { - if r.interrupted { - // Already shut down, nothing else to do - return nil + // ensure one shutdown is sent for each instance to read + r.instanceLock.Lock() + for range r.instances { + r.shutdown <- struct{}{} } - - r.interrupted = true - close(r.shutdown) + r.instanceLock.Unlock() if err := r.triggerShutdownForInstances(); err != nil { return fmt.Errorf("triggering shutdown for instances during runner shutdown: %w", err) @@ -385,7 +392,7 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error { r.registrationIds = newRegistrationIDs // mark rerun as required so that we can safely shutdown all workers and have the changes // picked back up from within the main Run function - r.rerunRequired = true + r.rerunRequired.Store(true) if err := r.Shutdown(); err != nil { r.slogger.Log(context.TODO(), slog.LevelWarn, @@ -396,9 +403,5 @@ func (r *Runner) UpdateRegistrationIDs(newRegistrationIDs []string) error { return err } - // reset the shutdown channel and interrupted state - r.shutdown = make(chan struct{}) - r.interrupted = false - return nil } diff --git a/pkg/osquery/runtime/runtime_test.go b/pkg/osquery/runtime/runtime_test.go index 8b61d1e46..eb8a8fa17 100644 --- a/pkg/osquery/runtime/runtime_test.go +++ b/pkg/osquery/runtime/runtime_test.go @@ -638,11 +638,12 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) { // Add in an extra instance extraRegistrationId := ulid.New() - runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId}) + updateErr := runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID, extraRegistrationId}) + require.NoError(t, updateErr) waitHealthy(t, runner, logBytes) updatedInstanceStatuses := runner.InstanceStatuses() // verify that rerunRequired has been reset for any future changes - require.False(t, runner.rerunRequired) + require.False(t, runner.rerunRequired.Load()) // now verify both instances are reported require.Equal(t, 2, len(runner.instances)) require.Contains(t, updatedInstanceStatuses, types.DefaultRegistrationID) @@ -655,7 +656,8 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) { // update registration IDs one more time, this time removing the additional registration originalDefaultInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime - runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID}) + updateErr = runner.UpdateRegistrationIDs([]string{types.DefaultRegistrationID}) + require.NoError(t, updateErr) waitHealthy(t, runner, logBytes) // now verify only the default instance remains @@ -666,7 +668,7 @@ func TestMultipleInstancesWithUpdatedRegistrationIDs(t *testing.T) { require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.StartTime, "start time should be added to default instance stats on start up") require.NotEmpty(t, runner.instances[types.DefaultRegistrationID].stats.ConnectTime, "connect time should be added to default instance stats on start up") // verify that rerunRequired has been reset for any future changes - require.False(t, runner.rerunRequired) + require.False(t, runner.rerunRequired.Load()) // verify the default instance was restarted require.NotEqual(t, originalDefaultInstanceStartTime, runner.instances[types.DefaultRegistrationID].stats.StartTime) @@ -726,7 +728,8 @@ func TestUpdatingRegistrationIDsOnlyRestartsForChanges(t *testing.T) { extraInstanceStartTime := runner.instances[extraRegistrationId].stats.StartTime // rerun with identical registrationIDs in swapped order and verify that the instances are not restarted - runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID}) + updateErr := runner.UpdateRegistrationIDs([]string{extraRegistrationId, types.DefaultRegistrationID}) + require.NoError(t, updateErr) waitHealthy(t, runner, logBytes) require.Equal(t, 2, len(runner.instances))