Skip to content

Commit

Permalink
planner: unify the argument of stats functions to use SessionCtx inst…
Browse files Browse the repository at this point in the history
…ead of StatementContext (#30668)
  • Loading branch information
qw4990 authored Dec 14, 2021
1 parent e9b1fb8 commit 2f42f7c
Show file tree
Hide file tree
Showing 16 changed files with 204 additions and 179 deletions.
17 changes: 8 additions & 9 deletions planner/core/find_best_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/types"
tidbutil "github.com/pingcap/tidb/util"
Expand Down Expand Up @@ -1478,7 +1478,7 @@ func getMostCorrCol4Handle(exprs []expression.Expression, histColl *statistics.T
}

// getColumnRangeCounts estimates row count for each range respectively.
func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) {
func getColumnRangeCounts(sctx sessionctx.Context, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) {
var err error
var count float64
rangeCounts := make([]float64, len(ranges))
Expand All @@ -1488,13 +1488,13 @@ func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*r
if idxHist == nil || idxHist.IsInvalid(false) {
return nil, false
}
count, err = histColl.GetRowCountByIndexRanges(sc, idxID, []*ranger.Range{ran})
count, err = histColl.GetRowCountByIndexRanges(sctx, idxID, []*ranger.Range{ran})
} else {
colHist, ok := histColl.Columns[colID]
if !ok || colHist.IsInvalid(sc, false) {
if !ok || colHist.IsInvalid(sctx, false) {
return nil, false
}
count, err = histColl.GetRowCountByColumnRanges(sc, colID, []*ranger.Range{ran})
count, err = histColl.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{ran})
}
if err != nil {
return nil, false
Expand Down Expand Up @@ -1564,7 +1564,6 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre
if len(accessConds) == 0 {
return 0, false, corr
}
sc := ds.ctx.GetSessionVars().StmtCtx
ranges, err := ranger.BuildColumnRange(accessConds, ds.ctx, col.RetType, types.UnspecifiedLength)
if len(ranges) == 0 || err != nil {
return 0, err == nil, corr
Expand All @@ -1573,7 +1572,7 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre
if !idxExists {
idxID = -1
}
rangeCounts, ok := getColumnRangeCounts(sc, colID, ranges, ds.tableStats.HistColl, idxID)
rangeCounts, ok := getColumnRangeCounts(ds.ctx, colID, ranges, ds.tableStats.HistColl, idxID)
if !ok {
return 0, false, corr
}
Expand All @@ -1583,9 +1582,9 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre
}
var rangeCount float64
if idxExists {
rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, idxID, convertedRanges)
rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, idxID, convertedRanges)
} else {
rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(sc, colID, convertedRanges)
rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(ds.ctx, colID, convertedRanges)
}
if err != nil {
return 0, false, corr
Expand Down
9 changes: 3 additions & 6 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,6 @@ func (ds *DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co
if len(conds) == 0 {
return nil
}
sc := ds.ctx.GetSessionVars().StmtCtx
if len(path.IdxCols) != 0 {
res, err := ranger.DetachCondAndBuildRangeForIndex(ds.ctx, conds, path.IdxCols, path.IdxColLens)
if err != nil {
Expand All @@ -744,7 +743,7 @@ func (ds *DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co
path.ConstCols[i] = res.ColumnValues[i] != nil
}
}
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges)
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges)
if err != nil {
return err
}
Expand Down Expand Up @@ -785,7 +784,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres
return ds.deriveCommonHandleTablePathStats(path, conds, isIm)
}
var err error
sc := ds.ctx.GetSessionVars().StmtCtx
path.CountAfterAccess = float64(ds.statisticTable.Count)
path.TableFilters = conds
var pkCol *expression.Column
Expand Down Expand Up @@ -848,7 +846,7 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres
if err != nil {
return err
}
path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(sc, pkCol.ID, path.Ranges)
path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(ds.ctx, pkCol.ID, path.Ranges)
// If the `CountAfterAccess` is less than `stats.RowCount`, there must be some inconsistent stats info.
// We prefer the `stats.RowCount` because it could use more stats info to calculate the selectivity.
if path.CountAfterAccess < ds.stats.RowCount && !isIm {
Expand All @@ -858,7 +856,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres
}

func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Expression) error {
sc := ds.ctx.GetSessionVars().StmtCtx
path.Ranges = ranger.FullRange()
path.CountAfterAccess = float64(ds.statisticTable.Count)
path.IdxCols, path.IdxColLens = expression.IndexInfo2PrefixCols(ds.Columns, ds.schema.Columns, path.Index)
Expand Down Expand Up @@ -900,7 +897,7 @@ func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Ex
path.ConstCols[i] = res.ColumnValues[i] != nil
}
}
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges)
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions planner/core/rule_partition_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (s *partitionProcessor) findUsedPartitions(ctx sessionctx.Context, tbl tabl
ranges := detachedResult.Ranges
used := make([]int, 0, len(ranges))
for _, r := range ranges {
if r.IsPointNullable(ctx.GetSessionVars().StmtCtx) {
if r.IsPointNullable(ctx) {
if !r.HighVal[0].IsNull() {
if len(r.HighVal) != len(partIdx) {
used = []int{-1}
Expand Down Expand Up @@ -473,7 +473,7 @@ func (l *listPartitionPruner) locateColumnPartitionsByCondition(cond expression.
return nil, true, nil
}
var locations []tables.ListPartitionLocation
if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) {
if r.IsPointNullable(l.ctx) {
location, err := colPrune.LocatePartition(sc, r.HighVal[0])
if types.ErrOverflow.Equal(err) {
return nil, true, nil // return full-scan if over-flow
Expand Down Expand Up @@ -555,7 +555,7 @@ func (l *listPartitionPruner) findUsedListPartitions(conds []expression.Expressi
}
used := make(map[int]struct{}, len(ranges))
for _, r := range ranges {
if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) {
if r.IsPointNullable(l.ctx) {
if len(r.HighVal) != len(exprCols) {
return l.fullRange, nil
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func (ds *DataSource) deriveStatsByFilter(conds expression.CNFExprs, filledPaths
}
stats := ds.tableStats.Scale(selectivity)
if ds.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 {
stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx.GetSessionVars().StmtCtx, nodes)
stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx, nodes)
}
return stats
}
Expand Down
2 changes: 1 addition & 1 deletion planner/util/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func isColEqCorColOrConstant(ctx sessionctx.Context, filter expression.Expressio
func (path *AccessPath) OnlyPointRange(sctx sessionctx.Context) bool {
if path.IsIntHandlePath {
for _, ran := range path.Ranges {
if !ran.IsPointNullable(sctx.GetSessionVars().StmtCtx) {
if !ran.IsPointNullable(sctx) {
return false
}
}
Expand Down
14 changes: 7 additions & 7 deletions statistics/handle/ddl_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import (
"testing"

"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -51,10 +51,10 @@ func TestDDLAfterLoad(t *testing.T) {
require.NoError(t, err)
tableInfo = tbl.Meta()

sc := new(stmtctx.StatementContext)
count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID)
sctx := mock.NewContext()
count := statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID)
require.Equal(t, 0.0, count)
count = statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID)
count = statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID)
require.Equal(t, 333, int(count))
}

Expand Down Expand Up @@ -131,11 +131,11 @@ func TestDDLHistogram(t *testing.T) {
tableInfo = tbl.Meta()
statsTbl = do.StatsHandle().GetTableStats(tableInfo)
require.False(t, statsTbl.Pseudo)
sc := new(stmtctx.StatementContext)
count, err := statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(0), tableInfo.Columns[3].ID)
sctx := mock.NewContext()
count, err := statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(0), tableInfo.Columns[3].ID)
require.NoError(t, err)
require.Equal(t, float64(2), count)
count, err = statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(1), tableInfo.Columns[3].ID)
count, err = statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1), tableInfo.Columns[3].ID)
require.NoError(t, err)
require.Equal(t, float64(0), count)

Expand Down
13 changes: 6 additions & 7 deletions statistics/handle/handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/statistics/handle"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/israce"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/ranger"
"github.com/pingcap/tidb/util/testkit"
"github.com/tikv/client-go/v2/oracle"
Expand Down Expand Up @@ -267,8 +267,7 @@ func (s *testStatsSuite) TestEmptyTable(c *C) {
c.Assert(err, IsNil)
tableInfo := tbl.Meta()
statsTbl := do.StatsHandle().GetTableStats(tableInfo)
sc := new(stmtctx.StatementContext)
count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(1), tableInfo.Columns[0].ID)
count := statsTbl.ColumnGreaterRowCount(mock.NewContext(), types.NewDatum(1), tableInfo.Columns[0].ID)
c.Assert(count, Equals, 0.0)
}

Expand All @@ -285,14 +284,14 @@ func (s *testStatsSuite) TestColumnIDs(c *C) {
c.Assert(err, IsNil)
tableInfo := tbl.Meta()
statsTbl := do.StatsHandle().GetTableStats(tableInfo)
sc := new(stmtctx.StatementContext)
sctx := mock.NewContext()
ran := &ranger.Range{
LowVal: []types.Datum{types.MinNotNullDatum()},
HighVal: []types.Datum{types.NewIntDatum(2)},
LowExclude: false,
HighExclude: true,
}
count, err := statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran})
count, err := statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran})
c.Assert(err, IsNil)
c.Assert(count, Equals, float64(1))

Expand All @@ -307,7 +306,7 @@ func (s *testStatsSuite) TestColumnIDs(c *C) {
tableInfo = tbl.Meta()
statsTbl = do.StatsHandle().GetTableStats(tableInfo)
// At that time, we should get c2's stats instead of c1's.
count, err = statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran})
count, err = statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran})
c.Assert(err, IsNil)
c.Assert(count, Equals, 0.0)
}
Expand Down Expand Up @@ -614,7 +613,7 @@ func (s *testStatsSuite) TestLoadStats(c *C) {
c.Assert(hg.Len(), Equals, 0)
cms = stat.Columns[tableInfo.Columns[2].ID].CMSketch
c.Assert(cms, IsNil)
_, err = stat.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo.Columns[2].ID)
_, err = stat.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo.Columns[2].ID)
c.Assert(err, IsNil)
c.Assert(h.LoadNeededHistograms(), IsNil)
stat = h.GetTableStats(tableInfo)
Expand Down
38 changes: 31 additions & 7 deletions statistics/handle/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics"
Expand Down Expand Up @@ -1266,18 +1267,29 @@ func (h *Handle) RecalculateExpectCount(q *statistics.QueryFeedback) error {
return nil
}

sc := &stmtctx.StatementContext{TimeZone: time.UTC}
se, err := h.pool.Get()
if err != nil {
return err
}
sctx := se.(sessionctx.Context)
timeZone := sctx.GetSessionVars().StmtCtx.TimeZone
defer func() {
sctx.GetSessionVars().StmtCtx.TimeZone = timeZone
h.pool.Put(se)
}()
sctx.GetSessionVars().StmtCtx.TimeZone = time.UTC

ranges, err := q.DecodeToRanges(isIndex)
if err != nil {
return errors.Trace(err)
}
expected := 0.0
if isIndex {
idx := t.Indices[id]
expected, err = idx.GetRowCount(sc, nil, ranges, t.Count)
expected, err = idx.GetRowCount(sctx, nil, ranges, t.Count)
} else {
c := t.Columns[id]
expected, err = c.GetColumnRowCount(sc, ranges, t.Count, true)
expected, err = c.GetColumnRowCount(sctx, ranges, t.Count, true)
}
q.Expected = int64(expected)
return err
Expand Down Expand Up @@ -1354,7 +1366,20 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics
if !ok {
return nil
}
sc := &stmtctx.StatementContext{TimeZone: time.UTC}

se, err := h.pool.Get()
if err != nil {
return err
}
sctx := se.(sessionctx.Context)
sc := sctx.GetSessionVars().StmtCtx
timeZone := sc.TimeZone
defer func() {
sctx.GetSessionVars().StmtCtx.TimeZone = timeZone
h.pool.Put(se)
}()
sc.TimeZone = time.UTC

if idx.CMSketch == nil || idx.StatsVer < statistics.Version1 {
return h.DumpFeedbackToKV(q)
}
Expand All @@ -1369,7 +1394,6 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics
if rangePosition == 0 || rangePosition == len(ran.LowVal) {
continue
}

bytes, err := codec.EncodeKey(sc, nil, ran.LowVal[:rangePosition]...)
if err != nil {
logutil.BgLogger().Debug("encode keys fail", zap.Error(err))
Expand All @@ -1385,12 +1409,12 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics
rangeFB := &statistics.QueryFeedback{PhysicalID: q.PhysicalID}
// prefer index stats over column stats
if idx := t.IndexStartWithColumn(colName); idx != nil && idx.Histogram.Len() != 0 {
rangeCount, err = t.GetRowCountByIndexRanges(sc, idx.ID, []*ranger.Range{rang})
rangeCount, err = t.GetRowCountByIndexRanges(sctx, idx.ID, []*ranger.Range{rang})
rangeFB.Tp, rangeFB.Hist = statistics.IndexType, &idx.Histogram
} else if col := t.ColumnByName(colName); col != nil && col.Histogram.Len() != 0 {
err = convertRangeType(rang, col.Tp, time.UTC)
if err == nil {
rangeCount, err = t.GetRowCountByColumnRanges(sc, col.ID, []*ranger.Range{rang})
rangeCount, err = t.GetRowCountByColumnRanges(sctx, col.ID, []*ranger.Range{rang})
rangeFB.Tp, rangeFB.Hist = statistics.ColType, &col.Histogram
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion statistics/handle/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) {
c.Assert(stats1.Count, Equals, int64(rowCount1*2))

// Test IncreaseFactor.
count, err := stats1.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo1.Columns[0].ID)
count, err := stats1.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo1.Columns[0].ID)
c.Assert(err, IsNil)
c.Assert(count, Equals, float64(rowCount1*2))

Expand Down
Loading

0 comments on commit 2f42f7c

Please sign in to comment.