diff --git a/executor/analyzetest/analyze_test.go b/executor/analyzetest/analyze_test.go index a9f8c6f12f915..b33631a8a9b50 100644 --- a/executor/analyzetest/analyze_test.go +++ b/executor/analyzetest/analyze_test.go @@ -3196,6 +3196,48 @@ func TestGlobalMemoryControlForAnalyze(t *testing.T) { tk0.MustExec(sql) } +func TestGlobalMemoryControlForPrepareAnalyze(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + + tk0 := testkit.NewTestKit(t, store) + tk0.MustExec("set global tidb_mem_oom_action = 'cancel'") + tk0.MustExec("set global tidb_mem_quota_query = 209715200 ") // 200MB + tk0.MustExec("set global tidb_server_memory_limit = 5GB") + tk0.MustExec("set global tidb_server_memory_limit_sess_min_size = 128") + + sm := &testkit.MockSessionManager{ + PS: []*util.ProcessInfo{tk0.Session().ShowProcess()}, + } + dom.ServerMemoryLimitHandle().SetSessionManager(sm) + go dom.ServerMemoryLimitHandle().Run() + + tk0.MustExec("use test") + tk0.MustExec("create table t(a int)") + tk0.MustExec("insert into t select 1") + for i := 1; i <= 8; i++ { + tk0.MustExec("insert into t select * from t") // 256 Lines + } + sqlPrepare := "prepare stmt from 'analyze table t with 1.0 samplerate';" + sqlExecute := "execute stmt;" // Need about 100MB + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/memory/ReadMemStats", `return(536870912)`)) // 512MB + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume", `return(100)`)) + // won't be killed by tidb_mem_quota_query + tk0.MustExec(sqlPrepare) + tk0.MustExec(sqlExecute) + runtime.GC() + // killed by tidb_server_memory_limit + tk0.MustExec("set global tidb_server_memory_limit = 512MB") + _, err0 := tk0.Exec(sqlPrepare) + require.NoError(t, err0) + _, err1 := tk0.Exec(sqlExecute) + require.True(t, strings.Contains(err1.Error(), "Out Of Memory Quota!")) + runtime.GC() + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/memory/ReadMemStats")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume")) + tk0.MustExec(sqlPrepare) + tk0.MustExec(sqlExecute) +} + func TestGlobalMemoryControlForAutoAnalyze(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) diff --git a/executor/executor.go b/executor/executor.go index cd83564c26853..167a6e6328e39 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1977,10 +1977,20 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { vars.MemTracker.SetBytesLimit(vars.MemQuotaQuery) vars.MemTracker.ResetMaxConsumed() vars.DiskTracker.ResetMaxConsumed() - vars.MemTracker.SessionID = vars.ConnectionID + vars.MemTracker.SessionID.Store(vars.ConnectionID) vars.StmtCtx.TableStats = make(map[int64]interface{}) - if _, ok := s.(*ast.AnalyzeTableStmt); ok { + isAnalyze := false + if execStmt, ok := s.(*ast.ExecuteStmt); ok { + prepareStmt, err := plannercore.GetPreparedStmt(execStmt, vars) + if err != nil { + return err + } + _, isAnalyze = prepareStmt.PreparedAst.Stmt.(*ast.AnalyzeTableStmt) + } else if _, ok := s.(*ast.AnalyzeTableStmt); ok { + isAnalyze = true + } + if isAnalyze { sc.InitMemTracker(memory.LabelForAnalyzeMemory, -1) vars.MemTracker.SetBytesLimit(-1) vars.MemTracker.AttachTo(GlobalAnalyzeMemoryTracker) @@ -2000,7 +2010,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { action.SetLogHook(logOnQueryExceedMemQuota) vars.MemTracker.SetActionOnExceed(action) } - sc.MemTracker.SessionID = vars.ConnectionID + sc.MemTracker.SessionID.Store(vars.ConnectionID) sc.MemTracker.AttachTo(vars.MemTracker) sc.InitDiskTracker(memory.LabelForSQLText, -1) globalConfig := config.GetGlobalConfig() diff --git a/executor/issuetest/executor_issue_test.go b/executor/issuetest/executor_issue_test.go index 7f39baba63851..7ebe79a10f9e9 100644 --- a/executor/issuetest/executor_issue_test.go +++ b/executor/issuetest/executor_issue_test.go @@ -1357,7 +1357,7 @@ func TestIssue42662(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.Session().GetSessionVars().ConnectionID = 12345 tk.Session().GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSession, -1) - tk.Session().GetSessionVars().MemTracker.SessionID = 12345 + tk.Session().GetSessionVars().MemTracker.SessionID.Store(12345) tk.Session().GetSessionVars().MemTracker.IsRootTrackerOfSess = true sm := &testkit.MockSessionManager{ diff --git a/util/memory/action.go b/util/memory/action.go index 75b587e2157f9..d2d5b457436ef 100644 --- a/util/memory/action.go +++ b/util/memory/action.go @@ -139,7 +139,7 @@ func (a *PanicOnExceed) Action(t *Tracker) { if !a.acted { if a.logHook == nil { logutil.BgLogger().Warn("memory exceeds quota", - zap.Uint64("connID", t.SessionID), zap.Error(errMemExceedThreshold.GenWithStackByArgs(t.label, t.BytesConsumed(), t.GetBytesLimit(), t.String()))) + zap.Uint64("conn", t.SessionID.Load()), zap.Error(errMemExceedThreshold.GenWithStackByArgs(t.label, t.BytesConsumed(), t.GetBytesLimit(), t.String()))) } else { a.logHook(a.ConnID) } diff --git a/util/memory/memstats.go b/util/memory/memstats.go index 9cc4a3b14fb5a..4ea192620bee2 100644 --- a/util/memory/memstats.go +++ b/util/memory/memstats.go @@ -37,7 +37,7 @@ func ReadMemStats() (memStats *runtime.MemStats) { } failpoint.Inject("ReadMemStats", func(val failpoint.Value) { injectedSize := val.(int) - memStats.HeapInuse += uint64(injectedSize) + memStats = &runtime.MemStats{HeapInuse: memStats.HeapInuse + uint64(injectedSize)} }) return } diff --git a/util/memory/tracker.go b/util/memory/tracker.go index bd682260393fb..700146e94e0d3 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -88,11 +88,11 @@ type Tracker struct { } label int // Label of this "Tracker". // following fields are used with atomic operations, so make them 64-byte aligned. - bytesConsumed int64 // Consumed bytes. - bytesReleased int64 // Released bytes. - maxConsumed atomicutil.Int64 // max number of bytes consumed during execution. - SessionID uint64 // SessionID indicates the sessionID the tracker is bound. - NeedKill atomic.Bool // NeedKill indicates whether this session need kill because OOM + bytesConsumed int64 // Consumed bytes. + bytesReleased int64 // Released bytes. + maxConsumed atomicutil.Int64 // max number of bytes consumed during execution. + SessionID atomicutil.Uint64 // SessionID indicates the sessionID the tracker is bound. + NeedKill atomic.Bool // NeedKill indicates whether this session need kill because OOM NeedKillReceived sync.Once IsRootTrackerOfSess bool // IsRootTrackerOfSess indicates whether this tracker is bound for session isGlobal bool // isGlobal indicates whether this tracker is global tracker @@ -459,7 +459,7 @@ func (t *Tracker) Consume(bs int64) { sessionRootTracker.NeedKillReceived.Do( func() { logutil.BgLogger().Warn("global memory controller, NeedKill signal is received successfully", - zap.Uint64("connID", sessionRootTracker.SessionID)) + zap.Uint64("conn", sessionRootTracker.SessionID.Load())) }) tryActionLastOne(&sessionRootTracker.actionMuForHardLimit, sessionRootTracker) } diff --git a/util/servermemorylimit/servermemorylimit.go b/util/servermemorylimit/servermemorylimit.go index 38a679a4ad755..436a375b72842 100644 --- a/util/servermemorylimit/servermemorylimit.go +++ b/util/servermemorylimit/servermemorylimit.go @@ -138,12 +138,13 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { if instanceStats.HeapInuse > bt { t := memory.MemUsageTop1Tracker.Load() if t != nil { + sessionID := t.SessionID.Load() memUsage := t.BytesConsumed() // If the memory usage of the top1 session is less than tidb_server_memory_limit_sess_min_size, we do not need to kill it. if uint64(memUsage) < limitSessMinSize { memory.MemUsageTop1Tracker.CompareAndSwap(t, nil) t = nil - } else if info, ok := sm.GetProcessInfo(t.SessionID); ok { + } else if info, ok := sm.GetProcessInfo(sessionID); ok { logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer", zap.Uint64("connID", info.ID), zap.String("sql digest", info.Digest), @@ -152,7 +153,7 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { zap.Uint64("heap inuse", instanceStats.HeapInuse), zap.Int64("sql memory usage", info.MemTracker.BytesConsumed()), ) - s.sessionID = t.SessionID + s.sessionID = sessionID s.sqlStartTime = info.Time s.isKilling = true s.sessionTracker = t