Skip to content

Commit

Permalink
Use sync.Map to prevent race on map access in osquery-perf (#24501)
Browse files Browse the repository at this point in the history
  • Loading branch information
dantecatalfamo authored Dec 20, 2024
1 parent 7aa0433 commit ccb44a3
Showing 1 changed file with 40 additions and 23 deletions.
63 changes: 40 additions & 23 deletions cmd/osquery-perf/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ type agent struct {
MDMCheckInInterval time.Duration
DiskEncryptionEnabled bool

scheduledQueriesMu sync.Mutex // protects the below members
scheduledQueryData map[string]scheduledQuery
scheduledQueryData *sync.Map
// bufferedResults contains result logs that are buffered when
// /api/v1/osquery/log requests to the Fleet server fail.
//
Expand Down Expand Up @@ -668,6 +667,7 @@ func newAgent(
disableFleetDesktop: disableFleetDesktop,
loggerTLSMaxLines: loggerTLSMaxLines,
bufferedResults: make(map[resultLog]int),
scheduledQueryData: new(sync.Map),
}
}

Expand Down Expand Up @@ -777,28 +777,38 @@ func (a *agent) runLoop(i int, onlyAlreadyEnrolled bool) {
// check if we have any scheduled queries that should be returning results
var results []resultLog
now := time.Now().Unix()
a.scheduledQueriesMu.Lock()
prevCount := a.countBuffered()
for queryName, query := range a.scheduledQueryData {

// NOTE The goroutine that pulls in new configurations
// MAY replace this map if it happens to run at the
// exact same time. The result would be. The result
// would be that the query lastRun does not get
// updated and cause the query to run more times than
// expected.
queryData := a.scheduledQueryData
queryData.Range(func(key, value any) bool {
queryName := key.(string)
query := value.(scheduledQuery)

if query.lastRun == 0 || now >= (query.lastRun+int64(query.ScheduleInterval)) {
results = append(results, resultLog{
packName: query.packName,
queryName: query.Name,
numRows: int(query.numRows),
})
// Update lastRun
v := a.scheduledQueryData[queryName]
v.lastRun = now
a.scheduledQueryData[queryName] = v
query.lastRun = now
queryData.Store(queryName, query)
}
}

return true
})
if prevCount+len(results) < 1_000_000 { // osquery buffered_log_max is 1M
a.addToBuffer(results)
}
a.sendLogsBatch()
newBufferedCount := a.countBuffered() - prevCount
a.stats.UpdateBufferedLogs(newBufferedCount)
a.scheduledQueriesMu.Unlock()
}
}

Expand Down Expand Up @@ -1518,7 +1528,16 @@ func (a *agent) config() error {
return fmt.Errorf("json parse at config: %w", err)
}

scheduledQueryData := make(map[string]scheduledQuery)
existingLastRunData := make(map[string]int64)

a.scheduledQueryData.Range(func(key, value any) bool {
existingLastRunData[key.(string)] = value.(scheduledQuery).lastRun

return true
})

newScheduledQueryData := new(sync.Map)

for packName, pack := range parsedResp.Packs {
for queryName, query := range pack.Queries {
m, ok := query.(map[string]interface{})
Expand Down Expand Up @@ -1546,17 +1565,14 @@ func (a *agent) config() error {
q.Query = m["query"].(string)

scheduledQueryName := packName + "_" + queryName
if existingEntry, ok := a.scheduledQueryData[scheduledQueryName]; ok {
// Keep lastRun if the query is already scheduled.
q.lastRun = existingEntry.lastRun
if lastRun, ok := existingLastRunData[scheduledQueryName]; ok {
q.lastRun = lastRun
}
scheduledQueryData[scheduledQueryName] = q
newScheduledQueryData.Store(scheduledQueryName, q)
}
}

a.scheduledQueriesMu.Lock()
a.scheduledQueryData = scheduledQueryData
a.scheduledQueriesMu.Unlock()
a.scheduledQueryData = newScheduledQueryData

return nil
}
Expand Down Expand Up @@ -1852,13 +1868,12 @@ func (a *agent) runPolicy(query string) []map[string]string {
}

func (a *agent) randomQueryStats() []map[string]string {
a.scheduledQueriesMu.Lock()
defer a.scheduledQueriesMu.Unlock()

var stats []map[string]string
for scheduledQuery := range a.scheduledQueryData {
a.scheduledQueryData.Range(func(key, value any) bool {
queryName := key.(string)

stats = append(stats, map[string]string{
"name": scheduledQuery,
"name": queryName,
"delimiter": "_",
"average_memory": fmt.Sprint(rand.Intn(200) + 10),
"denylisted": "false",
Expand All @@ -1871,7 +1886,9 @@ func (a *agent) randomQueryStats() []map[string]string {
"wall_time": fmt.Sprint(rand.Intn(4) + 1),
"wall_time_ms": fmt.Sprint(rand.Intn(4000) + 10),
})
}

return true
})
return stats
}

Expand Down

0 comments on commit ccb44a3

Please sign in to comment.