From ffe77e860cccb4275dea00424ba39697c5b89cf9 Mon Sep 17 00:00:00 2001 From: Michel Laterman <82832767+michel-laterman@users.noreply.github.com> Date: Mon, 23 May 2022 22:56:17 -0700 Subject: [PATCH] Support scheduled actions and cancellation (#419) * Support scheduled actions and cancellation Support scheduled actions by adding a new queue that actions will be added to/removed from before they are sent to the dispatcher. The queue is a priority queue (ordered by start_time). fleet_gateway is responsible for syncing the queue to storage. Cancellation of an action will be handled by a new action dispatcher that will remove actions from the queue (if any) and update the targetID action status. TODO - cancel handler - action expiration - fleet_gateway tests * Add queue tests in fleet_gateway_tests, fix check and linting issues * Force start_time/expiration to be utc * Remove logic todos, fix logging statement * Apply suggestions from code review Co-authored-by: Anderson Queiroz Co-authored-by: Anderson Queiroz --- .golangci.yml | 2 +- CHANGELOG.next.asciidoc | 1 + NOTICE.txt | 32 ++ go.mod | 1 + go.sum | 1 + .../gateway/fleet/fleet_gateway.go | 91 +++- .../gateway/fleet/fleet_gateway_test.go | 323 +++++++++++- .../pkg/agent/application/managed_mode.go | 16 + .../actions/handlers/handler_action_cancel.go | 47 ++ .../pipeline/dispatcher/dispatcher_test.go | 111 ++-- internal/pkg/agent/errors/error.go | 5 +- .../pkg/agent/storage/store/state_store.go | 33 +- .../agent/storage/store/state_store_test.go | 124 ++++- internal/pkg/fleetapi/action.go | 302 ++++++++++- internal/pkg/fleetapi/action_test.go | 55 ++ internal/pkg/queue/actionqueue.go | 131 +++++ internal/pkg/queue/actionqueue_test.go | 485 ++++++++++++++++++ 17 files changed, 1644 insertions(+), 116 deletions(-) create mode 100644 internal/pkg/agent/application/pipeline/actions/handlers/handler_action_cancel.go create mode 100644 internal/pkg/queue/actionqueue.go create mode 100644 internal/pkg/queue/actionqueue_test.go diff --git a/.golangci.yml b/.golangci.yml index 0d4e9e5b454..956b4b4b573 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -63,7 +63,7 @@ linters: - noctx # noctx finds sending http request without context.Context - unconvert # Remove unnecessary type conversions - wastedassign # wastedassign finds wasted assignment statements. - - godox # tool for detection of FIXME, TODO and other comment keywords + # - godox # tool for detection of FIXME, TODO and other comment keywords # all available settings of specific linters linters-settings: diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index a4bd5feffd5..2c3e563cf21 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -182,3 +182,4 @@ - Increase the download artifact timeout to 10mins and add log download statistics. {pull}308[308] - Save the agent configuration and the state encrypted on the disk. {issue}535[535] {pull}398[398] - Bump node.js version for heartbeat/synthetics to 16.15.0 +- Support scheduled actions and cancellation of pending actions. {issue}393[393] {pull}419[419] diff --git a/NOTICE.txt b/NOTICE.txt index 5cd23e0750d..0aa819523a0 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -14073,6 +14073,38 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +-------------------------------------------------------------------------------- +Dependency : github.com/stretchr/objx +Version: v0.2.0 +Licence type (autodetected): MIT +-------------------------------------------------------------------------------- + +Contents of probable licence file $GOMODCACHE/github.com/stretchr/objx@v0.2.0/LICENSE: + +The MIT License + +Copyright (c) 2014 Stretchr, Inc. +Copyright (c) 2017-2018 objx contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + -------------------------------------------------------------------------------- Dependency : github.com/tklauser/go-sysconf Version: v0.3.9 diff --git a/go.mod b/go.mod index e081cb3a865..f073a7105c3 100644 --- a/go.mod +++ b/go.mod @@ -109,6 +109,7 @@ require ( github.com/sergi/go-diff v1.1.0 // indirect github.com/sirupsen/logrus v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.2.0 // indirect github.com/tklauser/go-sysconf v0.3.9 // indirect github.com/tklauser/numcpus v0.3.0 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect diff --git a/go.sum b/go.sum index 4cd3152234b..91a2ec0da3a 100644 --- a/go.sum +++ b/go.sum @@ -1120,6 +1120,7 @@ github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag github.com/stretchr/objx v0.0.0-20180129172003-8a3f7159479f/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v0.0.0-20180303142811-b89eecf5ca5d/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go index 7ed160bbb2d..4ff4c34ad42 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go @@ -6,6 +6,7 @@ package fleet import ( "context" + stderr "errors" "fmt" "sync" "time" @@ -62,6 +63,14 @@ type stateStore interface { AckToken() string SetAckToken(ackToken string) Save() error + SetQueue([]fleetapi.Action) + Actions() []fleetapi.Action +} + +type actionQueue interface { + Add(fleetapi.Action, int64) + DequeueActions() []fleetapi.Action + Cancel(string) int Actions() []fleetapi.Action } @@ -82,6 +91,7 @@ type fleetGateway struct { statusController status.Controller statusReporter status.Reporter stateStore stateStore + queue actionQueue } // New creates a new fleet gateway @@ -95,6 +105,7 @@ func New( acker store.FleetAcker, statusController status.Controller, stateStore stateStore, + queue actionQueue, ) (gateway.FleetGateway, error) { scheduler := scheduler.NewPeriodicJitter(defaultGatewaySettings.Duration, defaultGatewaySettings.Jitter) @@ -110,6 +121,7 @@ func New( acker, statusController, stateStore, + queue, ) } @@ -125,6 +137,7 @@ func newFleetGatewayWithScheduler( acker store.FleetAcker, statusController status.Controller, stateStore stateStore, + queue actionQueue, ) (gateway.FleetGateway, error) { // Backoff implementation doesn't support the use of a context [cancellation] @@ -151,13 +164,14 @@ func newFleetGatewayWithScheduler( statusReporter: statusController.RegisterComponent("gateway"), statusController: statusController, stateStore: stateStore, + queue: queue, }, nil } func (f *fleetGateway) worker() { for { select { - case <-f.scheduler.WaitTick(): + case ts := <-f.scheduler.WaitTick(): f.log.Debug("FleetGateway calling Checkin API") // Execute the checkin call and for any errors returned by the fleet-server API @@ -168,12 +182,27 @@ func (f *fleetGateway) worker() { continue } - actions := make([]fleetapi.Action, len(resp.Actions)) - for idx, a := range resp.Actions { - actions[idx] = a + actions := f.queueScheduledActions(resp.Actions) + actions, err = f.dispatchCancelActions(actions) + if err != nil { + f.log.Error(err.Error()) } + queued, expired := f.gatherQueuedActions(ts.UTC()) + f.log.Debugf("Gathered %d actions from queue, %d actions expired", len(queued), len(expired)) + f.log.Debugf("Expired actions: %v", expired) + + actions = append(actions, queued...) + var errMsg string + // Persist state + f.stateStore.SetQueue(f.queue.Actions()) + if err := f.stateStore.Save(); err != nil { + errMsg = fmt.Sprintf("failed to persist action_queue, error: %s", err) + f.log.Error(errMsg) + f.statusReporter.Update(state.Failed, errMsg, nil) + } + if err := f.dispatcher.Dispatch(context.Background(), f.acker, actions...); err != nil { errMsg = fmt.Sprintf("failed to dispatch actions, error: %s", err) f.log.Error(errMsg) @@ -194,6 +223,60 @@ func (f *fleetGateway) worker() { } } +// queueScheduledActions will add any action in actions with a valid start time to the queue and return the rest. +// start time to current time comparisons are purposefully not made in case of cancel actions. +func (f *fleetGateway) queueScheduledActions(input fleetapi.Actions) []fleetapi.Action { + actions := make([]fleetapi.Action, 0, len(input)) + for _, action := range input { + start, err := action.StartTime() + if err == nil { + f.log.Debugf("Adding action id: %s to queue.", action.ID()) + f.queue.Add(action, start.Unix()) + continue + } + if !stderr.Is(err, fleetapi.ErrNoStartTime) { + f.log.Warnf("Issue gathering start time from action id %s: %v", action.ID(), err) + } + actions = append(actions, action) + } + return actions +} + +// dispatchCancelActions will separate and dispatch any cancel actions from the actions list and return the rest of the list. +// cancel actions are dispatched seperatly as they may remove items from the queue. +func (f *fleetGateway) dispatchCancelActions(actions []fleetapi.Action) ([]fleetapi.Action, error) { + // separate cancel actions from the actions list + cancelActions := make([]fleetapi.Action, 0, len(actions)) + for i := len(actions) - 1; i >= 0; i-- { + action := actions[i] + if action.Type() == fleetapi.ActionTypeCancel { + cancelActions = append(cancelActions, action) + actions = append(actions[:i], actions[i+1:]...) + } + } + // Dispatch cancel actions + if len(cancelActions) > 0 { + if err := f.dispatcher.Dispatch(context.Background(), f.acker, cancelActions...); err != nil { + return actions, fmt.Errorf("failed to dispatch cancel actions: %w", err) + } + } + return actions, nil +} + +// gatherQueuedActions will dequeue actions from the action queue and separate those that have already expired. +func (f *fleetGateway) gatherQueuedActions(ts time.Time) (queued, expired []fleetapi.Action) { + actions := f.queue.DequeueActions() + for _, action := range actions { + exp, _ := action.Expiration() + if ts.After(exp) { + expired = append(expired, action) + continue + } + queued = append(queued, action) + } + return queued, expired +} + func (f *fleetGateway) doExecute() (*fleetapi.CheckinResponse, error) { f.backoff.Reset() diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go index 8c645d57398..a9b9380519f 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go @@ -2,6 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. +//nolint:dupl // duplicate code is in test cases package fleet import ( @@ -18,6 +19,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/elastic/elastic-agent/internal/pkg/agent/application/gateway" @@ -92,9 +94,9 @@ func (t *testingDispatcher) Dispatch(_ context.Context, acker store.FleetAcker, // Ack everything and commit at the end. for _, action := range actions { - acker.Ack(ctx, action) + _ = acker.Ack(ctx, action) } - acker.Commit(ctx) + _ = acker.Commit(ctx) return nil } @@ -110,6 +112,29 @@ func newTestingDispatcher() *testingDispatcher { return &testingDispatcher{received: make(chan struct{}, 1)} } +type mockQueue struct { + mock.Mock +} + +func (m *mockQueue) Add(action fleetapi.Action, n int64) { + m.Called(action, n) +} + +func (m *mockQueue) DequeueActions() []fleetapi.Action { + args := m.Called() + return args.Get(0).([]fleetapi.Action) +} + +func (m *mockQueue) Cancel(id string) int { + args := m.Called(id) + return args.Int(0) +} + +func (m *mockQueue) Actions() []fleetapi.Action { + args := m.Called() + return args.Get(0).([]fleetapi.Action) +} + type withGatewayFunc func(*testing.T, gateway.FleetGateway, *testingClient, *testingDispatcher, *scheduler.Stepper, repo.Backend) func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGatewayFunc) func(t *testing.T) { @@ -128,6 +153,10 @@ func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGat stateStore, err := store.NewStateStore(log, diskStore) require.NoError(t, err) + queue := &mockQueue{} + queue.On("DequeueActions").Return([]fleetapi.Action{}) + queue.On("Actions").Return([]fleetapi.Action{}) + gateway, err := newFleetGatewayWithScheduler( ctx, log, @@ -140,6 +169,7 @@ func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGat noopacker.NewAcker(), &noopController{}, stateStore, + queue, ) require.NoError(t, err) @@ -194,7 +224,8 @@ func TestFleetGateway(t *testing.T) { return nil }), ) - gateway.Start() + err := gateway.Start() + require.NoError(t, err) // Synchronize scheduler and acking of calls from the worker go routine. scheduler.Next() @@ -234,11 +265,12 @@ func TestFleetGateway(t *testing.T) { return resp, nil }), dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Equal(t, 2, len(actions)) + require.Len(t, actions, 2) return nil }), ) - gateway.Start() + err := gateway.Start() + require.NoError(t, err) scheduler.Next() waitFn() @@ -259,6 +291,10 @@ func TestFleetGateway(t *testing.T) { stateStore, err := store.NewStateStore(log, diskStore) require.NoError(t, err) + queue := &mockQueue{} + queue.On("DequeueActions").Return([]fleetapi.Action{}) + queue.On("Actions").Return([]fleetapi.Action{}) + gateway, err := newFleetGatewayWithScheduler( ctx, log, @@ -271,6 +307,7 @@ func TestFleetGateway(t *testing.T) { noopacker.NewAcker(), &noopController{}, stateStore, + queue, ) require.NoError(t, err) @@ -286,7 +323,8 @@ func TestFleetGateway(t *testing.T) { }), ) - gateway.Start() + err = gateway.Start() + require.NoError(t, err) var count int for { @@ -298,6 +336,256 @@ func TestFleetGateway(t *testing.T) { } }) + t.Run("queue action from checkin", func(t *testing.T) { + scheduler := scheduler.NewStepper() + client := newTestingClient() + dispatcher := newTestingDispatcher() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + log, _ := logger.New("tst", false) + + diskStore := storage.NewDiskStore(paths.AgentStateStoreFile()) + stateStore, err := store.NewStateStore(log, diskStore) + require.NoError(t, err) + + ts := time.Now().UTC().Round(time.Second) + queue := &mockQueue{} + queue.On("Add", mock.Anything, ts.Add(time.Hour).Unix()).Return().Once() + queue.On("DequeueActions").Return([]fleetapi.Action{}) + queue.On("Actions").Return([]fleetapi.Action{}) + + gateway, err := newFleetGatewayWithScheduler( + ctx, + log, + settings, + agentInfo, + client, + dispatcher, + scheduler, + getReporter(agentInfo, log, t), + noopacker.NewAcker(), + &noopController{}, + stateStore, + queue, + ) + + require.NoError(t, err) + + waitFn := ackSeq( + client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { + resp := wrapStrToResp(http.StatusOK, fmt.Sprintf(`{"actions": [{ + "type": "UPGRADE", + "id": "id1", + "start_time": "%s", + "expiration": "%s", + "data": { + "version": "1.2.3" + } + }]}`, + ts.Add(time.Hour).Format(time.RFC3339), + ts.Add(2*time.Hour).Format(time.RFC3339), + )) + return resp, nil + }), + dispatcher.Answer(func(actions ...fleetapi.Action) error { + require.Equal(t, 0, len(actions)) + return nil + }), + ) + + err = gateway.Start() + require.NoError(t, err) + + scheduler.Next() + waitFn() + queue.AssertExpectations(t) + }) + + t.Run("run action from queue", func(t *testing.T) { + scheduler := scheduler.NewStepper() + client := newTestingClient() + dispatcher := newTestingDispatcher() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + log, _ := logger.New("tst", false) + + diskStore := storage.NewDiskStore(paths.AgentStateStoreFile()) + stateStore, err := store.NewStateStore(log, diskStore) + require.NoError(t, err) + + ts := time.Now().UTC().Round(time.Second) + queue := &mockQueue{} + queue.On("DequeueActions").Return([]fleetapi.Action{&fleetapi.ActionUpgrade{ActionID: "id1", ActionType: "UPGRADE", ActionStartTime: ts.Add(-1 * time.Hour).Format(time.RFC3339), ActionExpiration: ts.Add(time.Hour).Format(time.RFC3339)}}).Once() + queue.On("Actions").Return([]fleetapi.Action{}) + + gateway, err := newFleetGatewayWithScheduler( + ctx, + log, + settings, + agentInfo, + client, + dispatcher, + scheduler, + getReporter(agentInfo, log, t), + noopacker.NewAcker(), + &noopController{}, + stateStore, + queue, + ) + + require.NoError(t, err) + + waitFn := ackSeq( + client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { + resp := wrapStrToResp(http.StatusOK, `{"actions": []}`) + return resp, nil + }), + dispatcher.Answer(func(actions ...fleetapi.Action) error { + require.Equal(t, 1, len(actions)) + return nil + }), + ) + + err = gateway.Start() + require.NoError(t, err) + + scheduler.Next() + waitFn() + queue.AssertExpectations(t) + }) + + t.Run("discard expired action from queue", func(t *testing.T) { + scheduler := scheduler.NewStepper() + client := newTestingClient() + dispatcher := newTestingDispatcher() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + log, _ := logger.New("tst", false) + + diskStore := storage.NewDiskStore(paths.AgentStateStoreFile()) + stateStore, err := store.NewStateStore(log, diskStore) + require.NoError(t, err) + + ts := time.Now().UTC().Round(time.Second) + queue := &mockQueue{} + queue.On("DequeueActions").Return([]fleetapi.Action{&fleetapi.ActionUpgrade{ActionID: "id1", ActionType: "UPGRADE", ActionStartTime: ts.Add(-2 * time.Hour).Format(time.RFC3339), ActionExpiration: ts.Add(-1 * time.Hour).Format(time.RFC3339)}}).Once() + queue.On("Actions").Return([]fleetapi.Action{}) + + gateway, err := newFleetGatewayWithScheduler( + ctx, + log, + settings, + agentInfo, + client, + dispatcher, + scheduler, + getReporter(agentInfo, log, t), + noopacker.NewAcker(), + &noopController{}, + stateStore, + queue, + ) + + require.NoError(t, err) + + waitFn := ackSeq( + client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { + resp := wrapStrToResp(http.StatusOK, `{"actions": []}`) + return resp, nil + }), + dispatcher.Answer(func(actions ...fleetapi.Action) error { + require.Equal(t, 0, len(actions)) + return nil + }), + ) + + err = gateway.Start() + require.NoError(t, err) + + scheduler.Next() + waitFn() + queue.AssertExpectations(t) + }) + + t.Run("cancel action from checkin", func(t *testing.T) { + scheduler := scheduler.NewStepper() + client := newTestingClient() + dispatcher := newTestingDispatcher() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + log, _ := logger.New("tst", false) + + diskStore := storage.NewDiskStore(paths.AgentStateStoreFile()) + stateStore, err := store.NewStateStore(log, diskStore) + require.NoError(t, err) + + ts := time.Now().UTC().Round(time.Second) + queue := &mockQueue{} + queue.On("Add", mock.Anything, ts.Add(-1*time.Hour).Unix()).Return().Once() + queue.On("DequeueActions").Return([]fleetapi.Action{}) + queue.On("Actions").Return([]fleetapi.Action{}).Maybe() // this test seems flakey if we check for this call + // queue.Cancel does not need to be mocked here as it is ran in the cancel action dispatcher. + + gateway, err := newFleetGatewayWithScheduler( + ctx, + log, + settings, + agentInfo, + client, + dispatcher, + scheduler, + getReporter(agentInfo, log, t), + noopacker.NewAcker(), + &noopController{}, + stateStore, + queue, + ) + + require.NoError(t, err) + + waitFn := ackSeq( + client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { + resp := wrapStrToResp(http.StatusOK, fmt.Sprintf(`{"actions": [{ + "type": "UPGRADE", + "id": "id1", + "start_time": "%s", + "expiration": "%s", + "data": { + "version": "1.2.3" + } + }, { + "type": "CANCEL", + "id": "id2", + "data": { + "target_id": "id1" + } + }]}`, + ts.Add(-1*time.Hour).Format(time.RFC3339), + ts.Add(2*time.Hour).Format(time.RFC3339), + )) + return resp, nil + }), + dispatcher.Answer(func(actions ...fleetapi.Action) error { + return nil + }), + ) + + err = gateway.Start() + require.NoError(t, err) + + scheduler.Next() + waitFn() + queue.AssertExpectations(t) + }) + t.Run("send event and receive no action", withGateway(agentInfo, settings, func( t *testing.T, gateway gateway.FleetGateway, @@ -306,7 +594,7 @@ func TestFleetGateway(t *testing.T) { scheduler *scheduler.Stepper, rep repo.Backend, ) { - rep.Report(context.Background(), &testStateEvent{}) + _ = rep.Report(context.Background(), &testStateEvent{}) waitFn := ackSeq( client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { cr := &request{} @@ -329,7 +617,8 @@ func TestFleetGateway(t *testing.T) { return nil }), ) - gateway.Start() + err := gateway.Start() + require.NoError(t, err) // Synchronize scheduler and acking of calls from the worker go routine. scheduler.Next() @@ -351,6 +640,10 @@ func TestFleetGateway(t *testing.T) { stateStore, err := store.NewStateStore(log, diskStore) require.NoError(t, err) + queue := &mockQueue{} + queue.On("DequeueActions").Return([]fleetapi.Action{}) + queue.On("Actions").Return([]fleetapi.Action{}) + gateway, err := newFleetGatewayWithScheduler( ctx, log, @@ -366,6 +659,7 @@ func TestFleetGateway(t *testing.T) { noopacker.NewAcker(), &noopController{}, stateStore, + queue, ) require.NoError(t, err) @@ -376,7 +670,8 @@ func TestFleetGateway(t *testing.T) { return resp, nil }) - gateway.Start() + err = gateway.Start() + require.NoError(t, err) // Silently dispatch action. go func() { @@ -423,9 +718,10 @@ func TestRetriesOnFailures(t *testing.T) { return wrapStrToResp(http.StatusInternalServerError, "something is bad"), nil } clientWaitFn := client.Answer(fail) - gateway.Start() + err := gateway.Start() + require.NoError(t, err) - rep.Report(context.Background(), &testStateEvent{}) + _ = rep.Report(context.Background(), &testStateEvent{}) // Initial tick is done out of bound so we can block on channels. scheduler.Next() @@ -479,9 +775,10 @@ func TestRetriesOnFailures(t *testing.T) { return wrapStrToResp(http.StatusInternalServerError, "something is bad"), nil } waitChan := client.Answer(fail) - gateway.Start() + err := gateway.Start() + require.NoError(t, err) - rep.Report(context.Background(), &testStateEvent{}) + _ = rep.Report(context.Background(), &testStateEvent{}) // Initial tick is done out of bound so we can block on channels. scheduler.Next() diff --git a/internal/pkg/agent/application/managed_mode.go b/internal/pkg/agent/application/managed_mode.go index 061daa62cf4..d334ae0198c 100644 --- a/internal/pkg/agent/application/managed_mode.go +++ b/internal/pkg/agent/application/managed_mode.go @@ -41,6 +41,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker/lazy" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/acker/retrier" "github.com/elastic/elastic-agent/internal/pkg/fleetapi/client" + "github.com/elastic/elastic-agent/internal/pkg/queue" reporting "github.com/elastic/elastic-agent/internal/pkg/reporter" fleetreporter "github.com/elastic/elastic-agent/internal/pkg/reporter/fleet" logreporter "github.com/elastic/elastic-agent/internal/pkg/reporter/log" @@ -55,6 +56,7 @@ type stateStore interface { SetAckToken(ackToken string) Save() error Actions() []fleetapi.Action + Queue() []fleetapi.Action } // Managed application, when the application is run in managed mode, most of the configuration are @@ -179,6 +181,11 @@ func newManaged( managedApplication.stateStore = stateStore actionAcker := store.NewStateStoreActionAcker(batchedAcker, stateStore) + actionQueue, err := queue.NewActionQueue(stateStore.Queue()) + if err != nil { + return nil, fmt.Errorf("unable to initialize action queue: %w", err) + } + actionDispatcher, err := dispatcher.New(managedApplication.bgContext, log, handlers.NewDefault(log)) if err != nil { return nil, err @@ -237,6 +244,14 @@ func newManaged( ), ) + actionDispatcher.MustRegister( + &fleetapi.ActionCancel{}, + handlers.NewCancel( + log, + actionQueue, + ), + ) + actionDispatcher.MustRegister( &fleetapi.ActionApp{}, handlers.NewAppAction(log, managedApplication.srv), @@ -269,6 +284,7 @@ func newManaged( actionAcker, statusCtrl, stateStore, + actionQueue, ) if err != nil { return nil, err diff --git a/internal/pkg/agent/application/pipeline/actions/handlers/handler_action_cancel.go b/internal/pkg/agent/application/pipeline/actions/handlers/handler_action_cancel.go new file mode 100644 index 00000000000..a2208c7294d --- /dev/null +++ b/internal/pkg/agent/application/pipeline/actions/handlers/handler_action_cancel.go @@ -0,0 +1,47 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package handlers + +import ( + "context" + "fmt" + + "github.com/elastic/elastic-agent/internal/pkg/agent/storage/store" + "github.com/elastic/elastic-agent/internal/pkg/fleetapi" + "github.com/elastic/elastic-agent/pkg/core/logger" +) + +type queueCanceler interface { + Cancel(id string) int +} + +// Cancel is a handler for CANCEL actions. +type Cancel struct { + log *logger.Logger + c queueCanceler +} + +// NewCancel creates a new Cancel handler that uses the passed queue canceller. +func NewCancel(log *logger.Logger, cancel queueCanceler) *Cancel { + return &Cancel{ + log: log, + c: cancel, + } +} + +// Handle will cancel any actions in the queue that match target_id. +func (h *Cancel) Handle(ctx context.Context, a fleetapi.Action, acker store.FleetAcker) error { + action, ok := a.(*fleetapi.ActionCancel) + if !ok { + return fmt.Errorf("invalid type, expected ActionCancel and received %T", a) + } + n := h.c.Cancel(action.TargetID) + if n == 0 { + h.log.Debugf("Cancel action id: %s target id: %s found no actions in queue.", action.ActionID, action.TargetID) + return nil + } + h.log.Infof("Cancel action id: %s target id: %s removed %d action(s) from queue.", action.ActionID, action.TargetID, n) + return nil +} diff --git a/internal/pkg/agent/application/pipeline/dispatcher/dispatcher_test.go b/internal/pkg/agent/application/pipeline/dispatcher/dispatcher_test.go index 37654602c77..3c65dd4a2e7 100644 --- a/internal/pkg/agent/application/pipeline/dispatcher/dispatcher_test.go +++ b/internal/pkg/agent/application/pipeline/dispatcher/dispatcher_test.go @@ -7,10 +7,12 @@ package dispatcher import ( "context" "testing" + "time" "go.elastic.co/apm" "go.elastic.co/apm/apmtest" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/elastic/elastic-agent/internal/pkg/agent/storage/store" @@ -19,45 +21,58 @@ import ( ) type mockHandler struct { - received fleetapi.Action - called bool - err error + mock.Mock } -func (h *mockHandler) Handle(_ context.Context, a fleetapi.Action, acker store.FleetAcker) error { - h.called = true - h.received = a - return h.err +func (h *mockHandler) Handle(ctx context.Context, a fleetapi.Action, acker store.FleetAcker) error { + args := h.Called(ctx, a, acker) + return args.Error(0) } -type mockAction struct{} - -func (m *mockAction) ID() string { return "mockAction" } -func (m *mockAction) Type() string { return "mockAction" } -func (m *mockAction) String() string { return "mockAction" } - -type mockActionUnknown struct{} - -func (m *mockActionUnknown) ID() string { return "mockActionUnknown" } -func (m *mockActionUnknown) Type() string { return "mockActionUnknown" } -func (m *mockActionUnknown) String() string { return "mockActionUnknown" } - -type mockActionOther struct{} +// need various action structs as the dispather uses type reflection for routing, not action.Type() +type mockAction struct { + mock.Mock +} +type mockOtherAction struct { + mockAction +} +type mockUnknownAction struct { + mockAction +} -func (m *mockActionOther) ID() string { return "mockActionOther" } -func (m *mockActionOther) Type() string { return "mockActionOther" } -func (m *mockActionOther) String() string { return "mockActionOther" } +func (m *mockAction) ID() string { + args := m.Called() + return args.String(0) +} +func (m *mockAction) Type() string { + args := m.Called() + return args.String(0) +} +func (m *mockAction) String() string { + args := m.Called() + return args.String(0) +} +func (m *mockAction) StartTime() (time.Time, error) { + args := m.Called() + return args.Get(0).(time.Time), args.Error(1) +} +func (m *mockAction) Expiration() (time.Time, error) { + args := m.Called() + return args.Get(0).(time.Time), args.Error(1) +} type mockAcker struct { - CommitFn func(ctx context.Context) error + mock.Mock } -func (m mockAcker) Ack(ctx context.Context, action fleetapi.Action) error { - panic("implement me") +func (m *mockAcker) Ack(ctx context.Context, action fleetapi.Action) error { + args := m.Called(ctx, action) + return args.Error(0) } -func (m mockAcker) Commit(ctx context.Context) error { - return m.CommitFn(ctx) +func (m *mockAcker) Commit(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) } func TestActionDispatcher(t *testing.T) { @@ -66,24 +81,26 @@ func TestActionDispatcher(t *testing.T) { t.Run("Merges ActionDispatcher ctx cancel and Dispatch ctx value", func(t *testing.T) { action1 := &mockAction{} def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() span := apmtest.NewRecordingTracer(). StartTransaction("ignore", "ignore"). StartSpan("ignore", "ignore", nil) ctx1, cancel := context.WithCancel(context.Background()) - ack := mockAcker{CommitFn: func(ctx context.Context) error { - // ctx1 not cancelled yet + ack := &mockAcker{} + ack.On("Commit", mock.Anything).Run(func(args mock.Arguments) { + ctx, _ := args.Get(0).(context.Context) require.NoError(t, ctx.Err()) got := apm.SpanFromContext(ctx) require.Equal(t, span.TraceContext().Span, got.ParentID()) cancel() // cancel function from ctx1 require.Equal(t, ctx.Err(), context.Canceled) - return nil - }} + }).Return(nil) d, err := New(ctx1, nil, def) require.NoError(t, err) ctx2 := apm.ContextWithSpan(context.Background(), span) err = d.Dispatch(ctx2, ack, action1) require.NoError(t, err) + ack.AssertExpectations(t) }) t.Run("Success to dispatch multiples events", func(t *testing.T) { @@ -95,38 +112,38 @@ func TestActionDispatcher(t *testing.T) { success1 := &mockHandler{} success2 := &mockHandler{} - d.Register(&mockAction{}, success1) - d.Register(&mockActionOther{}, success2) + err = d.Register(&mockAction{}, success1) + require.NoError(t, err) + err = d.Register(&mockOtherAction{}, success2) + require.NoError(t, err) action1 := &mockAction{} - action2 := &mockActionOther{} + action2 := &mockOtherAction{} - err = d.Dispatch(ctx, ack, action1, action2) + // TODO better matching for actions + success1.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + success2.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + err = d.Dispatch(ctx, ack, action1, action2) require.NoError(t, err) - require.True(t, success1.called) - require.Equal(t, action1, success1.received) - - require.True(t, success2.called) - require.Equal(t, action2, success2.received) - - require.False(t, def.called) - require.Nil(t, def.received) + success1.AssertExpectations(t) + success2.AssertExpectations(t) + def.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything, mock.Anything) }) t.Run("Unknown action are caught by the unknown handler", func(t *testing.T) { def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() ctx := context.Background() d, err := New(ctx, nil, def) require.NoError(t, err) - action := &mockActionUnknown{} + action := &mockUnknownAction{} err = d.Dispatch(ctx, ack, action) require.NoError(t, err) - require.True(t, def.called) - require.Equal(t, action, def.received) + def.AssertExpectations(t) }) t.Run("Could not register two handlers on the same action", func(t *testing.T) { diff --git a/internal/pkg/agent/errors/error.go b/internal/pkg/agent/errors/error.go index 5a8f55d9cd0..03a4942c2f4 100644 --- a/internal/pkg/agent/errors/error.go +++ b/internal/pkg/agent/errors/error.go @@ -1,9 +1,8 @@ // Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -// -// nolint:errorlint // Postpone the change here until we refactor error handling. -// + +//nolint:errorlint,errcheck // Postpone the change here until we refactor error handling. // Package errors provides a small api to manage hierarchy of errors. package errors diff --git a/internal/pkg/agent/storage/store/state_store.go b/internal/pkg/agent/storage/store/state_store.go index fb15bfc9400..3316b34960b 100644 --- a/internal/pkg/agent/storage/store/state_store.go +++ b/internal/pkg/agent/storage/store/state_store.go @@ -59,9 +59,12 @@ type StateStore struct { type stateT struct { action action ackToken string + queue []action } -// Combined yml serializer for the ActionPolicyChange and ActionUnenroll +// actionSerializer is a combined yml serializer for the ActionPolicyChange and ActionUnenroll +// it is used to read the yaml file and assign the action to stateT.action as we must provide the +// underlying struct that provides the action interface. type actionSerializer struct { ID string `yaml:"action_id"` Type string `yaml:"action_type"` @@ -69,9 +72,14 @@ type actionSerializer struct { IsDetected *bool `yaml:"is_detected,omitempty"` } +// stateSerializer is used to serialize the state to yaml. +// action serialization is handled through the actionSerializer struct +// queue serialization is handled through yaml struct tags or the actions unmarshaller defined in fleetapi +// TODO clean up action serialization (have it be part of the fleetapi?) type stateSerializer struct { Action *actionSerializer `yaml:"action,omitempty"` AckToken string `yaml:"ack_token,omitempty"` + Queue fleetapi.Actions `yaml:"action_queue,omitempty"` } // NewStateStoreWithMigration creates a new state store and migrates the old one. @@ -95,8 +103,7 @@ func NewStateStore(log *logger.Logger, store storeLoad) (*StateStore, error) { // persisted and we return an empty store. reader, err := store.Load() if err != nil { - //nolint:nilerr // wad - return &StateStore{log: log, store: store}, nil + return &StateStore{log: log, store: store}, nil //nolint:nilerr // expected results } defer reader.Close() @@ -117,6 +124,7 @@ func NewStateStore(log *logger.Logger, store storeLoad) (*StateStore, error) { state := stateT{ ackToken: sr.AckToken, + queue: sr.Queue, } if sr.Action != nil { @@ -238,6 +246,15 @@ func (s *StateStore) SetAckToken(ackToken string) { s.state.ackToken = ackToken } +// SetQueue sets the action_queue to agent state +func (s *StateStore) SetQueue(q []action) { + s.mx.Lock() + defer s.mx.Unlock() + s.state.queue = q + s.dirty = true + +} + // Save saves the actions into a state store. func (s *StateStore) Save() error { s.mx.Lock() @@ -251,6 +268,7 @@ func (s *StateStore) Save() error { var reader io.Reader serialize := stateSerializer{ AckToken: s.state.ackToken, + Queue: s.state.queue, } if s.state.action != nil { @@ -275,6 +293,15 @@ func (s *StateStore) Save() error { return nil } +// Queue returns a copy of the queue +func (s *StateStore) Queue() []action { + s.mx.RLock() + defer s.mx.RUnlock() + q := make([]action, len(s.state.queue)) + copy(q, s.state.queue) + return q +} + // Actions returns a slice of action to execute in order, currently only a action policy change is // persisted. func (s *StateStore) Actions() []action { diff --git a/internal/pkg/agent/storage/store/state_store_test.go b/internal/pkg/agent/storage/store/state_store_test.go index 587d02c2d68..e73b8721fbe 100644 --- a/internal/pkg/agent/storage/store/state_store_test.go +++ b/internal/pkg/agent/storage/store/state_store_test.go @@ -6,11 +6,10 @@ package store import ( "context" - "io/ioutil" - "os" "path/filepath" "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -32,11 +31,9 @@ func TestStateStore(t *testing.T) { func runTestStateStore(t *testing.T, ackToken string) { log, _ := logger.New("state_store", false) - withFile := func(fn func(t *testing.T, file string)) func(*testing.T) { + withFile := func(fn func(t *testing.T, file string)) func(*testing.T) { //nolint:unparam // false positive return func(t *testing.T) { - dir, err := ioutil.TempDir("", "state-store") - require.NoError(t, err) - defer os.RemoveAll(dir) + dir := t.TempDir() file := filepath.Join(dir, "state.yml") fn(t, file) } @@ -47,7 +44,8 @@ func runTestStateStore(t *testing.T, ackToken string) { s := storage.NewDiskStore(file) store, err := NewStateStore(log, s) require.NoError(t, err) - require.Equal(t, 0, len(store.Actions())) + require.Empty(t, store.Actions()) + require.Empty(t, store.Queue()) })) t.Run("will discard silently unknown action", @@ -65,7 +63,8 @@ func runTestStateStore(t *testing.T, ackToken string) { store.SetAckToken(ackToken) err = store.Save() require.NoError(t, err) - require.Equal(t, 0, len(store.Actions())) + require.Empty(t, store.Actions()) + require.Empty(t, store.Queue()) require.Equal(t, ackToken, store.AckToken()) })) @@ -83,12 +82,14 @@ func runTestStateStore(t *testing.T, ackToken string) { store, err := NewStateStore(log, s) require.NoError(t, err) - require.Equal(t, 0, len(store.Actions())) + require.Empty(t, store.Actions()) + require.Empty(t, store.Queue()) store.Add(ActionPolicyChange) store.SetAckToken(ackToken) err = store.Save() require.NoError(t, err) - require.Equal(t, 1, len(store.Actions())) + require.Len(t, store.Actions(), 1) + require.Empty(t, store.Queue()) require.Equal(t, ackToken, store.AckToken()) s = storage.NewDiskStore(file) @@ -96,12 +97,90 @@ func runTestStateStore(t *testing.T, ackToken string) { require.NoError(t, err) actions := store1.Actions() - require.Equal(t, 1, len(actions)) + require.Len(t, actions, 1) + require.Empty(t, store1.Queue()) require.Equal(t, ActionPolicyChange, actions[0]) require.Equal(t, ackToken, store.AckToken()) })) + t.Run("can save a queue with one upgrade action", + withFile(func(t *testing.T, file string) { + ts := time.Now().UTC().Round(time.Second) + queue := []action{&fleetapi.ActionUpgrade{ + ActionID: "test", + ActionType: fleetapi.ActionTypeUpgrade, + ActionStartTime: ts.Format(time.RFC3339), + Version: "1.2.3", + SourceURI: "https://example.com", + }} + + s := storage.NewDiskStore(file) + store, err := NewStateStore(log, s) + require.NoError(t, err) + + require.Empty(t, store.Actions()) + store.SetQueue(queue) + err = store.Save() + require.NoError(t, err) + require.Empty(t, store.Actions()) + require.Len(t, store.Queue(), 1) + + s = storage.NewDiskStore(file) + store1, err := NewStateStore(log, s) + require.NoError(t, err) + require.Empty(t, store1.Actions()) + require.Len(t, store1.Queue(), 1) + require.Equal(t, "test", store1.Queue()[0].ID()) + start, err := store1.Queue()[0].StartTime() + require.NoError(t, err) + require.Equal(t, ts, start) + })) + + t.Run("can save a queue with two actions", + withFile(func(t *testing.T, file string) { + ts := time.Now().UTC().Round(time.Second) + queue := []action{&fleetapi.ActionUpgrade{ + ActionID: "test", + ActionType: fleetapi.ActionTypeUpgrade, + ActionStartTime: ts.Format(time.RFC3339), + Version: "1.2.3", + SourceURI: "https://example.com", + }, &fleetapi.ActionPolicyChange{ + ActionID: "abc123", + ActionType: "POLICY_CHANGE", + Policy: map[string]interface{}{ + "hello": "world", + }, + }} + + s := storage.NewDiskStore(file) + store, err := NewStateStore(log, s) + require.NoError(t, err) + + require.Empty(t, store.Actions()) + store.SetQueue(queue) + err = store.Save() + require.NoError(t, err) + require.Empty(t, store.Actions()) + require.Len(t, store.Queue(), 2) + + s = storage.NewDiskStore(file) + store1, err := NewStateStore(log, s) + require.NoError(t, err) + require.Empty(t, store1.Actions()) + require.Len(t, store1.Queue(), 2) + + require.Equal(t, "test", store1.Queue()[0].ID()) + start, err := store1.Queue()[0].StartTime() + require.NoError(t, err) + require.Equal(t, ts, start) + + require.Equal(t, "abc123", store1.Queue()[1].ID()) + _, err = store1.Queue()[1].StartTime() + require.ErrorIs(t, err, fleetapi.ErrNoStartTime) + })) + t.Run("can save to disk unenroll action type", withFile(func(t *testing.T, file string) { action := &fleetapi.ActionUnenroll{ @@ -113,12 +192,14 @@ func runTestStateStore(t *testing.T, ackToken string) { store, err := NewStateStore(log, s) require.NoError(t, err) - require.Equal(t, 0, len(store.Actions())) + require.Empty(t, store.Actions()) + require.Empty(t, store.Queue()) store.Add(action) store.SetAckToken(ackToken) err = store.Save() require.NoError(t, err) - require.Equal(t, 1, len(store.Actions())) + require.Len(t, store.Actions(), 1) + require.Empty(t, store.Queue()) require.Equal(t, ackToken, store.AckToken()) s = storage.NewDiskStore(file) @@ -126,8 +207,8 @@ func runTestStateStore(t *testing.T, ackToken string) { require.NoError(t, err) actions := store1.Actions() - require.Equal(t, 1, len(actions)) - + require.Len(t, actions, 1) + require.Empty(t, store1.Queue()) require.Equal(t, action, actions[0]) require.Equal(t, ackToken, store.AckToken()) })) @@ -144,10 +225,11 @@ func runTestStateStore(t *testing.T, ackToken string) { store.SetAckToken(ackToken) acker := NewStateStoreActionAcker(&testAcker{}, store) - require.Equal(t, 0, len(store.Actions())) + require.Empty(t, store.Actions()) require.NoError(t, acker.Ack(context.Background(), ActionPolicyChange)) - require.Equal(t, 1, len(store.Actions())) + require.Len(t, store.Actions(), 1) + require.Empty(t, store.Queue()) require.Equal(t, ackToken, store.AckToken()) })) @@ -159,8 +241,9 @@ func runTestStateStore(t *testing.T, ackToken string) { stateStore, err := NewStateStore(log, storage.NewDiskStore(stateStorePath)) require.NoError(t, err) stateStore.SetAckToken(ackToken) - require.Equal(t, 0, len(stateStore.Actions())) + require.Empty(t, stateStore.Actions()) require.Equal(t, ackToken, stateStore.AckToken()) + require.Empty(t, stateStore.Queue()) }) })) @@ -177,11 +260,11 @@ func runTestStateStore(t *testing.T, ackToken string) { actionStore, err := NewActionStore(log, storage.NewDiskStore(actionStorePath)) require.NoError(t, err) - require.Equal(t, 0, len(actionStore.Actions())) + require.Empty(t, actionStore.Actions()) actionStore.Add(ActionPolicyChange) err = actionStore.Save() require.NoError(t, err) - require.Equal(t, 1, len(actionStore.Actions())) + require.Len(t, actionStore.Actions(), 1) withFile(func(t *testing.T, stateStorePath string) { err = migrateStateStore(log, actionStorePath, stateStorePath) @@ -195,6 +278,7 @@ func runTestStateStore(t *testing.T, ackToken string) { t.Error(diff) } require.Equal(t, ackToken, stateStore.AckToken()) + require.Empty(t, stateStore.Queue()) }) })) diff --git a/internal/pkg/fleetapi/action.go b/internal/pkg/fleetapi/action.go index 81d23239ce1..4e6b08cd372 100644 --- a/internal/pkg/fleetapi/action.go +++ b/internal/pkg/fleetapi/action.go @@ -8,13 +8,17 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/mitchellh/mapstructure" + "gopkg.in/yaml.v2" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" ) const ( + // ActionTypeUnknown is used to indicate that the elastic-agent does not know how to handle the action + ActionTypeUnknown = "UNKNOWN" // ActionTypeUpgrade specifies upgrade action. ActionTypeUpgrade = "UPGRADE" // ActionTypeUnenroll specifies unenroll action. @@ -27,6 +31,14 @@ const ( ActionTypeSettings = "SETTINGS" // ActionTypeInputAction specifies agent action. ActionTypeInputAction = "INPUT_ACTION" + // ActionTypeCancel specifies a cancel action. + ActionTypeCancel = "CANCEL" +) + +// Error values that the Action interface can return +var ( + ErrNoStartTime = fmt.Errorf("action has no start time") + ErrNoExpiration = fmt.Errorf("action has no expiration") ) // Action base interface for all the implemented action from the fleet API. @@ -34,6 +46,28 @@ type Action interface { fmt.Stringer Type() string ID() string + // StartTime returns the earliest time an action should start (for schduled actions) + // Only ActionUpgrade implements this at the moment + StartTime() (time.Time, error) + // Expiration returns the time where an action is expired and should not be ran (for scheduled actions) + // Only ActionUpgrade implements this at the moment + Expiration() (time.Time, error) +} + +// FleetAction represents an action from fleet-server. +// should copy the action definition in fleet-server/model/schema.json +type FleetAction struct { + ActionID string `yaml:"action_id" json:"id"` // NOTE schema defines this as action_id, but fleet-server remaps it to id in the json response to agent check-in. + ActionType string `yaml:"type,omitempty" json:"type,omitempty"` + InputType string `yaml:"input_type,omitempty" json:"input_type,omitempty"` + ActionExpiration string `yaml:"expiration,omitempty" json:"expiration,omitempty"` + ActionStartTime string `yaml:"start_time,omitempty" json:"start_time,omitempty"` + Timeout int64 `yaml:"timeout,omitempty" json:"timeout,omitempty"` + Data json.RawMessage `yaml:"data,omitempty" json:"data,omitempty"` + //Agents []string // disabled, fleet-server uses this to generate each agent's actions + //Timestamp string // disabled, agent does not care when the document was created + //UserID string // disabled, agent does not care + //MinimumExecutionDuration int64 // disabled, used by fleet-server for scheduling } // ActionUnknown is an action that is not know by the current version of the Agent and we don't want @@ -49,7 +83,7 @@ type ActionUnknown struct { // Type returns the type of the Action. func (a *ActionUnknown) Type() string { - return "UNKNOWN" + return ActionTypeUnknown } // ID returns the ID of the Action. @@ -57,6 +91,16 @@ func (a *ActionUnknown) ID() string { return a.ActionID } +// StartTime returns ErrNoStartTime +func (a *ActionUnknown) StartTime() (time.Time, error) { + return time.Time{}, ErrNoStartTime +} + +// Expiration returns ErrNoExpiration +func (a *ActionUnknown) Expiration() (time.Time, error) { + return time.Time{}, ErrNoExpiration +} + func (a *ActionUnknown) String() string { var s strings.Builder s.WriteString("action_id: ") @@ -76,8 +120,8 @@ func (a *ActionUnknown) OriginalType() string { // ActionPolicyReassign is a request to apply a new type ActionPolicyReassign struct { - ActionID string - ActionType string + ActionID string `yaml:"action_id"` + ActionType string `yaml:"type"` } func (a *ActionPolicyReassign) String() string { @@ -99,11 +143,21 @@ func (a *ActionPolicyReassign) ID() string { return a.ActionID } +// StartTime returns ErrNoStartTime +func (a *ActionPolicyReassign) StartTime() (time.Time, error) { + return time.Time{}, ErrNoStartTime +} + +// Expiration returns ErrNoExpiration +func (a *ActionPolicyReassign) Expiration() (time.Time, error) { + return time.Time{}, ErrNoExpiration +} + // ActionPolicyChange is a request to apply a new type ActionPolicyChange struct { - ActionID string - ActionType string - Policy map[string]interface{} `json:"policy"` + ActionID string `yaml:"action_id"` + ActionType string `yaml:"type"` + Policy map[string]interface{} `json:"policy" yaml:"policy,omitempty"` } func (a *ActionPolicyChange) String() string { @@ -125,12 +179,24 @@ func (a *ActionPolicyChange) ID() string { return a.ActionID } +// StartTime returns ErrNoStartTime +func (a *ActionPolicyChange) StartTime() (time.Time, error) { + return time.Time{}, ErrNoStartTime +} + +// Expiration returns ErrNoExpiration +func (a *ActionPolicyChange) Expiration() (time.Time, error) { + return time.Time{}, ErrNoExpiration +} + // ActionUpgrade is a request for agent to upgrade. type ActionUpgrade struct { - ActionID string `json:"id" yaml:"id"` - ActionType string `json:"type" yaml:"type"` - Version string `json:"version" yaml:"version"` - SourceURI string `json:"source_uri,omitempty" yaml:"source_uri,omitempty"` + ActionID string `yaml:"action_id"` + ActionType string `yaml:"type"` + ActionStartTime string `json:"start_time" yaml:"start_time,omitempty"` // TODO change to time.Time in unmarshal + ActionExpiration string `json:"expiration" yaml:"expiration,omitempty"` + Version string `json:"version" yaml:"version,omitempty"` + SourceURI string `json:"source_uri,omitempty" yaml:"source_uri,omitempty"` } func (a *ActionUpgrade) String() string { @@ -152,11 +218,35 @@ func (a *ActionUpgrade) ID() string { return a.ActionID } +// StartTime returns the start_time as a UTC time.Time or ErrNoStartTime if there is no start time +func (a *ActionUpgrade) StartTime() (time.Time, error) { + if a.ActionStartTime == "" { + return time.Time{}, ErrNoStartTime + } + ts, err := time.Parse(time.RFC3339, a.ActionStartTime) + if err != nil { + return time.Time{}, err + } + return ts.UTC(), nil +} + +// Expiration returns the expiration as a UTC time.Time or ErrExpiration if there is no expiration +func (a *ActionUpgrade) Expiration() (time.Time, error) { + if a.ActionExpiration == "" { + return time.Time{}, ErrNoExpiration + } + ts, err := time.Parse(time.RFC3339, a.ActionExpiration) + if err != nil { + return time.Time{}, err + } + return ts.UTC(), nil +} + // ActionUnenroll is a request for agent to unhook from fleet. type ActionUnenroll struct { - ActionID string - ActionType string - IsDetected bool + ActionID string `yaml:"action_id"` + ActionType string `yaml:"type"` + IsDetected bool `json:"is_detected,omitempty" yaml:"is_detected,omitempty"` } func (a *ActionUnenroll) String() string { @@ -178,11 +268,21 @@ func (a *ActionUnenroll) ID() string { return a.ActionID } +// StartTime returns ErrNoStartTime +func (a *ActionUnenroll) StartTime() (time.Time, error) { + return time.Time{}, ErrNoStartTime +} + +// Expiration returns ErrNoExpiration +func (a *ActionUnenroll) Expiration() (time.Time, error) { + return time.Time{}, ErrNoExpiration +} + // ActionSettings is a request to change agent settings. type ActionSettings struct { - ActionID string - ActionType string - LogLevel string `json:"log_level"` + ActionID string `yaml:"action_id"` + ActionType string `yaml:"type"` + LogLevel string `json:"log_level" yaml:"log_level,omitempty"` } // ID returns the ID of the Action. @@ -195,6 +295,16 @@ func (a *ActionSettings) Type() string { return a.ActionType } +// StartTime returns ErrNoStartTime +func (a *ActionSettings) StartTime() (time.Time, error) { + return time.Time{}, ErrNoStartTime +} + +// Expiration returns ErrNoExpiration +func (a *ActionSettings) Expiration() (time.Time, error) { + return time.Time{}, ErrNoExpiration +} + func (a *ActionSettings) String() string { var s strings.Builder s.WriteString("action_id: ") @@ -206,6 +316,44 @@ func (a *ActionSettings) String() string { return s.String() } +// ActionCancel is a request to cancel an action. +type ActionCancel struct { + ActionID string `yaml:"action_id"` + ActionType string `yaml:"type"` + TargetID string `json:"target_id" yaml:"target_id,omitempty"` +} + +// ID returns the ID of the Action. +func (a *ActionCancel) ID() string { + return a.ActionID +} + +// Type returns the type of the Action. +func (a *ActionCancel) Type() string { + return a.ActionType +} + +// StartTime returns ErrNoStartTime +func (a *ActionCancel) StartTime() (time.Time, error) { + return time.Time{}, ErrNoStartTime +} + +// Expiration returns ErrNoExpiration +func (a *ActionCancel) Expiration() (time.Time, error) { + return time.Time{}, ErrNoExpiration +} + +func (a *ActionCancel) String() string { + var s strings.Builder + s.WriteString("action_id: ") + s.WriteString(a.ActionID) + s.WriteString(", type: ") + s.WriteString(a.ActionType) + s.WriteString(", target_id: ") + s.WriteString(a.TargetID) + return s.String() +} + // ActionApp is the application action request. type ActionApp struct { ActionID string `json:"id" mapstructure:"id"` @@ -240,6 +388,16 @@ func (a *ActionApp) Type() string { return a.ActionType } +// StartTime returns ErrNoStartTime +func (a *ActionApp) StartTime() (time.Time, error) { + return time.Time{}, ErrNoStartTime +} + +// Expiration returns ErrExpiration +func (a *ActionApp) Expiration() (time.Time, error) { + return time.Time{}, ErrNoExpiration +} + // MarshalMap marshals ActionApp into a corresponding map func (a *ActionApp) MarshalMap() (map[string]interface{}, error) { var res map[string]interface{} @@ -252,9 +410,7 @@ type Actions []Action // UnmarshalJSON takes every raw representation of an action and try to decode them. func (a *Actions) UnmarshalJSON(data []byte) error { - - var responses []ActionApp - + var responses []FleetAction if err := json.Unmarshal(data, &responses); err != nil { return errors.New(err, "fail to decode actions", @@ -262,9 +418,8 @@ func (a *Actions) UnmarshalJSON(data []byte) error { } actions := make([]Action, 0, len(responses)) - var action Action - for _, response := range responses { + var action Action switch response.ActionType { case ActionTypePolicyChange: action = &ActionPolicyChange{ @@ -288,7 +443,6 @@ func (a *Actions) UnmarshalJSON(data []byte) error { InputType: response.InputType, Timeout: response.Timeout, Data: response.Data, - Response: response.Response, } case ActionTypeUnenroll: action = &ActionUnenroll{ @@ -297,8 +451,10 @@ func (a *Actions) UnmarshalJSON(data []byte) error { } case ActionTypeUpgrade: action = &ActionUpgrade{ - ActionID: response.ActionID, - ActionType: response.ActionType, + ActionID: response.ActionID, + ActionType: response.ActionType, + ActionStartTime: response.ActionStartTime, + ActionExpiration: response.ActionExpiration, } if err := json.Unmarshal(response.Data, action); err != nil { @@ -317,10 +473,20 @@ func (a *Actions) UnmarshalJSON(data []byte) error { "fail to decode SETTINGS_ACTION action", errors.TypeConfig) } + case ActionTypeCancel: + action = &ActionCancel{ + ActionID: response.ActionID, + ActionType: response.ActionType, + } + if err := json.Unmarshal(response.Data, action); err != nil { + return errors.New(err, + "fail to decode CANCEL_ACTION action", + errors.TypeConfig) + } default: action = &ActionUnknown{ ActionID: response.ActionID, - ActionType: "UNKNOWN", + ActionType: ActionTypeUnknown, originalType: response.ActionType, } } @@ -330,3 +496,89 @@ func (a *Actions) UnmarshalJSON(data []byte) error { *a = actions return nil } + +// UnmarshalYAML attempts to decode yaml actions. +func (a *Actions) UnmarshalYAML(unmarshal func(interface{}) error) error { + var nodes []FleetAction + if err := unmarshal(&nodes); err != nil { + return errors.New(err, + "fail to decode action", + errors.TypeConfig) + } + actions := make([]Action, 0, len(nodes)) + for i := range nodes { + var action Action + n := nodes[i] + switch n.ActionType { + case ActionTypePolicyChange: + action = &ActionPolicyChange{ + ActionID: n.ActionID, + ActionType: n.ActionType, + } + if err := yaml.Unmarshal(n.Data, action); err != nil { + return errors.New(err, + "fail to decode POLICY_CHANGE action", + errors.TypeConfig) + } + case ActionTypePolicyReassign: + action = &ActionPolicyReassign{ + ActionID: n.ActionID, + ActionType: n.ActionType, + } + case ActionTypeInputAction: + action = &ActionApp{ + ActionID: n.ActionID, + ActionType: n.ActionType, + InputType: n.InputType, + Timeout: n.Timeout, + Data: n.Data, + } + case ActionTypeUnenroll: + action = &ActionUnenroll{ + ActionID: n.ActionID, + ActionType: n.ActionType, + } + case ActionTypeUpgrade: + action = &ActionUpgrade{ + ActionID: n.ActionID, + ActionType: n.ActionType, + ActionStartTime: n.ActionStartTime, + ActionExpiration: n.ActionExpiration, + } + if err := yaml.Unmarshal(n.Data, &action); err != nil { + return errors.New(err, + "fail to decode UPGRADE_ACTION action", + errors.TypeConfig) + } + case ActionTypeSettings: + action = &ActionSettings{ + ActionID: n.ActionID, + ActionType: n.ActionType, + } + if err := yaml.Unmarshal(n.Data, action); err != nil { + return errors.New(err, + "fail to decode SETTINGS_ACTION action", + errors.TypeConfig) + } + case ActionTypeCancel: + action = &ActionCancel{ + ActionID: n.ActionID, + ActionType: n.ActionType, + } + if err := yaml.Unmarshal(n.Data, action); err != nil { + return errors.New(err, + "fail to decode CANCEL_ACTION action", + errors.TypeConfig) + } + default: + action = &ActionUnknown{ + ActionID: n.ActionID, + ActionType: ActionTypeUnknown, + originalType: n.ActionType, + } + } + actions = append(actions, action) + } + *a = actions + return nil +} diff --git a/internal/pkg/fleetapi/action_test.go b/internal/pkg/fleetapi/action_test.go index 28e439699a7..b21e591c297 100644 --- a/internal/pkg/fleetapi/action_test.go +++ b/internal/pkg/fleetapi/action_test.go @@ -9,6 +9,8 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestActionSerialization(t *testing.T) { @@ -80,3 +82,56 @@ func mapRawMessageVal(m map[string]interface{}, key string) json.RawMessage { } return nil } + +func TestActionsUnmarshalJSON(t *testing.T) { + t.Run("ActionUpgrade no start time", func(t *testing.T) { + p := []byte(`[{"id":"testid","type":"UPGRADE","data":{"version":"1.2.3","source_uri":"http://example.com"}}]`) + a := &Actions{} + err := a.UnmarshalJSON(p) + require.Nil(t, err) + action, ok := (*a)[0].(*ActionUpgrade) + require.True(t, ok, "unable to cast action to specific type") + assert.Equal(t, "testid", action.ActionID) + assert.Equal(t, ActionTypeUpgrade, action.ActionType) + assert.Empty(t, action.ActionStartTime) + assert.Empty(t, action.ActionExpiration) + assert.Equal(t, "1.2.3", action.Version) + assert.Equal(t, "http://example.com", action.SourceURI) + }) + t.Run("ActionUpgrade with start time", func(t *testing.T) { + p := []byte(`[{"id":"testid","type":"UPGRADE","start_time":"2022-01-02T12:00:00Z","expiration":"2022-01-02T13:00:00Z","data":{"version":"1.2.3","source_uri":"http://example.com"}}]`) + a := &Actions{} + err := a.UnmarshalJSON(p) + require.Nil(t, err) + action, ok := (*a)[0].(*ActionUpgrade) + require.True(t, ok, "unable to cast action to specific type") + assert.Equal(t, "testid", action.ActionID) + assert.Equal(t, ActionTypeUpgrade, action.ActionType) + assert.Equal(t, "2022-01-02T12:00:00Z", action.ActionStartTime) + assert.Equal(t, "2022-01-02T13:00:00Z", action.ActionExpiration) + assert.Equal(t, "1.2.3", action.Version) + assert.Equal(t, "http://example.com", action.SourceURI) + }) + t.Run("ActionPolicyChange no start time", func(t *testing.T) { + p := []byte(`[{"id":"testid","type":"POLICY_CHANGE","data":{"policy":{"key":"value"}}}]`) + a := &Actions{} + err := a.UnmarshalJSON(p) + require.Nil(t, err) + action, ok := (*a)[0].(*ActionPolicyChange) + require.True(t, ok, "unable to cast action to specific type") + assert.Equal(t, "testid", action.ActionID) + assert.Equal(t, ActionTypePolicyChange, action.ActionType) + assert.NotNil(t, action.Policy) + }) + t.Run("ActionPolicyChange with start time", func(t *testing.T) { + p := []byte(`[{"id":"testid","type":"POLICY_CHANGE","start_time":"2022-01-02T12:00:00Z","expiration":"2022-01-02T13:00:00Z","data":{"policy":{"key":"value"}}}]`) + a := &Actions{} + err := a.UnmarshalJSON(p) + require.Nil(t, err) + action, ok := (*a)[0].(*ActionPolicyChange) + require.True(t, ok, "unable to cast action to specific type") + assert.Equal(t, "testid", action.ActionID) + assert.Equal(t, ActionTypePolicyChange, action.ActionType) + assert.NotNil(t, action.Policy) + }) +} diff --git a/internal/pkg/queue/actionqueue.go b/internal/pkg/queue/actionqueue.go new file mode 100644 index 00000000000..671291639a2 --- /dev/null +++ b/internal/pkg/queue/actionqueue.go @@ -0,0 +1,131 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package queue + +import ( + "container/heap" + "time" + + "github.com/elastic/elastic-agent/internal/pkg/fleetapi" +) + +// item tracks an action in the action queue +type item struct { + action fleetapi.Action + priority int64 + index int +} + +// ActionQueue uses the standard library's container/heap to implement a priority queue +// This queue should not be indexed directly, instead use the provided Add, DequeueActions, or Cancel methods to add or remove items +// Actions() is indended to get the list of actions in the queue for serialization. +type ActionQueue []*item + +// Len returns the length of the queue +func (q ActionQueue) Len() int { + return len(q) +} + +// Less will determine if item i's priority is less then item j's +func (q ActionQueue) Less(i, j int) bool { + return q[i].priority < q[j].priority +} + +// Swap will swap the items at index i and j +func (q ActionQueue) Swap(i, j int) { + q[i], q[j] = q[j], q[i] + q[i].index = i + q[j].index = j +} + +// Push will add x as an item to the queue +// When using the queue, the Add method should be used instead. +func (q *ActionQueue) Push(x interface{}) { + n := len(*q) + e := x.(*item) //nolint:errcheck // should be an *item + e.index = n + *q = append(*q, e) +} + +// Pop will return the last item from the queue +// When using the queue, DequeueActions should be used instead +func (q *ActionQueue) Pop() interface{} { + old := *q + n := len(old) + e := old[n-1] + old[n-1] = nil // avoid memory leak + e.index = -1 // for safety + *q = old[0 : n-1] + return e +} + +// NewActionQueue creates a new ActionQueue initialized with the passed actions. +// Will return an error if StartTime fails for any action. +func NewActionQueue(actions []fleetapi.Action) (*ActionQueue, error) { + q := make(ActionQueue, len(actions)) + for i, action := range actions { + ts, err := action.StartTime() + if err != nil { + return nil, err + } + q[i] = &item{ + action: action, + priority: ts.Unix(), + index: i, + } + } + heap.Init(&q) + return &q, nil +} + +// Add will add an action to the queue with the associated priority. +// The priority is meant to be the start-time of the action as a unix epoch time. +// Complexity: O(log n) +func (q *ActionQueue) Add(action fleetapi.Action, priority int64) { + e := &item{ + action: action, + priority: priority, + } + heap.Push(q, e) +} + +// DequeueActions will dequeue all actions that have a priority less then time.Now(). +// Complexity: O(n*log n) +func (q *ActionQueue) DequeueActions() []fleetapi.Action { + ts := time.Now().Unix() + actions := make([]fleetapi.Action, 0) + for q.Len() != 0 { + if (*q)[0].priority > ts { + break + } + item := heap.Pop(q).(*item) //nolint:errcheck // should be an *item + actions = append(actions, item.action) + } + return actions +} + +// Cancel will remove any actions in the queue with a matching actionID and return the number of entries cancelled. +// Complexity: O(n*log n) +func (q *ActionQueue) Cancel(actionID string) int { + items := make([]*item, 0) + for _, item := range *q { + if item.action.ID() == actionID { + items = append(items, item) + } + } + for _, item := range items { + heap.Remove(q, item.index) + } + return len(items) +} + +// Actions returns all actions in the queue, item 0 is garunteed to be the min, the rest may not be in sorted order. +func (q *ActionQueue) Actions() []fleetapi.Action { + actions := make([]fleetapi.Action, q.Len()) + for i, item := range *q { + actions[i] = item.action + } + return actions +} diff --git a/internal/pkg/queue/actionqueue_test.go b/internal/pkg/queue/actionqueue_test.go new file mode 100644 index 00000000000..1c1e1959a9f --- /dev/null +++ b/internal/pkg/queue/actionqueue_test.go @@ -0,0 +1,485 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//nolint:errcheck,dupl // lots of casting in test cases +package queue + +import ( + "container/heap" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent/internal/pkg/fleetapi" +) + +type mockAction struct { + mock.Mock +} + +func (m *mockAction) String() string { + args := m.Called() + return args.String(0) +} + +func (m *mockAction) Type() string { + args := m.Called() + return args.String(0) +} + +func (m *mockAction) ID() string { + args := m.Called() + return args.String(0) +} + +func (m *mockAction) StartTime() (time.Time, error) { + args := m.Called() + return args.Get(0).(time.Time), args.Error(1) +} + +func (m *mockAction) Expiration() (time.Time, error) { + args := m.Called() + return args.Get(0).(time.Time), args.Error(1) +} + +func TestNewActionQueue(t *testing.T) { + ts := time.Now() + a1 := &mockAction{} + a1.On("ID").Return("test-1") + a1.On("StartTime").Return(ts, nil) + a2 := &mockAction{} + a2.On("ID").Return("test-2") + a2.On("StartTime").Return(ts.Add(time.Second), nil) + a3 := &mockAction{} + a3.On("ID").Return("test-3") + a3.On("StartTime").Return(ts.Add(time.Minute), nil) + + t.Run("nil actions slice", func(t *testing.T) { + q, err := NewActionQueue(nil) + require.NoError(t, err) + assert.NotNil(t, q) + assert.Empty(t, q) + }) + + t.Run("empty actions slice", func(t *testing.T) { + q, err := NewActionQueue([]fleetapi.Action{}) + require.NoError(t, err) + assert.NotNil(t, q) + assert.Empty(t, q) + }) + + t.Run("ordered actions list", func(t *testing.T) { + q, err := NewActionQueue([]fleetapi.Action{a1, a2, a3}) + assert.NotNil(t, q) + require.NoError(t, err) + assert.Len(t, *q, 3) + + i := heap.Pop(q).(*item) + assert.Equal(t, "test-1", i.action.ID()) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-2", i.action.ID()) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Empty(t, *q) + }) + + t.Run("unordered actions list", func(t *testing.T) { + q, err := NewActionQueue([]fleetapi.Action{a3, a2, a1}) + require.NoError(t, err) + assert.NotNil(t, q) + assert.Len(t, *q, 3) + + i := heap.Pop(q).(*item) + assert.Equal(t, "test-1", i.action.ID()) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-2", i.action.ID()) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Empty(t, *q) + }) + + t.Run("start time error", func(t *testing.T) { + a := &mockAction{} + a.On("StartTime").Return(time.Time{}, errors.New("oh no")) + q, err := NewActionQueue([]fleetapi.Action{a}) + assert.EqualError(t, err, "oh no") + assert.Nil(t, q) + }) +} + +func assertOrdered(t *testing.T, q *ActionQueue) { + t.Helper() + require.Len(t, *q, 3) + i := heap.Pop(q).(*item) + assert.Equal(t, int64(1), i.priority) + assert.Equal(t, "test-1", i.action.ID()) + i = heap.Pop(q).(*item) + assert.Equal(t, int64(2), i.priority) + assert.Equal(t, "test-2", i.action.ID()) + i = heap.Pop(q).(*item) + assert.Equal(t, int64(3), i.priority) + assert.Equal(t, "test-3", i.action.ID()) + + assert.Empty(t, *q) +} + +func Test_ActionQueue_Add(t *testing.T) { + a1 := &mockAction{} + a1.On("ID").Return("test-1") + a2 := &mockAction{} + a2.On("ID").Return("test-2") + a3 := &mockAction{} + a3.On("ID").Return("test-3") + + t.Run("ascending order", func(t *testing.T) { + q := &ActionQueue{} + q.Add(a1, 1) + q.Add(a2, 2) + q.Add(a3, 3) + + assertOrdered(t, q) + }) + + t.Run("Add descending order", func(t *testing.T) { + q := &ActionQueue{} + q.Add(a3, 3) + q.Add(a2, 2) + q.Add(a1, 1) + + assertOrdered(t, q) + }) + + t.Run("mixed order", func(t *testing.T) { + q := &ActionQueue{} + q.Add(a1, 1) + q.Add(a3, 3) + q.Add(a2, 2) + + assertOrdered(t, q) + }) + + t.Run("two items have same priority", func(t *testing.T) { + q := &ActionQueue{} + q.Add(a1, 1) + q.Add(a2, 2) + q.Add(a3, 2) + + require.Len(t, *q, 3) + i := heap.Pop(q).(*item) + assert.Equal(t, int64(1), i.priority) + assert.Equal(t, "test-1", i.action.ID()) + // next two items have same priority, however the ids may not match insertion order + i = heap.Pop(q).(*item) + assert.Equal(t, int64(2), i.priority) + i = heap.Pop(q).(*item) + assert.Equal(t, int64(2), i.priority) + assert.Empty(t, *q) + }) +} + +func Test_ActionQueue_DequeueActions(t *testing.T) { + a1 := &mockAction{} + a1.On("ID").Return("test-1") + a2 := &mockAction{} + a2.On("ID").Return("test-2") + a3 := &mockAction{} + a3.On("ID").Return("test-3") + + t.Run("empty queue", func(t *testing.T) { + q := &ActionQueue{} + + actions := q.DequeueActions() + + assert.Empty(t, actions) + assert.Empty(t, *q) + }) + + t.Run("one action from queue", func(t *testing.T) { + ts := time.Now() + q := &ActionQueue{&item{ + action: a1, + priority: ts.Add(-1 * time.Minute).Unix(), + index: 0, + }, &item{ + action: a2, + priority: ts.Add(2 * time.Minute).Unix(), + index: 1, + }, &item{ + action: a3, + priority: ts.Add(3 * time.Minute).Unix(), + index: 2, + }} + heap.Init(q) + + actions := q.DequeueActions() + + require.Len(t, actions, 1) + assert.Equal(t, "test-1", actions[0].ID()) + + require.Len(t, *q, 2) + i := heap.Pop(q).(*item) + assert.Equal(t, "test-2", i.action.ID()) + assert.Equal(t, ts.Add(2*time.Minute).Unix(), i.priority) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Equal(t, ts.Add(3*time.Minute).Unix(), i.priority) + + assert.Empty(t, *q) + }) + + t.Run("two actions from queue", func(t *testing.T) { + ts := time.Now() + q := &ActionQueue{&item{ + action: a1, + priority: ts.Add(-1 * time.Minute).Unix(), + index: 0, + }, &item{ + action: a2, + priority: ts.Add(-2 * time.Minute).Unix(), + index: 1, + }, &item{ + action: a3, + priority: ts.Add(3 * time.Minute).Unix(), + index: 2, + }} + heap.Init(q) + + actions := q.DequeueActions() + + require.Len(t, actions, 2) + assert.Equal(t, "test-2", actions[0].ID()) + assert.Equal(t, "test-1", actions[1].ID()) + + require.Len(t, *q, 1) + i := heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Equal(t, ts.Add(3*time.Minute).Unix(), i.priority) + + assert.Empty(t, *q) + }) + + t.Run("all actions from queue", func(t *testing.T) { + ts := time.Now() + q := &ActionQueue{&item{ + action: a1, + priority: ts.Add(-1 * time.Minute).Unix(), + index: 0, + }, &item{ + action: a2, + priority: ts.Add(-2 * time.Minute).Unix(), + index: 1, + }, &item{ + action: a3, + priority: ts.Add(-3 * time.Minute).Unix(), + index: 2, + }} + heap.Init(q) + + actions := q.DequeueActions() + + require.Len(t, actions, 3) + assert.Equal(t, "test-3", actions[0].ID()) + assert.Equal(t, "test-2", actions[1].ID()) + assert.Equal(t, "test-1", actions[2].ID()) + + require.Empty(t, *q) + }) + + t.Run("no actions from queue", func(t *testing.T) { + ts := time.Now() + q := &ActionQueue{&item{ + action: a1, + priority: ts.Add(1 * time.Minute).Unix(), + index: 0, + }, &item{ + action: a2, + priority: ts.Add(2 * time.Minute).Unix(), + index: 1, + }, &item{ + action: a3, + priority: ts.Add(3 * time.Minute).Unix(), + index: 2, + }} + heap.Init(q) + + actions := q.DequeueActions() + assert.Empty(t, actions) + + require.Len(t, *q, 3) + i := heap.Pop(q).(*item) + assert.Equal(t, "test-1", i.action.ID()) + assert.Equal(t, ts.Add(1*time.Minute).Unix(), i.priority) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-2", i.action.ID()) + assert.Equal(t, ts.Add(2*time.Minute).Unix(), i.priority) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Equal(t, ts.Add(3*time.Minute).Unix(), i.priority) + + }) +} + +func Test_ActionQueue_Cancel(t *testing.T) { + a1 := &mockAction{} + a1.On("ID").Return("test-1") + a2 := &mockAction{} + a2.On("ID").Return("test-2") + a3 := &mockAction{} + a3.On("ID").Return("test-3") + + t.Run("empty queue", func(t *testing.T) { + q := &ActionQueue{} + + n := q.Cancel("test-1") + assert.Zero(t, n) + assert.Empty(t, *q) + }) + + t.Run("one item cancelled", func(t *testing.T) { + q := &ActionQueue{&item{ + action: a1, + priority: 1, + index: 0, + }, &item{ + action: a2, + priority: 2, + index: 1, + }, &item{ + action: a3, + priority: 3, + index: 2, + }} + heap.Init(q) + + n := q.Cancel("test-1") + assert.Equal(t, 1, n) + + assert.Len(t, *q, 2) + i := heap.Pop(q).(*item) + assert.Equal(t, "test-2", i.action.ID()) + assert.Equal(t, int64(2), i.priority) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Equal(t, int64(3), i.priority) + assert.Empty(t, *q) + }) + + t.Run("two items cancelled", func(t *testing.T) { + q := &ActionQueue{&item{ + action: a1, + priority: 1, + index: 0, + }, &item{ + action: a1, + priority: 2, + index: 1, + }, &item{ + action: a3, + priority: 3, + index: 2, + }} + heap.Init(q) + + n := q.Cancel("test-1") + assert.Equal(t, 2, n) + + assert.Len(t, *q, 1) + i := heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Equal(t, int64(3), i.priority) + assert.Empty(t, *q) + }) + + t.Run("all items cancelled", func(t *testing.T) { + q := &ActionQueue{&item{ + action: a1, + priority: 1, + index: 0, + }, &item{ + action: a1, + priority: 2, + index: 1, + }, &item{ + action: a1, + priority: 3, + index: 2, + }} + heap.Init(q) + + n := q.Cancel("test-1") + assert.Equal(t, 3, n) + assert.Empty(t, *q) + }) + + t.Run("no items cancelled", func(t *testing.T) { + q := &ActionQueue{&item{ + action: a1, + priority: 1, + index: 0, + }, &item{ + action: a2, + priority: 2, + index: 1, + }, &item{ + action: a3, + priority: 3, + index: 2, + }} + heap.Init(q) + + n := q.Cancel("test-0") + assert.Zero(t, n) + + assert.Len(t, *q, 3) + i := heap.Pop(q).(*item) + assert.Equal(t, "test-1", i.action.ID()) + assert.Equal(t, int64(1), i.priority) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-2", i.action.ID()) + assert.Equal(t, int64(2), i.priority) + i = heap.Pop(q).(*item) + assert.Equal(t, "test-3", i.action.ID()) + assert.Equal(t, int64(3), i.priority) + assert.Empty(t, *q) + }) +} + +func Test_ActionQueue_Actions(t *testing.T) { + t.Run("empty queue", func(t *testing.T) { + q := &ActionQueue{} + actions := q.Actions() + assert.Len(t, actions, 0) + }) + + t.Run("non-empty queue", func(t *testing.T) { + a1 := &mockAction{} + a1.On("ID").Return("test-1") + a2 := &mockAction{} + a2.On("ID").Return("test-2") + a3 := &mockAction{} + a3.On("ID").Return("test-3") + q := &ActionQueue{&item{ + action: a1, + priority: 1, + index: 0, + }, &item{ + action: a2, + priority: 2, + index: 1, + }, &item{ + action: a3, + priority: 3, + index: 2, + }} + heap.Init(q) + + actions := q.Actions() + assert.Len(t, actions, 3) + assert.Equal(t, "test-1", actions[0].ID()) + }) +}