diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go
index e196015b243e6..22411ad7d9134 100644
--- a/distsql/distsql_test.go
+++ b/distsql/distsql_test.go
@@ -19,6 +19,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/cznic/mathutil"
 	. "github.com/pingcap/check"
 	"github.com/pingcap/errors"
 	"github.com/pingcap/parser/charset"
@@ -34,7 +35,7 @@ import (
 	"github.com/pingcap/tipb/go-tipb"
 )
 
-func (s *testSuite) TestSelectNormal(c *C) {
+func (s *testSuite) createSelectNormal(batch, totalRows int, c *C) (*selectResult, []*types.FieldType) {
 	request, err := (&RequestBuilder{}).SetKeyRanges(nil).
 		SetDAGRequest(&tipb.DAGRequest{}).
 		SetDesc(false).
@@ -67,13 +68,23 @@ func (s *testSuite) TestSelectNormal(c *C) {
 	c.Assert(result.sqlType, Equals, "general")
 	c.Assert(result.rowLen, Equals, len(colTypes))
 
+	resp, ok := result.resp.(*mockResponse)
+	c.Assert(ok, IsTrue)
+	resp.total = totalRows
+	resp.batch = batch
+
+	return result, colTypes
+}
+
+func (s *testSuite) TestSelectNormal(c *C) {
+	response, colTypes := s.createSelectNormal(1, 2, c)
 	response.Fetch(context.TODO())
 
 	// Test Next.
 	chk := chunk.New(colTypes, 32, 32)
 	numAllRows := 0
 	for {
-		err = response.Next(context.TODO(), chk)
+		err := response.Next(context.TODO(), chk)
 		c.Assert(err, IsNil)
 		numAllRows += chk.NumRows()
 		if chk.NumRows() == 0 {
@@ -81,11 +92,17 @@ func (s *testSuite) TestSelectNormal(c *C) {
 		}
 	}
 	c.Assert(numAllRows, Equals, 2)
-	err = response.Close()
+	err := response.Close()
 	c.Assert(err, IsNil)
 }
 
-func (s *testSuite) TestSelectStreaming(c *C) {
+func (s *testSuite) TestSelectNormalBatchSize(c *C) {
+	response, colTypes := s.createSelectNormal(100, 1000000, c)
+	response.Fetch(context.TODO())
+	s.testBatchSize(response, colTypes, c)
+}
+
+func (s *testSuite) createSelectStreaming(batch, totalRows int, c *C) (*streamResult, []*types.FieldType) {
 	request, err := (&RequestBuilder{}).SetKeyRanges(nil).
 		SetDAGRequest(&tipb.DAGRequest{}).
 		SetDesc(false).
@@ -112,20 +129,29 @@ func (s *testSuite) TestSelectStreaming(c *C) {
 
 	s.sctx.GetSessionVars().EnableStreaming = true
 
-	// Test Next.
 	response, err := Select(context.TODO(), s.sctx, request, colTypes, statistics.NewQueryFeedback(0, nil, 0, false))
 	c.Assert(err, IsNil)
 	result, ok := response.(*streamResult)
 	c.Assert(ok, IsTrue)
 	c.Assert(result.rowLen, Equals, len(colTypes))
+	resp, ok := result.resp.(*mockResponse)
+	c.Assert(ok, IsTrue)
+	resp.total = totalRows
+	resp.batch = batch
+
+	return result, colTypes
+}
+
+func (s *testSuite) TestSelectStreaming(c *C) {
+	response, colTypes := s.createSelectStreaming(1, 2, c)
 	response.Fetch(context.TODO())
 
 	// Test Next.
 	chk := chunk.New(colTypes, 32, 32)
 	numAllRows := 0
 	for {
-		err = response.Next(context.TODO(), chk)
+		err := response.Next(context.TODO(), chk)
 		c.Assert(err, IsNil)
 		numAllRows += chk.NumRows()
 		if chk.NumRows() == 0 {
@@ -133,10 +159,64 @@ func (s *testSuite) TestSelectStreaming(c *C) {
 		}
 	}
 	c.Assert(numAllRows, Equals, 2)
-	err = response.Close()
+	err := response.Close()
 	c.Assert(err, IsNil)
 }
 
+func (s *testSuite) TestSelectStreamingBatchSize(c *C) {
+	response, colTypes := s.createSelectStreaming(100, 1000000, c)
+	response.Fetch(context.TODO())
+	s.testBatchSize(response, colTypes, c)
+}
+
+func (s *testSuite) testBatchSize(response SelectResult, colTypes []*types.FieldType, c *C) {
+	chk := chunk.New(colTypes, 32, 32)
+	batch := chunk.NewRecordBatch(chk)
+
+	err := response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 32)
+
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 32)
+
+	batch.SetRequiredRows(1)
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 1)
+
+	batch.SetRequiredRows(2)
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 2)
+
+	batch.SetRequiredRows(17)
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 17)
+
+	batch.SetRequiredRows(170)
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 32)
+
+	batch.SetRequiredRows(32)
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 32)
+
+	batch.SetRequiredRows(0)
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 32)
+
+	batch.SetRequiredRows(-1)
+	err = response.NextBatch(context.TODO(), batch)
+	c.Assert(err, IsNil)
+	c.Assert(batch.NumRows(), Equals, 32)
+}
+
 func (s *testSuite) TestAnalyze(c *C) {
 	request, err := (&RequestBuilder{}).SetKeyRanges(nil).
 		SetAnalyzeRequest(&tipb.AnalyzeReq{}).
@@ -166,6 +246,8 @@ func (s *testSuite) TestAnalyze(c *C) {
 
 // Used only for test.
 type mockResponse struct {
 	count int
+	total int
+	batch int
 	sync.Mutex
 }
@@ -183,17 +265,24 @@ func (resp *mockResponse) Next(ctx context.Context) (kv.ResultSubset, error) {
 	resp.Lock()
 	defer resp.Unlock()
 
-	if resp.count == 2 {
+	if resp.count >= resp.total {
 		return nil, nil
 	}
-	defer func() { resp.count++ }()
+	numRows := mathutil.Min(resp.batch, resp.total-resp.count)
+	resp.count += numRows
 
 	datum := types.NewIntDatum(1)
 	bytes := make([]byte, 0, 100)
 	bytes, _ = codec.EncodeValue(nil, bytes, datum, datum, datum, datum)
+	chunks := make([]tipb.Chunk, numRows)
+	for i := range chunks {
+		chkData := make([]byte, len(bytes))
+		copy(chkData, bytes)
+		chunks[i] = tipb.Chunk{RowsData: chkData}
+	}
 
 	respPB := &tipb.SelectResponse{
-		Chunks:       []tipb.Chunk{{RowsData: bytes}},
+		Chunks:       chunks,
 		OutputCounts: []int64{1},
 	}
 	respBytes, err := respPB.Marshal()
diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go
index 10d319b9c3e65..640127f163594 100644
--- a/distsql/request_builder_test.go
+++ b/distsql/request_builder_test.go
@@ -53,7 +53,10 @@ func (s *testSuite) SetUpSuite(c *C) {
 	ctx := mock.NewContext()
 	ctx.Store = &mock.Store{
 		Client: &mock.Client{
-			MockResponse: &mockResponse{},
+			MockResponse: &mockResponse{
+				batch: 1,
+				total: 2,
+			},
 		},
 	}
 	s.sctx = ctx
@@ -67,7 +70,10 @@ func (s *testSuite) SetUpTest(c *C) {
 	ctx := s.sctx.(*mock.Context)
 	store := ctx.Store.(*mock.Store)
 	store.Client = &mock.Client{
-		MockResponse: &mockResponse{},
+		MockResponse: &mockResponse{
+			batch: 1,
+			total: 2,
+		},
 	}
 }
 
diff --git a/distsql/select_result.go b/distsql/select_result.go
index 5badfc624ec1c..5080cba753b1a 100644
--- a/distsql/select_result.go
+++ b/distsql/select_result.go
@@ -41,7 +41,10 @@ type SelectResult interface {
 	// NextRaw gets the next raw result.
 	NextRaw(context.Context) ([]byte, error)
 	// Next reads the data into chunk.
-	Next(context.Context, *chunk.Chunk) error
+	// TODO: replace all calls to Next with NextBatch and remove this Next method.
+	Next(ctx context.Context, chk *chunk.Chunk) error
+	// NextBatch reads the data into batch.
+	NextBatch(ctx context.Context, batch *chunk.RecordBatch) error
 	// Close closes the iterator.
 	Close() error
 }
@@ -115,15 +118,20 @@ func (r *selectResult) NextRaw(ctx context.Context) ([]byte, error) {
 
 // Next reads data to the chunk.
 func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error {
-	chk.Reset()
-	for chk.NumRows() < r.ctx.GetSessionVars().MaxChunkSize {
+	return r.NextBatch(ctx, chunk.NewRecordBatch(chk))
+}
+
+// NextBatch reads the data into batch.
+func (r *selectResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error {
+	batch.Reset()
+	for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize {
 		if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) {
 			err := r.getSelectResp()
 			if err != nil || r.selectResp == nil {
 				return errors.Trace(err)
 			}
 		}
-		err := r.readRowsData(chk)
+		err := r.readRowsData(batch)
 		if err != nil {
 			return errors.Trace(err)
 		}
@@ -167,11 +175,10 @@ func (r *selectResult) getSelectResp() error {
 	}
 }
 
-func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) {
+func (r *selectResult) readRowsData(batch *chunk.RecordBatch) (err error) {
 	rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData
-	maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize
-	decoder := codec.NewDecoder(chk, r.ctx.GetSessionVars().Location())
-	for chk.NumRows() < maxChunkSize && len(rowsData) > 0 {
+	decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location())
+	for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize && len(rowsData) > 0 {
 		for i := 0; i < r.rowLen; i++ {
 			rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i])
 			if err != nil {
diff --git a/distsql/stream.go b/distsql/stream.go
index dada7053f7a09..a8a87a7738229 100644
--- a/distsql/stream.go
+++ b/distsql/stream.go
@@ -43,10 +43,15 @@ type streamResult struct {
 
 func (r *streamResult) Fetch(context.Context) {}
 
+// Next reads data to the chunk.
 func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error {
-	chk.Reset()
-	maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize
-	for chk.NumRows() < maxChunkSize {
+	return r.NextBatch(ctx, chunk.NewRecordBatch(chk))
+}
+
+// NextBatch reads the data into batch.
+func (r *streamResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error {
+	batch.Reset()
+	for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize {
 		err := r.readDataIfNecessary(ctx)
 		if err != nil {
 			return errors.Trace(err)
@@ -55,7 +60,7 @@ func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error {
 			return nil
 		}
 
-		err = r.flushToChunk(chk)
+		err = r.flushToBatch(batch)
 		if err != nil {
 			return errors.Trace(err)
 		}
@@ -113,11 +118,10 @@ func (r *streamResult) readDataIfNecessary(ctx context.Context) error {
 	return nil
 }
 
-func (r *streamResult) flushToChunk(chk *chunk.Chunk) (err error) {
+func (r *streamResult) flushToBatch(batch *chunk.RecordBatch) (err error) {
 	remainRowsData := r.curr.RowsData
-	maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize
-	decoder := codec.NewDecoder(chk, r.ctx.GetSessionVars().Location())
-	for chk.NumRows() < maxChunkSize && len(remainRowsData) > 0 {
+	decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location())
+	for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize && len(remainRowsData) > 0 {
 		for i := 0; i < r.rowLen; i++ {
 			remainRowsData, err = decoder.DecodeOne(remainRowsData, i, r.fieldTypes[i])
 			if err != nil {
diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go
index 7eb79f54f4333..4c1a3520af642 100644
--- a/util/chunk/recordbatch.go
+++ b/util/chunk/recordbatch.go
@@ -13,12 +13,46 @@
 
 package chunk
 
+// UnspecifiedNumRows indicates that requiredRows is not specified.
+const UnspecifiedNumRows = 0
+
 // RecordBatch is input parameter of Executor.Next` method.
 type RecordBatch struct {
 	*Chunk
+
+	// requiredRows indicates how many rows are required by the parent executor.
+	// The child executor should stop populating rows immediately once the
+	// Chunk holds at least requiredRows rows.
+	requiredRows int
 }
 
 // NewRecordBatch is used to construct a RecordBatch.
 func NewRecordBatch(chk *Chunk) *RecordBatch {
-	return &RecordBatch{chk}
+	return &RecordBatch{chk, UnspecifiedNumRows}
+}
+
+// SetRequiredRows sets the number of rows the parent executor wants.
+func (rb *RecordBatch) SetRequiredRows(numRows int) *RecordBatch {
+	if numRows <= 0 {
+		numRows = UnspecifiedNumRows
+	}
+	rb.requiredRows = numRows
+	return rb
+}
+
+// RequiredRows returns how many rows the parent executor wants.
+func (rb *RecordBatch) RequiredRows() int {
+	return rb.requiredRows
+}
+
+// IsFull reports whether this batch can be considered full.
+// IsFull only takes requiredRows into account; the caller of this method
+// should also consider maxChunkSize, i.e. it should behave like:
+//   if !batch.IsFull() && batch.NumRows() < maxChunkSize { ... }
+func (rb *RecordBatch) IsFull() bool {
+	if rb.requiredRows == UnspecifiedNumRows {
+		return false
+	}
+
+	return rb.NumRows() >= rb.requiredRows
 }
diff --git a/util/chunk/recordbatch_test.go b/util/chunk/recordbatch_test.go
new file mode 100644
index 0000000000000..b2274ef54190a
--- /dev/null
+++ b/util/chunk/recordbatch_test.go
@@ -0,0 +1,58 @@
+// Copyright 2019 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package chunk
+
+import (
+	"github.com/pingcap/check"
+	"github.com/pingcap/parser/mysql"
+	"github.com/pingcap/tidb/types"
+)
+
+func (s *testChunkSuite) TestRecordBatch(c *check.C) {
+	maxChunkSize := 10
+	chk := New([]*types.FieldType{types.NewFieldType(mysql.TypeLong)}, maxChunkSize, maxChunkSize)
+	batch := NewRecordBatch(chk)
+	c.Assert(batch.RequiredRows(), check.Equals, UnspecifiedNumRows)
+	for i := 1; i < 10; i++ {
+		batch.SetRequiredRows(i)
+		c.Assert(batch.RequiredRows(), check.Equals, i)
+	}
+	batch.SetRequiredRows(0)
+	c.Assert(batch.RequiredRows(), check.Equals, UnspecifiedNumRows)
+	batch.SetRequiredRows(-1)
+	c.Assert(batch.RequiredRows(), check.Equals, UnspecifiedNumRows)
+	batch.SetRequiredRows(1).SetRequiredRows(2).SetRequiredRows(3)
+	c.Assert(batch.RequiredRows(), check.Equals, 3)
+
+	batch.SetRequiredRows(5)
+	batch.AppendInt64(0, 1)
+	batch.AppendInt64(0, 1)
+	batch.AppendInt64(0, 1)
+	batch.AppendInt64(0, 1)
+	c.Assert(batch.NumRows(), check.Equals, 4)
+	c.Assert(batch.IsFull(), check.IsFalse)
+
+	batch.AppendInt64(0, 1)
+	c.Assert(batch.NumRows(), check.Equals, 5)
+	c.Assert(batch.IsFull(), check.IsTrue)
+
+	batch.AppendInt64(0, 1)
+	batch.AppendInt64(0, 1)
+	batch.AppendInt64(0, 1)
+	c.Assert(batch.NumRows(), check.Equals, 8)
+	c.Assert(batch.IsFull(), check.IsTrue)
+
+	batch.SetRequiredRows(UnspecifiedNumRows)
+	c.Assert(batch.IsFull(), check.IsFalse)
+}
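Note for reviewers: a minimal usage sketch, not part of the patch. The helper below is hypothetical (readAtMost, its parameters, and the 32-row chunk capacity are invented for illustration), but it uses only APIs introduced or touched by this diff (chunk.New, chunk.NewRecordBatch, RecordBatch.SetRequiredRows, SelectResult.NextBatch) to show how a parent executor can cap how many rows the child decodes per call instead of always filling to maxChunkSize.

package example // hypothetical package, for illustration only

import (
	"context"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/distsql"
	"github.com/pingcap/tidb/types"
	"github.com/pingcap/tidb/util/chunk"
)

// readAtMost drains at most limit rows from a SelectResult. Capping each
// NextBatch call with SetRequiredRows lets the child stop decoding as soon
// as the batch holds enough rows; NextBatch still respects maxChunkSize.
func readAtMost(ctx context.Context, result distsql.SelectResult, colTypes []*types.FieldType, limit int) (int, error) {
	batch := chunk.NewRecordBatch(chunk.New(colTypes, 32, 32))
	read := 0
	for read < limit {
		// Values <= 0 fall back to UnspecifiedNumRows, i.e. no cap.
		batch.SetRequiredRows(limit - read)
		if err := result.NextBatch(ctx, batch); err != nil {
			return read, errors.Trace(err)
		}
		if batch.NumRows() == 0 {
			break // result set drained
		}
		read += batch.NumRows()
	}
	return read, nil
}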