diff --git a/executor/aggregate.go b/executor/aggregate.go
index 2b580849e133b..868691763eb41 100644
--- a/executor/aggregate.go
+++ b/executor/aggregate.go
@@ -32,6 +32,7 @@ import (
 	"github.com/pingcap/tidb/util/codec"
 	"github.com/pingcap/tidb/util/logutil"
 	"github.com/pingcap/tidb/util/mathutil"
+	"github.com/pingcap/tidb/util/memory"
 	"github.com/pingcap/tidb/util/set"
 	"github.com/spaolacci/murmur3"
 	"go.uber.org/zap"
@@ -70,7 +71,8 @@ type HashAggPartialWorker struct {
 	groupKey          [][]byte
 	// chk stores the input data from child,
 	// and is reused by childExec and partial worker.
-	chk *chunk.Chunk
+	chk        *chunk.Chunk
+	memTracker *memory.Tracker
 }
 
 // HashAggFinalWorker indicates the final workers of parallel hash agg execution,
@@ -166,6 +168,8 @@ type HashAggExec struct {
 	isUnparallelExec bool
 	prepared         bool
 	executed         bool
+
+	memTracker *memory.Tracker // track memory usage.
 }
 
 // HashAggInput indicates the input of hash agg exec.
@@ -199,6 +203,7 @@ func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext,
 // Close implements the Executor Close interface.
 func (e *HashAggExec) Close() error {
 	if e.isUnparallelExec {
+		e.memTracker.Consume(-e.childResult.MemoryUsage())
 		e.childResult = nil
 		e.groupSet = nil
 		e.partialResultMap = nil
@@ -221,7 +226,8 @@ func (e *HashAggExec) Close() error {
 		}
 	}
 	for _, ch := range e.partialInputChs {
-		for range ch {
+		for chk := range ch {
+			e.memTracker.Consume(-chk.MemoryUsage())
 		}
 	}
 	for range e.finalOutputCh {
@@ -250,6 +256,9 @@ func (e *HashAggExec) Open(ctx context.Context) error {
 	}
 	e.prepared = false
 
+	e.memTracker = memory.NewTracker(e.id, -1)
+	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
+
 	if e.isUnparallelExec {
 		e.initForUnparallelExec()
 		return nil
@@ -263,6 +272,7 @@ func (e *HashAggExec) initForUnparallelExec() {
 	e.partialResultMap = make(aggPartialResultMapper)
 	e.groupKeyBuffer = make([][]byte, 0, 8)
 	e.childResult = newFirstChunk(e.children[0])
+	e.memTracker.Consume(e.childResult.MemoryUsage())
 }
 
 func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
@@ -298,13 +308,17 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
 			groupByItems:      e.GroupByItems,
 			chk:               newFirstChunk(e.children[0]),
 			groupKey:          make([][]byte, 0, 8),
+			memTracker:        e.memTracker,
 		}
-
+		e.memTracker.Consume(w.chk.MemoryUsage())
 		e.partialWorkers[i] = w
-		e.inputCh <- &HashAggInput{
+
+		input := &HashAggInput{
 			chk:        newFirstChunk(e.children[0]),
 			giveBackCh: w.inputCh,
 		}
+		e.memTracker.Consume(input.chk.MemoryUsage())
+		e.inputCh <- input
 	}
 
 	// Init final workers.
@@ -356,6 +370,7 @@ func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitG
 		if needShuffle {
 			w.shuffleIntermData(sc, finalConcurrency)
 		}
+		w.memTracker.Consume(-w.chk.MemoryUsage())
 		waitGroup.Done()
 	}()
 	for {
@@ -606,20 +621,28 @@ func (e *HashAggExec) fetchChildData(ctx context.Context) {
 			}
 			chk = input.chk
 		}
+		mSize := chk.MemoryUsage()
 		err = Next(ctx, e.children[0], chk)
 		if err != nil {
 			e.finalOutputCh <- &AfFinalResult{err: err}
+			e.memTracker.Consume(-mSize)
 			return
 		}
 		if chk.NumRows() == 0 {
+			e.memTracker.Consume(-mSize)
 			return
 		}
+		e.memTracker.Consume(chk.MemoryUsage() - mSize)
 		input.giveBackCh <- chk
 	}
 }
 
 func (e *HashAggExec) waitPartialWorkerAndCloseOutputChs(waitGroup *sync.WaitGroup) {
 	waitGroup.Wait()
+	close(e.inputCh)
+	for input := range e.inputCh {
+		e.memTracker.Consume(-input.chk.MemoryUsage())
+	}
 	for _, ch := range e.partialOutputChs {
 		close(ch)
 	}
@@ -733,7 +756,9 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro
 // execute fetches Chunks from src and update each aggregate function for each row in Chunk.
 func (e *HashAggExec) execute(ctx context.Context) (err error) {
 	for {
+		mSize := e.childResult.MemoryUsage()
 		err := Next(ctx, e.children[0], e.childResult)
+		e.memTracker.Consume(e.childResult.MemoryUsage() - mSize)
 		if err != nil {
 			return err
 		}
@@ -800,6 +825,8 @@ type StreamAggExec struct {
 	partialResults []aggfuncs.PartialResult
 	groupRows      []chunk.Row
 	childResult    *chunk.Chunk
+
+	memTracker *memory.Tracker // track memory usage.
 }
 
 // Open implements the Executor Open interface.
@@ -818,11 +845,16 @@ func (e *StreamAggExec) Open(ctx context.Context) error {
 		e.partialResults = append(e.partialResults, aggFunc.AllocPartialResult())
 	}
 
+	// bytesLimit <= 0 means no limit, for now we just track the memory footprint
+	e.memTracker = memory.NewTracker(e.id, -1)
+	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
+	e.memTracker.Consume(e.childResult.MemoryUsage())
 	return nil
 }
 
 // Close implements the Executor Close interface.
 func (e *StreamAggExec) Close() error {
+	e.memTracker.Consume(-e.childResult.MemoryUsage())
 	e.childResult = nil
 	e.groupChecker.reset()
 	return e.baseExecutor.Close()
@@ -910,7 +942,9 @@ func (e *StreamAggExec) consumeCurGroupRowsAndFetchChild(ctx context.Context, ch
 		return err
 	}
 
+	mSize := e.childResult.MemoryUsage()
 	err = Next(ctx, e.children[0], e.childResult)
+	e.memTracker.Consume(e.childResult.MemoryUsage() - mSize)
 	if err != nil {
 		return err
 	}
diff --git a/executor/explain_test.go b/executor/explain_test.go
index 3af5cc2321687..96c2880beff78 100644
--- a/executor/explain_test.go
+++ b/executor/explain_test.go
@@ -129,7 +129,7 @@ func (s *testSuite1) TestExplainAnalyzeMemory(c *C) {
 
 func (s *testSuite1) checkMemoryInfo(c *C, tk *testkit.TestKit, sql string) {
 	memCol := 5
-	ops := []string{"Join", "Reader", "Top", "Sort", "LookUp", "Projection", "Selection"}
+	ops := []string{"Join", "Reader", "Top", "Sort", "LookUp", "Projection", "Selection", "Agg"}
 	rows := tk.MustQuery(sql).Rows()
 	for _, row := range rows {
 		strs := make([]string, len(row))
@@ -165,7 +165,10 @@ func (s *testSuite1) TestMemoryUsageAfterClose(c *C) {
 	}
 	SQLs := []string{"select v+abs(k) from t",
 		"select v from t where abs(v) > 0",
-		"select v from t order by v"}
+		"select v from t order by v",
+		"select count(v) from t",            // StreamAgg
+		"select count(v) from t group by v", // HashAgg
+	}
 	for _, sql := range SQLs {
 		tk.MustQuery(sql)
 		c.Assert(tk.Se.GetSessionVars().StmtCtx.MemTracker.BytesConsumed(), Equals, int64(0))
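For context only, not part of the patch: a minimal, self-contained sketch of the consume/release pattern the diff wires into `HashAggExec` and `StreamAggExec`. The `Tracker` type below is a simplified stand-in for TiDB's `util/memory.Tracker` (its fields, constructor signature, and the byte counts are made up for illustration); only the call pattern it demonstrates, `AttachTo` a parent tracker, `Consume` a positive delta when a chunk is allocated or grows, and `Consume` the negative delta when the chunk is released in `Close`, mirrors what the patch does.

```go
package main

import "fmt"

// Tracker is a simplified, hypothetical stand-in for a hierarchical memory
// tracker; it is NOT TiDB's util/memory implementation.
type Tracker struct {
	label    string
	consumed int64
	parent   *Tracker
}

// NewTracker creates an unattached tracker with the given label.
func NewTracker(label string) *Tracker { return &Tracker{label: label} }

// AttachTo makes every Consume on t propagate to parent as well.
func (t *Tracker) AttachTo(parent *Tracker) { t.parent = parent }

// Consume adds bytes to this tracker and all of its ancestors;
// a negative value releases previously tracked memory.
func (t *Tracker) Consume(bytes int64) {
	for cur := t; cur != nil; cur = cur.parent {
		cur.consumed += bytes
	}
}

// BytesConsumed reports the bytes currently tracked.
func (t *Tracker) BytesConsumed() int64 { return t.consumed }

func main() {
	stmt := NewTracker("statement") // plays the role of StmtCtx.MemTracker
	agg := NewTracker("HashAgg")    // plays the role of e.memTracker
	agg.AttachTo(stmt)

	chunkBytes := int64(4096) // stands in for chk.MemoryUsage()

	agg.Consume(chunkBytes)           // Open / fetchChildData: track the chunk
	fmt.Println(stmt.BytesConsumed()) // 4096

	agg.Consume(-chunkBytes)          // Close: release the chunk
	fmt.Println(stmt.BytesConsumed()) // 0
}
```

The invariant the patch maintains is that every positive `Consume` for a chunk the aggregate executor holds is matched by a negative `Consume` when that chunk is drained or released in `Close`, which is why `TestMemoryUsageAfterClose` can assert that the statement-level tracker returns to zero after the new `count(v)` queries.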