Skip to content

Commit

Permalink
executor: refine HashAgg.Close when unparallelExec (pingcap#8810)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored and AndrewDi committed Dec 28, 2018
1 parent b7f7090 commit ac84e81
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
16 changes: 14 additions & 2 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (e *HashAggExec) Close() error {
e.childResult = nil
e.groupSet = nil
e.partialResultMap = nil
return nil
return e.baseExecutor.Close()
}
// `Close` may be called after `Open` without calling `Next` in test.
if !e.prepared {
Expand All @@ -216,7 +216,7 @@ func (e *HashAggExec) Close() error {
}
for range e.finalOutputCh {
}
return errors.Trace(e.baseExecutor.Close())
return e.baseExecutor.Close()
}

// Open implements the Executor Open interface.
Expand Down Expand Up @@ -611,6 +611,12 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error
e.prepare4ParallelExec(ctx)
e.prepared = true
}

// gofail: var parallelHashAggError bool
// if parallelHashAggError {
// return errors.New("HashAggExec.parallelExec error")
// }

for {
result, ok := <-e.finalOutputCh
if !ok || result.err != nil || result.chk.NumRows() == 0 {
Expand Down Expand Up @@ -682,6 +688,12 @@ func (e *HashAggExec) execute(ctx context.Context) (err error) {
if err != nil {
return errors.Trace(err)
}

// gofail: var unparallelHashAggError bool
// if unparallelHashAggError {
// return errors.New("HashAggExec.unparallelExec error")
// }

// no more data.
if e.childResult.NumRows() == 0 {
return nil
Expand Down
4 changes: 2 additions & 2 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1031,8 +1031,8 @@ func (b *executorBuilder) buildProjBelowAgg(aggFuncs []*aggregation.AggFuncDesc,
}

return &ProjectionExec{
baseExecutor: newBaseExecutor(b.ctx, expression.NewSchema(projSchemaCols...), projFromID, src),
//numWorkers: b.ctx.GetSessionVars().ProjectionConcurrency,
baseExecutor: newBaseExecutor(b.ctx, expression.NewSchema(projSchemaCols...), projFromID, src),
numWorkers: b.ctx.GetSessionVars().ProjectionConcurrency,
evaluatorSuit: expression.NewEvaluatorSuite(projExprs, false),
}
}
Expand Down
45 changes: 45 additions & 0 deletions executor/seqtest/seq_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Note: All the tests in this file will be executed sequentially.

package executor_test

import (
Expand Down Expand Up @@ -647,6 +649,49 @@ func (s *seqTestSuite) TestIndexDoubleReadClose(c *C) {
atomic.StoreInt32(&executor.LookupTableTaskChannelSize, originSize)
}

func (s *seqTestSuite) TestParallelHashAggClose(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec(`use test;`)
tk.MustExec(`drop table if exists t;`)
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values(1,1),(2,2)")
// desc select sum(a) from (select cast(t.a as signed) as a, b from t) t group by b
// HashAgg_8 | 2.40 | root | group by:t.b, funcs:sum(t.a)
// └─Projection_9 | 3.00 | root | cast(test.t.a), test.t.b
// └─TableReader_11 | 3.00 | root | data:TableScan_10
// └─TableScan_10 | 3.00 | cop | table:t, range:[-inf,+inf], keep order:fa$se, stats:pseudo |

// Goroutine should not leak when error happen.
gofail.Enable("github.com/pingcap/tidb/executor/parallelHashAggError", `return(true)`)
defer gofail.Disable("github.com/pingcap/tidb/executor/parallelHashAggError")
ctx := context.Background()
rss, err := tk.Se.Execute(ctx, "select sum(a) from (select cast(t.a as signed) as a, b from t) t group by b;")
c.Assert(err, IsNil)
rs := rss[0]
chk := rs.NewChunk()
err = rs.Next(ctx, chk)
c.Assert(err.Error(), Equals, "HashAggExec.parallelExec error")
}

func (s *seqTestSuite) TestUnparallelHashAggClose(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec(`use test;`)
tk.MustExec(`drop table if exists t;`)
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values(1,1),(2,2)")

// Goroutine should not leak when error happen.
gofail.Enable("github.com/pingcap/tidb/executor/unparallelHashAggError", `return(true)`)
defer gofail.Disable("github.com/pingcap/tidb/executor/unparallelHashAggError")
ctx := context.Background()
rss, err := tk.Se.Execute(ctx, "select sum(distinct a) from (select cast(t.a as signed) as a, b from t) t group by b;")
c.Assert(err, IsNil)
rs := rss[0]
chk := rs.NewChunk()
err = rs.Next(ctx, chk)
c.Assert(err.Error(), Equals, "HashAggExec.unparallelExec error")
}

func checkGoroutineExists(keyword string) bool {
buf := new(bytes.Buffer)
profile := pprof.Lookup("goroutine")
Expand Down

0 comments on commit ac84e81

Please sign in to comment.