diff --git a/executor/join.go b/executor/join.go index 09990e6c5d4b1..26a7e5f706fc2 100644 --- a/executor/join.go +++ b/executor/join.go @@ -20,7 +20,6 @@ import ( "fmt" "runtime/trace" "strconv" - "sync" "sync/atomic" "time" @@ -67,6 +66,8 @@ type HashJoinExec struct { // closeCh add a lock for closing executor. closeCh chan struct{} + worker util.WaitGroupWrapper + waiter util.WaitGroupWrapper joinType plannercore.JoinType requiredRows int64 @@ -89,9 +90,7 @@ type HashJoinExec struct { prepared bool isOuterJoin bool - // joinWorkerWaitGroup is for sync multiple join workers. - joinWorkerWaitGroup sync.WaitGroup - finished atomic.Value + finished atomic.Value stats *hashJoinRuntimeStats } @@ -146,6 +145,7 @@ func (e *HashJoinExec) Close() error { e.probeChkResourceCh = nil e.joinChkResourceCh = nil terror.Call(e.rowContainer.Close) + e.waiter.Wait() } e.outerMatchedStatus = e.outerMatchedStatus[:0] @@ -168,9 +168,10 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.diskTracker = disk.NewTracker(e.id, -1) e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker) + e.worker = util.WaitGroupWrapper{} + e.waiter = util.WaitGroupWrapper{} e.closeCh = make(chan struct{}) e.finished.Store(false) - e.joinWorkerWaitGroup = sync.WaitGroup{} if e.probeTypes == nil { e.probeTypes = retTypes(e.probeSideExec) @@ -264,13 +265,13 @@ func (e *HashJoinExec) wait4BuildSide() (emptyBuild bool, err error) { // fetchBuildSideRows fetches all rows from build side executor, and append them // to e.buildSideResult. 
-func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, doneCh <-chan struct{}) { +func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) { defer close(chkCh) var err error failpoint.Inject("issue30289", func(val failpoint.Value) { if val.(bool) { err = errors.Errorf("issue30289 build return error") - e.buildFinished <- errors.Trace(err) + errCh <- errors.Trace(err) return } }) @@ -281,7 +282,7 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu chk := chunk.NewChunkWithCapacity(e.buildSideExec.base().retFieldTypes, e.ctx.GetSessionVars().MaxChunkSize) err = Next(ctx, e.buildSideExec, chk) if err != nil { - e.buildFinished <- errors.Trace(err) + errCh <- errors.Trace(err) return } failpoint.Inject("errorFetchBuildSideRowsMockOOMPanic", nil) @@ -332,8 +333,7 @@ func (e *HashJoinExec) initializeForProbe() { func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { e.initializeForProbe() - e.joinWorkerWaitGroup.Add(1) - go util.WithRecovery(func() { + e.worker.RunWithRecover(func() { defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End() e.fetchProbeSideChunks(ctx) }, e.handleProbeSideFetcherPanic) @@ -344,14 +344,13 @@ func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { } for i := uint(0); i < e.concurrency; i++ { - e.joinWorkerWaitGroup.Add(1) workID := i - go util.WithRecovery(func() { + e.worker.RunWithRecover(func() { defer trace.StartRegion(ctx, "HashJoinWorker").End() e.runJoinWorker(workID, probeKeyColIdx) }, e.handleJoinWorkerPanic) } - go util.WithRecovery(e.waitJoinWorkersAndCloseResultChan, nil) + e.waiter.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil) } func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) { @@ -361,14 +360,12 @@ func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) { if r != nil { e.joinResultCh <- &hashjoinWorkerResult{err: 
errors.Errorf("%v", r)} } - e.joinWorkerWaitGroup.Done() } func (e *HashJoinExec) handleJoinWorkerPanic(r interface{}) { if r != nil { e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} } - e.joinWorkerWaitGroup.Done() } // Concurrently handling unmatched rows from the hash table @@ -408,15 +405,14 @@ func (e *HashJoinExec) handleUnmatchedRowsFromHashTable(workerID uint) { } func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() { - e.joinWorkerWaitGroup.Wait() + e.worker.Wait() if e.useOuterToBuild { // Concurrently handling unmatched rows from the hash table at the tail for i := uint(0); i < e.concurrency; i++ { var workerID = i - e.joinWorkerWaitGroup.Add(1) - go util.WithRecovery(func() { e.handleUnmatchedRowsFromHashTable(workerID) }, e.handleJoinWorkerPanic) + e.worker.RunWithRecover(func() { e.handleUnmatchedRowsFromHashTable(workerID) }, e.handleJoinWorkerPanic) } - e.joinWorkerWaitGroup.Wait() + e.worker.Wait() } close(e.joinResultCh) } @@ -682,7 +678,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { e.rowContainerForProbe[i] = e.rowContainer.ShallowCopy() } } - go util.WithRecovery(func() { + e.worker.RunWithRecover(func() { defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End() e.fetchAndBuildHashTable(ctx) }, e.handleFetchAndBuildHashTablePanic) @@ -725,10 +721,10 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) { buildSideResultCh := make(chan *chunk.Chunk, 1) doneCh := make(chan struct{}) fetchBuildSideRowsOk := make(chan error, 1) - go util.WithRecovery( + e.worker.RunWithRecover( func() { defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End() - e.fetchBuildSideRows(ctx, buildSideResultCh, doneCh) + e.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh) }, func(r interface{}) { if r != nil { diff --git a/util/wait_group_wrapper.go b/util/wait_group_wrapper.go new file mode 100644 index 0000000000000..3fb72049f1365 --- /dev/null +++ 
b/util/wait_group_wrapper.go @@ -0,0 +1,53 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "sync" +) + +// WaitGroupWrapper is a wrapper for sync.WaitGroup. +type WaitGroupWrapper struct { + sync.WaitGroup +} + +// Run runs exec in a new goroutine. It adds 1 to the WaitGroup +// and calls Done when exec returns. Please DO NOT panic +// in the exec function, because Run does not recover from panics. +func (w *WaitGroupWrapper) Run(exec func()) { + w.Add(1) + go func() { + defer w.Done() + exec() + }() +} + +// RunWithRecover runs exec in a new goroutine with panic recovery. It adds 1 to the WaitGroup +// and calls Done after exec returns or panics. Any panic raised by exec is recovered. +// exec is the function to execute. recoverFn, if not nil, is always called with the result of recover(), +// which is nil when exec did not panic; passing `nil` as recoverFn means noop. +func (w *WaitGroupWrapper) RunWithRecover(exec func(), recoverFn func(r interface{})) { + w.Add(1) + go func() { + defer func() { + r := recover() + if recoverFn != nil { + recoverFn(r) + } + w.Done() + }() + exec() + }() +}