
executor: support spill intermediate data for unparalleled hash agg #25714

Merged: 18 commits, Jul 15, 2021
181 changes: 155 additions & 26 deletions executor/aggregate.go
@@ -34,6 +34,7 @@ import (
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
@@ -191,9 +192,18 @@ type HashAggExec struct {
prepared bool
executed bool

memTracker *memory.Tracker // track memory usage.
diskTracker *disk.Tracker // track disk usage.

stats *HashAggRuntimeStats

// listInDisk holds the intermediate rows spilled to disk; every spill round appends to it.
listInDisk *chunk.ListInDisk
// lastChunkNum is the number of chunks in listInDisk when the previous round finished.
lastChunkNum int
// processIdx is the index of the next spilled chunk to restore.
processIdx int
// spillMode is read and written atomically; 1 means rows of new groups are spilled instead of aggregated.
spillMode uint32
// spillChunk buffers rows to be spilled before they are written to listInDisk.
spillChunk *chunk.Chunk
// spillAction is the memory action that switches the executor into spill mode.
spillAction *AggSpillDiskAction
// childDrained indicates whether all data from the child executor has been fetched.
childDrained bool
Contributor: We need comments for these variables.
}

// HashAggInput indicates the input of hash agg exec.
@@ -233,6 +243,15 @@ func (e *HashAggExec) Close() error {
if e.memTracker != nil {
e.memTracker.ReplaceBytesUsed(0)
}
if e.listInDisk != nil {
if err := e.listInDisk.Close(); err != nil {
return err
Contributor: Don't we close the childrenExec? This may cause leaks.
Contributor (author): Addressed.
}
}
if e.spillAction != nil {
Contributor: What's this for?
Contributor (author): Useless; I've removed the code now.

e.spillAction.spillTimes = maxSpillTimes
}
e.spillAction, e.spillChunk = nil, nil
return e.baseExecutor.Close()
}
if e.parallelExecInitialized {
@@ -301,6 +320,17 @@ func (e *HashAggExec) initForUnparallelExec() {
e.groupKeyBuffer = make([][]byte, 0, 8)
e.childResult = newFirstChunk(e.children[0])
e.memTracker.Consume(e.childResult.MemoryUsage())

e.processIdx, e.lastChunkNum = 0, 0
e.executed, e.childDrained = false, false
e.listInDisk = chunk.NewListInDisk(retTypes(e.children[0]))
e.spillChunk = newFirstChunk(e.children[0])
if e.ctx.GetSessionVars().TrackAggregateMemoryUsage {
Contributor: Why do we need to check this?
Contributor (author): If TiDB doesn't track aggregate executor memory usage, should we also try to spill hashAgg when it's exceeded? In addition, oom-use-tmp-storage should also be checked... I've added the check now. PTAL again.
e.diskTracker = disk.NewTracker(e.id, -1)
e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker)
e.listInDisk.GetDiskTracker().AttachTo(e.diskTracker)
e.ctx.GetSessionVars().StmtCtx.MemTracker.FallbackOldAndSetNewActionForSoftLimit(e.ActionSpill())
}
}
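Per the exchange above, the merged version also gates spill initialization on oom-use-tmp-storage. A minimal sketch of the combined guard, assuming that option surfaces as config.GetGlobalConfig().OOMUseTmpStorage (the exact accessor is an assumption here, not part of this diff):

```go
// Sketch only: the post-review check, not the code shown above.
vars := e.ctx.GetSessionVars()
if vars.TrackAggregateMemoryUsage && config.GetGlobalConfig().OOMUseTmpStorage {
	// Spill is registered only when aggregate memory is tracked and
	// spilling to temporary storage is enabled.
	e.diskTracker = disk.NewTracker(e.id, -1)
	e.diskTracker.AttachTo(vars.StmtCtx.DiskTracker)
	e.listInDisk.GetDiskTracker().AttachTo(e.diskTracker)
	vars.StmtCtx.MemTracker.FallbackOldAndSetNewActionForSoftLimit(e.ActionSpill())
}
```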

func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
@@ -853,10 +883,32 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error {

// unparallelExec executes hash aggregation algorithm in single thread.
func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error {
chk.Reset()
for {
if e.prepared {
// Since we return e.maxChunkSize rows every time, we should not traverse
// `groupSet` because of its randomness.
for ; e.cursor4GroupKey < len(e.groupKeys); e.cursor4GroupKey++ {
partialResults := e.getPartialResults(e.groupKeys[e.cursor4GroupKey])
if len(e.PartialAggFuncs) == 0 {
chk.SetNumVirtualRows(chk.NumRows() + 1)
}
for i, af := range e.PartialAggFuncs {
if err := af.AppendFinalResult2Chunk(e.ctx, partialResults[i], chk); err != nil {
return err
}
}
if chk.IsFull() {
e.cursor4GroupKey++
return nil
}
}
e.resetSpillMode()
}
if e.executed {
return nil
}
if err := e.execute(ctx); err != nil {
return err
}
if (len(e.groupSet.StringSet) == 0) && len(e.GroupByItems) == 0 {
@@ -869,33 +921,34 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error {
}
e.prepared = true
}
}

func (e *HashAggExec) resetSpillMode() {
e.cursor4GroupKey, e.groupKeys = 0, e.groupKeys[:0]
var setSize int64
e.groupSet, setSize = set.NewStringSetWithMemoryUsage()
e.partialResultMap = make(aggPartialResultMapper)
e.bInMap = 0
e.prepared = false
e.executed = e.lastChunkNum == e.listInDisk.NumChunks()
e.lastChunkNum = e.listInDisk.NumChunks()
e.memTracker.ReplaceBytesUsed(setSize)
atomic.StoreUint32(&e.spillMode, 0)
}
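The termination check is worth spelling out: each round re-reads only the chunks spilled during the previous round, so once a round spills nothing new, lastChunkNum equals listInDisk.NumChunks() and the executor is finished. A self-contained toy of that bookkeeping (illustrative numbers only, not TiDB code):

```go
package main

import "fmt"

func main() {
	newlySpilled := []int{4, 2, 0}  // chunks spilled during each processing round
	numChunks, lastChunkNum := 0, 0 // mirror listInDisk.NumChunks() and e.lastChunkNum
	for round, n := range newlySpilled {
		numChunks += n
		executed := lastChunkNum == numChunks // no new spill data: we are done
		fmt.Printf("round %d: chunks on disk=%d, executed=%v\n", round, numChunks, executed)
		lastChunkNum = numChunks
	}
}
```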

// execute fetches Chunks from src and updates each aggregate function for each row in Chunk.
func (e *HashAggExec) execute(ctx context.Context) (err error) {
defer func() {
if e.spillChunk.NumRows() > 0 && err == nil {
err = e.listInDisk.Add(e.spillChunk)
e.spillChunk.Reset()
}
}()
for {
mSize := e.childResult.MemoryUsage()
if err := e.getNextChunk(ctx); err != nil {
return err
}
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(e.childResult.MemoryUsage() - mSize)
@@ -912,7 +965,6 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
if e.childResult.NumRows() == 0 {
return nil
}

e.groupKeyBuffer, err = getGroupKey(e.ctx, e.childResult, e.groupKeyBuffer, e.GroupByItems)
if err != nil {
return err
@@ -922,6 +974,17 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
for j := 0; j < e.childResult.NumRows(); j++ {
groupKey := string(e.groupKeyBuffer[j]) // do memory copy here, because e.groupKeyBuffer may be reused.
if !e.groupSet.Exist(groupKey) {
if atomic.LoadUint32(&e.spillMode) == 1 && e.groupSet.Count() > 0 {
e.spillChunk.Append(e.childResult, j, j+1)
Contributor: Can we use Chunk.sel to optimize this if-block?

  1. We can check e.groupSet.Exist(groupKey) and build the sel first, and then invoke e.spillChunk.Append based on the sel.
  2. Further, if len(sel) == len(e.childResult), we can invoke e.listInDisk.Add(e.childResult) directly.

Contributor (author): Addressed.
(A sketch of this suggestion follows after execute below.)

if e.spillChunk.IsFull() {
err = e.listInDisk.Add(e.spillChunk)
if err != nil {
return err
}
e.spillChunk.Reset()
}
continue
}
allMemDelta += e.groupSet.Insert(groupKey)
e.groupKeys = append(e.groupKeys, groupKey)
}
@@ -939,6 +1002,29 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
}
}
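For reference, the Chunk.sel suggestion from the thread above might look roughly like this. This is a sketch under assumptions — spillUnprocessedRows is a hypothetical helper, and only calls already used in this diff appear — not the merged implementation:

```go
// Hypothetical helper sketching the reviewer's suggestion (not the merged code).
func (e *HashAggExec) spillUnprocessedRows() error {
	// 1. Build the selection of rows whose group is absent from the in-memory map.
	sel := make([]int, 0, e.childResult.NumRows())
	for j := 0; j < e.childResult.NumRows(); j++ {
		if !e.groupSet.Exist(string(e.groupKeyBuffer[j])) {
			sel = append(sel, j)
		}
	}
	// 2. If every row is unprocessed, spill the whole chunk with one call.
	if len(sel) == e.childResult.NumRows() {
		return e.listInDisk.Add(e.childResult)
	}
	// Otherwise copy only the selected rows into the spill buffer.
	for _, j := range sel {
		e.spillChunk.Append(e.childResult, j, j+1)
		if e.spillChunk.IsFull() {
			if err := e.listInDisk.Add(e.spillChunk); err != nil {
				return err
			}
			e.spillChunk.Reset()
		}
	}
	return nil
}
```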

func (e *HashAggExec) getNextChunk(ctx context.Context) (err error) {
e.childResult.Reset()
if !e.childDrained {
if err := Next(ctx, e.children[0], e.childResult); err != nil {
return err
}
if e.childResult.NumRows() == 0 {
e.childDrained = true
} else {
return nil
}
}
if e.processIdx < e.lastChunkNum {
e.childResult, err = e.listInDisk.GetChunk(e.processIdx)
if err != nil {
return err
}
e.processIdx++
return nil
}
return nil
}
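The fetch order matters: the child is drained completely before any spilled chunk is replayed, so a round never mixes fresh input with the previous round's spill. A standalone toy of that two-phase pattern (plain ints stand in for chunks; not TiDB code):

```go
package main

import "fmt"

type fetcher struct {
	child        []int // chunks still available from the child executor
	disk         []int // chunks restored from listInDisk
	processIdx   int
	childDrained bool
}

func (f *fetcher) next() (int, bool) {
	if !f.childDrained {
		if len(f.child) > 0 { // phase 1: drain the child first
			v := f.child[0]
			f.child = f.child[1:]
			return v, true
		}
		f.childDrained = true
	}
	if f.processIdx < len(f.disk) { // phase 2: replay spilled chunks
		v := f.disk[f.processIdx]
		f.processIdx++
		return v, true
	}
	return 0, false // both sources exhausted
}

func main() {
	f := &fetcher{child: []int{1, 2}, disk: []int{3, 4}}
	for v, ok := f.next(); ok; v, ok = f.next() {
		fmt.Println("chunk", v)
	}
}
```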

func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResult {
partialResults, ok := e.partialResultMap[groupKey]
allMemDelta := int64(0)
Expand Down Expand Up @@ -1744,3 +1830,46 @@ func (e *vecGroupChecker) reset() {
e.lastRowDatums = e.lastRowDatums[:0]
}
}

// ActionSpill returns an AggSpillDiskAction for spilling intermediate data for hashAgg.
func (e *HashAggExec) ActionSpill() *AggSpillDiskAction {
if e.spillAction == nil {
e.spillAction = &AggSpillDiskAction{
e: e,
}
}
return e.spillAction
}

// maxSpillTimes indicates how many times the data can spill at most.
const maxSpillTimes = 10
Contributor: Add a comment for this.

// AggSpillDiskAction implements memory.ActionOnExceed for unparalleled HashAgg.
// If the memory quota of a query is exceeded, AggSpillDiskAction.Action is
// triggered.
type AggSpillDiskAction struct {
memory.BaseOOMAction
e *HashAggExec
spillTimes uint32
}

// Action sets the HashAggExec spill mode.
func (a *AggSpillDiskAction) Action(t *memory.Tracker) {
if atomic.LoadUint32(&a.e.spillMode) == 0 && a.spillTimes < maxSpillTimes {
a.spillTimes++
logutil.BgLogger().Info("memory exceeds quota, set aggregate mode to spill-mode",
zap.Uint32("spillTimes", a.spillTimes))
atomic.StoreUint32(&a.e.spillMode, 1)
return
}
if fallback := a.GetFallback(); fallback != nil {
fallback.Action(t)
}
}
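The control flow is small enough to model in isolation: the action flips an atomic flag that execute() polls, and only falls back (typically to a cancelling action) once spill mode is already on or maxSpillTimes is exhausted. A standalone toy of that pattern — an assumption-laden sketch, not TiDB's memory package:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

const maxSpillTimes = 10

type spillAction struct {
	spillMode  uint32 // flag polled by the "executor"; 1 means spill new groups
	spillTimes uint32
	fallback   func() // next action in the chain, e.g. log or cancel
}

func (a *spillAction) action() {
	if atomic.LoadUint32(&a.spillMode) == 0 && a.spillTimes < maxSpillTimes {
		a.spillTimes++
		atomic.StoreUint32(&a.spillMode, 1) // flip the flag instead of cancelling
		return
	}
	if a.fallback != nil {
		a.fallback() // already spilling, or spilled too often: defer to the fallback
	}
}

func main() {
	a := &spillAction{fallback: func() { fmt.Println("fallback action") }}
	a.action()
	fmt.Println("spillMode:", atomic.LoadUint32(&a.spillMode)) // 1
	atomic.StoreUint32(&a.spillMode, 0)                        // resetSpillMode clears it between rounds
	a.action()                                                 // may trigger again next round
}
```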

// GetPriority gets the priority of the Action.
func (a *AggSpillDiskAction) GetPriority() int64 {
return memory.DefSpillPriority
}

// SetLogHook sets the hook; it does nothing, just to implement the memory.ActionOnExceed interface.
func (a *AggSpillDiskAction) SetLogHook(hook func(uint64)) {}
49 changes: 37 additions & 12 deletions util/memory/tracker.go
@@ -46,11 +46,9 @@ type Tracker struct {
// we wouldn't maintain its children in order to avoid mutex contention.
children map[int][]*Tracker
}
actionMu actionMu
actionMuForSoftLimit actionMu
parMu struct {
sync.Mutex
parent *Tracker // The parent memory tracker.
}
@@ -62,6 +60,11 @@ type Tracker struct {
isGlobal bool // isGlobal indicates whether this tracker is global tracker
}

type actionMu struct {
sync.Mutex
actionOnExceed ActionOnExceed
}

// NewTracker creates a memory tracker.
// 1. "label" is the label used in the usage string.
// 2. "bytesLimit <= 0" means no limit.
@@ -125,6 +128,14 @@ func (t *Tracker) FallbackOldAndSetNewAction(a ActionOnExceed) {
t.actionMu.actionOnExceed = reArrangeFallback(t.actionMu.actionOnExceed, a)
}

// FallbackOldAndSetNewActionForSoftLimit sets the action triggered when memory usage exceeds the soft bytesLimit
// and sets the original action as its fallback.
func (t *Tracker) FallbackOldAndSetNewActionForSoftLimit(a ActionOnExceed) {
t.actionMuForSoftLimit.Lock()
defer t.actionMuForSoftLimit.Unlock()
t.actionMuForSoftLimit.actionOnExceed = reArrangeFallback(t.actionMuForSoftLimit.actionOnExceed, a)
}

// GetFallbackForTest get the oom action used by test.
func (t *Tracker) GetFallbackForTest() ActionOnExceed {
t.actionMu.Lock()
@@ -255,18 +266,24 @@ func (t *Tracker) ReplaceChild(oldChild, newChild *Tracker) {
t.Consume(newConsumed)
}

// softScale is the fraction of bytesLimit at which the soft-limit action is triggered.
const softScale = 0.8

// Consume is used to consume a memory usage. "bytes" can be a negative value,
// which means this is a memory release operation. When memory usage of a tracker
// exceeds its bytesLimit, the tracker calls its action, so does each of its ancestors.
func (t *Tracker) Consume(bytes int64) {
if bytes == 0 {
return
}
var rootExceed *Tracker
var rootExceed, rootExceedForSoftLimit *Tracker
for tracker := t; tracker != nil; tracker = tracker.getParent() {
if atomic.AddInt64(&tracker.bytesConsumed, bytes) >= tracker.bytesLimit && tracker.bytesLimit > 0 {
bytesConsumed := atomic.AddInt64(&tracker.bytesConsumed, bytes)
if bytesConsumed >= tracker.bytesLimit && tracker.bytesLimit > 0 {
rootExceed = tracker
}
if bytesConsumed >= int64(float64(tracker.bytesLimit)*softScale) && tracker.bytesLimit > 0 {
rootExceedForSoftLimit = tracker
}

for {
maxNow := atomic.LoadInt64(&tracker.maxConsumed)
@@ -277,13 +294,21 @@ func (t *Tracker) Consume(bytes int64) {
break
}
}
if bytes > 0 && rootExceed != nil {
rootExceed.actionMu.Lock()
defer rootExceed.actionMu.Unlock()
if rootExceed.actionMu.actionOnExceed != nil {
rootExceed.actionMu.actionOnExceed.Action(rootExceed)

tryAction := func(mu *actionMu, tracker *Tracker) {
mu.Lock()
defer mu.Unlock()
if mu.actionOnExceed != nil {
mu.actionOnExceed.Action(tracker)
}
}

if bytes > 0 && rootExceedForSoftLimit != nil {
tryAction(&rootExceedForSoftLimit.actionMuForSoftLimit, rootExceedForSoftLimit)
}
if bytes > 0 && rootExceed != nil {
tryAction(&rootExceed.actionMu, rootExceed)
}
}
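To make the two thresholds concrete, a quick worked example with illustrative numbers: with a 1 GiB bytesLimit, the soft-limit action — here the aggregate spill registered above — fires once consumption crosses 1 GiB × 0.8 ≈ 819 MiB, while the hard-limit action still waits for the full limit:

```go
package main

import "fmt"

func main() {
	const softScale = 0.8
	bytesLimit := int64(1 << 30)                        // 1 GiB hard limit
	softLimit := int64(float64(bytesLimit) * softScale) // ~819 MiB
	consumed := int64(900 << 20)                        // 900 MiB currently consumed

	fmt.Println("soft exceeded:", consumed >= softLimit)  // true  -> spill action fires
	fmt.Println("hard exceeded:", consumed >= bytesLimit) // false -> hard action not yet
}
```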

// BytesConsumed returns the consumed memory usage value in bytes.