diff --git a/cmd/launcher/internal/record_metadata.go b/cmd/launcher/internal/record_metadata.go
index 987480463..755f076c4 100644
--- a/cmd/launcher/internal/record_metadata.go
+++ b/cmd/launcher/internal/record_metadata.go
@@ -108,7 +108,7 @@ func (mw *metadataWriter) recordMetadata(metadata *metadata) error {
 	return nil
 }
 
-func (mw *metadataWriter) getServerDataValue(store types.GetterSetterDeleterIteratorUpdater, key string) string {
+func (mw *metadataWriter) getServerDataValue(store types.KVStore, key string) string {
 	val, err := store.Get([]byte(key))
 	if err != nil {
 		mw.slogger.Log(context.TODO(), slog.LevelDebug,
diff --git a/ee/agent/storage/bbolt/keyvalue_store_bbolt.go b/ee/agent/storage/bbolt/keyvalue_store_bbolt.go
index b9005fdb7..7c8924ceb 100644
--- a/ee/agent/storage/bbolt/keyvalue_store_bbolt.go
+++ b/ee/agent/storage/bbolt/keyvalue_store_bbolt.go
@@ -2,6 +2,7 @@ package agentbbolt
 
 import (
 	"context"
+	"encoding/binary"
 	"fmt"
 	"log/slog"
 
@@ -137,12 +138,14 @@ func (s *bboltKeyValueStore) DeleteAll() error {
 	})
}
 
+// ForEach provides a read-only iterator over all key-value pairs stored within s.bucketName.
+// This allows bboltKeyValueStore to adhere to the types.Iterator interface.
 func (s *bboltKeyValueStore) ForEach(fn func(k, v []byte) error) error {
 	if s == nil || s.db == nil {
 		return NoDbError{}
 	}
 
-	return s.db.Update(func(tx *bbolt.Tx) error {
+	return s.db.View(func(tx *bbolt.Tx) error {
 		b := tx.Bucket([]byte(s.bucketName))
 		if b == nil {
 			return NewNoBucketError(s.bucketName)
@@ -225,3 +228,65 @@ func (s *bboltKeyValueStore) Update(kvPairs map[string]string) ([]string, error)
 
 	return deletedKeys, nil
 }
+
+func (s *bboltKeyValueStore) Count() (int, error) {
+	if s == nil || s.db == nil {
+		s.slogger.Log(context.TODO(), slog.LevelError, "unable to count uninitialized bbolt storage db")
+		return 0, NoDbError{}
+	}
+
+	var count int
+	if err := s.db.View(func(tx *bbolt.Tx) error {
+		b := tx.Bucket([]byte(s.bucketName))
+		if b == nil {
+			return NewNoBucketError(s.bucketName)
+		}
+
+		count = b.Stats().KeyN
+		return nil
+	}); err != nil {
+		s.slogger.Log(context.TODO(), slog.LevelError,
+			"err counting from bucket",
+			"err", err,
+		)
+		return 0, err
+	}
+
+	return count, nil
+}
+
+// AppendValues utilizes bbolt's NextSequence functionality to add ordered values,
+// generating the next autoincrementing key for each one
+func (s *bboltKeyValueStore) AppendValues(values ...[]byte) error {
+	if s == nil || s.db == nil {
+		return fmt.Errorf("unable to append values into uninitialized bbolt db store")
+	}
+
+	return s.db.Update(func(tx *bbolt.Tx) error {
+		b := tx.Bucket([]byte(s.bucketName))
+		if b == nil {
+			return NewNoBucketError(s.bucketName)
+		}
+
+		for _, value := range values {
+			key, err := b.NextSequence()
+			if err != nil {
+				return fmt.Errorf("generating key: %w", err)
+			}
+
+			if err = b.Put(byteKeyFromUint64(key), value); err != nil {
+				return fmt.Errorf("adding ordered value: %w", err)
+			}
+		}
+
+		return nil
+	})
+}
+
+func byteKeyFromUint64(k uint64) []byte {
+	// Adapted from Bolt docs
+	// 8 bytes in a uint64
+	b := make([]byte, 8)
+	binary.BigEndian.PutUint64(b, k)
+	return b
+}
diff --git a/ee/agent/storage/ci/keyvalue_store_test.go b/ee/agent/storage/ci/keyvalue_store_test.go
index f777a6d77..2e614e010 100644
--- a/ee/agent/storage/ci/keyvalue_store_test.go
+++ b/ee/agent/storage/ci/keyvalue_store_test.go
@@ -142,14 +142,9 @@ func Test_Delete(t *testing.T) {
 				require.NoError(t, err)
 			}
 
-			// There should be no records, count and verify
-			var recordCount int
-			rowFn := func(k, v []byte) error {
-				recordCount = recordCount + 1
-				return nil
-			}
-			s.ForEach(rowFn)
-			assert.Equal(t, tt.expectedRecordCount, recordCount)
+			totalCount, err := s.Count()
+			require.NoError(t, err)
+			assert.Equal(t, tt.expectedRecordCount, totalCount)
 		}
 	})
 }
@@ -189,13 +184,9 @@ func Test_DeleteAll(t *testing.T) {
 			require.NoError(t, s.DeleteAll())
 
 			// There should be no records, count and verify
-			var recordCount int
-			rowFn := func(k, v []byte) error {
-				recordCount = recordCount + 1
-				return nil
-			}
-			s.ForEach(rowFn)
-			assert.Equal(t, 0, recordCount)
+			totalCount, err := s.Count()
+			require.NoError(t, err)
+			assert.Equal(t, 0, totalCount)
 		}
 	})
 }
@@ -377,6 +368,94 @@ func Test_ForEach(t *testing.T) {
 	}
 }
 
+func Test_Count(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name          string
+		sets          map[string]string
+		expectedCount int
+	}{
+		{
+			name:          "empty",
+			sets:          map[string]string{},
+			expectedCount: 0,
+		},
+		{
+			name:          "one value",
+			sets:          map[string]string{"key1": "value1"},
+			expectedCount: 1,
+		},
+		{
+			name:          "multiple values",
+			sets:          map[string]string{"key1": "value1", "key2": "value2", "key3": "value3", "key4": "value4"},
+			expectedCount: 4,
+		},
+	}
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			for _, s := range getStores(t) {
+				_, err := s.Update(tt.sets)
+				require.NoError(t, err)
+				totalCount, err := s.Count()
+				require.NoError(t, err)
+				assert.Equal(t, tt.expectedCount, totalCount)
+			}
+		})
+	}
+}
+
+func Test_AppendValues(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name          string
+		sets          [][]byte
+		expectedCount int
+	}{
+		{
+			name:          "empty",
+			sets:          [][]byte{},
+			expectedCount: 0,
+		},
+		{
+			name:          "one value",
+			sets:          [][]byte{[]byte("one")},
+			expectedCount: 1,
+		},
+		{
+			name:          "multiple values",
+			sets:          [][]byte{[]byte("one"), []byte("two"), []byte("three"), []byte("four"), []byte("five")},
+			expectedCount: 5,
+		},
+	}
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			for _, s := range getStores(t) {
+				err := s.AppendValues(tt.sets...)
+				require.NoError(t, err)
+				// check the count to ensure the tests below will perform the expected number of iterations
+				totalCount, err := s.Count()
+				require.NoError(t, err)
+				require.Equal(t, tt.expectedCount, totalCount)
+				idx := 0
+				// now we expect to be able to iterate over these in the same order that we appended them
+				s.ForEach(func(k, v []byte) error {
+					require.Equal(t, tt.sets[idx], v)
+					idx++
+					return nil
+				})
+			}
+		})
+	}
+}
+
 func getKeyValueRows(store types.KVStore, bucketName string) ([]map[string]string, error) {
 	results := make([]map[string]string, 0)
 
diff --git a/ee/agent/storage/inmemory/keyvalue_store_in_memory.go b/ee/agent/storage/inmemory/keyvalue_store_in_memory.go
index 19d6adb49..91c86c973 100644
--- a/ee/agent/storage/inmemory/keyvalue_store_in_memory.go
+++ b/ee/agent/storage/inmemory/keyvalue_store_in_memory.go
@@ -1,18 +1,23 @@
 package inmemory
 
 import (
+	"encoding/binary"
 	"errors"
+	"fmt"
 	"sync"
 )
 
 type inMemoryKeyValueStore struct {
-	mu    sync.RWMutex
-	items map[string][]byte
+	mu       sync.RWMutex
+	items    map[string][]byte
+	order    []string
+	sequence uint64
 }
 
 func NewStore() *inMemoryKeyValueStore {
 	s := &inMemoryKeyValueStore{
 		items: make(map[string][]byte),
+		order: make([]string, 0),
 	}
 
 	return s
@@ -42,7 +47,14 @@ func (s *inMemoryKeyValueStore) Set(key, value []byte) error {
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	s.items[string(key)] = value
+
+	if _, exists := s.items[string(key)]; !exists {
+		s.order = append(s.order, string(key))
+	}
+
+	s.items[string(key)] = make([]byte, len(value))
+	copy(s.items[string(key)], value)
+
 	return nil
 }
 
@@ -55,7 +67,14 @@ func (s *inMemoryKeyValueStore) Delete(keys ...[]byte) error {
 	defer s.mu.Unlock()
 	for _, key := range keys {
 		delete(s.items, string(key))
+		for i, k := range s.order {
+			if k == string(key) {
+				s.order = append(s.order[:i], s.order[i+1:]...)
+				break
+			}
+		}
 	}
+
 	return nil
 }
 
@@ -66,7 +85,9 @@ func (s *inMemoryKeyValueStore) DeleteAll() error {
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	clear(s.items)
+	s.items = make(map[string][]byte)
+	s.order = make([]string, 0)
+
 	return nil
 }
 
@@ -77,40 +98,38 @@ func (s *inMemoryKeyValueStore) ForEach(fn func(k, v []byte) error) error {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	for k, v := range s.items {
-		if err := fn([]byte(k), v); err != nil {
+	for _, key := range s.order {
+		if err := fn([]byte(key), s.items[key]); err != nil {
 			return err
 		}
 	}
 	return nil
 }
 
+// Update adheres to the Updater interface for bulk replacing data in a key/value store.
+// Note that this method internally delegates all mutating operations to the existing Set/Delete
+// functions, so the mutex is not locked here.
 func (s *inMemoryKeyValueStore) Update(kvPairs map[string]string) ([]string, error) {
 	if s == nil {
 		return nil, errors.New("store is nil")
 	}
 
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	s.items = make(map[string][]byte)
-
 	for key, value := range kvPairs {
 		if key == "" {
 			return nil, errors.New("key is blank")
 		}
 
-		s.items[key] = []byte(value)
+		s.Set([]byte(key), []byte(value))
 	}
 
-	var deletedKeys []string
+	deletedKeys := make([]string, 0)
 
 	for key := range s.items {
 		if _, ok := kvPairs[key]; ok {
 			continue
 		}
 
-		delete(s.items, key)
+		s.Delete([]byte(key))
 
 		// Remember which keys we're deleting
 		deletedKeys = append(deletedKeys, key)
@@ -118,3 +137,32 @@ func (s *inMemoryKeyValueStore) Update(kvPairs map[string]string) ([]string, err
 
 	return deletedKeys, nil
 }
+
+func (s *inMemoryKeyValueStore) Count() (int, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return len(s.items), nil
+}
+
+func (s *inMemoryKeyValueStore) AppendValues(values ...[]byte) error {
+	if s == nil {
+		return fmt.Errorf("unable to append values into uninitialized inmemory db store")
+	}
+
+	for _, value := range values {
+		s.Set(s.nextSequenceKey(), value)
+	}
+
+	return nil
+}
+
+func (s *inMemoryKeyValueStore) nextSequenceKey() []byte {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.sequence++
+	b := make([]byte, 8)
+	binary.BigEndian.PutUint64(b, s.sequence)
+	return b
+}
diff --git a/ee/agent/types/keyvalue_store.go b/ee/agent/types/keyvalue_store.go
index 52f985a33..d8c9c10f8 100644
--- a/ee/agent/types/keyvalue_store.go
+++ b/ee/agent/types/keyvalue_store.go
@@ -37,10 +37,24 @@ type Iterator interface {
 type Updater interface {
 	// Update takes a map of key-value pairs, and inserts
 	// these key-values into the store. Any preexisting keys in the store which
-	// do not exist in data will be deleted.
+	// do not exist in data will be deleted, and the deleted keys will be returned.
 	Update(kvPairs map[string]string) ([]string, error)
 }
 
+// Counter is an interface for reporting the count of key-value
+// pairs held by the underlying storage methodology.
+type Counter interface {
+	// Count should return the total number of current key-value pairs
+	Count() (int, error)
+}
+
+// Appender is an interface for supporting the ordered addition of values to a store.
+// Implementations should generate keys to ensure that ordered iteration is possible.
+type Appender interface {
+	// AppendValues takes one or more values and appends them to the store in order
+	AppendValues(values ...[]byte) error
+}
+
 // GetterSetter is an interface that groups the Get and Set methods.
 type GetterSetter interface {
 	Getter
@@ -75,13 +89,15 @@ type GetterSetterDeleterIterator interface {
 }
 
-// GetterSetterDeleterIteratorUpdater is an interface that groups the Get, Set, Delete, ForEach, and Update methods.
-type GetterSetterDeleterIteratorUpdater interface {
+// GetterSetterDeleterIteratorUpdaterCounterAppender is an interface that groups the Get, Set, Delete, ForEach, Update, Count, and AppendValues methods.
+type GetterSetterDeleterIteratorUpdaterCounterAppender interface {
 	Getter
 	Setter
 	Deleter
 	Iterator
 	Updater
+	Counter
+	Appender
 }
 
 // Convenient alias for a key value store that supports all methods
-type KVStore = GetterSetterDeleterIteratorUpdater
+type KVStore = GetterSetterDeleterIteratorUpdaterCounterAppender
diff --git a/ee/agent/types/mocks/knapsack.go b/ee/agent/types/mocks/knapsack.go
index 9d1b279dd..8be4fe74f 100644
--- a/ee/agent/types/mocks/knapsack.go
+++ b/ee/agent/types/mocks/knapsack.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v2.21.1. DO NOT EDIT.
+// Code generated by mockery v2.34.2. DO NOT EDIT.
 
 package mocks
 
@@ -37,15 +37,15 @@ func (_m *Knapsack) AddSlogHandler(handler ...slog.Handler) {
 }
 
 // AgentFlagsStore provides a mock function with given fields:
-func (_m *Knapsack) AgentFlagsStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) AgentFlagsStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -67,15 +67,15 @@ func (_m *Knapsack) Autoupdate() bool {
 }
 
 // AutoupdateErrorsStore provides a mock function with given fields:
-func (_m *Knapsack) AutoupdateErrorsStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) AutoupdateErrorsStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -143,15 +143,15 @@ func (_m *Knapsack) CertPins() [][]byte {
 }
 
 // ConfigStore provides a mock function with given fields:
-func (_m *Knapsack) ConfigStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) ConfigStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -187,15 +187,15 @@ func (_m *Knapsack) ControlServerURL() string {
 }
 
 // ControlStore provides a mock function with given fields:
-func (_m *Knapsack) ControlStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) ControlStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -437,15 +437,15 @@ func (_m *Knapsack) InModernStandby() bool {
 }
 
 // InitialResultsStore provides a mock function with given fields:
-func (_m *Knapsack) InitialResultsStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) InitialResultsStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -651,15 +651,15 @@ func (_m *Knapsack) OsqueryHealthcheckStartupDelay() time.Duration {
 }
 
 // OsqueryHistoryInstanceStore provides a mock function with given fields:
-func (_m *Knapsack) OsqueryHistoryInstanceStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) OsqueryHistoryInstanceStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -765,15 +765,15 @@ func (_m *Knapsack) OsquerydPath() string {
 }
 
 // PersistentHostDataStore provides a mock function with given fields:
-func (_m *Knapsack) PersistentHostDataStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) PersistentHostDataStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -845,15 +845,15 @@ func (_m *Knapsack) RegisterChangeObserver(observer types.FlagsChangeObserver, f
 }
 
 // ResultLogsStore provides a mock function with given fields:
-func (_m *Knapsack) ResultLogsStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) ResultLogsStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -889,15 +889,15 @@ func (_m *Knapsack) RootPEM() string {
 }
 
 // SentNotificationsStore provides a mock function with given fields:
-func (_m *Knapsack) SentNotificationsStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) SentNotificationsStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -905,15 +905,15 @@ func (_m *Knapsack) SentNotificationsStore() types.GetterSetterDeleterIteratorUp
 }
 
 // ServerProvidedDataStore provides a mock function with given fields:
-func (_m *Knapsack) ServerProvidedDataStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) ServerProvidedDataStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -1475,15 +1475,15 @@ func (_m *Knapsack) Slogger() *slog.Logger {
 }
 
 // StatusLogsStore provides a mock function with given fields:
-func (_m *Knapsack) StatusLogsStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) StatusLogsStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -1491,15 +1491,15 @@ func (_m *Knapsack) StatusLogsStore() types.GetterSetterDeleterIteratorUpdater {
 }
 
 // Stores provides a mock function with given fields:
-func (_m *Knapsack) Stores() map[storage.Store]types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) Stores() map[storage.Store]types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 map[storage.Store]types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() map[storage.Store]types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 map[storage.Store]types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() map[storage.Store]types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(map[storage.Store]types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(map[storage.Store]types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -1523,15 +1523,15 @@ func (_m *Knapsack) SystemSlogger() *slog.Logger {
 }
 
 // TokenStore provides a mock function with given fields:
-func (_m *Knapsack) TokenStore() types.GetterSetterDeleterIteratorUpdater {
+func (_m *Knapsack) TokenStore() types.GetterSetterDeleterIteratorUpdaterCounterAppender {
 	ret := _m.Called()
 
-	var r0 types.GetterSetterDeleterIteratorUpdater
-	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdater); ok {
+	var r0 types.GetterSetterDeleterIteratorUpdaterCounterAppender
+	if rf, ok := ret.Get(0).(func() types.GetterSetterDeleterIteratorUpdaterCounterAppender); ok {
 		r0 = rf()
 	} else {
 		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdater)
+			r0 = ret.Get(0).(types.GetterSetterDeleterIteratorUpdaterCounterAppender)
 		}
 	}
 
@@ -1692,13 +1692,12 @@ func (_m *Knapsack) WatchdogUtilizationLimitPercent() int {
 	return r0
 }
 
-type mockConstructorTestingTNewKnapsack interface {
+// NewKnapsack creates a new instance of Knapsack. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewKnapsack(t interface {
 	mock.TestingT
 	Cleanup(func())
-}
-
-// NewKnapsack creates a new instance of Knapsack. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
-func NewKnapsack(t mockConstructorTestingTNewKnapsack) *Knapsack {
+}) *Knapsack {
 	mock := &Knapsack{}
 	mock.Mock.Test(t)
 
diff --git a/pkg/osquery/extension.go b/pkg/osquery/extension.go
index 294340006..10e722c3f 100644
--- a/pkg/osquery/extension.go
+++ b/pkg/osquery/extension.go
@@ -15,7 +15,6 @@ import (
 
 	"github.com/google/uuid"
 	"github.com/kolide/launcher/ee/agent"
-	"github.com/kolide/launcher/ee/agent/storage"
 	"github.com/kolide/launcher/ee/agent/types"
 	"github.com/kolide/launcher/pkg/backoff"
 	"github.com/kolide/launcher/pkg/osquery/runtime/history"
@@ -26,7 +25,6 @@ import (
 	"github.com/osquery/osquery-go/plugin/logger"
 	"github.com/pkg/errors"
 
-	"go.etcd.io/bbolt"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
@@ -97,6 +95,13 @@ type ExtensionOpts struct {
 	skipHardwareKeysSetup bool
 }
 
+// iterationTerminatedError is a sentinel error used to stop a kv store iteration early.
+type iterationTerminatedError struct{}
+
+func (e iterationTerminatedError) Error() string {
+	return "ceasing kv store iteration"
+}
+
 // NewExtension creates a new Extension from the provided service.KolideService
 // implementation. The background routines should be started by calling
 // Start().
@@ -631,16 +635,15 @@ func uint64FromByteKey(k []byte) uint64 {
 	return binary.BigEndian.Uint64(k)
 }
 
-// bucketNameFromLogType returns the Bolt bucket name that stores logs of the
-// provided type.
-func bucketNameFromLogType(typ logger.LogType) (string, error) {
+// storeForLogType returns the store that holds logs of the provided type.
+func storeForLogType(s types.Stores, typ logger.LogType) (types.KVStore, error) {
 	switch typ {
 	case logger.LogTypeString, logger.LogTypeSnapshot:
-		return storage.ResultLogsStore.String(), nil
+		return s.ResultLogsStore(), nil
 	case logger.LogTypeStatus:
-		return storage.StatusLogsStore.String(), nil
+		return s.StatusLogsStore(), nil
 	default:
-		return "", fmt.Errorf("unknown log type: %v", typ)
+		return nil, fmt.Errorf("unknown log type: %v", typ)
 	}
 }
 
@@ -676,31 +679,11 @@ func (e *Extension) writeAndPurgeLogs() {
 	}
 }
 
-// numberOfBufferedLogs returns the number of logs buffered for a given type.
-func (e *Extension) numberOfBufferedLogs(typ logger.LogType) (int, error) {
-	bucketName, err := bucketNameFromLogType(typ)
-	if err != nil {
-		return 0, err
-	}
-
-	var count int
-	err = e.knapsack.BboltDB().View(func(tx *bbolt.Tx) error {
-		b := tx.Bucket([]byte(bucketName))
-		count = b.Stats().KeyN
-		return nil
-	})
-	if err != nil {
-		return 0, fmt.Errorf("counting buffered logs: %w", err)
-	}
-
-	return count, nil
-}
-
 // writeBufferedLogs flushes the log buffers, writing up to
 // Opts.MaxBytesPerBatch bytes worth of logs in one run. If the logs write
 // successfully, they will be deleted from the buffer.
 func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
-	bucketName, err := bucketNameFromLogType(typ)
+	store, err := storeForLogType(e.knapsack, typ)
 	if err != nil {
 		return err
 	}
@@ -709,54 +692,49 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
 	var logs []string
 	var logIDs [][]byte
 	bufferFilled := false
-	err = e.knapsack.BboltDB().View(func(tx *bbolt.Tx) error {
-		b := tx.Bucket([]byte(bucketName))
-
-		c := b.Cursor()
-		k, v := c.First()
-		for totalBytes := 0; k != nil; {
-			// A somewhat cumbersome if block...
-			//
-			// 1. If the log is too big, skip it and mark for deletion.
-			// 2. If the buffer would be too big with the log, break for
-			// 3. Else append it
-			//
-			// Note that (1) must come first, otherwise (2) will always trigger.
-			if e.logPublicationState.ExceedsCurrentBatchThreshold(len(v)) {
-				// Discard logs that are too big
-				logheadSize := minInt(len(v), 100)
-				e.slogger.Log(context.TODO(), slog.LevelInfo,
-					"dropped log",
-					"log_id", k,
-					"size", len(v),
-					"limit", e.Opts.MaxBytesPerBatch,
-					"loghead", string(v)[0:logheadSize],
-				)
-			} else if e.logPublicationState.ExceedsCurrentBatchThreshold(totalBytes + len(v)) {
-				// Buffer is filled. Break the loop and come back later.
-				bufferFilled = true
-				break
-			} else {
-				logs = append(logs, string(v))
-				totalBytes += len(v)
-			}
-
-			// Note the logID for deletion. We do this by
-			// making a copy of k. It is retained in
-			// logIDs after the transaction is closed,
-			// when the goroutine ticks it zeroes out some
-			// of the IDs to delete below, causing logs to
-			// remain in the buffer and be sent again to
-			// the server.
-			logID := make([]byte, len(k))
-			copy(logID, k)
-			logIDs = append(logIDs, logID)
-
-			k, v = c.Next()
+	totalBytes := 0
+	err = store.ForEach(func(k, v []byte) error {
+		// A somewhat cumbersome if block...
+		//
+		// 1. If the log is too big, skip it and mark for deletion.
+		// 2. If the buffer would be too big with the log, stop iterating
+		// 3. Else append it
+		//
+		// Note that (1) must come first, otherwise (2) will always trigger.
+		if e.logPublicationState.ExceedsCurrentBatchThreshold(len(v)) {
+			// Discard logs that are too big
+			logheadSize := minInt(len(v), 100)
+			e.slogger.Log(context.TODO(), slog.LevelInfo,
+				"dropped log",
+				"log_id", k,
+				"size", len(v),
+				"limit", e.Opts.MaxBytesPerBatch,
+				"loghead", string(v)[0:logheadSize],
+			)
+		} else if e.logPublicationState.ExceedsCurrentBatchThreshold(totalBytes + len(v)) {
+			// Buffer is filled. Stop iterating and come back later.
+			return iterationTerminatedError{}
+		} else {
+			logs = append(logs, string(v))
+			totalBytes += len(v)
 		}
+
+		// Note the logID for deletion. We do this by
+		// making a copy of k. It is retained in
+		// logIDs after the transaction is closed,
+		// when the goroutine ticks it zeroes out some
+		// of the IDs to delete below, causing logs to
+		// remain in the buffer and be sent again to
+		// the server.
+		logID := make([]byte, len(k))
+		copy(logID, k)
+		logIDs = append(logIDs, logID)
 		return nil
 	})
-	if err != nil {
+
+	if errors.Is(err, iterationTerminatedError{}) {
+		bufferFilled = true
+	} else if err != nil {
 		return fmt.Errorf("reading buffered logs: %w", err)
 	}
@@ -778,13 +756,8 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
 	}
 
 	// Delete logs that were successfully sent
-	err = e.knapsack.BboltDB().Update(func(tx *bbolt.Tx) error {
-		b := tx.Bucket([]byte(bucketName))
-		for _, k := range logIDs {
-			b.Delete(k)
-		}
-		return nil
-	})
+	err = store.Delete(logIDs...)
+
 	if err != nil {
 		return fmt.Errorf("deleting sent logs: %w", err)
 	}
@@ -829,40 +802,38 @@ func (e *Extension) writeLogsWithReenroll(ctx context.Context, typ logger.LogTyp
 // purgeBufferedLogsForType flushes the log buffers for the provided type,
 // ensuring that at most Opts.MaxBufferedLogs logs remain.
 func (e *Extension) purgeBufferedLogsForType(typ logger.LogType) error {
-	bucketName, err := bucketNameFromLogType(typ)
+	store, err := storeForLogType(e.knapsack, typ)
 	if err != nil {
 		return err
 	}
 
-	err = e.knapsack.BboltDB().Update(func(tx *bbolt.Tx) error {
-		b := tx.Bucket([]byte(bucketName))
-		logCount := b.Stats().KeyN
-		deleteCount := logCount - e.Opts.MaxBufferedLogs
-
-		if deleteCount <= 0 {
-			// Limit not exceeded
-			return nil
-		}
+	totalCount, err := store.Count()
+	if err != nil {
+		return err
+	}
 
-		e.slogger.Log(context.TODO(), slog.LevelInfo,
-			"buffered logs limit exceeded, purging excess",
-			"limit", e.Opts.MaxBufferedLogs,
-			"purge_count", deleteCount,
-		)
+	deleteCount := totalCount - e.Opts.MaxBufferedLogs
+	if deleteCount <= 0 { // Limit not exceeded
+		return nil
+	}
 
-		c := b.Cursor()
-		k, _ := c.First()
-		for total := 0; k != nil && total < deleteCount; total++ {
-			c.Delete() // Note: This advances the cursor
-			k, _ = c.First()
+	logIDsCollectedCount := 0
+	logIDsForDeletion := make([][]byte, 0, deleteCount)
+	if err = store.ForEach(func(k, v []byte) error {
+		if logIDsCollectedCount >= deleteCount {
+			return iterationTerminatedError{}
 		}
+		logID := make([]byte, len(k))
+		copy(logID, k)
+		logIDsForDeletion = append(logIDsForDeletion, logID)
+		logIDsCollectedCount++
 		return nil
-	})
-	if err != nil {
-		return fmt.Errorf("deleting overflowed logs: %w", err)
+	}); err != nil && !errors.Is(err, iterationTerminatedError{}) {
+		return fmt.Errorf("collecting overflowed log keys for deletion: %w", err)
 	}
-	return nil
+
+	return store.Delete(logIDsForDeletion...)
 }
 
 // LogString will buffer logs from osquery into the local BoltDB store. No
@@ -875,7 +846,7 @@ func (e *Extension) LogString(ctx context.Context, typ logger.LogType, logText s
 		return nil
 	}
 
-	bucketName, err := bucketNameFromLogType(typ)
+	store, err := storeForLogType(e.knapsack, typ)
 	if err != nil {
 		e.slogger.Log(ctx, slog.LevelInfo,
 			"received unknown log type",
@@ -885,25 +856,9 @@ func (e *Extension) LogString(ctx context.Context, typ logger.LogType, logText s
 	}
 
 	// Buffer the log for sending later in a batch
-	err = e.knapsack.BboltDB().Update(func(tx *bbolt.Tx) error {
-		b := tx.Bucket([]byte(bucketName))
-
-		// Log keys are generated with the auto-incrementing sequence
-		// number provided by BoltDB. These must be converted to []byte
-		// (which we do with byteKeyFromUint64 function).
-		key, err := b.NextSequence()
-		if err != nil {
-			return fmt.Errorf("generating key: %w", err)
-		}
-
-		return b.Put(byteKeyFromUint64(key), []byte(logText))
-	})
-
-	if err != nil {
-		return fmt.Errorf("buffering log: %w", err)
-	}
-
-	return nil
+	// note that AppendValues guarantees these logs are inserted with
+	// sequential keys for ordered retrieval later
+	return store.AppendValues([]byte(logText))
 }
 
 // GetQueries will request the distributed queries to execute from the server.
diff --git a/pkg/osquery/extension_test.go b/pkg/osquery/extension_test.go
index 07de49f04..d6bed35c8 100644
--- a/pkg/osquery/extension_test.go
+++ b/pkg/osquery/extension_test.go
@@ -555,16 +555,15 @@ func TestExtensionWriteBufferedLogsEmpty(t *testing.T) {
 			return "", "", false, nil
 		},
 	}
-	db, cleanup := makeTempDB(t)
-	defer cleanup()
 
 	// Create the status logs bucket ahead of time
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.StatusLogsStore.String())
+	statusLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.StatusLogsStore.String())
+	require.NoError(t, err)
 
 	k := mocks.NewKnapsack(t)
 	k.On("ConfigStore").Return(storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String()))
-	k.On("BboltDB").Return(db)
 	k.On("Slogger").Return(multislogger.NewNopLogger()).Maybe()
+	k.On("StatusLogsStore").Return(statusLogsStore)
 	k.On("ReadEnrollSecret").Maybe().Return("enroll_secret", nil)
 
 	e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{
@@ -595,17 +594,18 @@ func TestExtensionWriteBufferedLogs(t *testing.T) {
 			return "", "", false, nil
 		},
 	}
-	db, cleanup := makeTempDB(t)
-	defer cleanup()
 
 	// Create these buckets ahead of time
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.ResultLogsStore.String())
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.StatusLogsStore.String())
+	statusLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.StatusLogsStore.String())
+	require.NoError(t, err)
+	resultLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.ResultLogsStore.String())
+	require.NoError(t, err)
 
 	k := mocks.NewKnapsack(t)
 	k.On("ConfigStore").Return(storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String()))
-	k.On("BboltDB").Return(db)
 	k.On("Slogger").Return(multislogger.NewNopLogger()).Maybe()
+	k.On("StatusLogsStore").Return(statusLogsStore)
+	k.On("ResultLogsStore").Return(resultLogsStore)
 	k.On("ReadEnrollSecret").Maybe().Return("enroll_secret", nil)
 
 	e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{
@@ -666,15 +666,14 @@ func TestExtensionWriteBufferedLogsEnrollmentInvalid(t *testing.T) {
 			return expectedNodeKey, false, nil
 		},
 	}
-	db, cleanup := makeTempDB(t)
-	defer cleanup()
 
 	// Create the status logs bucket ahead of time
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.StatusLogsStore.String())
+	statusLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.StatusLogsStore.String())
+	require.NoError(t, err)
 
 	k := mocks.NewKnapsack(t)
 	k.On("ConfigStore").Return(storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String()))
-	k.On("BboltDB").Return(db)
+	k.On("StatusLogsStore").Return(statusLogsStore)
 	k.On("OsquerydPath").Maybe().Return("")
 	k.On("LatestOsquerydPath", testifymock.Anything).Maybe().Return("")
 	k.On("Slogger").Return(multislogger.NewNopLogger())
@@ -714,17 +713,18 @@ func TestExtensionWriteBufferedLogsLimit(t *testing.T) {
 			return "", "", false, nil
 		},
 	}
-	db, cleanup := makeTempDB(t)
-	defer cleanup()
 
 	// Create the status logs bucket ahead of time
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.StatusLogsStore.String())
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.ResultLogsStore.String())
+	statusLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.StatusLogsStore.String())
+	require.NoError(t, err)
+	resultLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.ResultLogsStore.String())
+	require.NoError(t, err)
 
 	k := mocks.NewKnapsack(t)
 	k.On("ConfigStore").Return(storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String()))
-	k.On("BboltDB").Return(db)
 	k.On("Slogger").Return(multislogger.NewNopLogger())
+	k.On("StatusLogsStore").Return(statusLogsStore)
+	k.On("ResultLogsStore").Return(resultLogsStore)
 
 	e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{
 		MaxBytesPerBatch: 100,
@@ -788,17 +788,14 @@ func TestExtensionWriteBufferedLogsDropsBigLog(t *testing.T) {
 			return "", "", false, nil
 		},
 	}
-	db, cleanup := makeTempDB(t)
-	defer cleanup()
-
-	// Create the status logs bucket ahead of time
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.StatusLogsStore.String())
-	agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.ResultLogsStore.String())
+	// Create the result logs store ahead of time
+	resultLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.ResultLogsStore.String())
+	require.NoError(t, err)
 
 	k := mocks.NewKnapsack(t)
 	k.On("ConfigStore").Return(storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String()))
-	k.On("BboltDB").Return(db)
 	k.On("Slogger").Return(multislogger.NewNopLogger())
+	k.On("ResultLogsStore").Return(resultLogsStore)
 
 	e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{
 		MaxBytesPerBatch: 15,
@@ -806,7 +803,7 @@ func TestExtensionWriteBufferedLogsDropsBigLog(t *testing.T) {
 	})
 	require.Nil(t, err)
 
-	startLogCount, err := e.numberOfBufferedLogs(logger.LogTypeString)
+	startLogCount, err := e.knapsack.ResultLogsStore().Count()
 	require.NoError(t, err)
 	require.Equal(t, 0, startLogCount, "start with no buffered logs")
 
 	e.LogString(context.Background(), logger.LogTypeString, "res4")
 	e.LogString(context.Background(), logger.LogTypeString, "this_result_is_tooooooo_big! darn")
darn") - queuedLogCount, err := e.numberOfBufferedLogs(logger.LogTypeString) + queuedLogCount, err := e.knapsack.ResultLogsStore().Count() require.NoError(t, err) require.Equal(t, 8, queuedLogCount, "correct number of enqueued logs") @@ -845,7 +842,7 @@ func TestExtensionWriteBufferedLogsDropsBigLog(t *testing.T) { assert.Nil(t, gotResultLogs) assert.Nil(t, gotStatusLogs) - finalLogCount, err := e.numberOfBufferedLogs(logger.LogTypeString) + finalLogCount, err := e.knapsack.ResultLogsStore().Count() require.NoError(t, err) require.Equal(t, 0, finalLogCount, "no more queued logs") } @@ -871,17 +868,18 @@ func TestExtensionWriteLogsLoop(t *testing.T) { return "", "", false, nil }, } - db, cleanup := makeTempDB(t) - defer cleanup() // Create the status logs bucket ahead of time - agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.StatusLogsStore.String()) - agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.ResultLogsStore.String()) + statusLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.StatusLogsStore.String()) + require.NoError(t, err) + resultLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.ResultLogsStore.String()) + require.NoError(t, err) k := mocks.NewKnapsack(t) k.On("ConfigStore").Return(storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String())) - k.On("BboltDB").Return(db) k.On("Slogger").Return(multislogger.NewNopLogger()) + k.On("StatusLogsStore").Return(statusLogsStore) + k.On("ResultLogsStore").Return(resultLogsStore) mockClock := clock.NewMockClock() expectedLoggingInterval := 10 * time.Second @@ -1002,16 +1000,17 @@ func TestExtensionPurgeBufferedLogs(t *testing.T) { return "", "", false, errors.New("server rejected logs") }, } - db, cleanup := makeTempDB(t) - defer cleanup() // Create these buckets ahead of time - agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.ResultLogsStore.String()) - agentbbolt.NewStore(multislogger.NewNopLogger(), db, storage.StatusLogsStore.String()) + statusLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.StatusLogsStore.String()) + require.NoError(t, err) + resultLogsStore, err := storageci.NewStore(t, multislogger.NewNopLogger(), storage.ResultLogsStore.String()) + require.NoError(t, err) k := mocks.NewKnapsack(t) k.On("ConfigStore").Return(storageci.NewStore(t, multislogger.NewNopLogger(), storage.ConfigStore.String())) - k.On("BboltDB").Return(db) + k.On("StatusLogsStore").Return(statusLogsStore) + k.On("ResultLogsStore").Return(resultLogsStore) k.On("Slogger").Return(multislogger.NewNopLogger()) max := 10