Commit 7b1b2f1

add dynamic buffer size for log publication (#1630)

zackattack01 authored Mar 1, 2024
1 parent a3198f4 commit 7b1b2f1

Showing 4 changed files with 302 additions and 36 deletions.
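At a high level, this change replaces the fixed MaxBytesPerBatch comparison in writeBufferedLogsForType with a logPublicationState tracker whose effective limit moves in 0.5 MiB steps: down after a batch that fails slowly (a likely timeout), up after a full batch that publishes successfully, bounded below by minBytesPerBatch and above by the configured MaxBytesPerBatch. A minimal standalone sketch of that arithmetic (constants copied from the new pkg/osquery/log_publication_state.go below; the 2 MiB starting limit is illustrative, and Go 1.21+ is assumed for the min/max builtins):

package main

import "fmt"

const (
	minBytesPerBatch     = 524288 // 0.5 MiB floor (from log_publication_state.go)
	batchIncrementAmount = 524288 // 0.5 MiB step (from log_publication_state.go)
)

func main() {
	maxBytesPerBatch := 4 * minBytesPerBatch // illustrative 2 MiB configured limit
	current := maxBytesPerBatch

	// Two slow failures step the effective limit down, never below the floor.
	for i := 0; i < 2; i++ {
		current = max(current-batchIncrementAmount, minBytesPerBatch)
	}
	fmt.Println(current) // 1048576 (1 MiB)

	// A successful full batch steps it back up, never above the configured limit.
	current = min(current+batchIncrementAmount, maxBytesPerBatch)
	fmt.Println(current) // 1572864 (1.5 MiB)
}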
97 changes: 61 additions & 36 deletions pkg/osquery/extension.go
@@ -36,14 +36,15 @@ import (
// and servers -- It provides a grpc and jsonrpc interface for
// osquery. It does not provide any tables.
type Extension struct {
NodeKey string
Opts ExtensionOpts
knapsack types.Knapsack
serviceClient service.KolideService
enrollMutex sync.Mutex
done chan struct{}
interrupted bool
slogger *slog.Logger
NodeKey string
Opts ExtensionOpts
knapsack types.Knapsack
serviceClient service.KolideService
enrollMutex sync.Mutex
done chan struct{}
interrupted bool
slogger *slog.Logger
logPublicationState *logPublicationState
}

const (
@@ -146,12 +147,13 @@ func NewExtension(ctx context.Context, client service.KolideService, k types.Kna
}

return &Extension{
slogger: slogger,
serviceClient: client,
knapsack: k,
NodeKey: nodekey,
Opts: opts,
done: make(chan struct{}),
slogger: slogger,
serviceClient: client,
knapsack: k,
NodeKey: nodekey,
Opts: opts,
done: make(chan struct{}),
logPublicationState: NewLogPublicationState(opts.MaxBytesPerBatch),
}, nil
}

@@ -647,12 +649,14 @@ func bucketNameFromLogType(typ logger.LogType) (string, error) {
// buffers.
func (e *Extension) writeAndPurgeLogs() {
for _, typ := range []logger.LogType{logger.LogTypeStatus, logger.LogTypeString} {
originalBatchState := e.logPublicationState.CurrentValues()
// Write logs
err := e.writeBufferedLogsForType(typ)
if err != nil {
e.slogger.Log(context.TODO(), slog.LevelInfo,
"sending logs",
"type", typ.String(),
"attempted_publication_state", originalBatchState,
"err", err,
)
}
@@ -701,6 +705,7 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
// Collect up logs to be sent
var logs []string
var logIDs [][]byte
bufferFilled := false
err = e.knapsack.BboltDB().View(func(tx *bbolt.Tx) error {
b := tx.Bucket([]byte(bucketName))

@@ -714,7 +719,7 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
// 3. Else append it
//
// Note that (1) must come first, otherwise (2) will always trigger.
if len(v) > e.Opts.MaxBytesPerBatch {
if e.logPublicationState.ExceedsCurrentBatchThreshold(len(v)) {
// Discard logs that are too big
logheadSize := minInt(len(v), 100)
e.slogger.Log(context.TODO(), slog.LevelInfo,
@@ -724,8 +729,9 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
"limit", e.Opts.MaxBytesPerBatch,
"loghead", string(v)[0:logheadSize],
)
} else if totalBytes+len(v) > e.Opts.MaxBytesPerBatch {
} else if e.logPublicationState.ExceedsCurrentBatchThreshold(totalBytes + len(v)) {
// Buffer is filled. Break the loop and come back later.
bufferFilled = true
break
} else {
logs = append(logs, string(v))
@@ -756,7 +762,14 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
return nil
}

err = e.writeLogsWithReenroll(context.Background(), typ, logs, true)
// inform the publication state tracker whether this batch should be used to
// determine the appropriate limit
e.logPublicationState.BeginBatch(time.Now(), bufferFilled)
publicationCtx := context.WithValue(context.Background(),
service.PublicationCtxKey,
e.logPublicationState.CurrentValues(),
)
err = e.writeLogsWithReenroll(publicationCtx, typ, logs, true)
if err != nil {
return fmt.Errorf("writing logs: %w", err)
}
@@ -779,31 +792,35 @@ func (e *Extension) writeBufferedLogsForType(typ logger.LogType) error {
// Helper to allow for a single attempt at re-enrollment
func (e *Extension) writeLogsWithReenroll(ctx context.Context, typ logger.LogType, logs []string, reenroll bool) error {
_, _, invalid, err := e.serviceClient.PublishLogs(ctx, e.NodeKey, typ, logs)
if isNodeInvalidErr(err) {
invalid = true
} else if err != nil {
return fmt.Errorf("transport error sending logs: %w", err)
invalid = invalid || isNodeInvalidErr(err)
if !invalid && err == nil {
// publication was successful - update logPublicationState and move on
e.logPublicationState.EndBatch(logs, true)
return nil
}

if invalid {
if !reenroll {
return errors.New("enrollment invalid, reenroll disabled")
}
if err != nil {
// logPublicationState will determine whether this failure should impact
// the batch size limit based on the elapsed time
e.logPublicationState.EndBatch(logs, false)
return fmt.Errorf("transport error sending logs: %w", err)
}

e.RequireReenroll(ctx)
_, invalid, err := e.Enroll(ctx)
if err != nil {
return fmt.Errorf("enrollment invalid, reenrollment errored: %w", err)
}
if invalid {
return errors.New("enrollment invalid, reenrollment invalid")
}
if !reenroll {
return errors.New("enrollment invalid, reenroll disabled")
}

// Don't attempt reenroll after first attempt
return e.writeLogsWithReenroll(ctx, typ, logs, false)
e.RequireReenroll(ctx)
_, invalid, err = e.Enroll(ctx)
if err != nil {
return fmt.Errorf("enrollment invalid, reenrollment errored: %w", err)
}
if invalid {
return errors.New("enrollment invalid, reenrollment invalid")
}

return nil
// Don't attempt reenroll after first attempt
return e.writeLogsWithReenroll(ctx, typ, logs, false)
}

// purgeBufferedLogsForType flushes the log buffers for the provided type,
@@ -981,3 +998,11 @@ func minInt(a, b int) int {

return b
}

func maxInt(a, b int) int {
if a > b {
return a
}

return b
}
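The new publicationCtx above attaches the tracker's current values to the outgoing PublishLogs call as a context value. A hedged sketch of how a service.KolideService implementation could read them back (only service.PublicationCtxKey and the map[string]int payload come from this diff; the helper itself is assumed, and the surrounding file's context and service imports are reused):

// publicationStateFromContext is a hypothetical helper; it recovers the map
// written by writeBufferedLogsForType via context.WithValue above.
func publicationStateFromContext(ctx context.Context) (map[string]int, bool) {
	vals, ok := ctx.Value(service.PublicationCtxKey).(map[string]int)
	return vals, ok
}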
105 changes: 105 additions & 0 deletions pkg/osquery/log_publication_state.go
@@ -0,0 +1,105 @@
package osquery

import "time"

const (
// minBytesPerBatch sets the minimum batch size to 0.5 MiB as a lower bound for correction
minBytesPerBatch int = 524288
// batchIncrementAmount (0.5 MiB) is the step by which the target batch size is adjusted:
// increased when previous runs have been successful, decreased after slow failures
batchIncrementAmount int = 524288
// maxPublicationDuration is the elapsed time after which a failed batch triggers a reduction
// in the max batch size (faster failures are assumed not to be timeout related)
maxPublicationDuration time.Duration = 20 * time.Second
)

// logPublicationState holds stateful logic to influence the log batch publication size
// depending on prior successes or failures. The primary intent here is to prevent repeatedly
// consuming the entire network bandwidth available to devices that are unable to ship
// the standard maxBytesPerBatch within the connection timeout (currently 30 seconds, enforced cloud side).
// Note that we always expect these batches to be sent sequentially (BeginBatch -> EndBatch);
// this would need rework (likely state locking and batch ID tracking) to support concurrent batch publications
type logPublicationState struct {
// maxBytesPerBatch is passed in from the extension opts and respected
// as a fixed upper limit for batch size, regardless of publication success/failure rates
maxBytesPerBatch int
// currentMaxBytesPerBatch represents the (stateful) upper limit being enforced
currentMaxBytesPerBatch int
// currentBatchBufferFilled is used to indicate when a batch's success can be used
// to increase the threshold (we only want to increase the threshold after sending full, not partial, batches successfully)
currentBatchBufferFilled bool
currentBatchStartTime time.Time
}

func NewLogPublicationState(maxBytesPerBatch int) *logPublicationState {
return &logPublicationState{
maxBytesPerBatch: maxBytesPerBatch,
currentMaxBytesPerBatch: maxBytesPerBatch,
}
}

// BeginBatch sets the opening state before attempting to publish a batch of logs. Specifically, it must
// - note the time (to determine if an error later was timeout related)
// - note whether this batch is full (to determine whether a later success should increase the limit)
func (lps *logPublicationState) BeginBatch(startTime time.Time, bufferFilled bool) {
lps.currentBatchStartTime = startTime
lps.currentBatchBufferFilled = bufferFilled
}

func (lps *logPublicationState) CurrentValues() map[string]int {
return map[string]int{
"options_batch_limit_bytes": lps.maxBytesPerBatch,
"current_batch_limit_bytes": lps.currentMaxBytesPerBatch,
}
}

func (lps *logPublicationState) EndBatch(logs []string, successful bool) {
// ensure we reset all batch state at the end
defer func() {
lps.currentBatchBufferFilled = false
lps.currentBatchStartTime = time.Time{}
}()

// we can always safely decrease the threshold for a failed batch, but
// shouldn't increase the threshold for a successful batch unless we've at
// least filled the buffer
if successful && !lps.currentBatchBufferFilled {
return
}

// in practice there could be one of a few different transport or timeout errors that bubble up
// depending on network conditions. instead of trying to keep up with all potential errors,
// only reduce the threshold if the calls are failing after more than 20 seconds
if !successful && time.Since(lps.currentBatchStartTime) < maxPublicationDuration {
return
}

if successful {
lps.increaseBatchThreshold()
return
}

lps.reduceBatchThreshold()
}

func (lps *logPublicationState) ExceedsCurrentBatchThreshold(amountBytes int) bool {
return amountBytes > lps.currentMaxBytesPerBatch
}

func (lps *logPublicationState) reduceBatchThreshold() {
if lps.currentMaxBytesPerBatch <= minBytesPerBatch {
return
}

newTargetThreshold := lps.currentMaxBytesPerBatch - batchIncrementAmount
lps.currentMaxBytesPerBatch = maxInt(newTargetThreshold, minBytesPerBatch)
}

func (lps *logPublicationState) increaseBatchThreshold() {
if lps.currentMaxBytesPerBatch >= lps.maxBytesPerBatch {
return
}

newTargetThreshold := lps.currentMaxBytesPerBatch + batchIncrementAmount
lps.currentMaxBytesPerBatch = minInt(newTargetThreshold, lps.maxBytesPerBatch)
}
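Taken together, the expected lifecycle is BeginBatch -> publish -> EndBatch, as wired up in writeBufferedLogsForType above. A minimal usage sketch under stated assumptions: publishOnce is a hypothetical wrapper, not part of the commit, and it reuses the service and logger packages imported in extension.go.

// publishOnce shows the intended sequential use of logPublicationState
// around a single publication attempt.
func publishOnce(lps *logPublicationState, client service.KolideService, nodeKey string, logs []string, bufferFilled bool) error {
	// bufferFilled is true when draining stopped because the batch hit the
	// current threshold, i.e. this is a full rather than partial batch.
	lps.BeginBatch(time.Now(), bufferFilled)

	_, _, _, err := client.PublishLogs(context.Background(), nodeKey, logger.LogTypeString, logs)

	// On success with a full buffer the threshold grows by batchIncrementAmount
	// (capped at maxBytesPerBatch); on a failure slower than maxPublicationDuration
	// it shrinks (floored at minBytesPerBatch).
	lps.EndBatch(logs, err == nil)
	return err
}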
123 changes: 123 additions & 0 deletions pkg/osquery/log_publication_state_test.go
@@ -0,0 +1,123 @@
//nolint:paralleltest
package osquery

import (
"context"
"errors"
"testing"
"time"

"github.com/kolide/launcher/pkg/service/mock"
"github.com/osquery/osquery-go/plugin/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestExtensionLogPublicationHappyPath(t *testing.T) {
startingBatchLimitBytes := minBytesPerBatch * 4
m := &mock.KolideService{
PublishLogsFunc: func(ctx context.Context, nodeKey string, logType logger.LogType, logs []string) (string, string, bool, error) {
return "", "", false, nil
},
}
db, cleanup := makeTempDB(t)
defer cleanup()
k := makeKnapsack(t, db)
e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{MaxBytesPerBatch: startingBatchLimitBytes})
require.Nil(t, err)

// issue a few successful calls, expect that the batch limit is unchanged from the original opts
for i := 0; i < 3; i++ {
e.logPublicationState.BeginBatch(time.Now(), true)
err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true)
assert.Nil(t, err)
assert.Equal(t, e.logPublicationState.currentMaxBytesPerBatch, startingBatchLimitBytes)
// always expect that these values are reset between runs
assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime)
assert.False(t, e.logPublicationState.currentBatchBufferFilled)
}
}

func TestExtensionLogPublicationRespondsToNetworkTimeouts(t *testing.T) {
numberOfPublicationRounds := 3
publicationCalledCount := -1
// startingBatchLimitBytes is set this way to ensure sufficient room for correction in both directions
startingBatchLimitBytes := minBytesPerBatch * (numberOfPublicationRounds + 1)
m := &mock.KolideService{
PublishLogsFunc: func(ctx context.Context, nodeKey string, logType logger.LogType, logs []string) (string, string, bool, error) {
publicationCalledCount++
switch {
case publicationCalledCount < numberOfPublicationRounds:
return "", "", false, errors.New("transport")
default:
return "", "", false, nil
}
},
}
db, cleanup := makeTempDB(t)
defer cleanup()
k := makeKnapsack(t, db)
e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{MaxBytesPerBatch: startingBatchLimitBytes})
require.Nil(t, err)

// expect each subsequent failed call to reduce the batch size until the min threshold is reached
expectedMaxValue := e.Opts.MaxBytesPerBatch
for i := 0; i < numberOfPublicationRounds; i++ {
// backdate the batch start time past the 20-second failure threshold
e.logPublicationState.BeginBatch(time.Now().Add(-21*time.Second), true)
err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true)
assert.NotNil(t, err)
assert.Less(t, e.logPublicationState.currentMaxBytesPerBatch, expectedMaxValue)
// always expect that these values are reset between runs
assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime)
assert.False(t, e.logPublicationState.currentBatchBufferFilled)
expectedMaxValue = e.logPublicationState.currentMaxBytesPerBatch
}

// now run a successful publication round without filling the buffer - we expect
// this to have no effect on the current batch size
err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true)
assert.Nil(t, err)
assert.Equal(t, expectedMaxValue, e.logPublicationState.currentMaxBytesPerBatch)

// this time mark the buffer as filled for subsequent successful calls and expect that we move back up towards the original batch limit
for i := 0; i < numberOfPublicationRounds; i++ {
e.logPublicationState.BeginBatch(time.Now(), true)
err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true)
assert.Nil(t, err)
assert.Greater(t, e.logPublicationState.currentMaxBytesPerBatch, expectedMaxValue)
// always expect that these values are reset between runs
assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime)
assert.False(t, e.logPublicationState.currentBatchBufferFilled)
expectedMaxValue = e.logPublicationState.currentMaxBytesPerBatch
}

// lastly expect that we've returned to our baseline state
assert.Equal(t, e.logPublicationState.currentMaxBytesPerBatch, startingBatchLimitBytes)
}

func TestExtensionLogPublicationIgnoresNonTimeoutErrors(t *testing.T) {
startingBatchLimitBytes := minBytesPerBatch * 4
m := &mock.KolideService{
PublishLogsFunc: func(ctx context.Context, nodeKey string, logType logger.LogType, logs []string) (string, string, bool, error) {
return "", "", false, errors.New("transport")
},
}
db, cleanup := makeTempDB(t)
defer cleanup()
k := makeKnapsack(t, db)
e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{MaxBytesPerBatch: startingBatchLimitBytes})
require.Nil(t, err)

// issue a few calls that error immediately, expect that the batch limit is unchanged from the original opts
for i := 0; i < 3; i++ {
e.logPublicationState.BeginBatch(time.Now(), true)
err = e.writeLogsWithReenroll(context.Background(), logger.LogTypeSnapshot, []string{"foobar"}, true)
// we still expect an error, but the batch limitation should not have changed
assert.NotNil(t, err)
assert.Equal(t, e.logPublicationState.currentMaxBytesPerBatch, startingBatchLimitBytes)
// always expect that these values are reset between runs
assert.Equal(t, time.Time{}, e.logPublicationState.currentBatchStartTime)
assert.False(t, e.logPublicationState.currentBatchBufferFilled)
}
}