
executor: support spill intermediate data for unparalleled hash agg #25714

Merged · 18 commits · Jul 15, 2021
213 changes: 186 additions & 27 deletions executor/aggregate.go
@@ -26,6 +26,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
@@ -34,6 +35,7 @@ import (
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
@@ -191,9 +193,30 @@ type HashAggExec struct {
prepared bool
executed bool

memTracker *memory.Tracker // track memory usage.
diskTracker *disk.Tracker

stats *HashAggRuntimeStats

// listInDisk holds the chunks that store row values for the spilled data.
// The HashAggExec may enter `spill mode` multiple times, and all spilled data is appended to listInDisk.
listInDisk *chunk.ListInDisk
// numOfSpilledChks indicates the number of chunks spilled when the last round of processing finished.
// If no data is spilled again after one round of processing, all data has been processed.
numOfSpilledChks int
// offsetOfSpillChks indicates the number of spilled chunks in disk that have been processed.
// In one round of processing, we need to process all the data spilled in the last round.
offsetOfSpillChks int

// inSpillMode indicates whether HashAgg is in `spill mode`.
// When HashAgg is in `spill mode`, the number of tuples in partialResultMap no longer grows.
inSpillMode uint32
// tmpChkForSpill is the temp chunk for spilling.
tmpChkForSpill *chunk.Chunk
// spillAction saves the Action for spilling.
spillAction *AggSpillDiskAction
// isChildDrained indicates whether all data from the child has been consumed.
isChildDrained bool
}
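
The new fields above implement a round-based spill protocol. As a rough, self-contained sketch (toy types and a toy group budget, not TiDB's actual chunk or tracker APIs), the control flow looks something like this:

package main

import "fmt"

// spillModel is a toy model of the round-based protocol: each "chunk" is a
// batch of group keys, and "disk" stands in for listInDisk.
type spillModel struct {
	disk       [][]string // spilled chunks
	numSpilled int        // chunks on disk when the last round ended
	offset     int        // spilled chunks already re-read
	groups     map[string]bool
	budget     int // max distinct groups kept in memory per round
}

// processRound aggregates one round of input; once the in-memory budget is
// reached (spill mode), rows belonging to new groups are deferred to disk.
func (m *spillModel) processRound(input [][]string) {
	for _, chk := range input {
		var overflow []string
		for _, k := range chk {
			if m.groups[k] {
				continue // existing group: always aggregated in memory
			}
			if len(m.groups) >= m.budget {
				overflow = append(overflow, k) // spill mode: defer new groups
				continue
			}
			m.groups[k] = true
		}
		if len(overflow) > 0 {
			m.disk = append(m.disk, overflow)
		}
	}
}

func main() {
	m := &spillModel{groups: map[string]bool{}, budget: 2}
	input := [][]string{{"a", "b", "c"}, {"a", "d"}}
	for round := 0; ; round++ {
		m.processRound(input)
		fmt.Printf("round %d emits groups %v\n", round, m.groups)
		done := m.numSpilled == len(m.disk) // no new spill: all data processed
		m.numSpilled = len(m.disk)
		if done {
			break
		}
		// the next round re-reads exactly the chunks spilled so far
		input = m.disk[m.offset:m.numSpilled]
		m.offset = m.numSpilled
		m.groups = map[string]bool{}
	}
}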

// HashAggInput indicates the input of hash agg exec.
@@ -227,13 +250,21 @@ func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext,
// Close implements the Executor Close interface.
func (e *HashAggExec) Close() error {
if e.isUnparallelExec {
var firstErr error
e.childResult = nil
e.groupSet, _ = set.NewStringSetWithMemoryUsage()
e.partialResultMap = nil
if e.memTracker != nil {
e.memTracker.ReplaceBytesUsed(0)
}
return e.baseExecutor.Close()
if e.listInDisk != nil {
firstErr = e.listInDisk.Close()
}
e.spillAction, e.tmpChkForSpill = nil, nil
if err := e.baseExecutor.Close(); firstErr == nil {
firstErr = err
}
return firstErr
}
if e.parallelExecInitialized {
// `Close` may be called after `Open` without calling `Next` in test.
@@ -301,6 +332,17 @@ func (e *HashAggExec) initForUnparallelExec() {
e.groupKeyBuffer = make([][]byte, 0, 8)
e.childResult = newFirstChunk(e.children[0])
e.memTracker.Consume(e.childResult.MemoryUsage())

e.offsetOfSpillChks, e.numOfSpilledChks = 0, 0
e.executed, e.isChildDrained = false, false
e.listInDisk = chunk.NewListInDisk(retTypes(e.children[0]))
e.tmpChkForSpill = newFirstChunk(e.children[0])
if e.ctx.GetSessionVars().TrackAggregateMemoryUsage && config.GetGlobalConfig().OOMUseTmpStorage {
e.diskTracker = disk.NewTracker(e.id, -1)
e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker)
e.listInDisk.GetDiskTracker().AttachTo(e.diskTracker)
e.ctx.GetSessionVars().StmtCtx.MemTracker.FallbackOldAndSetNewActionForSoftLimit(e.ActionSpill())
}
}

func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
@@ -853,10 +895,32 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error

// unparallelExec executes hash aggregation algorithm in single thread.
func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error {
// In this stage we consider all data from src as a single group.
if !e.prepared {
err := e.execute(ctx)
if err != nil {
chk.Reset()
for {
if e.prepared {
// Since we return e.maxChunkSize rows every time, we should not traverse
// `groupSet` because of its randomness.
for ; e.cursor4GroupKey < len(e.groupKeys); e.cursor4GroupKey++ {
partialResults := e.getPartialResults(e.groupKeys[e.cursor4GroupKey])
if len(e.PartialAggFuncs) == 0 {
chk.SetNumVirtualRows(chk.NumRows() + 1)
}
for i, af := range e.PartialAggFuncs {
if err := af.AppendFinalResult2Chunk(e.ctx, partialResults[i], chk); err != nil {
return err
}
}
if chk.IsFull() {
e.cursor4GroupKey++
return nil
}
}
e.resetSpillMode()
}
if e.executed {
return nil
}
if err := e.execute(ctx); err != nil {
return err
}
if (len(e.groupSet.StringSet) == 0) && len(e.GroupByItems) == 0 {
@@ -869,33 +933,34 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error
}
e.prepared = true
}
chk.Reset()
}

// Since we return e.maxChunkSize rows every time, we should not traverse
// `groupSet` because of its randomness.
for ; e.cursor4GroupKey < len(e.groupKeys); e.cursor4GroupKey++ {
partialResults := e.getPartialResults(e.groupKeys[e.cursor4GroupKey])
if len(e.PartialAggFuncs) == 0 {
chk.SetNumVirtualRows(chk.NumRows() + 1)
}
for i, af := range e.PartialAggFuncs {
if err := af.AppendFinalResult2Chunk(e.ctx, partialResults[i], chk); err != nil {
return err
}
}
if chk.IsFull() {
e.cursor4GroupKey++
return nil
}
}
return nil
func (e *HashAggExec) resetSpillMode() {
e.cursor4GroupKey, e.groupKeys = 0, e.groupKeys[:0]
var setSize int64
e.groupSet, setSize = set.NewStringSetWithMemoryUsage()
e.partialResultMap = make(aggPartialResultMapper)
e.bInMap = 0
e.prepared = false
e.executed = e.numOfSpilledChks == e.listInDisk.NumChunks() // No data is spilled again; all data has been processed.
e.numOfSpilledChks = e.listInDisk.NumChunks()
e.memTracker.ReplaceBytesUsed(setSize)
atomic.StoreUint32(&e.inSpillMode, 0)
}

// execute fetches Chunks from src and updates each aggregate function for each row in Chunk.
func (e *HashAggExec) execute(ctx context.Context) (err error) {
defer func() {
if e.tmpChkForSpill.NumRows() > 0 && err == nil {
err = e.listInDisk.Add(e.tmpChkForSpill)
e.tmpChkForSpill.Reset()
}
}()
for {
mSize := e.childResult.MemoryUsage()
err := Next(ctx, e.children[0], e.childResult)
if err := e.getNextChunk(ctx); err != nil {
return err
}
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(e.childResult.MemoryUsage() - mSize)
if err != nil {
@@ -912,16 +977,20 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
if e.childResult.NumRows() == 0 {
return nil
}

e.groupKeyBuffer, err = getGroupKey(e.ctx, e.childResult, e.groupKeyBuffer, e.GroupByItems)
if err != nil {
return err
}

allMemDelta := int64(0)
sel := make([]int, 0, e.childResult.NumRows())
for j := 0; j < e.childResult.NumRows(); j++ {
groupKey := string(e.groupKeyBuffer[j]) // do memory copy here, because e.groupKeyBuffer may be reused.
if !e.groupSet.Exist(groupKey) {
if atomic.LoadUint32(&e.inSpillMode) == 1 && e.groupSet.Count() > 0 {
sel = append(sel, j)
continue
}
allMemDelta += e.groupSet.Insert(groupKey)
e.groupKeys = append(e.groupKeys, groupKey)
}
Expand All @@ -934,11 +1003,57 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
allMemDelta += memDelta
}
}

// spill the unprocessed data when the memory quota is exceeded.
if len(sel) > 0 {
e.childResult.SetSel(sel)
err = e.spillUnprocessedData()
if err != nil {
return err
}
}

failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(allMemDelta)
}
}

func (e *HashAggExec) spillUnprocessedData() (err error) {
for i := 0; i < e.childResult.NumRows(); i++ {
e.tmpChkForSpill.AppendRow(e.childResult.GetRow(i))
if e.tmpChkForSpill.IsFull() {
err = e.listInDisk.Add(e.tmpChkForSpill)
if err != nil {
return err
}
e.tmpChkForSpill.Reset()
}
}
return nil
}
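
The staging pattern above (buffer rows, write one full chunk at a time) can be sketched with toy types as follows; in the real code, the final partial chunk is flushed by the deferred listInDisk.Add in execute:

package main

import "fmt"

// flushWhenFull mirrors the staging pattern with toy types: rows are
// buffered in a small batch and written out only when the batch is full.
func flushWhenFull(rows []int, batchSize int, write func([]int)) []int {
	buf := make([]int, 0, batchSize)
	for _, r := range rows {
		buf = append(buf, r)
		if len(buf) == batchSize {
			write(append([]int(nil), buf...)) // hand off a copy of the batch
			buf = buf[:0]
		}
	}
	return buf // a partial batch stays buffered for the caller to flush later
}

func main() {
	rest := flushWhenFull([]int{1, 2, 3, 4, 5}, 2,
		func(b []int) { fmt.Println("spill", b) })
	fmt.Println("pending", rest) // [5], flushed later by the caller
}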

func (e *HashAggExec) getNextChunk(ctx context.Context) (err error) {
e.childResult.Reset()
if !e.isChildDrained {
if err := Next(ctx, e.children[0], e.childResult); err != nil {
return err
}
if e.childResult.NumRows() == 0 {
e.isChildDrained = true
} else {
return nil
}
}
if e.offsetOfSpillChks < e.numOfSpilledChks {
e.childResult, err = e.listInDisk.GetChunk(e.offsetOfSpillChks)
if err != nil {
return err
}
e.offsetOfSpillChks++
}
return nil
}
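
getNextChunk defines the input order for a round: drain the live child first, then replay what the previous round spilled. A minimal sketch with toy batch types (not TiDB's chunk API):

package main

import "fmt"

// reader is a toy version of the two-phase input order.
type reader struct {
	child        [][]int // batches still pending in the child executor
	disk         [][]int // spilled batches
	offset, seen int     // replay cursor and spill count at round start
}

func (r *reader) next() ([]int, bool) {
	if len(r.child) > 0 { // phase 1: the child is not drained yet
		b := r.child[0]
		r.child = r.child[1:]
		return b, true
	}
	if r.offset < r.seen { // phase 2: replay last round's spilled batches
		b := r.disk[r.offset]
		r.offset++
		return b, true
	}
	return nil, false // nothing left: this round is over
}

func main() {
	r := &reader{child: [][]int{{1, 2}}, disk: [][]int{{3}}, seen: 1}
	for b, ok := r.next(); ok; b, ok = r.next() {
		fmt.Println(b)
	}
}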

func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResult {
partialResults, ok := e.partialResultMap[groupKey]
allMemDelta := int64(0)
@@ -1744,3 +1859,47 @@ func (e *vecGroupChecker) reset() {
e.lastRowDatums = e.lastRowDatums[:0]
}
}

// ActionSpill returns an AggSpillDiskAction for spilling intermediate data for hashAgg.
func (e *HashAggExec) ActionSpill() *AggSpillDiskAction {
if e.spillAction == nil {
e.spillAction = &AggSpillDiskAction{
e: e,
}
}
return e.spillAction
}

// maxSpillTimes indicates the maximum number of times the data can spill.
const maxSpillTimes = 10
Contributor (review comment): Add a comment for this.

// AggSpillDiskAction implements memory.ActionOnExceed for unparalleled HashAgg.
// If the memory quota of a query is exceeded, AggSpillDiskAction.Action is
// triggered.
type AggSpillDiskAction struct {
memory.BaseOOMAction
e *HashAggExec
spillTimes uint32
}

// Action sets HashAggExec to spill mode.
func (a *AggSpillDiskAction) Action(t *memory.Tracker) {
if atomic.LoadUint32(&a.e.inSpillMode) == 0 && a.spillTimes < maxSpillTimes {
a.spillTimes++
logutil.BgLogger().Info("memory exceeds quota, set aggregate mode to spill-mode",
zap.Uint32("spillTimes", a.spillTimes))
atomic.StoreUint32(&a.e.inSpillMode, 1)
return
}
if fallback := a.GetFallback(); fallback != nil {
fallback.Action(t)
}
}
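
Action is a soft action: it only flags spill mode, and once the spill budget is used up (or spilling is already in progress) it delegates to the fallback action. A toy sketch of that chain, assuming a hypothetical cancel fallback rather than TiDB's real memory.ActionOnExceed types:

package main

import "fmt"

// action is a toy version of memory.ActionOnExceed: an action may handle the
// quota breach itself or delegate to a fallback.
type action interface{ act() }

type spillAction struct {
	spillTimes, maxSpillTimes int
	inSpillMode               bool
	fallback                  action
}

func (a *spillAction) act() {
	if !a.inSpillMode && a.spillTimes < a.maxSpillTimes {
		a.spillTimes++
		a.inSpillMode = true // soft action: only flag spill mode
		fmt.Println("enter spill mode, spillTimes =", a.spillTimes)
		return
	}
	if a.fallback != nil {
		a.fallback.act() // budget exhausted or already spilling: escalate
	}
}

type cancelAction struct{}

func (cancelAction) act() { fmt.Println("fallback: cancel the query") }

func main() {
	a := &spillAction{maxSpillTimes: 2, fallback: cancelAction{}}
	a.act() // first breach: enter spill mode
	a.act() // still in spill mode: delegate to the fallback action
}

In the real executor, resetSpillMode clears inSpillMode at the start of each round, so the spill action can fire once per round, up to maxSpillTimes.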

// GetPriority gets the priority of the Action.
func (a *AggSpillDiskAction) GetPriority() int64 {
return memory.DefSpillPriority
}

// SetLogHook sets the hook; it does nothing and exists only to satisfy the memory.ActionOnExceed interface.
func (a *AggSpillDiskAction) SetLogHook(hook func(uint64)) {}
28 changes: 28 additions & 0 deletions executor/aggregate_test.go
@@ -1461,3 +1461,31 @@ func (s *testSuiteAgg) TestIssue23314(c *C) {
res := tk.MustQuery("select col1 from t1 group by col1")
res.Check(testkit.Rows("16:40:20.01"))
}

func (s *testSerialSuite) TestAggInDisk(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("set tidb_hashagg_final_concurrency = 1;")
tk.MustExec("set tidb_hashagg_partial_concurrency = 1;")
tk.MustExec("set tidb_mem_quota_query = 4194304")
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t(a int)")
sql := "insert into t values (0)"
for i := 1; i <= 300; i++ {
sql += fmt.Sprintf(",(%v)", i)
}
sql += ";"
tk.MustExec(sql)
rows := tk.MustQuery("desc analyze select /*+ HASH_AGG() */ avg(t1.a) from t t1 join t t2 group by t1.a, t2.a;").Rows()
for _, row := range rows {
length := len(row)
line := fmt.Sprintf("%v", row)
disk := fmt.Sprintf("%v", row[length-1])
if strings.Contains(line, "HashAgg") {
c.Assert(strings.Contains(disk, "0 Bytes"), IsFalse)
c.Assert(strings.Contains(disk, "MB") ||
strings.Contains(disk, "KB") ||
strings.Contains(disk, "Bytes"), IsTrue)
}
}
}
2 changes: 1 addition & 1 deletion util/memory/tracker.go
@@ -43,7 +43,7 @@ import (
// We have two limits for the memory quota: soft limit and hard limit.
// If the soft limit is exceeded, we will trigger the action that alleviates the
// speed of memory growth. The soft limit is hard-coded as `0.8*hard limit`.
// The actions that could be triggered are: None.
// The actions that could be triggered are: AggSpillDiskAction.
//
// If the hard limit is exceeded, we will trigger the action that immediately
// reduces memory usage. The hard limit is set by the config item `mem-quota-query`
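
For concreteness, a small sketch of the soft-limit arithmetic under the quota used in the test above (assumed values, not a TiDB API):

package main

import "fmt"

func main() {
	// with the test's quota of tidb_mem_quota_query = 4194304 bytes (4 MiB),
	// the hard-coded 0.8 factor puts the soft limit (the point where
	// AggSpillDiskAction first fires) at:
	hard := int64(4194304)
	soft := int64(float64(hard) * 0.8)
	fmt.Println(soft) // 3355443 bytes
}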