Skip to content

Commit

Permalink
hotfix: make the memory usage cache use a mutex instead of relying on…
Browse files Browse the repository at this point in the history
… CAS operations
  • Loading branch information
equals215 committed Jan 20, 2025
1 parent 97a7c50 commit 2fb6a1a
Showing 1 changed file with 78 additions and 86 deletions.
164 changes: 78 additions & 86 deletions pkg/spooledtempfile/spooled.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -26,13 +25,14 @@ const DefaultMaxRAMUsageFraction = 0.50
const memoryCheckInterval = 500 * time.Millisecond

// globalMemoryCache is a struct representing global cache of memory usage data.
type memoryUsageData struct {
type globalMemoryCache struct {
sync.Mutex
lastChecked time.Time
lastFraction float64
}

// memoryUsageCache is an atomic pointer to memoryUsageData.
var memoryUsageCache atomic.Pointer[memoryUsageData]
var memoryUsageCache *globalMemoryCache

var spooledPool = sync.Pool{
New: func() interface{} {
Expand Down Expand Up @@ -236,7 +236,7 @@ func (s *spooledTempFile) Close() error {
// If we're above the RAM threshold, we don't want to keep the buffer around.
if s.buf != nil && s.buf.Cap() > s.maxInMemorySize {
s.buf = nil
} else {
} else if s.buf != nil {
// Release the buffer
s.buf.Reset()
spooledPool.Put(s.buf)
Expand Down Expand Up @@ -267,7 +267,7 @@ func (s *spooledTempFile) FileName() string {
// exceeds s.maxRAMUsageFraction of total system memory.
// This implementation is Linux-specific via /proc/meminfo.
func (s *spooledTempFile) isSystemMemoryUsageHigh() bool {
usedFraction, err := getSystemMemoryUsedFraction()
usedFraction, err := getCachedMemoryUsage()
if err != nil {
// If we fail to get memory usage info, we conservatively return false,
// or you may choose to return true to avoid in-memory usage.
Expand All @@ -276,98 +276,90 @@ func (s *spooledTempFile) isSystemMemoryUsageHigh() bool {
return usedFraction >= s.maxRAMUsageFraction
}

// memoryUsageCacheInitMu serializes the lazy allocation of memoryUsageCache so
// that concurrent first callers cannot race on the package-level pointer.
var memoryUsageCacheInitMu sync.Mutex

// getCachedMemoryUsage returns the fraction of system memory currently in use,
// as reported by getSystemMemoryUsedFraction, re-reading it at most once per
// memoryCheckInterval.
//
// Every access to the cached fields happens while holding the cache's mutex,
// and freshness is evaluated only after the lock is acquired. Goroutines that
// were waiting while another one refreshed the value therefore reuse that
// result instead of each performing their own expensive read.
//
// On failure it returns 0 and the error from getSystemMemoryUsedFraction; the
// cache is left untouched so the next call retries the read.
func getCachedMemoryUsage() (float64, error) {
	// Allocate the cache on first use. The pointer is copied to a local while
	// the init lock is held so the rest of the function never touches the
	// package-level variable without synchronization.
	memoryUsageCacheInitMu.Lock()
	if memoryUsageCache == nil {
		memoryUsageCache = &globalMemoryCache{}
	}
	cache := memoryUsageCache
	memoryUsageCacheInitMu.Unlock()

	cache.Lock()
	defer cache.Unlock()

	// 1) If it's still fresh, just return the cached value. A zero
	// lastChecked is always considered stale, which forces the first read.
	if time.Since(cache.lastChecked) < memoryCheckInterval {
		return cache.lastFraction, nil
	}

	// 2) Otherwise, do a fresh read (expensive).
	fraction, err := getSystemMemoryUsedFraction()
	if err != nil {
		return 0, err
	}

	// 3) Update the cache.
	cache.lastChecked = time.Now()
	cache.lastFraction = fraction

	return fraction, nil
}

// getSystemMemoryUsedFraction parses /proc/meminfo on Linux to figure out
// how much memory is used vs total. Returns fraction = used / total
// This is a Linux-specific implementation.
// This function is defined as a variable so it can be overridden in tests.
// Now includes lock-free CAS caching to avoid hammering /proc/meminfo on every call.
var getSystemMemoryUsedFraction = func() (float64, error) {
for {
// Atomically load the current pointer.
oldPtr := memoryUsageCache.Load()
if oldPtr != nil {
data := *oldPtr
// If it's still fresh, just return it.
if time.Since(data.lastChecked) < memoryCheckInterval {
return data.lastFraction, nil
}
}

// Data is nil or stale -> we attempt to refresh.
// But first, double-check if someone else already updated
// between the time we loaded oldPtr and now.
againPtr := memoryUsageCache.Load()
if againPtr != oldPtr {
// Another goroutine already updated it => just loop again
// so we can use the fresh data. (No need to read /proc/meminfo.)
// We're the winners and need to refresh the data.
f, err := os.Open("/proc/meminfo")
if err != nil {
// If we cannot open /proc/meminfo, return an error
// or fallback if you prefer
return 0, fmt.Errorf("failed to open /proc/meminfo: %v", err)
}
defer f.Close()

var memTotal, memAvailable, memFree, buffers, cached uint64
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}

// We're the winners and need to refresh the data.
f, err := os.Open("/proc/meminfo")
if err != nil {
// If we cannot open /proc/meminfo, return an error
// or fallback if you prefer
return 0, fmt.Errorf("failed to open /proc/meminfo: %v", err)
}
defer f.Close()

var memTotal, memAvailable, memFree, buffers, cached uint64
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
key := strings.TrimRight(fields[0], ":")
value, _ := strconv.ParseUint(fields[1], 10, 64)
// value is typically in kB
switch key {
case "MemTotal":
memTotal = value
case "MemAvailable":
memAvailable = value
case "MemFree":
memFree = value
case "Buffers":
buffers = value
case "Cached":
cached = value
}
}
if err := scanner.Err(); err != nil {
return 0, fmt.Errorf("scanner error reading /proc/meminfo: %v", err)
}

if memTotal == 0 {
return 0, fmt.Errorf("could not find MemTotal in /proc/meminfo")
key := strings.TrimRight(fields[0], ":")
value, _ := strconv.ParseUint(fields[1], 10, 64)
// value is typically in kB
switch key {
case "MemTotal":
memTotal = value
case "MemAvailable":
memAvailable = value
case "MemFree":
memFree = value
case "Buffers":
buffers = value
case "Cached":
cached = value
}
}
if err := scanner.Err(); err != nil {
return 0, fmt.Errorf("scanner error reading /proc/meminfo: %v", err)
}

var used uint64
if memAvailable > 0 {
// Linux 3.14+ has MemAvailable for better measure
used = memTotal - memAvailable
} else {
// Approximate available as free + buffers + cached
approxAvailable := memFree + buffers + cached
used = memTotal - approxAvailable
}
if memTotal == 0 {
return 0, fmt.Errorf("could not find MemTotal in /proc/meminfo")
}

fraction := float64(used) / float64(memTotal)
var used uint64
if memAvailable > 0 {
// Linux 3.14+ has MemAvailable for better measure
used = memTotal - memAvailable
} else {
// Approximate available as free + buffers + cached
approxAvailable := memFree + buffers + cached
used = memTotal - approxAvailable
}

newData := &memoryUsageData{
lastChecked: time.Now(),
lastFraction: fraction,
}
fraction := float64(used) / float64(memTotal)

// CAS to store the new data (only if oldPtr is still valid).
swapped := memoryUsageCache.CompareAndSwap(oldPtr, newData)
if swapped {
// We successfully updated => return the fresh fraction.
return fraction, nil
}
// If swap fails, it means another goroutine beat us to it.
// So we just loop around, load their data, and return that.
}
return fraction, nil
}

0 comments on commit 2fb6a1a

Please sign in to comment.