Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: support spill intermediate data for unparalleled hash agg #25714

Merged
merged 18 commits into from
Jul 15, 2021
Merged
207 changes: 180 additions & 27 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
Expand All @@ -34,6 +35,7 @@ import (
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
Expand Down Expand Up @@ -191,9 +193,18 @@ type HashAggExec struct {
prepared bool
executed bool

memTracker *memory.Tracker // track memory usage.
memTracker *memory.Tracker // track memory usage.
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
diskTracker *disk.Tracker

stats *HashAggRuntimeStats

listInDisk *chunk.ListInDisk // listInDisk is the chunks to store row values for spilling data.
lastChunkNum int // lastChunkNum indicates the num of spilling chunk.
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
processIdx int // processIdx indicates the num of processed chunk in disk.
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
spillMode uint32 // spillMode means that no new groups are added to hash table.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. isSpillModeSet?
  2. Add an explanation for what does 0 and 1 mean

spillChunk *chunk.Chunk // spillChunk is the temp chunk for spilling.
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
spillAction *AggSpillDiskAction // spillAction save the Action for spilling.
childDrained bool // childDrained indicates whether the all data from child has been taken out.
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
}

// HashAggInput indicates the input of hash agg exec.
Expand Down Expand Up @@ -227,13 +238,21 @@ func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext,
// Close implements the Executor Close interface.
func (e *HashAggExec) Close() error {
if e.isUnparallelExec {
var firstErr error
e.childResult = nil
e.groupSet, _ = set.NewStringSetWithMemoryUsage()
e.partialResultMap = nil
if e.memTracker != nil {
e.memTracker.ReplaceBytesUsed(0)
}
return e.baseExecutor.Close()
if e.listInDisk != nil {
firstErr = e.listInDisk.Close()
}
e.spillAction, e.spillChunk = nil, nil
if err := e.baseExecutor.Close(); firstErr == nil {
firstErr = err
}
return firstErr
}
if e.parallelExecInitialized {
// `Close` may be called after `Open` without calling `Next` in test.
Expand Down Expand Up @@ -301,6 +320,17 @@ func (e *HashAggExec) initForUnparallelExec() {
e.groupKeyBuffer = make([][]byte, 0, 8)
e.childResult = newFirstChunk(e.children[0])
e.memTracker.Consume(e.childResult.MemoryUsage())

e.processIdx, e.lastChunkNum = 0, 0
e.executed, e.childDrained = false, false
e.listInDisk = chunk.NewListInDisk(retTypes(e.children[0]))
e.spillChunk = newFirstChunk(e.children[0])
if e.ctx.GetSessionVars().TrackAggregateMemoryUsage && config.GetGlobalConfig().OOMUseTmpStorage {
e.diskTracker = disk.NewTracker(e.id, -1)
e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker)
e.listInDisk.GetDiskTracker().AttachTo(e.diskTracker)
e.ctx.GetSessionVars().StmtCtx.MemTracker.FallbackOldAndSetNewActionForSoftLimit(e.ActionSpill())
}
}

func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
Expand Down Expand Up @@ -853,10 +883,32 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error

// unparallelExec executes hash aggregation algorithm in single thread.
func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error {
// In this stage we consider all data from src as a single group.
if !e.prepared {
err := e.execute(ctx)
if err != nil {
chk.Reset()
for {
if e.prepared {
// Since we return e.maxChunkSize rows every time, so we should not traverse
// `groupSet` because of its randomness.
for ; e.cursor4GroupKey < len(e.groupKeys); e.cursor4GroupKey++ {
partialResults := e.getPartialResults(e.groupKeys[e.cursor4GroupKey])
if len(e.PartialAggFuncs) == 0 {
chk.SetNumVirtualRows(chk.NumRows() + 1)
}
for i, af := range e.PartialAggFuncs {
if err := af.AppendFinalResult2Chunk(e.ctx, partialResults[i], chk); err != nil {
return err
}
}
if chk.IsFull() {
e.cursor4GroupKey++
return nil
}
}
e.resetSpillMode()
}
if e.executed {
return nil
}
if err := e.execute(ctx); err != nil {
return err
}
if (len(e.groupSet.StringSet) == 0) && len(e.GroupByItems) == 0 {
Expand All @@ -869,33 +921,34 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro
}
e.prepared = true
}
chk.Reset()
}

// Since we return e.maxChunkSize rows every time, so we should not traverse
// `groupSet` because of its randomness.
for ; e.cursor4GroupKey < len(e.groupKeys); e.cursor4GroupKey++ {
partialResults := e.getPartialResults(e.groupKeys[e.cursor4GroupKey])
if len(e.PartialAggFuncs) == 0 {
chk.SetNumVirtualRows(chk.NumRows() + 1)
}
for i, af := range e.PartialAggFuncs {
if err := af.AppendFinalResult2Chunk(e.ctx, partialResults[i], chk); err != nil {
return err
}
}
if chk.IsFull() {
e.cursor4GroupKey++
return nil
}
}
return nil
// resetSpillMode prepares the executor for another aggregation pass over the
// rows spilled to disk during the previous pass: it clears the in-memory group
// set and partial-result map, records how many chunks listInDisk currently
// holds, and leaves spill mode (spillMode back to 0).
func (e *HashAggExec) resetSpillMode() {
	e.cursor4GroupKey, e.groupKeys = 0, e.groupKeys[:0]
	var setSize int64
	e.groupSet, setSize = set.NewStringSetWithMemoryUsage()
	e.partialResultMap = make(aggPartialResultMapper)
	e.bInMap = 0
	e.prepared = false
	// Must be computed before lastChunkNum is refreshed below: if no new
	// chunks were spilled during the last pass, every row has been aggregated.
	e.executed = e.lastChunkNum == e.listInDisk.NumChunks() // No data is spilling again, all data have been processed.
	e.lastChunkNum = e.listInDisk.NumChunks()
	// The old group set's memory is released; only the fresh set's size remains.
	e.memTracker.ReplaceBytesUsed(setSize)
	atomic.StoreUint32(&e.spillMode, 0)
}

// execute fetches Chunks from src and update each aggregate function for each row in Chunk.
func (e *HashAggExec) execute(ctx context.Context) (err error) {
defer func() {
if e.spillChunk.NumRows() > 0 && err == nil {
err = e.listInDisk.Add(e.spillChunk)
e.spillChunk.Reset()
}
}()
for {
mSize := e.childResult.MemoryUsage()
err := Next(ctx, e.children[0], e.childResult)
if err := e.getNextChunk(ctx); err != nil {
return err
}
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(e.childResult.MemoryUsage() - mSize)
if err != nil {
Expand All @@ -912,16 +965,20 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
if e.childResult.NumRows() == 0 {
return nil
}

e.groupKeyBuffer, err = getGroupKey(e.ctx, e.childResult, e.groupKeyBuffer, e.GroupByItems)
if err != nil {
return err
}

allMemDelta := int64(0)
sel := make([]int, 0, e.childResult.NumRows())
for j := 0; j < e.childResult.NumRows(); j++ {
groupKey := string(e.groupKeyBuffer[j]) // do memory copy here, because e.groupKeyBuffer may be reused.
if !e.groupSet.Exist(groupKey) {
if atomic.LoadUint32(&e.spillMode) == 1 && e.groupSet.Count() > 0 {
sel = append(sel, j)
continue
}
allMemDelta += e.groupSet.Insert(groupKey)
e.groupKeys = append(e.groupKeys, groupKey)
}
Expand All @@ -934,11 +991,63 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
allMemDelta += memDelta
}
}

// spill unprocessed data when exceeded.
if len(sel) > 0 {
err = e.spillUnprocessedData(sel)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input argument sel is useless?

e.childResult.SetSel(sel)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.childResult.SetSel(sel) would make len(sel) == len(e.childResult) always true, so e.listInDisk.Add(e.childResult) would be called directly. If there are only a few elements in sel, that may cause a performance issue.
I removed the logic e.listInDisk.Add(e.childResult) and now always append to tmpChkForSpill. PTAL

if err != nil {
return err
}
}

failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(allMemDelta)
}
}

func (e *HashAggExec) spillUnprocessedData(sel []int) (err error) {
if len(sel) == e.childResult.NumRows() {
err = e.listInDisk.Add(e.childResult)
if err != nil {
return err
}
} else {
for _, j := range sel {
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
e.spillChunk.Append(e.childResult, j, j+1)
if e.spillChunk.IsFull() {
err = e.listInDisk.Add(e.spillChunk)
if err != nil {
return err
}
e.spillChunk.Reset()
}
}
}
return nil
}

// getNextChunk fills e.childResult with the next batch of input rows. Rows are
// taken from the child executor until it is drained; after that, chunks
// previously spilled to e.listInDisk are replayed in order via processIdx.
// An empty e.childResult after return means there is no more input.
func (e *HashAggExec) getNextChunk(ctx context.Context) (err error) {
	e.childResult.Reset()
	if !e.childDrained {
		if err = Next(ctx, e.children[0], e.childResult); err != nil {
			return err
		}
		if e.childResult.NumRows() > 0 {
			return nil
		}
		// The child returned an empty chunk: switch to reading spilled data.
		e.childDrained = true
	}
	if e.processIdx >= e.lastChunkNum {
		// All spilled chunks of the current pass have been consumed.
		return nil
	}
	if e.childResult, err = e.listInDisk.GetChunk(e.processIdx); err != nil {
		return err
	}
	e.processIdx++
	return nil
}

func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResult {
partialResults, ok := e.partialResultMap[groupKey]
allMemDelta := int64(0)
Expand Down Expand Up @@ -1744,3 +1853,47 @@ func (e *vecGroupChecker) reset() {
e.lastRowDatums = e.lastRowDatums[:0]
}
}

// ActionSpill returns the AggSpillDiskAction bound to this executor for
// spilling intermediate data. The action is created lazily on first call, so
// repeated calls hand back the same instance.
func (e *HashAggExec) ActionSpill() *AggSpillDiskAction {
	if e.spillAction != nil {
		return e.spillAction
	}
	e.spillAction = &AggSpillDiskAction{e: e}
	return e.spillAction
}

// maxSpillTimes indicates how many times the data can spill at most.
// AggSpillDiskAction.Action stops entering spill mode once spillTimes reaches
// this bound and delegates to the fallback action instead.
const maxSpillTimes = 10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment for this


// AggSpillDiskAction implements memory.ActionOnExceed for unparalleled HashAgg.
// If the memory quota of a query is exceeded, AggSpillDiskAction.Action is
// triggered.
type AggSpillDiskAction struct {
	memory.BaseOOMAction
	e          *HashAggExec // the executor whose spillMode this action controls
	spillTimes uint32       // number of times Action has entered spill mode; capped by maxSpillTimes
}

// Action puts the unparalleled hash aggregation into spill mode (spillMode set
// to 1) when the tracked memory exceeds the quota. If spill mode is already
// active, or the action has already fired maxSpillTimes times, it delegates to
// the fallback action instead.
func (a *AggSpillDiskAction) Action(t *memory.Tracker) {
	alreadySpilling := atomic.LoadUint32(&a.e.spillMode) != 0
	if alreadySpilling || a.spillTimes >= maxSpillTimes {
		if fallback := a.GetFallback(); fallback != nil {
			fallback.Action(t)
		}
		return
	}
	a.spillTimes++
	logutil.BgLogger().Info("memory exceeds quota, set aggregate mode to spill-mode",
		zap.Uint32("spillTimes", a.spillTimes))
	atomic.StoreUint32(&a.e.spillMode, 1)
}

// GetPriority returns the priority of the Action, which is the fixed spill
// priority memory.DefSpillPriority.
func (a *AggSpillDiskAction) GetPriority() int64 {
	return memory.DefSpillPriority
}

// SetLogHook is a no-op; it exists only to satisfy the memory.ActionOnExceed
// interface. The hook parameter is intentionally ignored (the original
// machine-generated parameter name `uint642` is replaced with the blank
// identifier to make that explicit).
func (a *AggSpillDiskAction) SetLogHook(_ func(uint64)) {}
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
28 changes: 28 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1461,3 +1461,31 @@ func (s *testSuiteAgg) TestIssue23314(c *C) {
res := tk.MustQuery("select col1 from t1 group by col1")
res.Check(testkit.Rows("16:40:20.01"))
}

// TestAggInDisk checks that an unparalleled hash aggregation spills
// intermediate data to disk when the query memory quota is small, by asserting
// that the HashAgg operator reports a non-zero disk usage column in
// `explain analyze` output.
func (s *testSerialSuite) TestAggInDisk(c *C) {
	tk := testkit.NewTestKit(c, s.store)
	tk.MustExec("use test")
	tk.MustExec("set tidb_hashagg_final_concurrency = 1;")
	tk.MustExec("set tidb_hashagg_partial_concurrency = 1;")
	tk.MustExec("set tidb_mem_quota_query = 4194304")
	// Fix: the original dropped "t1" but the table created below is "t",
	// so a stale "t" from an earlier run would break the create statement.
	tk.MustExec("drop table if exists t")
	tk.MustExec("create table t(a int)")
	// Insert rows 0..300 in a single statement; build it with strings.Builder
	// to avoid the quadratic cost of repeated string concatenation.
	var sb strings.Builder
	sb.WriteString("insert into t values (0)")
	for i := 1; i <= 300; i++ {
		fmt.Fprintf(&sb, ",(%v)", i)
	}
	sb.WriteString(";")
	tk.MustExec(sb.String())
	rows := tk.MustQuery("desc analyze select /*+ HASH_AGG() */ avg(t1.a) from t t1 join t t2 group by t1.a, t2.a;").Rows()
	for _, row := range rows {
		length := len(row)
		line := fmt.Sprintf("%v", row)
		// The disk-usage column is the last one in the explain-analyze row.
		disk := fmt.Sprintf("%v", row[length-1])
		if strings.Contains(line, "HashAgg") {
			c.Assert(strings.Contains(disk, "0 Bytes"), IsFalse)
			c.Assert(strings.Contains(disk, "MB") ||
				strings.Contains(disk, "KB") ||
				strings.Contains(disk, "Bytes"), IsTrue)
		}
	}
}
2 changes: 1 addition & 1 deletion util/memory/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import (
// We have two limits for the memory quota: soft limit and hard limit.
// If the soft limit is exceeded, we will trigger the action that alleviates the
// speed of memory growth. The soft limit is hard-coded as `0.8*hard limit`.
// The actions that could be triggered are: None.
// The actions that could be triggered are: AggSpillDiskAction.
//
// If the hard limit is exceeded, we will trigger the action that immediately
// reduces memory usage. The hard limit is set by the config item `mem-quota-query`
Expand Down