Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: unify the argument of stats functions to use SessionCtx instead of StatementContext #30668

Merged
merged 10 commits into from
Dec 14, 2021
17 changes: 8 additions & 9 deletions planner/core/find_best_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/types"
tidbutil "github.com/pingcap/tidb/util"
Expand Down Expand Up @@ -1478,7 +1478,7 @@ func getMostCorrCol4Handle(exprs []expression.Expression, histColl *statistics.T
}

// getColumnRangeCounts estimates row count for each range respectively.
func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) {
func getColumnRangeCounts(sctx sessionctx.Context, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) {
var err error
var count float64
rangeCounts := make([]float64, len(ranges))
Expand All @@ -1488,13 +1488,13 @@ func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*r
if idxHist == nil || idxHist.IsInvalid(false) {
return nil, false
}
count, err = histColl.GetRowCountByIndexRanges(sc, idxID, []*ranger.Range{ran})
count, err = histColl.GetRowCountByIndexRanges(sctx, idxID, []*ranger.Range{ran})
} else {
colHist, ok := histColl.Columns[colID]
if !ok || colHist.IsInvalid(sc, false) {
if !ok || colHist.IsInvalid(sctx, false) {
return nil, false
}
count, err = histColl.GetRowCountByColumnRanges(sc, colID, []*ranger.Range{ran})
count, err = histColl.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{ran})
}
if err != nil {
return nil, false
Expand Down Expand Up @@ -1564,7 +1564,6 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre
if len(accessConds) == 0 {
return 0, false, corr
}
sc := ds.ctx.GetSessionVars().StmtCtx
ranges, err := ranger.BuildColumnRange(accessConds, ds.ctx, col.RetType, types.UnspecifiedLength)
if len(ranges) == 0 || err != nil {
return 0, err == nil, corr
Expand All @@ -1573,7 +1572,7 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre
if !idxExists {
idxID = -1
}
rangeCounts, ok := getColumnRangeCounts(sc, colID, ranges, ds.tableStats.HistColl, idxID)
rangeCounts, ok := getColumnRangeCounts(ds.ctx, colID, ranges, ds.tableStats.HistColl, idxID)
if !ok {
return 0, false, corr
}
Expand All @@ -1583,9 +1582,9 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre
}
var rangeCount float64
if idxExists {
rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, idxID, convertedRanges)
rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, idxID, convertedRanges)
} else {
rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(sc, colID, convertedRanges)
rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(ds.ctx, colID, convertedRanges)
}
if err != nil {
return 0, false, corr
Expand Down
9 changes: 3 additions & 6 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,6 @@ func (ds *DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co
if len(conds) == 0 {
return nil
}
sc := ds.ctx.GetSessionVars().StmtCtx
if len(path.IdxCols) != 0 {
res, err := ranger.DetachCondAndBuildRangeForIndex(ds.ctx, conds, path.IdxCols, path.IdxColLens)
if err != nil {
Expand All @@ -744,7 +743,7 @@ func (ds *DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co
path.ConstCols[i] = res.ColumnValues[i] != nil
}
}
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges)
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges)
if err != nil {
return err
}
Expand Down Expand Up @@ -785,7 +784,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres
return ds.deriveCommonHandleTablePathStats(path, conds, isIm)
}
var err error
sc := ds.ctx.GetSessionVars().StmtCtx
path.CountAfterAccess = float64(ds.statisticTable.Count)
path.TableFilters = conds
var pkCol *expression.Column
Expand Down Expand Up @@ -848,7 +846,7 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres
if err != nil {
return err
}
path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(sc, pkCol.ID, path.Ranges)
path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(ds.ctx, pkCol.ID, path.Ranges)
// If the `CountAfterAccess` is less than `stats.RowCount`, there must be some inconsistent stats info.
// We prefer the `stats.RowCount` because it could use more stats info to calculate the selectivity.
if path.CountAfterAccess < ds.stats.RowCount && !isIm {
Expand All @@ -858,7 +856,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres
}

func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Expression) error {
sc := ds.ctx.GetSessionVars().StmtCtx
path.Ranges = ranger.FullRange()
path.CountAfterAccess = float64(ds.statisticTable.Count)
path.IdxCols, path.IdxColLens = expression.IndexInfo2PrefixCols(ds.Columns, ds.schema.Columns, path.Index)
Expand Down Expand Up @@ -900,7 +897,7 @@ func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Ex
path.ConstCols[i] = res.ColumnValues[i] != nil
}
}
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges)
path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions planner/core/rule_partition_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (s *partitionProcessor) findUsedPartitions(ctx sessionctx.Context, tbl tabl
ranges := detachedResult.Ranges
used := make([]int, 0, len(ranges))
for _, r := range ranges {
if r.IsPointNullable(ctx.GetSessionVars().StmtCtx) {
if r.IsPointNullable(ctx) {
if !r.HighVal[0].IsNull() {
if len(r.HighVal) != len(partIdx) {
used = []int{-1}
Expand Down Expand Up @@ -473,7 +473,7 @@ func (l *listPartitionPruner) locateColumnPartitionsByCondition(cond expression.
return nil, true, nil
}
var locations []tables.ListPartitionLocation
if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) {
if r.IsPointNullable(l.ctx) {
location, err := colPrune.LocatePartition(sc, r.HighVal[0])
if types.ErrOverflow.Equal(err) {
return nil, true, nil // return full-scan if over-flow
Expand Down Expand Up @@ -555,7 +555,7 @@ func (l *listPartitionPruner) findUsedListPartitions(conds []expression.Expressi
}
used := make(map[int]struct{}, len(ranges))
for _, r := range ranges {
if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) {
if r.IsPointNullable(l.ctx) {
if len(r.HighVal) != len(exprCols) {
return l.fullRange, nil
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func (ds *DataSource) deriveStatsByFilter(conds expression.CNFExprs, filledPaths
}
stats := ds.tableStats.Scale(selectivity)
if ds.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 {
stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx.GetSessionVars().StmtCtx, nodes)
stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx, nodes)
}
return stats
}
Expand Down
2 changes: 1 addition & 1 deletion planner/util/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func isColEqCorColOrConstant(ctx sessionctx.Context, filter expression.Expressio
func (path *AccessPath) OnlyPointRange(sctx sessionctx.Context) bool {
if path.IsIntHandlePath {
for _, ran := range path.Ranges {
if !ran.IsPointNullable(sctx.GetSessionVars().StmtCtx) {
if !ran.IsPointNullable(sctx) {
return false
}
}
Expand Down
14 changes: 7 additions & 7 deletions statistics/handle/ddl_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import (
"testing"

"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -51,10 +51,10 @@ func TestDDLAfterLoad(t *testing.T) {
require.NoError(t, err)
tableInfo = tbl.Meta()

sc := new(stmtctx.StatementContext)
count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID)
sctx := mock.NewContext()
count := statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID)
require.Equal(t, 0.0, count)
count = statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID)
count = statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID)
require.Equal(t, 333, int(count))
}

Expand Down Expand Up @@ -131,11 +131,11 @@ func TestDDLHistogram(t *testing.T) {
tableInfo = tbl.Meta()
statsTbl = do.StatsHandle().GetTableStats(tableInfo)
require.False(t, statsTbl.Pseudo)
sc := new(stmtctx.StatementContext)
count, err := statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(0), tableInfo.Columns[3].ID)
sctx := mock.NewContext()
count, err := statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(0), tableInfo.Columns[3].ID)
require.NoError(t, err)
require.Equal(t, float64(2), count)
count, err = statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(1), tableInfo.Columns[3].ID)
count, err = statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1), tableInfo.Columns[3].ID)
require.NoError(t, err)
require.Equal(t, float64(0), count)

Expand Down
13 changes: 6 additions & 7 deletions statistics/handle/handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/statistics/handle"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/israce"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/ranger"
"github.com/pingcap/tidb/util/testkit"
"github.com/tikv/client-go/v2/oracle"
Expand Down Expand Up @@ -267,8 +267,7 @@ func (s *testStatsSuite) TestEmptyTable(c *C) {
c.Assert(err, IsNil)
tableInfo := tbl.Meta()
statsTbl := do.StatsHandle().GetTableStats(tableInfo)
sc := new(stmtctx.StatementContext)
count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(1), tableInfo.Columns[0].ID)
count := statsTbl.ColumnGreaterRowCount(mock.NewContext(), types.NewDatum(1), tableInfo.Columns[0].ID)
c.Assert(count, Equals, 0.0)
}

Expand All @@ -285,14 +284,14 @@ func (s *testStatsSuite) TestColumnIDs(c *C) {
c.Assert(err, IsNil)
tableInfo := tbl.Meta()
statsTbl := do.StatsHandle().GetTableStats(tableInfo)
sc := new(stmtctx.StatementContext)
sctx := mock.NewContext()
ran := &ranger.Range{
LowVal: []types.Datum{types.MinNotNullDatum()},
HighVal: []types.Datum{types.NewIntDatum(2)},
LowExclude: false,
HighExclude: true,
}
count, err := statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran})
count, err := statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran})
c.Assert(err, IsNil)
c.Assert(count, Equals, float64(1))

Expand All @@ -307,7 +306,7 @@ func (s *testStatsSuite) TestColumnIDs(c *C) {
tableInfo = tbl.Meta()
statsTbl = do.StatsHandle().GetTableStats(tableInfo)
// At that time, we should get c2's stats instead of c1's.
count, err = statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran})
count, err = statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran})
c.Assert(err, IsNil)
c.Assert(count, Equals, 0.0)
}
Expand Down Expand Up @@ -614,7 +613,7 @@ func (s *testStatsSuite) TestLoadStats(c *C) {
c.Assert(hg.Len(), Equals, 0)
cms = stat.Columns[tableInfo.Columns[2].ID].CMSketch
c.Assert(cms, IsNil)
_, err = stat.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo.Columns[2].ID)
_, err = stat.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo.Columns[2].ID)
c.Assert(err, IsNil)
c.Assert(h.LoadNeededHistograms(), IsNil)
stat = h.GetTableStats(tableInfo)
Expand Down
38 changes: 31 additions & 7 deletions statistics/handle/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics"
Expand Down Expand Up @@ -1266,18 +1267,29 @@ func (h *Handle) RecalculateExpectCount(q *statistics.QueryFeedback) error {
return nil
}

sc := &stmtctx.StatementContext{TimeZone: time.UTC}
se, err := h.pool.Get()
if err != nil {
return err
}
sctx := se.(sessionctx.Context)
timeZone := sctx.GetSessionVars().StmtCtx.TimeZone
defer func() {
sctx.GetSessionVars().StmtCtx.TimeZone = timeZone
h.pool.Put(se)
}()
sctx.GetSessionVars().StmtCtx.TimeZone = time.UTC

ranges, err := q.DecodeToRanges(isIndex)
if err != nil {
return errors.Trace(err)
}
expected := 0.0
if isIndex {
idx := t.Indices[id]
expected, err = idx.GetRowCount(sc, nil, ranges, t.Count)
expected, err = idx.GetRowCount(sctx, nil, ranges, t.Count)
} else {
c := t.Columns[id]
expected, err = c.GetColumnRowCount(sc, ranges, t.Count, true)
expected, err = c.GetColumnRowCount(sctx, ranges, t.Count, true)
}
q.Expected = int64(expected)
return err
Expand Down Expand Up @@ -1354,7 +1366,20 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics
if !ok {
return nil
}
sc := &stmtctx.StatementContext{TimeZone: time.UTC}

se, err := h.pool.Get()
if err != nil {
return err
}
sctx := se.(sessionctx.Context)
sc := sctx.GetSessionVars().StmtCtx
timeZone := sc.TimeZone
defer func() {
sctx.GetSessionVars().StmtCtx.TimeZone = timeZone
h.pool.Put(se)
}()
sc.TimeZone = time.UTC

if idx.CMSketch == nil || idx.StatsVer < statistics.Version1 {
return h.DumpFeedbackToKV(q)
}
Expand All @@ -1369,7 +1394,6 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics
if rangePosition == 0 || rangePosition == len(ran.LowVal) {
continue
}

bytes, err := codec.EncodeKey(sc, nil, ran.LowVal[:rangePosition]...)
if err != nil {
logutil.BgLogger().Debug("encode keys fail", zap.Error(err))
Expand All @@ -1385,12 +1409,12 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics
rangeFB := &statistics.QueryFeedback{PhysicalID: q.PhysicalID}
// prefer index stats over column stats
if idx := t.IndexStartWithColumn(colName); idx != nil && idx.Histogram.Len() != 0 {
rangeCount, err = t.GetRowCountByIndexRanges(sc, idx.ID, []*ranger.Range{rang})
rangeCount, err = t.GetRowCountByIndexRanges(sctx, idx.ID, []*ranger.Range{rang})
rangeFB.Tp, rangeFB.Hist = statistics.IndexType, &idx.Histogram
} else if col := t.ColumnByName(colName); col != nil && col.Histogram.Len() != 0 {
err = convertRangeType(rang, col.Tp, time.UTC)
if err == nil {
rangeCount, err = t.GetRowCountByColumnRanges(sc, col.ID, []*ranger.Range{rang})
rangeCount, err = t.GetRowCountByColumnRanges(sctx, col.ID, []*ranger.Range{rang})
rangeFB.Tp, rangeFB.Hist = statistics.ColType, &col.Histogram
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion statistics/handle/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) {
c.Assert(stats1.Count, Equals, int64(rowCount1*2))

// Test IncreaseFactor.
count, err := stats1.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo1.Columns[0].ID)
count, err := stats1.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo1.Columns[0].ID)
c.Assert(err, IsNil)
c.Assert(count, Equals, float64(rowCount1*2))

Expand Down
Loading