diff --git a/pkg/osquery/extension.go b/pkg/osquery/extension.go index abe2e5c70..583b03064 100644 --- a/pkg/osquery/extension.go +++ b/pkg/osquery/extension.go @@ -36,14 +36,15 @@ import ( // and servers -- It provides a grpc and jsonrpc interface for // osquery. It does not provide any tables. type Extension struct { - NodeKey string - Opts ExtensionOpts - knapsack types.Knapsack - serviceClient service.KolideService - enrollMutex sync.Mutex - done chan struct{} - interrupted bool - slogger *slog.Logger + NodeKey string + Opts ExtensionOpts + knapsack types.Knapsack + serviceClient service.KolideService + enrollMutex sync.Mutex + done chan struct{} + interrupted bool + slogger *slog.Logger + logPublicationState *logPublicationState } const ( @@ -146,12 +147,13 @@ func NewExtension(ctx context.Context, client service.KolideService, k types.Kna } return &Extension{ - slogger: slogger, - serviceClient: client, - knapsack: k, - NodeKey: nodekey, - Opts: opts, - done: make(chan struct{}), + slogger: slogger, + serviceClient: client, + knapsack: k, + NodeKey: nodekey, + Opts: opts, + done: make(chan struct{}), + logPublicationState: NewLogPublicationState(opts.MaxBytesPerBatch), }, nil } @@ -647,12 +649,14 @@ func bucketNameFromLogType(typ logger.LogType) (string, error) { // buffers. func (e *Extension) writeAndPurgeLogs() { for _, typ := range []logger.LogType{logger.LogTypeStatus, logger.LogTypeString} { + originalBatchState := e.logPublicationState.CurrentValues() // Write logs err := e.writeBufferedLogsForType(typ) if err != nil { e.slogger.Log(context.TODO(), slog.LevelInfo, "sending logs", "type", typ.String(), + "attempted_publication_state", originalBatchState, "err", err, ) } @@ -701,6 +705,7 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error { // Collect up logs to be sent var logs []string var logIDs [][]byte + bufferFilled := false err = e.knapsack.BboltDB().View(func(tx *bbolt.Tx) error { b := tx.Bucket([]byte(bucketName)) @@ -714,7 +719,7 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error { // 3. Else append it // // Note that (1) must come first, otherwise (2) will always trigger. - if len(v) > e.Opts.MaxBytesPerBatch { + if e.logPublicationState.ExceedsCurrentBatchThreshold(len(v)) { // Discard logs that are too big logheadSize := minInt(len(v), 100) e.slogger.Log(context.TODO(), slog.LevelInfo, @@ -724,8 +729,9 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error { "limit", e.Opts.MaxBytesPerBatch, "loghead", string(v)[0:logheadSize], ) - } else if totalBytes+len(v) > e.Opts.MaxBytesPerBatch { + } else if e.logPublicationState.ExceedsCurrentBatchThreshold(totalBytes + len(v)) { // Buffer is filled. Break the loop and come back later. + bufferFilled = true break } else { logs = append(logs, string(v)) @@ -756,7 +762,14 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error { return nil } - err = e.writeLogsWithReenroll(context.Background(), typ, logs, true) + // inform the publication state tracking whether this batch should be used to + // determine the appropriate limit + e.logPublicationState.BeginBatch(time.Now(), bufferFilled) + publicationCtx := context.WithValue(context.Background(), + service.PublicationCtxKey, + e.logPublicationState.CurrentValues(), + ) + err = e.writeLogsWithReenroll(publicationCtx, typ, logs, true) if err != nil { return fmt.Errorf("writing logs: %w", err) } @@ -779,31 +792,35 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error { // Helper to allow for a single attempt at re-enrollment func (e *Extension) writeLogsWithReenroll(ctx context.Context, typ logger.LogType, logs []string, reenroll bool) error { _, _, invalid, err := e.serviceClient.PublishLogs(ctx, e.NodeKey, typ, logs) - if isNodeInvalidErr(err) { - invalid = true - } else if err != nil { - return fmt.Errorf("transport error sending logs: %w", err) + invalid = invalid || isNodeInvalidErr(err) + if !invalid && err == nil { + // publication was successful- update logPublicationState and move on + e.logPublicationState.EndBatch(logs, true) + return nil } - if invalid { - if !reenroll { - return errors.New("enrollment invalid, reenroll disabled") - } + if err != nil { + // logPublicationState will determine whether this failure should impact + // the batch size limit based on the elapsed time + e.logPublicationState.EndBatch(logs, false) + return fmt.Errorf("transport error sending logs: %w", err) + } - e.RequireReenroll(ctx) - _, invalid, err := e.Enroll(ctx) - if err != nil { - return fmt.Errorf("enrollment invalid, reenrollment errored: %w", err) - } - if invalid { - return errors.New("enrollment invalid, reenrollment invalid") - } + if !reenroll { + return errors.New("enrollment invalid, reenroll disabled") + } - // Don't attempt reenroll after first attempt - return e.writeLogsWithReenroll(ctx, typ, logs, false) + e.RequireReenroll(ctx) + _, invalid, err = e.Enroll(ctx) + if err != nil { + return fmt.Errorf("enrollment invalid, reenrollment errored: %w", err) + } + if invalid { + return errors.New("enrollment invalid, reenrollment invalid") } - return nil + // Don't attempt reenroll after first attempt + return e.writeLogsWithReenroll(ctx, typ, logs, false) } // purgeBufferedLogsForType flushes the log buffers for the provided type, @@ -981,3 +998,11 @@ func minInt(a, b int) int { return b } + +func maxInt(a, b int) int { + if a > b { + return a + } + + return b +} diff --git a/pkg/osquery/log_publication_state.go b/pkg/osquery/log_publication_state.go new file mode 100644 index 000000000..cec144d66 --- /dev/null +++ b/pkg/osquery/log_publication_state.go @@ -0,0 +1,105 @@ +package osquery + +import "time" + +const ( + // minBytesPerBatch sets the minimum batch size to 0.5mb as lower bound for correction + minBytesPerBatch int = 524288 + // batchIncrementAmount (0.5mb) is the incremental increase amount for the target batch + // size when previous runs have been successful + batchIncrementAmount int = 524288 + // maxPublicationDuration is the total time a batch can take without an error triggering a reduction + // in the max batch size + maxPublicationDuration time.Duration = 20 * time.Second +) + +// logPublicationState holds stateful logic to influence the log batch publication size +// depending on prior successes or failures. The primary intent here is to prevent repeatedly +// consuming the entire network bandwidth available to devices that are unable to ship +// the standard maxBytesPerBatch inside of the connection timeout (currently 30 seconds, enforced cloud side). +// Note that we always expect these batches to be sent sequentially, BeginBatch -> EndBatch +// this would need rework (likely state locking and batch ID tracking) to support concurrent batch publications +type logPublicationState struct { + // maxBytesPerBatch is passed in from the extension opts and respected + // as a fixed upper limit for batch size, regardless of publication success/failure rates + maxBytesPerBatch int + // currentMaxBytesPerBatch represents the (stateful) upper limit being enforced + currentMaxBytesPerBatch int + // currentBatchBufferFilled is used to indicate when a batch's success can be used + // to increase the threshold (we only want to increase the threshold after sending full, not partial, batches successfully) + currentBatchBufferFilled bool + currentBatchStartTime time.Time +} + +func NewLogPublicationState(maxBytesPerBatch int) *logPublicationState { + return &logPublicationState{ + maxBytesPerBatch: maxBytesPerBatch, + currentMaxBytesPerBatch: maxBytesPerBatch, + } +} + +// BeginBatch sets the opening state before attempting to publish a batch of logs. Specifically, it must +// - note the time (to determine if an error later was timeout related) +// - note whether this batch is full (to determine whether success should increase the limit on success later) +func (lps *logPublicationState) BeginBatch(startTime time.Time, bufferFilled bool) { + lps.currentBatchStartTime = startTime + lps.currentBatchBufferFilled = bufferFilled +} + +func (lps *logPublicationState) CurrentValues() map[string]int { + return map[string]int{ + "options_batch_limit_bytes": lps.maxBytesPerBatch, + "current_batch_limit_bytes": lps.currentMaxBytesPerBatch, + } +} + +func (lps *logPublicationState) EndBatch(logs []string, successful bool) { + // ensure we reset all batch state at the end + defer func() { + lps.currentBatchBufferFilled = false + lps.currentBatchStartTime = time.Time{} + }() + + // we can always safely decrease the threshold for a failed batch, but + // shouldn't increase the threshold for a successful batch unless we've at + // least filled the buffer + if successful && !lps.currentBatchBufferFilled { + return + } + + // in practice there could be one of a few different transport or timeout errors that bubble up + // depending on network conditions. instead of trying to keep up with all potential errors, + // only reduce the threshold if the calls are failing after more than 20 seconds + if !successful && time.Since(lps.currentBatchStartTime) < maxPublicationDuration { + return + } + + if successful { + lps.increaseBatchThreshold() + return + } + + lps.reduceBatchThreshold() +} + +func (lps *logPublicationState) ExceedsCurrentBatchThreshold(amountBytes int) bool { + return amountBytes > lps.currentMaxBytesPerBatch +} + +func (lps *logPublicationState) reduceBatchThreshold() { + if lps.currentMaxBytesPerBatch <= minBytesPerBatch { + return + } + + newTargetThreshold := lps.currentMaxBytesPerBatch - batchIncrementAmount + lps.currentMaxBytesPerBatch = maxInt(newTargetThreshold, minBytesPerBatch) +} + +func (lps *logPublicationState) increaseBatchThreshold() { + if lps.currentMaxBytesPerBatch >= lps.maxBytesPerBatch { + return + } + + newTargetThreshold := lps.currentMaxBytesPerBatch + batchIncrementAmount + lps.currentMaxBytesPerBatch = minInt(newTargetThreshold, lps.maxBytesPerBatch) +} diff --git a/pkg/osquery/log_publication_state_test.go b/pkg/osquery/log_publication_state_test.go new file mode 100644 index 000000000..a4f129a77 --- /dev/null +++ b/pkg/osquery/log_publication_state_test.go @@ -0,0 +1,123 @@ +//nolint:paralleltest +package osquery + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/kolide/launcher/pkg/service/mock" + "github.com/osquery/osquery-go/plugin/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtensionLogPublicationHappyPath(t *testing.T) { + startingBatchLimitBytes := minBytesPerBatch * 4 + m := &mock.KolideService{ + PublishLogsFunc: func(ctx context.Context, nodeKey string, logType logger.LogType, logs []string) (string, string, bool, error) { + return "", "", false, nil + }, + } + db, cleanup := makeTempDB(t) + defer cleanup() + k := makeKnapsack(t, db) + e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{MaxBytesPerBatch: startingBatchLimitBytes}) + require.Nil(t, err) + + // issue a few successful calls, expect that the batch limit is unchanged from the original opts + for i := 0; i < 3; i++ { + e.logPublicationState.BeginBatch(time.Now(), true) + err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true) + assert.Nil(t, err) + assert.Equal(t, e.logPublicationState.currentMaxBytesPerBatch, startingBatchLimitBytes) + // always expect that these values are reset between runs + assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime) + assert.False(t, e.logPublicationState.currentBatchBufferFilled) + } +} + +func TestExtensionLogPublicationRespondsToNetworkTimeouts(t *testing.T) { + numberOfPublicationRounds := 3 + publicationCalledCount := -1 + // startingBatchLimitBytes is set this way to ensure sufficient room for correction in both directions + startingBatchLimitBytes := minBytesPerBatch * (numberOfPublicationRounds + 1) + m := &mock.KolideService{ + PublishLogsFunc: func(ctx context.Context, nodeKey string, logType logger.LogType, logs []string) (string, string, bool, error) { + publicationCalledCount++ + switch { + case publicationCalledCount < numberOfPublicationRounds: + return "", "", false, errors.New("transport") + default: + return "", "", false, nil + } + }, + } + db, cleanup := makeTempDB(t) + defer cleanup() + k := makeKnapsack(t, db) + e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{MaxBytesPerBatch: startingBatchLimitBytes}) + require.Nil(t, err) + + // expect each subsequent failed call to reduce the batch size until the min threshold is reached + expectedMaxValue := e.Opts.MaxBytesPerBatch + for i := 0; i < numberOfPublicationRounds; i++ { + // set the batch state to have started earlier than the 20 seconds threshold ago + e.logPublicationState.BeginBatch(time.Now().Add(-21*time.Second), true) + err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true) + assert.NotNil(t, err) + assert.Less(t, e.logPublicationState.currentMaxBytesPerBatch, expectedMaxValue) + // always expect that these values are reset between runs + assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime) + assert.False(t, e.logPublicationState.currentBatchBufferFilled) + expectedMaxValue = e.logPublicationState.currentMaxBytesPerBatch + } + + // now run a successful publication loop without filling the buffer - we expect + // this should have no effect on the current batch size + err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true) + assert.Nil(t, err) + assert.Equal(t, expectedMaxValue, e.logPublicationState.currentMaxBytesPerBatch) + + // this time mark the buffer as filled for subsequent successful calls and expect that we move back up towards the original batch limit + for i := 0; i < numberOfPublicationRounds; i++ { + e.logPublicationState.BeginBatch(time.Now(), true) + err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true) + assert.Nil(t, err) + assert.Greater(t, e.logPublicationState.currentMaxBytesPerBatch, expectedMaxValue) + // always expect that these values are reset between runs + assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime) + assert.False(t, e.logPublicationState.currentBatchBufferFilled) + expectedMaxValue = e.logPublicationState.currentMaxBytesPerBatch + } + + // lastly expect that we've returned to our baseline state + assert.Equal(t, e.logPublicationState.currentMaxBytesPerBatch, startingBatchLimitBytes) +} + +func TestExtensionLogPublicationIgnoresNonTimeoutErrors(t *testing.T) { + startingBatchLimitBytes := minBytesPerBatch * 4 + m := &mock.KolideService{ + PublishLogsFunc: func(ctx context.Context, nodeKey string, logType logger.LogType, logs []string) (string, string, bool, error) { + return "", "", false, errors.New("transport") + }, + } + db, cleanup := makeTempDB(t) + defer cleanup() + k := makeKnapsack(t, db) + e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{MaxBytesPerBatch: startingBatchLimitBytes}) + require.Nil(t, err) + + // issue a few calls that error immediately, expect that the batch limit is unchanged from the original opts + for i := 0; i < 3; i++ { + e.logPublicationState.BeginBatch(time.Now(), true) + err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true) + // we still expect an error, but the batch limitation should not have changed + assert.NotNil(t, err) + assert.Equal(t, e.logPublicationState.currentMaxBytesPerBatch, startingBatchLimitBytes) + // always expect that these values are reset between runs + assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime) + assert.False(t, e.logPublicationState.currentBatchBufferFilled) + } +} diff --git a/pkg/service/publish_logs.go b/pkg/service/publish_logs.go index bce891255..236795d8b 100644 --- a/pkg/service/publish_logs.go +++ b/pkg/service/publish_logs.go @@ -14,6 +14,13 @@ import ( pb "github.com/kolide/launcher/pkg/pb/launcher" ) +type contextKey string + +const ( + // PublicationCtxKey is used to set the relevant thresholds in context for reporting when logs are published + PublicationCtxKey contextKey = "log_publication_state" +) + type logCollection struct { NodeKey string `json:"node_key"` LogType logger.LogType @@ -179,6 +186,11 @@ func (mw logmw) PublishLogs(ctx context.Context, nodeKey string, logType logger. } } + pubStateVals, ok := ctx.Value(PublicationCtxKey).(map[string]int) + if !ok { + pubStateVals = make(map[string]int) + } + mw.knapsack.Slogger().Log(ctx, levelForError(err), message, // nolint:sloglint // it's fine to not have a constant or literal here "method", "PublishLogs", "uuid", uuid, @@ -188,6 +200,7 @@ func (mw logmw) PublishLogs(ctx context.Context, nodeKey string, logType logger. "reauth", reauth, "err", err, "took", time.Since(begin), + "publication_state", pubStateVals, ) }(time.Now())