diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index 2361baf73f5..acdf4efc087 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -197,3 +197,4 @@ - Add liveness endpoint, allow fleet-gateway component to report degraded state, add update time and messages to status output. {issue}390[390] {pull}569[569] - Redact sensitive information on diagnostics collect command. {issue}[241] {pull}[566] - Fix incorrectly creating a filebeat redis input when a policy contains a packetbeat redis input. {issue}[427] {pull}[700] +- Allow upgrade actions to be retried on failure with action queue scheduling. {issue}778[778] {pull}1219[1219] diff --git a/internal/pkg/agent/application/coordinator/coordinator.go b/internal/pkg/agent/application/coordinator/coordinator.go index e49198da65f..407ebbca625 100644 --- a/internal/pkg/agent/application/coordinator/coordinator.go +++ b/internal/pkg/agent/application/coordinator/coordinator.go @@ -12,6 +12,7 @@ import ( "gopkg.in/yaml.v2" "github.com/elastic/elastic-agent/internal/pkg/diagnostics" + "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" @@ -51,6 +52,9 @@ type UpgradeManager interface { // Upgrade upgrades running agent. Upgrade(ctx context.Context, version string, sourceURI string, action *fleetapi.ActionUpgrade) (_ reexec.ShutdownCallbackFn, err error) + + // Ack is used on startup to check if the agent has upgraded and needs to send an ack for the action + Ack(ctx context.Context, acker acker.Acker) error } // Runner provides interface to run a manager and receive running errors. 
@@ -251,10 +255,16 @@ func (c *Coordinator) Upgrade(ctx context.Context, version string, sourceURI str c.state.overrideState = nil return err } - c.ReExec(cb) + if cb != nil { + c.ReExec(cb) + } return nil } +func (c *Coordinator) AckUpgrade(ctx context.Context, acker acker.Acker) error { + return c.upgradeMgr.Ack(ctx, acker) +} + // PerformAction executes an action on a unit. func (c *Coordinator) PerformAction(ctx context.Context, unit component.Unit, name string, params map[string]interface{}) (map[string]interface{}, error) { return c.runtimeMgr.PerformAction(ctx, unit, name, params) diff --git a/internal/pkg/agent/application/dispatcher/dispatcher.go b/internal/pkg/agent/application/dispatcher/dispatcher.go index 700c7d35349..e37fbdc770b 100644 --- a/internal/pkg/agent/application/dispatcher/dispatcher.go +++ b/internal/pkg/agent/application/dispatcher/dispatcher.go @@ -23,14 +23,16 @@ import ( type actionHandlers map[string]actions.Handler type priorityQueue interface { - Add(fleetapi.Action, int64) - DequeueActions() []fleetapi.Action + Add(fleetapi.ScheduledAction, int64) + DequeueActions() []fleetapi.ScheduledAction + CancelType(string) int Save() error } // Dispatcher processes actions coming from fleet api. type Dispatcher interface { - Dispatch(context.Context, acker.Acker, ...fleetapi.Action) error + Dispatch(context.Context, acker.Acker, ...fleetapi.Action) + Errors() <-chan error } // ActionDispatcher processes actions coming from fleet using registered set of handlers. @@ -39,6 +41,8 @@ type ActionDispatcher struct { handlers actionHandlers def actions.Handler queue priorityQueue + rt *retryConfig + errCh chan error } // New creates a new action dispatcher. 
@@ -60,9 +64,15 @@ func New(log *logger.Logger, def actions.Handler, queue priorityQueue) (*ActionD handlers: make(actionHandlers), def: def, queue: queue, + rt: defaultRetryConfig(), + errCh: make(chan error), }, nil } +func (ad *ActionDispatcher) Errors() <-chan error { + return ad.errCh +} + // Register registers a new handler for action. func (ad *ActionDispatcher) Register(a fleetapi.Action, handler actions.Handler) error { k := ad.key(a) @@ -88,13 +98,18 @@ func (ad *ActionDispatcher) key(a fleetapi.Action) string { } // Dispatch dispatches an action using pre-registered set of handlers. -func (ad *ActionDispatcher) Dispatch(ctx context.Context, acker acker.Acker, actions ...fleetapi.Action) (err error) { +// Dispatch will handle action queue operations, and retries. +// Any action that implements the ScheduledAction interface may be added/removed from the queue based on StartTime. +// Any action that implements the RetryableAction interface will be rescheduled if the handler returns an error. 
+func (ad *ActionDispatcher) Dispatch(ctx context.Context, acker acker.Acker, actions ...fleetapi.Action) { + var err error span, ctx := apm.StartSpan(ctx, "dispatch", "app.internal") defer func() { apm.CaptureError(ctx, err).Send() span.End() }() + ad.removeQueuedUpgrades(actions) actions = ad.queueScheduledActions(actions) actions = ad.dispatchCancelActions(ctx, actions, acker) queued, expired := ad.gatherQueuedActions(time.Now().UTC()) @@ -108,7 +123,7 @@ func (ad *ActionDispatcher) Dispatch(ctx context.Context, acker acker.Acker, act if len(actions) == 0 { ad.log.Debug("No action to dispatch") - return nil + return } ad.log.Debugf( @@ -118,18 +133,28 @@ func (ad *ActionDispatcher) Dispatch(ctx context.Context, acker acker.Acker, act ) for _, action := range actions { - if err := ctx.Err(); err != nil { - return err + if err = ctx.Err(); err != nil { + ad.errCh <- err + return } if err := ad.dispatchAction(ctx, action, acker); err != nil { + rAction, ok := action.(fleetapi.RetryableAction) + if ok { + rAction.SetError(err) // set the retryable action error to what the dispatcher returned + ad.scheduleRetry(ctx, rAction, acker) + continue + } ad.log.Debugf("Failed to dispatch action '%+v', error: %+v", action, err) - return err + ad.errCh <- err + continue } ad.log.Debugf("Successfully dispatched action: '%+v'", action) } - return acker.Commit(ctx) + if err = acker.Commit(ctx); err != nil { + ad.errCh <- err + } } func (ad *ActionDispatcher) dispatchAction(ctx context.Context, a fleetapi.Action, acker acker.Acker) error { @@ -154,15 +179,18 @@ func detectTypes(actions []fleetapi.Action) []string { func (ad *ActionDispatcher) queueScheduledActions(input []fleetapi.Action) []fleetapi.Action { actions := make([]fleetapi.Action, 0, len(input)) for _, action := range input { - start, err := action.StartTime() - if err == nil { - ad.log.Debugf("Adding action id: %s to queue.", action.ID()) - ad.queue.Add(action, start.Unix()) + sAction, ok := 
action.(fleetapi.ScheduledAction) + if ok { + start, err := sAction.StartTime() + if err != nil { + ad.log.Warnf("Skipping addition to action-queue, issue gathering start time from action id %s: %v", sAction.ID(), err) + actions = append(actions, action) + continue + } + ad.log.Debugf("Adding action id: %s to queue.", sAction.ID()) + ad.queue.Add(sAction, start.Unix()) continue } - if !errors.Is(err, fleetapi.ErrNoStartTime) { - ad.log.Warnf("Issue gathering start time from action id %s: %v", action.ID(), err) - } actions = append(actions, action) } return actions @@ -197,3 +225,50 @@ func (ad *ActionDispatcher) gatherQueuedActions(ts time.Time) (queued, expired [ } return queued, expired } + +// removeQueuedUpgrades scans the passed actions and, if any of them is an upgrade action, removes all upgrade actions from the queue without altering the passed list. +// This is done to try to have only the most recent upgrade action executed; it does not eliminate duplicate upgrade actions retrieved directly from the gateway. +func (ad *ActionDispatcher) removeQueuedUpgrades(actions []fleetapi.Action) { + for _, action := range actions { + if action.Type() == fleetapi.ActionTypeUpgrade { + if n := ad.queue.CancelType(fleetapi.ActionTypeUpgrade); n > 0 { + ad.log.Debugw("New upgrade action retrieved from gateway, removing queued upgrade actions", "actions_found", n) + } + return + } + } +} + +func (ad *ActionDispatcher) scheduleRetry(ctx context.Context, action fleetapi.RetryableAction, acker acker.Acker) { + attempt := action.RetryAttempt() + d, err := ad.rt.GetWait(attempt) + if err != nil { + ad.log.Errorf("No more retries for action id %s: %v", action.ID(), err) + action.SetRetryAttempt(-1) + if err := acker.Ack(ctx, action); err != nil { + ad.log.Errorf("Unable to ack action failure (id %s) to fleet-server: %v", action.ID(), err) + return + } + if err := acker.Commit(ctx); err != nil { + ad.log.Errorf("Unable to commit action failure (id %s) to fleet-server: %v", action.ID(), 
err) + } + return + } + attempt = attempt + 1 + startTime := time.Now().UTC().Add(d) + action.SetRetryAttempt(attempt) + action.SetStartTime(startTime) + ad.log.Debugf("Adding action id: %s to queue.", action.ID()) + ad.queue.Add(action, startTime.Unix()) + err = ad.queue.Save() + if err != nil { + ad.log.Errorf("retry action id %s attempt %d failed to persist action_queue: %v", action.ID(), attempt, err) + } + if err := acker.Ack(ctx, action); err != nil { + ad.log.Errorf("Unable to ack action retry (id %s) to fleet-server: %v", action.ID(), err) + return + } + if err := acker.Commit(ctx); err != nil { + ad.log.Errorf("Unable to commit action retry (id %s) to fleet-server: %v", action.ID(), err) + } +} diff --git a/internal/pkg/agent/application/dispatcher/dispatcher_test.go b/internal/pkg/agent/application/dispatcher/dispatcher_test.go index d140033655c..c9c1397443c 100644 --- a/internal/pkg/agent/application/dispatcher/dispatcher_test.go +++ b/internal/pkg/agent/application/dispatcher/dispatcher_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker/noop" @@ -33,9 +34,12 @@ type mockAction struct { type mockOtherAction struct { mockAction } -type mockUnknownAction struct { +type mockScheduledAction struct { mockAction } +type mockRetryableAction struct { + mockScheduledAction +} func (m *mockAction) ID() string { args := m.Called() @@ -49,26 +53,48 @@ func (m *mockAction) String() string { args := m.Called() return args.String(0) } -func (m *mockAction) StartTime() (time.Time, error) { +func (m *mockScheduledAction) StartTime() (time.Time, error) { args := m.Called() return args.Get(0).(time.Time), args.Error(1) } -func (m *mockAction) Expiration() (time.Time, error) { 
+func (m *mockScheduledAction) Expiration() (time.Time, error) { args := m.Called() return args.Get(0).(time.Time), args.Error(1) } +func (m *mockRetryableAction) RetryAttempt() int { + args := m.Called() + return args.Int(0) +} +func (m *mockRetryableAction) SetRetryAttempt(n int) { + m.Called(n) +} +func (m *mockRetryableAction) SetStartTime(ts time.Time) { + m.Called(ts) +} +func (m *mockRetryableAction) GetError() error { + args := m.Called() + return args.Error(0) +} +func (m *mockRetryableAction) SetError(err error) { + m.Called(err) +} type mockQueue struct { mock.Mock } -func (m *mockQueue) Add(action fleetapi.Action, n int64) { +func (m *mockQueue) Add(action fleetapi.ScheduledAction, n int64) { m.Called(action, n) } -func (m *mockQueue) DequeueActions() []fleetapi.Action { +func (m *mockQueue) DequeueActions() []fleetapi.ScheduledAction { args := m.Called() - return args.Get(0).([]fleetapi.Action) + return args.Get(0).([]fleetapi.ScheduledAction) +} + +func (m *mockQueue) CancelType(t string) int { + args := m.Called(t) + return args.Int(0) } func (m *mockQueue) Save() error { @@ -84,7 +110,7 @@ func TestActionDispatcher(t *testing.T) { def := &mockHandler{} queue := &mockQueue{} queue.On("Save").Return(nil).Once() - queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{}).Once() d, err := New(nil, def, queue) require.NoError(t, err) @@ -97,11 +123,9 @@ func TestActionDispatcher(t *testing.T) { require.NoError(t, err) action1 := &mockAction{} - action1.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) action1.On("Type").Return("action") action1.On("ID").Return("id") action2 := &mockOtherAction{} - action2.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) action2.On("Type").Return("action") action2.On("ID").Return("id") @@ -109,8 +133,12 @@ func TestActionDispatcher(t *testing.T) { success1.On("Handle", mock.Anything, mock.Anything, 
mock.Anything).Return(nil).Once() success2.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - err = d.Dispatch(ctx, ack, action1, action2) - require.NoError(t, err) + d.Dispatch(ctx, ack, action1, action2) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } success1.AssertExpectations(t) success2.AssertExpectations(t) @@ -124,17 +152,20 @@ func TestActionDispatcher(t *testing.T) { ctx := context.Background() queue := &mockQueue{} queue.On("Save").Return(nil).Once() - queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{}).Once() d, err := New(nil, def, queue) require.NoError(t, err) - action := &mockUnknownAction{} - action.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action := &mockOtherAction{} action.On("Type").Return("action") action.On("ID").Return("id") - err = d.Dispatch(ctx, ack, action) + d.Dispatch(ctx, ack, action) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } - require.NoError(t, err) def.AssertExpectations(t) queue.AssertExpectations(t) }) @@ -162,7 +193,7 @@ func TestActionDispatcher(t *testing.T) { queue := &mockQueue{} queue.On("Save").Return(nil).Once() - queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{}).Once() queue.On("Add", mock.Anything, mock.Anything).Once() d, err := New(nil, def, queue) @@ -171,16 +202,19 @@ func TestActionDispatcher(t *testing.T) { require.NoError(t, err) action1 := &mockAction{} - action1.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) action1.On("Type").Return("action") action1.On("ID").Return("id") - action2 := &mockAction{} + action2 := &mockScheduledAction{} action2.On("StartTime").Return(time.Now().Add(time.Hour), nil) action2.On("Type").Return("action") action2.On("ID").Return("id") - err = d.Dispatch(context.Background(), 
ack, action1, action2) - require.NoError(t, err) + d.Dispatch(context.Background(), ack, action1, action2) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } def.AssertExpectations(t) queue.AssertExpectations(t) }) @@ -191,7 +225,7 @@ func TestActionDispatcher(t *testing.T) { queue := &mockQueue{} queue.On("Save").Return(nil).Once() - queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{}).Once() d, err := New(nil, def, queue) require.NoError(t, err) @@ -199,12 +233,15 @@ func TestActionDispatcher(t *testing.T) { require.NoError(t, err) action := &mockAction{} - action.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) action.On("Type").Return(fleetapi.ActionTypeCancel) action.On("ID").Return("id") - err = d.Dispatch(context.Background(), ack, action) - require.NoError(t, err) + d.Dispatch(context.Background(), ack, action) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } def.AssertExpectations(t) queue.AssertExpectations(t) }) @@ -213,7 +250,7 @@ func TestActionDispatcher(t *testing.T) { def := &mockHandler{} def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() - action1 := &mockAction{} + action1 := &mockScheduledAction{} action1.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) action1.On("Expiration").Return(time.Now().Add(time.Hour), fleetapi.ErrNoStartTime) action1.On("Type").Return(fleetapi.ActionTypeCancel) @@ -221,7 +258,7 @@ func TestActionDispatcher(t *testing.T) { queue := &mockQueue{} queue.On("Save").Return(nil).Once() - queue.On("DequeueActions").Return([]fleetapi.Action{action1}).Once() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{action1}).Once() d, err := New(nil, def, queue) require.NoError(t, err) @@ -229,12 +266,15 @@ func TestActionDispatcher(t *testing.T) { require.NoError(t, err) action2 := &mockAction{} - 
action2.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) action2.On("Type").Return(fleetapi.ActionTypeCancel) action2.On("ID").Return("id") - err = d.Dispatch(context.Background(), ack, action2) - require.NoError(t, err) + d.Dispatch(context.Background(), ack, action2) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } def.AssertExpectations(t) queue.AssertExpectations(t) }) @@ -245,15 +285,132 @@ func TestActionDispatcher(t *testing.T) { queue := &mockQueue{} queue.On("Save").Return(nil).Once() - queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{}).Once() d, err := New(nil, def, queue) require.NoError(t, err) err = d.Register(&mockAction{}, def) require.NoError(t, err) - err = d.Dispatch(context.Background(), ack) - require.NoError(t, err) + d.Dispatch(context.Background(), ack) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } def.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything, mock.Anything) }) + + t.Run("Dispatch of a retryable action returns an error", func(t *testing.T) { + def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test error")).Once() + + queue := &mockQueue{} + queue.On("Save").Return(nil).Twice() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{}).Once() + queue.On("Add", mock.Anything, mock.Anything).Once() + + d, err := New(nil, def, queue) + require.NoError(t, err) + err = d.Register(&mockRetryableAction{}, def) + require.NoError(t, err) + + action := &mockRetryableAction{} + action.On("Type").Return("action") + action.On("ID").Return("id") + action.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime).Once() + action.On("SetError", mock.Anything).Once() + action.On("RetryAttempt").Return(0).Once() + action.On("SetRetryAttempt", 1).Once() + action.On("SetStartTime", mock.Anything).Once() + 
+ d.Dispatch(context.Background(), ack, action) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } + def.AssertExpectations(t) + queue.AssertExpectations(t) + action.AssertExpectations(t) + }) + + t.Run("Dispatch multiples events returns one error", func(t *testing.T) { + def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test error")).Once() + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("DequeueActions").Return([]fleetapi.ScheduledAction{}).Once() + + d, err := New(nil, def, queue) + require.NoError(t, err) + err = d.Register(&mockAction{}, def) + require.NoError(t, err) + + action1 := &mockAction{} + action1.On("Type").Return("action") + action1.On("ID").Return("id") + action2 := &mockAction{} + action2.On("Type").Return("action") + action2.On("ID").Return("id") + + // Kind of a dirty work around to test an error return. 
+ // launch in another goroutine and sleep to check if an error is generated + go d.Dispatch(context.Background(), ack, action1, action2) + time.Sleep(time.Millisecond * 200) + select { + case <-d.Errors(): + default: + t.Fatal("Expected error") + } + time.Sleep(time.Millisecond * 200) + select { + case err := <-d.Errors(): + t.Fatalf("Unexpected error: %v", err) + default: + } + + def.AssertExpectations(t) + queue.AssertExpectations(t) + }) +} + +func Test_ActionDispatcher_scheduleRetry(t *testing.T) { + ack := noop.New() + def := &mockHandler{} + + t.Run("no more attmpts", func(t *testing.T) { + queue := &mockQueue{} + d, err := New(nil, def, queue) + require.NoError(t, err) + + action := &mockRetryableAction{} + action.On("ID").Return("id") + action.On("RetryAttempt").Return(len(d.rt.steps)).Once() + action.On("SetRetryAttempt", mock.Anything).Once() + + d.scheduleRetry(context.Background(), action, ack) + queue.AssertExpectations(t) + action.AssertExpectations(t) + }) + + t.Run("schedule an attempt", func(t *testing.T) { + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("Add", mock.Anything, mock.Anything).Once() + d, err := New(nil, def, queue) + require.NoError(t, err) + + action := &mockRetryableAction{} + action.On("ID").Return("id") + action.On("RetryAttempt").Return(0).Once() + action.On("SetRetryAttempt", 1).Once() + action.On("SetStartTime", mock.Anything).Once() + + d.scheduleRetry(context.Background(), action, ack) + queue.AssertExpectations(t) + action.AssertExpectations(t) + }) } diff --git a/internal/pkg/agent/application/dispatcher/retryconfig.go b/internal/pkg/agent/application/dispatcher/retryconfig.go new file mode 100644 index 00000000000..8ed5a6e31af --- /dev/null +++ b/internal/pkg/agent/application/dispatcher/retryconfig.go @@ -0,0 +1,29 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package dispatcher + +import ( + "fmt" + "time" +) + +var ErrNoRetry = fmt.Errorf("no retry attempts remaining") + +type retryConfig struct { + steps []time.Duration +} + +func defaultRetryConfig() *retryConfig { + return &retryConfig{ + steps: []time.Duration{time.Minute, 5 * time.Minute, 10 * time.Minute, 15 * time.Minute, 30 * time.Minute, time.Hour}, + } +} + +func (r *retryConfig) GetWait(step int) (time.Duration, error) { + if step < 0 || step >= len(r.steps) { + return time.Duration(0), ErrNoRetry + } + return r.steps[step], nil +} diff --git a/internal/pkg/agent/application/dispatcher/retryconfig_test.go b/internal/pkg/agent/application/dispatcher/retryconfig_test.go new file mode 100644 index 00000000000..d0db8a7650c --- /dev/null +++ b/internal/pkg/agent/application/dispatcher/retryconfig_test.go @@ -0,0 +1,34 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. 
+ +package dispatcher + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_retryConfig_GetWait(t *testing.T) { + rt := defaultRetryConfig() + + t.Run("step is negative", func(t *testing.T) { + d, err := rt.GetWait(-1) + assert.Equal(t, time.Duration(0), d) + assert.ErrorIs(t, err, ErrNoRetry) + }) + + t.Run("returns duration", func(t *testing.T) { + d, err := rt.GetWait(0) + assert.Equal(t, time.Minute, d) + assert.NoError(t, err) + }) + + t.Run("step too large", func(t *testing.T) { + d, err := rt.GetWait(len(rt.steps)) + assert.Equal(t, time.Duration(0), d) + assert.ErrorIs(t, err, ErrNoRetry) + }) +} diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go index 38fad92057c..0288152f726 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go @@ -6,12 +6,10 @@ package fleet import ( "context" - "fmt" "time" eaclient "github.com/elastic/elastic-agent-client/v7/pkg/client" "github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator" - "github.com/elastic/elastic-agent/internal/pkg/agent/application/dispatcher" "github.com/elastic/elastic-agent/internal/pkg/agent/application/gateway" "github.com/elastic/elastic-agent/internal/pkg/agent/application/info" agentclient "github.com/elastic/elastic-agent/internal/pkg/agent/control/client" @@ -66,7 +64,6 @@ type stateStore interface { type fleetGateway struct { log *logger.Logger - dispatcher dispatcher.Dispatcher client client.Sender scheduler scheduler.Scheduler settings *fleetGatewaySettings @@ -76,6 +73,7 @@ type fleetGateway struct { stateFetcher coordinator.StateFetcher stateStore stateStore errCh chan error + actionCh chan []fleetapi.Action } // New creates a new fleet gateway @@ -83,7 +81,6 @@ func New( log *logger.Logger, agentInfo agentInfo, client client.Sender, - d dispatcher.Dispatcher, 
acker acker.Acker, stateFetcher coordinator.StateFetcher, stateStore stateStore, @@ -95,7 +92,6 @@ func New( defaultGatewaySettings, agentInfo, client, - d, scheduler, acker, stateFetcher, @@ -108,7 +104,6 @@ func newFleetGatewayWithScheduler( settings *fleetGatewaySettings, agentInfo agentInfo, client client.Sender, - d dispatcher.Dispatcher, scheduler scheduler.Scheduler, acker acker.Acker, stateFetcher coordinator.StateFetcher, @@ -116,7 +111,6 @@ func newFleetGatewayWithScheduler( ) (gateway.FleetGateway, error) { return &fleetGateway{ log: log, - dispatcher: d, client: client, settings: settings, agentInfo: agentInfo, @@ -125,9 +119,14 @@ func newFleetGatewayWithScheduler( stateFetcher: stateFetcher, stateStore: stateStore, errCh: make(chan error), + actionCh: make(chan []fleetapi.Action, 1), }, nil } +func (f *fleetGateway) Actions() <-chan []fleetapi.Action { + return f.actionCh +} + func (f *fleetGateway) Run(ctx context.Context) error { // Backoff implementation doesn't support the use of a context [cancellation] as the shutdown mechanism. // So we keep a done channel that will be closed when the current context is shutdown. 
@@ -162,19 +161,8 @@ func (f *fleetGateway) Run(ctx context.Context) error { actions := make([]fleetapi.Action, len(resp.Actions)) copy(actions, resp.Actions) - - // Persist state - hadErr := false - if err := f.dispatcher.Dispatch(context.Background(), f.acker, actions...); err != nil { - err = fmt.Errorf("failed to dispatch actions, error: %w", err) - f.log.Error(err) - f.errCh <- err - hadErr = true - } - - f.log.Debugf("FleetGateway is sleeping, next update in %s", f.settings.Duration) - if !hadErr { - f.errCh <- nil + if len(actions) > 0 { + f.actionCh <- actions } } } diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go index 49c05112e18..076453f1374 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go @@ -25,8 +25,6 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/gateway" "github.com/elastic/elastic-agent/internal/pkg/agent/storage" "github.com/elastic/elastic-agent/internal/pkg/agent/storage/store" - "github.com/elastic/elastic-agent/internal/pkg/fleetapi" - "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker/noop" "github.com/elastic/elastic-agent/internal/pkg/scheduler" "github.com/elastic/elastic-agent/pkg/core/logger" @@ -69,53 +67,12 @@ func newTestingClient() *testingClient { return &testingClient{received: make(chan struct{}, 1)} } -type testingDispatcherFunc func(...fleetapi.Action) error - -type testingDispatcher struct { - sync.Mutex - callback testingDispatcherFunc - received chan struct{} -} - -func (t *testingDispatcher) Dispatch(_ context.Context, acker acker.Acker, actions ...fleetapi.Action) error { - t.Lock() - defer t.Unlock() - defer func() { t.received <- struct{}{} }() - // Get a dummy context. 
- ctx := context.Background() - - // In context of testing we need to abort on error. - if err := t.callback(actions...); err != nil { - return err - } - - // Ack everything and commit at the end. - for _, action := range actions { - _ = acker.Ack(ctx, action) - } - _ = acker.Commit(ctx) - - return nil -} - -func (t *testingDispatcher) Answer(fn testingDispatcherFunc) <-chan struct{} { - t.Lock() - defer t.Unlock() - t.callback = fn - return t.received -} - -func newTestingDispatcher() *testingDispatcher { - return &testingDispatcher{received: make(chan struct{}, 1)} -} - -type withGatewayFunc func(*testing.T, gateway.FleetGateway, *testingClient, *testingDispatcher, *scheduler.Stepper) +type withGatewayFunc func(*testing.T, gateway.FleetGateway, *testingClient, *scheduler.Stepper) func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGatewayFunc) func(t *testing.T) { return func(t *testing.T) { scheduler := scheduler.NewStepper() client := newTestingClient() - dispatcher := newTestingDispatcher() log, _ := logger.New("fleet_gateway", false) @@ -126,7 +83,6 @@ func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGat settings, agentInfo, client, - dispatcher, scheduler, noop.New(), &emptyStateFetcher{}, @@ -135,7 +91,7 @@ func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGat require.NoError(t, err) - fn(t, gateway, client, dispatcher, scheduler) + fn(t, gateway, client, scheduler) } } @@ -171,7 +127,6 @@ func TestFleetGateway(t *testing.T) { t *testing.T, gateway gateway.FleetGateway, client *testingClient, - dispatcher *testingDispatcher, scheduler *scheduler.Stepper, ) { ctx, cancel := context.WithCancel(context.Background()) @@ -182,10 +137,6 @@ func TestFleetGateway(t *testing.T) { resp := wrapStrToResp(http.StatusOK, `{ "actions": [] }`) return resp, nil }), - dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Equal(t, 0, len(actions)) - return nil - }), ) errCh := 
runFleetGateway(ctx, gateway) @@ -197,13 +148,17 @@ func TestFleetGateway(t *testing.T) { cancel() err := <-errCh require.NoError(t, err) + select { + case actions := <-gateway.Actions(): + t.Errorf("Expected no actions, got %v", actions) + default: + } })) t.Run("Successfully connects and receives a series of actions", withGateway(agentInfo, settings, func( t *testing.T, gateway gateway.FleetGateway, client *testingClient, - dispatcher *testingDispatcher, scheduler *scheduler.Stepper, ) { ctx, cancel := context.WithCancel(context.Background()) @@ -233,10 +188,6 @@ func TestFleetGateway(t *testing.T) { `) return resp, nil }), - dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Len(t, actions, 2) - return nil - }), ) errCh := runFleetGateway(ctx, gateway) @@ -247,13 +198,18 @@ func TestFleetGateway(t *testing.T) { cancel() err := <-errCh require.NoError(t, err) + select { + case actions := <-gateway.Actions(): + require.Len(t, actions, 2) + default: + t.Errorf("Expected to receive actions") + } })) // Test the normal time based execution. 
t.Run("Periodically communicates with Fleet", func(t *testing.T) { scheduler := scheduler.NewPeriodic(150 * time.Millisecond) client := newTestingClient() - dispatcher := newTestingDispatcher() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -266,7 +222,6 @@ func TestFleetGateway(t *testing.T) { settings, agentInfo, client, - dispatcher, scheduler, noop.New(), &emptyStateFetcher{}, @@ -279,10 +234,6 @@ func TestFleetGateway(t *testing.T) { resp := wrapStrToResp(http.StatusOK, `{ "actions": [] }`) return resp, nil }), - dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Equal(t, 0, len(actions)) - return nil - }), ) errCh := runFleetGateway(ctx, gateway) @@ -309,7 +260,6 @@ func TestFleetGateway(t *testing.T) { d := 20 * time.Minute scheduler := scheduler.NewPeriodic(d) client := newTestingClient() - dispatcher := newTestingDispatcher() ctx, cancel := context.WithCancel(context.Background()) @@ -324,7 +274,6 @@ func TestFleetGateway(t *testing.T) { }, agentInfo, client, - dispatcher, scheduler, noop.New(), &emptyStateFetcher{}, @@ -332,7 +281,6 @@ func TestFleetGateway(t *testing.T) { ) require.NoError(t, err) - ch1 := dispatcher.Answer(func(actions ...fleetapi.Action) error { return nil }) ch2 := client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { resp := wrapStrToResp(http.StatusOK, `{ "actions": [] }`) return resp, nil @@ -340,14 +288,7 @@ func TestFleetGateway(t *testing.T) { errCh := runFleetGateway(ctx, gateway) - // Silently dispatch action. - go func() { - for range ch1 { - } - }() - // Make sure that all API calls to the checkin API are successful, the following will happen: - // block on the first call. 
<-ch2 @@ -379,7 +320,6 @@ func TestRetriesOnFailures(t *testing.T) { t *testing.T, gateway gateway.FleetGateway, client *testingClient, - dispatcher *testingDispatcher, scheduler *scheduler.Stepper, ) { ctx, cancel := context.WithCancel(context.Background()) @@ -406,11 +346,6 @@ func TestRetriesOnFailures(t *testing.T) { resp := wrapStrToResp(http.StatusOK, `{ "actions": [] }`) return resp, nil }), - - dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Equal(t, 0, len(actions)) - return nil - }), ) waitFn() @@ -418,6 +353,11 @@ func TestRetriesOnFailures(t *testing.T) { cancel() err := <-errCh require.NoError(t, err) + select { + case actions := <-gateway.Actions(): + t.Errorf("Expected no actions, got %v", actions) + default: + } })) t.Run("The retry loop is interruptible", @@ -428,7 +368,6 @@ func TestRetriesOnFailures(t *testing.T) { t *testing.T, gateway gateway.FleetGateway, client *testingClient, - dispatcher *testingDispatcher, scheduler *scheduler.Stepper, ) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/pkg/agent/application/gateway/gateway.go b/internal/pkg/agent/application/gateway/gateway.go index d43dd32a0c2..6946c8671a4 100644 --- a/internal/pkg/agent/application/gateway/gateway.go +++ b/internal/pkg/agent/application/gateway/gateway.go @@ -7,6 +7,7 @@ package gateway import ( "context" + "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/client" ) @@ -21,6 +22,9 @@ type FleetGateway interface { // Errors returns the channel to watch for reported errors. Errors() <-chan error + // Actions returns the channel to watch for new actions from the fleet-server. + Actions() <-chan []fleetapi.Action + // SetClient sets the client for the gateway. 
SetClient(client.Sender) } diff --git a/internal/pkg/agent/application/managed_mode.go b/internal/pkg/agent/application/managed_mode.go index cb72af2a700..ca50495dcb6 100644 --- a/internal/pkg/agent/application/managed_mode.go +++ b/internal/pkg/agent/application/managed_mode.go @@ -123,6 +123,10 @@ func (m *managedConfigManager) Run(ctx context.Context) error { batchedAcker := lazy.NewAcker(ack, m.log, lazy.WithRetrier(retrier)) actionAcker := store.NewStateStoreActionAcker(batchedAcker, m.stateStore) + if err := m.coord.AckUpgrade(ctx, actionAcker); err != nil { + m.log.Warnf("Failed to ack upgrade: %v", err) + } + // Run the retrier. retrierRun := make(chan bool) retrierCtx, retrierCancel := context.WithCancel(ctx) @@ -135,15 +139,26 @@ func (m *managedConfigManager) Run(ctx context.Context) error { close(retrierRun) }() + // Gather errors from the dispatcher and pass to the error channel. + go func() { + for { + select { + case <-ctx.Done(): + return + case err := <-actionDispatcher.Errors(): + m.errCh <- err // err is one or more failures from dispatching an action + } + } + }() + actions := m.stateStore.Actions() stateRestored := false if len(actions) > 0 && !m.wasUnenrolled() { // TODO(ph) We will need an improvement on fleet, if there is an error while dispatching a // persisted action on disk we should be able to ask Fleet to get the latest configuration. // But at the moment this is not possible because the policy change was acked. - if err := store.ReplayActions(ctx, m.log, actionDispatcher, actionAcker, actions...); err != nil { - m.log.Errorf("could not recover state, error %+v, skipping...", err) - } + m.log.Info("restoring current policy from disk") + actionDispatcher.Dispatch(ctx, actionAcker, actions...) 
stateRestored = true } @@ -167,7 +182,6 @@ func (m *managedConfigManager) Run(ctx context.Context) error { m.log, m.agentInfo, m.client, - actionDispatcher, actionAcker, m.coord, m.stateStore, @@ -200,6 +214,18 @@ func (m *managedConfigManager) Run(ctx context.Context) error { return gateway.Run(ctx) }) + // pass actions collected from gateway to dispatcher + go func() { + for { + select { + case <-ctx.Done(): + return + case actions := <-gateway.Actions(): + actionDispatcher.Dispatch(ctx, actionAcker, actions...) + } + } + }() + <-ctx.Done() return gatewayRunner.Err() } diff --git a/internal/pkg/agent/application/upgrade/upgrade.go b/internal/pkg/agent/application/upgrade/upgrade.go index edc70c3f5c0..31f48d8d0d0 100644 --- a/internal/pkg/agent/application/upgrade/upgrade.go +++ b/internal/pkg/agent/application/upgrade/upgrade.go @@ -142,7 +142,8 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, sourceURI string } if strings.HasPrefix(release.Commit(), newHash) { - return nil, ErrSameVersion + u.log.Warn("Upgrade action skipped: upgrade did not occur because its the same version") + return nil, nil } if err := copyActionStore(newHash); err != nil { @@ -161,7 +162,7 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, sourceURI string if err := InvokeWatcher(u.log); err != nil { rollbackInstall(ctx, newHash) - return nil, errors.New("failed to invoke rollback watcher", err) + return nil, err } cb := shutdownCallback(u.log, paths.Home(), release.Version(), version, release.TrimCommit(newHash)) diff --git a/internal/pkg/agent/storage/store/state_store.go b/internal/pkg/agent/storage/store/state_store.go index 8a6d3fc5e8d..522e46fdade 100644 --- a/internal/pkg/agent/storage/store/state_store.go +++ b/internal/pkg/agent/storage/store/state_store.go @@ -20,10 +20,6 @@ import ( "github.com/elastic/elastic-agent/pkg/core/logger" ) -type dispatcher interface { - Dispatch(context.Context, acker.Acker, ...action) error -} - type store interface 
{ Save(io.Reader) error } @@ -98,7 +94,7 @@ func NewStateStore(log *logger.Logger, store storeLoad) (*StateStore, error) { // persisted and we return an empty store. reader, err := store.Load() if err != nil { - return &StateStore{log: log, store: store}, nil //nolint:nilerr // expected results + return &StateStore{log: log, store: store}, nil } defer reader.Close() @@ -340,23 +336,6 @@ func (a *StateStoreActionAcker) Commit(ctx context.Context) error { return a.acker.Commit(ctx) } -// ReplayActions replays list of actions. -func ReplayActions( - ctx context.Context, - log *logger.Logger, - dispatcher dispatcher, - acker acker.Acker, - actions ...action, -) error { - log.Info("restoring current policy from disk") - - if err := dispatcher.Dispatch(ctx, acker, actions...); err != nil { - return err - } - - return nil -} - func yamlToReader(in interface{}) (io.Reader, error) { data, err := yaml.Marshal(in) if err != nil { diff --git a/internal/pkg/agent/storage/store/state_store_test.go b/internal/pkg/agent/storage/store/state_store_test.go index e73b8721fbe..446433ca1ae 100644 --- a/internal/pkg/agent/storage/store/state_store_test.go +++ b/internal/pkg/agent/storage/store/state_store_test.go @@ -31,7 +31,7 @@ func TestStateStore(t *testing.T) { func runTestStateStore(t *testing.T, ackToken string) { log, _ := logger.New("state_store", false) - withFile := func(fn func(t *testing.T, file string)) func(*testing.T) { //nolint:unparam // false positive + withFile := func(fn func(t *testing.T, file string)) func(*testing.T) { return func(t *testing.T) { dir := t.TempDir() file := filepath.Join(dir, "state.yml") @@ -132,7 +132,9 @@ func runTestStateStore(t *testing.T, ackToken string) { require.Empty(t, store1.Actions()) require.Len(t, store1.Queue(), 1) require.Equal(t, "test", store1.Queue()[0].ID()) - start, err := store1.Queue()[0].StartTime() + scheduledAction, ok := store1.Queue()[0].(fleetapi.ScheduledAction) + require.True(t, ok, "expected to be able to cast 
Action as ScheduledAction") + start, err := scheduledAction.StartTime() require.NoError(t, err) require.Equal(t, ts, start) })) @@ -146,6 +148,7 @@ func runTestStateStore(t *testing.T, ackToken string) { ActionStartTime: ts.Format(time.RFC3339), Version: "1.2.3", SourceURI: "https://example.com", + Retry: 1, }, &fleetapi.ActionPolicyChange{ ActionID: "abc123", ActionType: "POLICY_CHANGE", @@ -172,13 +175,18 @@ func runTestStateStore(t *testing.T, ackToken string) { require.Len(t, store1.Queue(), 2) require.Equal(t, "test", store1.Queue()[0].ID()) - start, err := store1.Queue()[0].StartTime() + scheduledAction, ok := store1.Queue()[0].(fleetapi.ScheduledAction) + require.True(t, ok, "expected to be able to cast Action as ScheduledAction") + start, err := scheduledAction.StartTime() require.NoError(t, err) require.Equal(t, ts, start) + retryableAction, ok := store1.Queue()[0].(fleetapi.RetryableAction) + require.True(t, ok, "expected to be able to cast Action as RetryableAction") + require.Equal(t, 1, retryableAction.RetryAttempt()) require.Equal(t, "abc123", store1.Queue()[1].ID()) - _, err = store1.Queue()[1].StartTime() - require.ErrorIs(t, err, fleetapi.ErrNoStartTime) + _, ok = store1.Queue()[1].(fleetapi.ScheduledAction) + require.False(t, ok, "expected cast to ScheduledAction to fail") })) t.Run("can save to disk unenroll action type", diff --git a/internal/pkg/fleetapi/ack_cmd.go b/internal/pkg/fleetapi/ack_cmd.go index 09ba6f6b4ac..e8d8ac7e9e3 100644 --- a/internal/pkg/fleetapi/ack_cmd.go +++ b/internal/pkg/fleetapi/ack_cmd.go @@ -21,13 +21,13 @@ const ackPath = "/api/fleet/agents/%s/acks" // AckEvent is an event sent in an ACK request. 
type AckEvent struct { - EventType string `json:"type"` // 'STATE' | 'ERROR' | 'ACTION_RESULT' | 'ACTION' - SubType string `json:"subtype"` // 'RUNNING','STARTING','IN_PROGRESS','CONFIG','FAILED','STOPPING','STOPPED','DATA_DUMP','ACKNOWLEDGED','UNKNOWN'; - Timestamp string `json:"timestamp"` // : '2019-01-05T14:32:03.36764-05:00', - ActionID string `json:"action_id"` // : '48cebde1-c906-4893-b89f-595d943b72a2', - AgentID string `json:"agent_id"` // : 'agent1', - Message string `json:"message,omitempty"` // : 'hello2', - Payload string `json:"payload,omitempty"` // : 'payload2', + EventType string `json:"type"` // 'STATE' | 'ERROR' | 'ACTION_RESULT' | 'ACTION' + SubType string `json:"subtype"` // 'RUNNING','STARTING','IN_PROGRESS','CONFIG','FAILED','STOPPING','STOPPED','DATA_DUMP','ACKNOWLEDGED','UNKNOWN'; + Timestamp string `json:"timestamp"` // : '2019-01-05T14:32:03.36764-05:00', + ActionID string `json:"action_id"` // : '48cebde1-c906-4893-b89f-595d943b72a2', + AgentID string `json:"agent_id"` // : 'agent1', + Message string `json:"message,omitempty"` // : 'hello2', + Payload json.RawMessage `json:"payload,omitempty"` // : 'payload2', ActionInputType string `json:"action_input_type,omitempty"` // copy of original action input_type ActionData json.RawMessage `json:"action_data,omitempty"` // copy of original action data diff --git a/internal/pkg/fleetapi/acker/fleet/fleet_acker.go b/internal/pkg/fleetapi/acker/fleet/fleet_acker.go index c34fd8c3309..b78a55069d8 100644 --- a/internal/pkg/fleetapi/acker/fleet/fleet_acker.go +++ b/internal/pkg/fleetapi/acker/fleet/fleet_acker.go @@ -6,6 +6,7 @@ package fleet import ( "context" + "encoding/json" "fmt" "strings" "time" @@ -127,6 +128,23 @@ func constructEvent(action fleetapi.Action, agentID string) fleetapi.AckEvent { Message: fmt.Sprintf("Action '%s' of type '%s' acknowledged.", action.ID(), action.Type()), } + if a, ok := action.(fleetapi.RetryableAction); ok { + if err := a.GetError(); err != nil { + ackev.Error = 
err.Error() + var payload struct { + Retry bool `json:"retry"` + Attempt int `json:"retry_attempt,omitempty"` + } + payload.Retry = true + payload.Attempt = a.RetryAttempt() + if a.RetryAttempt() < 1 { + payload.Retry = false + } + p, _ := json.Marshal(payload) + ackev.Payload = p + } + } + if a, ok := action.(*fleetapi.ActionApp); ok { ackev.ActionInputType = a.InputType ackev.ActionData = a.Data diff --git a/internal/pkg/fleetapi/acker/fleet/fleet_acker_test.go b/internal/pkg/fleetapi/acker/fleet/fleet_acker_test.go index 251495b2173..ebea939a910 100644 --- a/internal/pkg/fleetapi/acker/fleet/fleet_acker_test.go +++ b/internal/pkg/fleetapi/acker/fleet/fleet_acker_test.go @@ -16,7 +16,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/fleetapi" "github.com/elastic/elastic-agent/pkg/core/logger" ) @@ -114,6 +116,27 @@ func TestAcker_Ack(t *testing.T) { }, }, }, + { + name: "ackupgrade", + actions: []fleetapi.Action{ + &fleetapi.ActionUpgrade{ + ActionID: "upgrade-ok", + ActionType: fleetapi.ActionTypeUpgrade, + }, + &fleetapi.ActionUpgrade{ + ActionID: "upgrade-retry", + ActionType: fleetapi.ActionTypeUpgrade, + Retry: 1, + Err: errors.New("upgrade failed"), + }, + &fleetapi.ActionUpgrade{ + ActionID: "upgrade-failed", + ActionType: fleetapi.ActionTypeUpgrade, + Retry: -1, + Err: errors.New("upgrade failed"), + }, + }, + }, } log, _ := logger.New("fleet_acker", false) @@ -131,6 +154,29 @@ func TestAcker_Ack(t *testing.T) { assert.EqualValues(t, ac.ID(), req.Events[i].ActionID) assert.EqualValues(t, agentInfo.AgentID(), req.Events[i].AgentID) assert.EqualValues(t, fmt.Sprintf("Action '%s' of type '%s' acknowledged.", ac.ID(), ac.Type()), req.Events[i].Message) + // Check if the fleet acker handles RetryableActions correctly using the UpgradeAction + if a, ok := ac.(*fleetapi.ActionUpgrade); ok { + if a.Err != 
nil { + assert.EqualValues(t, a.Err.Error(), req.Events[i].Error) + // Check payload + require.NotEmpty(t, req.Events[i].Payload) + var pl struct { + Retry bool `json:"retry"` + Attempt int `json:"retry_attempt,omitempty"` + } + err := json.Unmarshal(req.Events[i].Payload, &pl) + require.NoError(t, err) + assert.Equal(t, a.Retry, pl.Attempt, "action ID %s failed", a.ActionID) + // Check retry flag + if pl.Attempt > 0 { + assert.True(t, pl.Retry) + } else { + assert.False(t, pl.Retry) + } + } else { + assert.Empty(t, req.Events[i].Error) + } + } if a, ok := ac.(*fleetapi.ActionApp); ok { assert.EqualValues(t, a.InputType, req.Events[i].ActionInputType) assert.EqualValues(t, a.Data, req.Events[i].ActionData) @@ -147,27 +193,18 @@ func TestAcker_Ack(t *testing.T) { t.Run(tc.name, func(t *testing.T) { sender := &testSender{} acker, err := NewAcker(log, agentInfo, sender) - if err != nil { - t.Fatal(err) - } - - if acker == nil { - t.Fatal("acker not initialized") - } + require.NoError(t, err) + require.NotNil(t, acker, "acker not initialized") if len(tc.actions) == 1 { err = acker.Ack(context.Background(), tc.actions[0]) } else { _, err = acker.AckBatch(context.Background(), tc.actions) } + require.NoError(t, err) - if err != nil { - t.Fatal(err) - } - - if err := acker.Commit(context.Background()); err != nil { - t.Fatal(err) - } + err = acker.Commit(context.Background()) + require.NoError(t, err) checkRequest(t, tc.actions, sender.req) }) diff --git a/internal/pkg/fleetapi/action.go b/internal/pkg/fleetapi/action.go index 4e6b08cd372..f23ec6e89e4 100644 --- a/internal/pkg/fleetapi/action.go +++ b/internal/pkg/fleetapi/action.go @@ -46,14 +46,40 @@ type Action interface { fmt.Stringer Type() string ID() string - // StartTime returns the earliest time an action should start (for schduled actions) - // Only ActionUpgrade implements this at the moment +} + +// ScheduledAction is an Action that may be executed at a later date +// Only ActionUpgrade implements this at the 
moment +type ScheduledAction interface { + Action + // StartTime returns the earliest time an action should start. StartTime() (time.Time, error) - // Expiration returns the time where an action is expired and should not be ran (for scheduled actions) - // Only ActionUpgrade implements this at the moment + // Expiration returns the time where an action is expired and should not be run. Expiration() (time.Time, error) } +// RetryableAction is an Action that may be scheduled for a retry. +type RetryableAction interface { + ScheduledAction + // RetryAttempt returns the retry-attempt number of the action + // the retry_attempt number is meant to be an internal counter for the elastic-agent and not communicated to fleet-server or ES. + // If RetryAttempt returns > 0, and GetError is not nil the acker should signal that the action is being retried. + // If RetryAttempt returns < 1, and GetError is not nil the acker should signal that the action has failed. + RetryAttempt() int + // SetRetryAttempt sets the retry-attempt number of the action + // the retry_attempt number is meant to be an internal counter for the elastic-agent and not communicated to fleet-server or ES. + SetRetryAttempt(int) + // SetStartTime sets the start_time of the action to the specified value. + // this is used by the action-retry mechanism. + SetStartTime(t time.Time) + // GetError returns the error that is associated with the retry. + // If it is a retryable action fleet-server should mark it as such. + // Otherwise fleet-server should mark the action as failed. + GetError() error + // SetError sets the retryable action error + SetError(error) +} + // FleetAction represents an action from fleet-server. 
// should copy the action definition in fleet-server/model/schema.json type FleetAction struct { @@ -64,6 +90,7 @@ type FleetAction struct { ActionStartTime string `yaml:"start_time,omitempty" json:"start_time,omitempty"` Timeout int64 `yaml:"timeout,omitempty" json:"timeout,omitempty"` Data json.RawMessage `yaml:"data,omitempty" json:"data,omitempty"` + Retry int `json:"retry_attempt,omitempty" yaml:"retry_attempt,omitempty"` // used internally for serialization by elastic-agent. //Agents []string // disabled, fleet-server uses this to generate each agent's actions //Timestamp string // disabled, agent does not care when the document was created //UserID string // disabled, agent does not care @@ -91,16 +118,6 @@ func (a *ActionUnknown) ID() string { return a.ActionID } -// StartTime returns ErrNoStartTime -func (a *ActionUnknown) StartTime() (time.Time, error) { - return time.Time{}, ErrNoStartTime -} - -// Expiration returns ErrNoExpiration -func (a *ActionUnknown) Expiration() (time.Time, error) { - return time.Time{}, ErrNoExpiration -} - func (a *ActionUnknown) String() string { var s strings.Builder s.WriteString("action_id: ") @@ -143,16 +160,6 @@ func (a *ActionPolicyReassign) ID() string { return a.ActionID } -// StartTime returns ErrNoStartTime -func (a *ActionPolicyReassign) StartTime() (time.Time, error) { - return time.Time{}, ErrNoStartTime -} - -// Expiration returns ErrNoExpiration -func (a *ActionPolicyReassign) Expiration() (time.Time, error) { - return time.Time{}, ErrNoExpiration -} - // ActionPolicyChange is a request to apply a new type ActionPolicyChange struct { ActionID string `yaml:"action_id"` @@ -179,16 +186,6 @@ func (a *ActionPolicyChange) ID() string { return a.ActionID } -// StartTime returns ErrNoStartTime -func (a *ActionPolicyChange) StartTime() (time.Time, error) { - return time.Time{}, ErrNoStartTime -} - -// Expiration returns ErrNoExpiration -func (a *ActionPolicyChange) Expiration() (time.Time, error) { - return time.Time{}, 
ErrNoExpiration -} - // ActionUpgrade is a request for agent to upgrade. type ActionUpgrade struct { ActionID string `yaml:"action_id"` @@ -197,6 +194,8 @@ type ActionUpgrade struct { ActionExpiration string `json:"expiration" yaml:"expiration,omitempty"` Version string `json:"version" yaml:"version,omitempty"` SourceURI string `json:"source_uri,omitempty" yaml:"source_uri,omitempty"` + Retry int `json:"retry_attempt,omitempty" yaml:"retry_attempt,omitempty"` + Err error } func (a *ActionUpgrade) String() string { @@ -242,6 +241,31 @@ func (a *ActionUpgrade) Expiration() (time.Time, error) { return ts.UTC(), nil } +// RetryAttempt will return the retry_attempt of the action +func (a *ActionUpgrade) RetryAttempt() int { + return a.Retry +} + +// SetRetryAttempt sets the retry_attempt of the action +func (a *ActionUpgrade) SetRetryAttempt(n int) { + a.Retry = n +} + +// GetError returns the error associated with the attempt to run the action. +func (a *ActionUpgrade) GetError() error { + return a.Err +} + +// SetError sets the error associated with the attempt to run the action. +func (a *ActionUpgrade) SetError(err error) { + a.Err = err +} + +// SetStartTime sets the start time of the action. +func (a *ActionUpgrade) SetStartTime(t time.Time) { + a.ActionStartTime = t.Format(time.RFC3339) +} + // ActionUnenroll is a request for agent to unhook from fleet. type ActionUnenroll struct { ActionID string `yaml:"action_id"` @@ -268,16 +292,6 @@ func (a *ActionUnenroll) ID() string { return a.ActionID } -// StartTime returns ErrNoStartTime -func (a *ActionUnenroll) StartTime() (time.Time, error) { - return time.Time{}, ErrNoStartTime -} - -// Expiration returns ErrNoExpiration -func (a *ActionUnenroll) Expiration() (time.Time, error) { - return time.Time{}, ErrNoExpiration -} - // ActionSettings is a request to change agent settings. 
type ActionSettings struct { ActionID string `yaml:"action_id"` @@ -295,16 +309,6 @@ func (a *ActionSettings) Type() string { return a.ActionType } -// StartTime returns ErrNoStartTime -func (a *ActionSettings) StartTime() (time.Time, error) { - return time.Time{}, ErrNoStartTime -} - -// Expiration returns ErrNoExpiration -func (a *ActionSettings) Expiration() (time.Time, error) { - return time.Time{}, ErrNoExpiration -} - func (a *ActionSettings) String() string { var s strings.Builder s.WriteString("action_id: ") @@ -333,16 +337,6 @@ func (a *ActionCancel) Type() string { return a.ActionType } -// StartTime returns ErrNoStartTime -func (a *ActionCancel) StartTime() (time.Time, error) { - return time.Time{}, ErrNoStartTime -} - -// Expiration returns ErrNoExpiration -func (a *ActionCancel) Expiration() (time.Time, error) { - return time.Time{}, ErrNoExpiration -} - func (a *ActionCancel) String() string { var s strings.Builder s.WriteString("action_id: ") @@ -388,16 +382,6 @@ func (a *ActionApp) Type() string { return a.ActionType } -// StartTime returns ErrNoStartTime -func (a *ActionApp) StartTime() (time.Time, error) { - return time.Time{}, ErrNoStartTime -} - -// Expiration returns ErrExpiration -func (a *ActionApp) Expiration() (time.Time, error) { - return time.Time{}, ErrNoExpiration -} - // MarshalMap marshals ActionApp into a corresponding map func (a *ActionApp) MarshalMap() (map[string]interface{}, error) { var res map[string]interface{} @@ -544,6 +528,7 @@ func (a *Actions) UnmarshalYAML(unmarshal func(interface{}) error) error { ActionType: n.ActionType, ActionStartTime: n.ActionStartTime, ActionExpiration: n.ActionExpiration, + Retry: n.Retry, } if err := yaml.Unmarshal(n.Data, &action); err != nil { return errors.New(err, diff --git a/internal/pkg/fleetapi/action_test.go b/internal/pkg/fleetapi/action_test.go index b21e591c297..6a8dae3b31a 100644 --- a/internal/pkg/fleetapi/action_test.go +++ b/internal/pkg/fleetapi/action_test.go @@ -2,6 +2,7 @@ 
// or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. +//nolint:dupl // duplicate code is in test cases package fleetapi import ( @@ -97,6 +98,7 @@ func TestActionsUnmarshalJSON(t *testing.T) { assert.Empty(t, action.ActionExpiration) assert.Equal(t, "1.2.3", action.Version) assert.Equal(t, "http://example.com", action.SourceURI) + assert.Equal(t, 0, action.Retry) }) t.Run("ActionUpgrade with start time", func(t *testing.T) { p := []byte(`[{"id":"testid","type":"UPGRADE","start_time":"2022-01-02T12:00:00Z","expiration":"2022-01-02T13:00:00Z","data":{"version":"1.2.3","source_uri":"http://example.com"}}]`) @@ -111,6 +113,7 @@ func TestActionsUnmarshalJSON(t *testing.T) { assert.Equal(t, "2022-01-02T13:00:00Z", action.ActionExpiration) assert.Equal(t, "1.2.3", action.Version) assert.Equal(t, "http://example.com", action.SourceURI) + assert.Equal(t, 0, action.Retry) }) t.Run("ActionPolicyChange no start time", func(t *testing.T) { p := []byte(`[{"id":"testid","type":"POLICY_CHANGE","data":{"policy":{"key":"value"}}}]`) @@ -134,4 +137,19 @@ func TestActionsUnmarshalJSON(t *testing.T) { assert.Equal(t, ActionTypePolicyChange, action.ActionType) assert.NotNil(t, action.Policy) }) + t.Run("ActionUpgrade with retry_attempt", func(t *testing.T) { + p := []byte(`[{"id":"testid","type":"UPGRADE","data":{"version":"1.2.3","source_uri":"http://example.com","retry_attempt":1}}]`) + a := &Actions{} + err := a.UnmarshalJSON(p) + require.Nil(t, err) + action, ok := (*a)[0].(*ActionUpgrade) + require.True(t, ok, "unable to cast action to specific type") + assert.Equal(t, "testid", action.ActionID) + assert.Equal(t, ActionTypeUpgrade, action.ActionType) + assert.Empty(t, action.ActionStartTime) + assert.Empty(t, action.ActionExpiration) + assert.Equal(t, "1.2.3", action.Version) + assert.Equal(t, "http://example.com", action.SourceURI) + assert.Equal(t, 1, action.Retry) + }) } diff 
--git a/internal/pkg/queue/actionqueue.go b/internal/pkg/queue/actionqueue.go index 0f3a2c20ffc..b0cdc127dff 100644 --- a/internal/pkg/queue/actionqueue.go +++ b/internal/pkg/queue/actionqueue.go @@ -19,7 +19,7 @@ type saver interface { // item tracks an action in the action queue type item struct { - action fleetapi.Action + action fleetapi.ScheduledAction priority int64 index int } @@ -76,7 +76,11 @@ func (q *queue) Pop() interface{} { // Will return an error if StartTime fails for any action. func newQueue(actions []fleetapi.Action) (*queue, error) { q := make(queue, len(actions)) - for i, action := range actions { + for i, a := range actions { + action, ok := a.(fleetapi.ScheduledAction) + if !ok { + continue + } ts, err := action.StartTime() if err != nil { return nil, err @@ -106,7 +110,7 @@ func NewActionQueue(actions []fleetapi.Action, s saver) (*ActionQueue, error) { // Add will add an action to the queue with the associated priority. // The priority is meant to be the start-time of the action as a unix epoch time. // Complexity: O(log n) -func (q *ActionQueue) Add(action fleetapi.Action, priority int64) { +func (q *ActionQueue) Add(action fleetapi.ScheduledAction, priority int64) { e := &item{ action: action, priority: priority, @@ -116,9 +120,9 @@ func (q *ActionQueue) Add(action fleetapi.Action, priority int64) { // DequeueActions will dequeue all actions that have a priority less then time.Now(). // Complexity: O(n*log n) -func (q *ActionQueue) DequeueActions() []fleetapi.Action { +func (q *ActionQueue) DequeueActions() []fleetapi.ScheduledAction { ts := time.Now().Unix() - actions := make([]fleetapi.Action, 0) + actions := make([]fleetapi.ScheduledAction, 0) for q.q.Len() != 0 { if (*q.q)[0].priority > ts { break @@ -153,6 +157,20 @@ func (q *ActionQueue) Actions() []fleetapi.Action { return actions } +// CancelType cancels all actions in the queue with a matching action type and returns the number of entries cancelled. 
+func (q *ActionQueue) CancelType(actionType string) int { + items := make([]*item, 0) + for _, item := range *q.q { + if item.action.Type() == actionType { + items = append(items, item) + } + } + for _, item := range items { + heap.Remove(q.q, item.index) + } + return len(items) +} + // Save persists the queue to disk. func (q *ActionQueue) Save() error { q.s.SetQueue(q.Actions()) diff --git a/internal/pkg/queue/actionqueue_test.go b/internal/pkg/queue/actionqueue_test.go index d951f855737..29643a80326 100644 --- a/internal/pkg/queue/actionqueue_test.go +++ b/internal/pkg/queue/actionqueue_test.go @@ -47,15 +47,15 @@ func (m *mockAction) Expiration() (time.Time, error) { return args.Get(0).(time.Time), args.Error(1) } -type mockPersistor struct { +type mockSaver struct { mock.Mock } -func (m *mockPersistor) SetQueue(a []fleetapi.Action) { +func (m *mockSaver) SetQueue(a []fleetapi.Action) { m.Called(a) } -func (m *mockPersistor) Save() error { +func (m *mockSaver) Save() error { args := m.Called() return args.Error(0) } @@ -238,7 +238,7 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} actions := aq.DequeueActions() @@ -272,7 +272,7 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} actions := aq.DequeueActions() @@ -304,7 +304,7 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} actions := aq.DequeueActions() @@ -332,7 +332,7 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} actions := aq.DequeueActions() assert.Empty(t, actions) @@ -361,7 +361,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { t.Run("empty queue", func(t 
*testing.T) { q := &queue{} - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} n := aq.Cancel("test-1") assert.Zero(t, n) @@ -383,7 +383,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} n := aq.Cancel("test-1") assert.Equal(t, 1, n) @@ -413,7 +413,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} n := aq.Cancel("test-1") assert.Equal(t, 2, n) @@ -440,7 +440,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} n := aq.Cancel("test-1") assert.Equal(t, 3, n) @@ -462,7 +462,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} n := aq.Cancel("test-0") assert.Zero(t, n) @@ -484,7 +484,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { func Test_ActionQueue_Actions(t *testing.T) { t.Run("empty queue", func(t *testing.T) { q := &queue{} - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} actions := aq.Actions() assert.Len(t, actions, 0) }) @@ -510,10 +510,72 @@ func Test_ActionQueue_Actions(t *testing.T) { index: 2, }} heap.Init(q) - aq := &ActionQueue{q, &mockPersistor{}} + aq := &ActionQueue{q, &mockSaver{}} actions := aq.Actions() assert.Len(t, actions, 3) assert.Equal(t, "test-1", actions[0].ID()) }) } + +func Test_ActionQueue_CancelType(t *testing.T) { + a1 := &mockAction{} + a1.On("ID").Return("test-1") + a1.On("Type").Return("upgrade") + a2 := &mockAction{} + a2.On("ID").Return("test-2") + a2.On("Type").Return("upgrade") + a3 := &mockAction{} + a3.On("ID").Return("test-3") + a3.On("Type").Return("unknown") + + t.Run("empty queue", func(t *testing.T) { + aq := &ActionQueue{&queue{}, &mockSaver{}} + + 
n := aq.CancelType("upgrade") + assert.Equal(t, 0, n) + }) + + t.Run("single item in queue", func(t *testing.T) { + q := &queue{&item{ + action: a1, + priority: 1, + index: 0, + }} + heap.Init(q) + aq := &ActionQueue{q, &mockSaver{}} + + n := aq.CancelType("upgrade") + assert.Equal(t, 1, n) + }) + + t.Run("no matches in queue", func(t *testing.T) { + q := &queue{&item{ + action: a3, + priority: 1, + index: 0, + }} + heap.Init(q) + aq := &ActionQueue{q, &mockSaver{}} + + n := aq.CancelType("upgrade") + assert.Equal(t, 0, n) + }) + + t.Run("all items cancelled", func(t *testing.T) { + q := &queue{&item{ + action: a1, + priority: 1, + index: 0, + }, &item{ + action: a2, + priority: 2, + index: 1, + }} + heap.Init(q) + aq := &ActionQueue{q, &mockSaver{}} + + n := aq.CancelType("upgrade") + assert.Equal(t, 2, n) + }) +}