diff --git a/pkg/spooledtempfile/spooled.go b/pkg/spooledtempfile/spooled.go index efe3f97..79246e0 100644 --- a/pkg/spooledtempfile/spooled.go +++ b/pkg/spooledtempfile/spooled.go @@ -10,7 +10,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" ) @@ -26,13 +25,14 @@ const DefaultMaxRAMUsageFraction = 0.50 const memoryCheckInterval = 500 * time.Millisecond // globalMemoryCache is a struct representing global cache of memory usage data. -type memoryUsageData struct { +type globalMemoryCache struct { + sync.Mutex lastChecked time.Time lastFraction float64 } // memoryUsageCache is an atomic pointer to memoryUsageData. -var memoryUsageCache atomic.Pointer[memoryUsageData] +var memoryUsageCache *globalMemoryCache var spooledPool = sync.Pool{ New: func() interface{} { @@ -236,7 +236,7 @@ func (s *spooledTempFile) Close() error { // If we're above the RAM threshold, we don't want to keep the buffer around. if s.buf != nil && s.buf.Cap() > s.maxInMemorySize { s.buf = nil - } else { + } else if s.buf != nil { // Release the buffer s.buf.Reset() spooledPool.Put(s.buf) @@ -267,7 +267,7 @@ func (s *spooledTempFile) FileName() string { // exceeds s.maxRAMUsageFraction of total system memory. // This implementation is Linux-specific via /proc/meminfo. func (s *spooledTempFile) isSystemMemoryUsageHigh() bool { - usedFraction, err := getSystemMemoryUsedFraction() + usedFraction, err := getCachedMemoryUsage() if err != nil { // If we fail to get memory usage info, we conservatively return false, // or you may choose to return true to avoid in-memory usage. @@ -276,98 +276,90 @@ func (s *spooledTempFile) isSystemMemoryUsageHigh() bool { return usedFraction >= s.maxRAMUsageFraction } +func getCachedMemoryUsage() (float64, error) { + if memoryUsageCache == nil { + memoryUsageCache = &globalMemoryCache{} + } + + // 1) If it's still fresh, just return the cached value. + if time.Since(memoryUsageCache.lastChecked) < memoryCheckInterval { + return memoryUsageCache.lastFraction, nil + } + + memoryUsageCache.Lock() + defer memoryUsageCache.Unlock() + + // 2) Otherwise, do a fresh read (expensive). + fraction, err := getSystemMemoryUsedFraction() + if err != nil { + return 0, err + } + + // 3) Update the cache + memoryUsageCache.lastChecked = time.Now() + memoryUsageCache.lastFraction = fraction + + return fraction, nil +} + // getSystemMemoryUsedFraction parses /proc/meminfo on Linux to figure out // how much memory is used vs total. Returns fraction = used / total // This is a Linux-specific implementation. // This function is defined as a variable so it can be overridden in tests. // Now includes lock-free CAS caching to avoid hammering /proc/meminfo on every call. var getSystemMemoryUsedFraction = func() (float64, error) { - for { - // Atomically load the current pointer. - oldPtr := memoryUsageCache.Load() - if oldPtr != nil { - data := *oldPtr - // If it's still fresh, just return it. - if time.Since(data.lastChecked) < memoryCheckInterval { - return data.lastFraction, nil - } - } - - // Data is nil or stale -> we attempt to refresh. - // But first, double-check if someone else already updated - // between the time we loaded oldPtr and now. - againPtr := memoryUsageCache.Load() - if againPtr != oldPtr { - // Another goroutine already updated it => just loop again - // so we can use the fresh data. (No need to read /proc/meminfo.) + // We're the winners and need to refresh the data. + f, err := os.Open("/proc/meminfo") + if err != nil { + // If we cannot open /proc/meminfo, return an error + // or fallback if you prefer + return 0, fmt.Errorf("failed to open /proc/meminfo: %v", err) + } + defer f.Close() + + var memTotal, memAvailable, memFree, buffers, cached uint64 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) < 2 { continue } - - // We're the winners and need to refresh the data. - f, err := os.Open("/proc/meminfo") - if err != nil { - // If we cannot open /proc/meminfo, return an error - // or fallback if you prefer - return 0, fmt.Errorf("failed to open /proc/meminfo: %v", err) - } - defer f.Close() - - var memTotal, memAvailable, memFree, buffers, cached uint64 - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) < 2 { - continue - } - key := strings.TrimRight(fields[0], ":") - value, _ := strconv.ParseUint(fields[1], 10, 64) - // value is typically in kB - switch key { - case "MemTotal": - memTotal = value - case "MemAvailable": - memAvailable = value - case "MemFree": - memFree = value - case "Buffers": - buffers = value - case "Cached": - cached = value - } - } - if err := scanner.Err(); err != nil { - return 0, fmt.Errorf("scanner error reading /proc/meminfo: %v", err) - } - - if memTotal == 0 { - return 0, fmt.Errorf("could not find MemTotal in /proc/meminfo") + key := strings.TrimRight(fields[0], ":") + value, _ := strconv.ParseUint(fields[1], 10, 64) + // value is typically in kB + switch key { + case "MemTotal": + memTotal = value + case "MemAvailable": + memAvailable = value + case "MemFree": + memFree = value + case "Buffers": + buffers = value + case "Cached": + cached = value } + } + if err := scanner.Err(); err != nil { + return 0, fmt.Errorf("scanner error reading /proc/meminfo: %v", err) + } - var used uint64 - if memAvailable > 0 { - // Linux 3.14+ has MemAvailable for better measure - used = memTotal - memAvailable - } else { - // Approximate available as free + buffers + cached - approxAvailable := memFree + buffers + cached - used = memTotal - approxAvailable - } + if memTotal == 0 { + return 0, fmt.Errorf("could not find MemTotal in /proc/meminfo") + } - fraction := float64(used) / float64(memTotal) + var used uint64 + if memAvailable > 0 { + // Linux 3.14+ has MemAvailable for better measure + used = memTotal - memAvailable + } else { + // Approximate available as free + buffers + cached + approxAvailable := memFree + buffers + cached + used = memTotal - approxAvailable + } - newData := &memoryUsageData{ - lastChecked: time.Now(), - lastFraction: fraction, - } + fraction := float64(used) / float64(memTotal) - // CAS to store the new data (only if oldPtr is still valid). - swapped := memoryUsageCache.CompareAndSwap(oldPtr, newData) - if swapped { - // We successfully updated => return the fresh fraction. - return fraction, nil - } - // If swap fails, it means another goroutine beat us to it. - // So we just loop around, load their data, and return that. - } + return fraction, nil }