diff --git a/DEPS.bzl b/DEPS.bzl index ac0a348ad55fe..dbd1edbaf98d6 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -4467,8 +4467,8 @@ def go_deps(): name = "org_golang_x_time", build_file_proto_mode = "disable_global", importpath = "golang.org/x/time", - sum = "h1:52I/1L54xyEQAYdtcSuxtiT84KGYTBGXwayxmIpNJhE=", - version = "v0.2.0", + sum = "h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=", + version = "v0.3.0", ) go_repository( name = "org_golang_x_tools", diff --git a/br/pkg/lightning/backend/kv/sql2kv.go b/br/pkg/lightning/backend/kv/sql2kv.go index 6cebb1e29e329..9ad552ef5f340 100644 --- a/br/pkg/lightning/backend/kv/sql2kv.go +++ b/br/pkg/lightning/backend/kv/sql2kv.go @@ -169,7 +169,7 @@ func collectGeneratedColumns(se *session, meta *model.TableInfo, cols []*table.C var genCols []genCol for i, col := range cols { if col.GeneratedExpr != nil { - expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names) + expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names, false) if err != nil { return nil, err } diff --git a/ddl/backfilling.go b/ddl/backfilling.go index d1035bad084bd..0f0910e1caf28 100644 --- a/ddl/backfilling.go +++ b/ddl/backfilling.go @@ -807,12 +807,7 @@ func (b *backfillScheduler) initCopReqSenderPool() { logutil.BgLogger().Warn("[ddl-ingest] cannot init cop request sender", zap.Error(err)) return } - ver, err := sessCtx.GetStore().CurrentVersion(kv.GlobalTxnScope) - if err != nil { - logutil.BgLogger().Warn("[ddl-ingest] cannot init cop request sender", zap.Error(err)) - return - } - b.copReqSenderPool = newCopReqSenderPool(b.ctx, copCtx, ver.Ver) + b.copReqSenderPool = newCopReqSenderPool(b.ctx, copCtx, sessCtx.GetStore()) } func (b *backfillScheduler) canSkipError(err error) bool { diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 7c301598f5238..a5e89e4996d0a 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -6172,7 +6172,7 @@ func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName m // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. // For same reason, decide whether index is global here. - indexColumns, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications) + indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications) if err != nil { return errors.Trace(err) } @@ -6282,7 +6282,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as if err != nil { return nil, errors.Trace(err) } - expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr) + expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr, true) if err != nil { // TODO: refine the error message. return nil, err @@ -6397,7 +6397,7 @@ func (d *ddl) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.Inde // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. // For same reason, decide whether index is global here. 
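+ // buildIndexColumns now also returns whether any key part is a multi-valued (JSON array) column; createIndex discards it here, and BuildIndexInfo in ddl/index.go records it as IndexInfo.MVIndex.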
- indexColumns, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications) + indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications) if err != nil { return errors.Trace(err) } diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index 8621dcb08361c..89e515db8e1bc 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -536,16 +536,20 @@ func cleanMDLInfo(pool *sessionPool, jobID int64, ec *clientv3.Client) { } // checkMDLInfo checks if metadata lock info exists. It means the schema is locked by some TiDBs if exists. -func checkMDLInfo(jobID int64, pool *sessionPool) (bool, error) { - sql := fmt.Sprintf("select * from mysql.tidb_mdl_info where job_id = %d", jobID) +func checkMDLInfo(jobID int64, pool *sessionPool) (bool, int64, error) { + sql := fmt.Sprintf("select version from mysql.tidb_mdl_info where job_id = %d", jobID) sctx, _ := pool.get() defer pool.put(sctx) sess := newSession(sctx) rows, err := sess.execute(context.Background(), sql, "check-mdl-info") if err != nil { - return false, err + return false, 0, err } - return len(rows) > 0, nil + if len(rows) == 0 { + return false, 0, nil + } + ver := rows[0].GetInt64(0) + return true, ver, nil } func needUpdateRawArgs(job *model.Job, meetErr bool) bool { @@ -1377,6 +1381,32 @@ func waitSchemaChanged(ctx context.Context, d *ddlCtx, waitTime time.Duration, l zap.String("job", job.String())) } +// waitSchemaSyncedForMDL is like waitSchemaSynced, but it waits to acquire the metadata lock of the latest version of this DDL. +func waitSchemaSyncedForMDL(d *ddlCtx, job *model.Job, latestSchemaVersion int64) error { + failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val.(bool) { + if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { + panic("check down before update global version failed") + } else { + mockDDLErrOnce = -1 + } + } + }) + + timeStart := time.Now() + // OwnerCheckAllVersions returns only when all TiDB schemas are synced (excluding the isolated TiDB).
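+ // If another TiDB instance still holds the metadata lock on an older schema version (e.g. an in-flight transaction), this call does not return until that lock is released; on error the caller has already released its worker and simply returns, so the job is retried later.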
+ err := d.schemaSyncer.OwnerCheckAllVersions(context.Background(), job.ID, latestSchemaVersion) + if err != nil { + logutil.Logger(d.ctx).Info("[ddl] wait latest schema version encounter error", zap.Int64("ver", latestSchemaVersion), zap.Error(err)) + return err + } + logutil.Logger(d.ctx).Info("[ddl] wait latest schema version changed(get the metadata lock if tidb_enable_metadata_lock is true)", + zap.Int64("ver", latestSchemaVersion), + zap.Duration("take time", time.Since(timeStart)), + zap.String("job", job.String())) + return nil +} + // waitSchemaSynced handles the following situation: // If the job enters a new state, and the worker crashs when it's in the process of waiting for 2 * lease time, // Then the worker restarts quickly, we may run the job immediately again, diff --git a/ddl/export_test.go b/ddl/export_test.go index 486390f9a6810..3ea26fb04290c 100644 --- a/ddl/export_test.go +++ b/ddl/export_test.go @@ -28,7 +28,7 @@ func SetBatchInsertDeleteRangeSize(i int) { var NewCopContext4Test = newCopContext -func FetchRowsFromCop4Test(copCtx *copContext, startKey, endKey kv.Key, startTS uint64, +func FetchRowsFromCop4Test(copCtx *copContext, startKey, endKey kv.Key, store kv.Storage, batchSize int) ([]*indexRecord, bool, error) { variable.SetDDLReorgBatchSize(int32(batchSize)) task := &reorgBackfillTask{ @@ -36,7 +36,7 @@ func FetchRowsFromCop4Test(copCtx *copContext, startKey, endKey kv.Key, startTS startKey: startKey, endKey: endKey, } - pool := newCopReqSenderPool(context.Background(), copCtx, startTS) + pool := newCopReqSenderPool(context.Background(), copCtx, store) pool.adjustSize(1) pool.tasksCh <- task idxRec, _, _, done, err := pool.fetchRowColValsFromCop(*task) diff --git a/ddl/fktest/foreign_key_test.go b/ddl/fktest/foreign_key_test.go index 741c296c69459..f64de90ca4955 100644 --- a/ddl/fktest/foreign_key_test.go +++ b/ddl/fktest/foreign_key_test.go @@ -319,6 +319,24 @@ func TestCreateTableWithForeignKeyPrivilegeCheck(t *testing.T) { tk2.MustExec("create table t4 (a int, foreign key fk(a) references t1(id), foreign key (a) references t3(id));") } +func TestAlterTableWithForeignKeyPrivilegeCheck(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create user 'u1'@'%' identified by '';") + tk.MustExec("grant create,alter on *.* to 'u1'@'%';") + tk.MustExec("create table t1 (id int key);") + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + tk2.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost", CurrentUser: true, AuthUsername: "u1", AuthHostname: "%"}, nil, []byte("012345678901234567890")) + tk2.MustExec("create table t2 (a int)") + err := tk2.ExecToErr("alter table t2 add foreign key (a) references t1 (id) on update cascade") + require.Error(t, err) + require.Equal(t, "[planner:1142]REFERENCES command denied to user 'u1'@'%' for table 't1'", err.Error()) + tk.MustExec("grant references on test.t1 to 'u1'@'%';") + tk2.MustExec("alter table t2 add foreign key (a) references t1 (id) on update cascade") +} + func TestRenameTableWithForeignKeyMetaInfo(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) diff --git a/ddl/generated_column.go b/ddl/generated_column.go index 2f4ceee8b60a9..678d803edf521 100644 --- a/ddl/generated_column.go +++ b/ddl/generated_column.go @@ -268,12 +268,14 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol } type illegalFunctionChecker struct 
{ - hasIllegalFunc bool - hasAggFunc bool - hasRowVal bool // hasRowVal checks whether the functional index refers to a row value - hasWindowFunc bool - hasNotGAFunc4ExprIdx bool - otherErr error + hasIllegalFunc bool + hasAggFunc bool + hasRowVal bool // hasRowVal checks whether the functional index refers to a row value + hasWindowFunc bool + hasNotGAFunc4ExprIdx bool + hasCastArrayFunc bool + disallowCastArrayFunc bool + otherErr error } func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { @@ -308,7 +310,14 @@ func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipC case *ast.WindowFuncExpr: c.hasWindowFunc = true return inNode, true + case *ast.FuncCastExpr: + c.hasCastArrayFunc = c.hasCastArrayFunc || node.Tp.IsArray() + if c.disallowCastArrayFunc && node.Tp.IsArray() { + c.otherErr = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions") + return inNode, true + } } + c.disallowCastArrayFunc = true return inNode, false } @@ -355,6 +364,9 @@ func checkIllegalFn4Generated(name string, genType int, expr ast.ExprNode) error if genType == typeIndex && c.hasNotGAFunc4ExprIdx && !config.GetGlobalConfig().Experimental.AllowsExpressionIndex { return dbterror.ErrUnsupportedExpressionIndex } + if genType == typeColumn && c.hasCastArrayFunc { + return expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions") + } return nil } diff --git a/ddl/index.go b/ddl/index.go index 0f70b73b61046..273b89e041233 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -64,26 +64,28 @@ var ( telemetryAddIndexIngestUsage = metrics.TelemetryAddIndexIngestCnt ) -func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, error) { +func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) { // Build offsets. idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications)) var col *model.ColumnInfo + var mvIndex bool maxIndexLength := config.GetGlobalConfig().MaxIndexLength // The sum of length of all index columns. sumLength := 0 for _, ip := range indexPartSpecifications { col = model.FindColumnInfo(columns, ip.Column.Name.L) if col == nil { - return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name) + return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name) } if err := checkIndexColumn(ctx, col, ip.Length); err != nil { - return nil, err + return nil, false, err } + mvIndex = mvIndex || col.FieldType.IsArray() indexColLen := ip.Length indexColumnLength, err := getIndexColumnLength(col, ip.Length) if err != nil { - return nil, err + return nil, false, err } sumLength += indexColumnLength @@ -92,12 +94,12 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde // The multiple column index and the unique index in which the length sum exceeds the maximum size // will return an error instead produce a warning. 
if ctx == nil || ctx.GetSessionVars().StrictSQLMode || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { - return nil, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength) + return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength) } // truncate index length and produce warning message in non-restrict sql mode. colLenPerUint, err := getIndexColumnLength(col, 1) if err != nil { - return nil, err + return nil, false, err } indexColLen = maxIndexLength / colLenPerUint // produce warning message @@ -111,7 +113,7 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde }) } - return idxParts, nil + return idxParts, mvIndex, nil } // CheckPKOnGeneratedColumn checks the specification of PK is valid. @@ -154,7 +156,7 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn } // JSON column cannot index. - if col.FieldType.GetType() == mysql.TypeJSON { + if col.FieldType.GetType() == mysql.TypeJSON && !col.FieldType.IsArray() { if col.Hidden { return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction } @@ -263,7 +265,7 @@ func BuildIndexInfo( return nil, errors.Trace(err) } - idxColumns, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications) + idxColumns, mvIndex, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications) if err != nil { return nil, errors.Trace(err) } @@ -276,6 +278,7 @@ func BuildIndexInfo( Primary: isPrimary, Unique: isUnique, Global: isGlobal, + MVIndex: mvIndex, } if indexOption != nil { diff --git a/ddl/index_cop.go b/ddl/index_cop.go index 0a04ac63eb190..fab097727139b 100644 --- a/ddl/index_cop.go +++ b/ddl/index_cop.go @@ -103,9 +103,9 @@ type copReqSenderPool struct { resultsCh chan idxRecResult results generic.SyncMap[int, struct{}] - ctx context.Context - copCtx *copContext - startTS uint64 + ctx context.Context + copCtx *copContext + store kv.Storage senders []*copReqSender wg sync.WaitGroup @@ -139,7 +139,12 @@ func (c *copReqSender) run() { curTaskID = task.id logutil.BgLogger().Info("[ddl-ingest] start a cop-request task", zap.Int("id", task.id), zap.String("task", task.String())) - rs, err := p.copCtx.buildTableScan(p.ctx, p.startTS, task.startKey, task.excludedEndKey()) + ver, err := p.store.CurrentVersion(kv.GlobalTxnScope) + if err != nil { + p.resultsCh <- idxRecResult{id: task.id, err: err} + return + } + rs, err := p.copCtx.buildTableScan(p.ctx, ver.Ver, task.startKey, task.excludedEndKey()) if err != nil { p.resultsCh <- idxRecResult{id: task.id, err: err} return @@ -167,7 +172,7 @@ func (c *copReqSender) run() { } } -func newCopReqSenderPool(ctx context.Context, copCtx *copContext, startTS uint64) *copReqSenderPool { +func newCopReqSenderPool(ctx context.Context, copCtx *copContext, store kv.Storage) *copReqSenderPool { poolSize := copReadChunkPoolSize() idxBufPool := make(chan []*indexRecord, poolSize) srcChkPool := make(chan *chunk.Chunk, poolSize) @@ -181,7 +186,7 @@ func newCopReqSenderPool(ctx context.Context, copCtx *copContext, startTS uint64 results: generic.NewSyncMap[int, struct{}](10), ctx: ctx, copCtx: copCtx, - startTS: startTS, + store: store, senders: make([]*copReqSender, 0, variable.GetDDLReorgWorkerCounter()), wg: sync.WaitGroup{}, idxBufPool: idxBufPool, diff --git a/ddl/index_cop_test.go b/ddl/index_cop_test.go index 80e37f6a74121..38bced0b6678d 100644 --- a/ddl/index_cop_test.go +++ b/ddl/index_cop_test.go @@ -43,7 +43,7 @@ func TestAddIndexFetchRowsFromCoprocessor(t *testing.T) { endKey := 
startKey.PrefixNext() txn, err := store.Begin() require.NoError(t, err) - idxRec, done, err := ddl.FetchRowsFromCop4Test(copCtx, startKey, endKey, txn.StartTS(), 10) + idxRec, done, err := ddl.FetchRowsFromCop4Test(copCtx, startKey, endKey, store, 10) require.NoError(t, err) require.False(t, done) require.NoError(t, txn.Rollback()) diff --git a/ddl/job_table.go b/ddl/job_table.go index a6e19b7f7edf0..771a83b8f8264 100644 --- a/ddl/job_table.go +++ b/ddl/job_table.go @@ -237,7 +237,7 @@ func (d *ddl) delivery2worker(wk *worker, pool *workerPool, job *model.Job) { // check if this ddl job is synced to all servers. if !d.isSynced(job) || d.once.Load() { if variable.EnableMDL.Load() { - exist, err := checkMDLInfo(job.ID, d.sessPool) + exist, version, err := checkMDLInfo(job.ID, d.sessPool) if err != nil { logutil.BgLogger().Warn("[ddl] check MDL info failed", zap.Error(err), zap.String("job", job.String())) // Release the worker resource. @@ -246,10 +246,8 @@ } else if exist { // Release the worker resource. pool.put(wk) - err = waitSchemaSynced(d.ddlCtx, job, 2*d.lease) + err = waitSchemaSyncedForMDL(d.ddlCtx, job, version) if err != nil { - logutil.BgLogger().Warn("[ddl] wait ddl job sync failed", zap.Error(err), zap.String("job", job.String())) - time.Sleep(time.Second) return } d.once.Store(false) diff --git a/ddl/partition.go b/ddl/partition.go index 0a1ea4e6fbe66..2c95f389707f9 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1375,7 +1375,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, tblInfo * return nil } - e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr) + e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr, false) if err != nil { return errors.Trace(err) } diff --git a/docs/design/2020-08-04-global-index.md b/docs/design/2020-08-04-global-index.md index 80078688777b7..f5e2d89f932c4 100644 --- a/docs/design/2020-08-04-global-index.md +++ b/docs/design/2020-08-04-global-index.md @@ -183,7 +183,7 @@ In TiDB, operators in the partitioned table will be translated to UnionAll in th ## Compatibility -MySQL does not support global index, which means this feature may cause some compatibility issues. We add an option `enable_global_index` in `config.Config` to control it. The default value of this option is `false`, so TiDB will keep consistent with MySQL, unless the user open global index feature manually. +MySQL does not support global index, which means this feature may cause some compatibility issues. We add an option `enable-global-index` in `config.Config` to control it. The default value of this option is `false`, so TiDB stays consistent with MySQL unless the user enables the global index feature manually.
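+Users must opt in explicitly, for example by setting `enable-global-index = true` in the TiDB configuration file (assuming the option keeps this top-level TOML key; consult the shipped config template for the exact location).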
## Implementation diff --git a/executor/analyzetest/BUILD.bazel b/executor/analyzetest/BUILD.bazel index 53126213363a5..3112abe57c00f 100644 --- a/executor/analyzetest/BUILD.bazel +++ b/executor/analyzetest/BUILD.bazel @@ -8,7 +8,6 @@ go_test( "main_test.go", ], flaky = True, - race = "on", shard_count = 50, deps = [ "//domain", @@ -30,6 +29,7 @@ go_test( "//tablecodec", "//testkit", "//types", + "//util", "//util/codec", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", diff --git a/executor/analyzetest/analyze_test.go b/executor/analyzetest/analyze_test.go index e3bf9d51d9260..55f3ad9397be9 100644 --- a/executor/analyzetest/analyze_test.go +++ b/executor/analyzetest/analyze_test.go @@ -17,6 +17,7 @@ package analyzetest import ( "context" "fmt" + "runtime" "strconv" "strings" "testing" @@ -43,6 +44,7 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/codec" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/testutils" @@ -3060,3 +3062,115 @@ func TestAutoAnalyzeAwareGlobalVariableChange(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/injectBaseCount")) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/injectBaseModifyCount")) } + +func TestGlobalMemoryControlForAnalyze(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + + tk0 := testkit.NewTestKit(t, store) + tk0.MustExec("set global tidb_mem_oom_action = 'cancel'") + tk0.MustExec("set global tidb_server_memory_limit = 512MB") + tk0.MustExec("set global tidb_server_memory_limit_sess_min_size = 128") + + sm := &testkit.MockSessionManager{ + PS: []*util.ProcessInfo{tk0.Session().ShowProcess()}, + } + dom.ServerMemoryLimitHandle().SetSessionManager(sm) + go dom.ServerMemoryLimitHandle().Run() + + tk0.MustExec("use test") + tk0.MustExec("create table t(a int)") + tk0.MustExec("insert into t select 1") + for i := 1; i <= 8; i++ { + tk0.MustExec("insert into t select * from t") // 256 Lines + } + sql := "analyze table t with 1.0 samplerate;" // Need about 100MB + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/memory/ReadMemStats", `return(536870912)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume", `return(100)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/memory/ReadMemStats")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume")) + }() + _, err := tk0.Exec(sql) + require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!")) + runtime.GC() +} + +func TestGlobalMemoryControlForAutoAnalyze(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + originalVal1 := tk.MustQuery("select @@global.tidb_mem_oom_action").Rows()[0][0].(string) + tk.MustExec("set global tidb_mem_oom_action = 'cancel'") + //originalVal2 := tk.MustQuery("select @@global.tidb_server_memory_limit").Rows()[0][0].(string) + tk.MustExec("set global tidb_server_memory_limit = 512MB") + originalVal3 := tk.MustQuery("select @@global.tidb_server_memory_limit_sess_min_size").Rows()[0][0].(string) + tk.MustExec("set global tidb_server_memory_limit_sess_min_size = 128") + defer func() { + tk.MustExec(fmt.Sprintf("set global tidb_mem_oom_action = %v", originalVal1)) + //tk.MustExec(fmt.Sprintf("set 
global tidb_server_memory_limit = %v", originalVal2)) + tk.MustExec(fmt.Sprintf("set global tidb_server_memory_limit_sess_min_size = %v", originalVal3)) + }() + + // clean child trackers + oldChildTrackers := executor.GlobalAnalyzeMemoryTracker.GetChildrenForTest() + for _, tracker := range oldChildTrackers { + tracker.Detach() + } + defer func() { + for _, tracker := range oldChildTrackers { + tracker.AttachTo(executor.GlobalAnalyzeMemoryTracker) + } + }() + childTrackers := executor.GlobalAnalyzeMemoryTracker.GetChildrenForTest() + require.Len(t, childTrackers, 0) + + tk.MustExec("use test") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t select 1") + for i := 1; i <= 8; i++ { + tk.MustExec("insert into t select * from t") // 256 Lines + } + _, err0 := tk.Exec("analyze table t with 1.0 samplerate;") + require.NoError(t, err0) + rs0 := tk.MustQuery("select fail_reason from mysql.analyze_jobs where table_name=? and state=? limit 1", "t", "failed") + require.Len(t, rs0.Rows(), 0) + + h := dom.StatsHandle() + originalVal4 := handle.AutoAnalyzeMinCnt + originalVal5 := tk.MustQuery("select @@global.tidb_auto_analyze_ratio").Rows()[0][0].(string) + handle.AutoAnalyzeMinCnt = 0 + tk.MustExec("set global tidb_auto_analyze_ratio = 0.001") + defer func() { + handle.AutoAnalyzeMinCnt = originalVal4 + tk.MustExec(fmt.Sprintf("set global tidb_auto_analyze_ratio = %v", originalVal5)) + }() + + sm := &testkit.MockSessionManager{ + Dom: dom, + PS: []*util.ProcessInfo{tk.Session().ShowProcess()}, + } + dom.ServerMemoryLimitHandle().SetSessionManager(sm) + go dom.ServerMemoryLimitHandle().Run() + + tk.MustExec("insert into t values(4),(5),(6)") + require.NoError(t, h.DumpStatsDeltaToKV(handle.DumpAll)) + err := h.Update(dom.InfoSchema()) + require.NoError(t, err) + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/memory/ReadMemStats", `return(536870912)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume", `return(100)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/memory/ReadMemStats")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume")) + }() + tk.MustQuery("select 1") + childTrackers = executor.GlobalAnalyzeMemoryTracker.GetChildrenForTest() + require.Len(t, childTrackers, 0) + + h.HandleAutoAnalyze(dom.InfoSchema()) + rs := tk.MustQuery("select fail_reason from mysql.analyze_jobs where table_name=? and state=? 
limit 1", "t", "failed") + failReason := rs.Rows()[0][0].(string) + require.True(t, strings.Contains(failReason, "Out Of Memory Quota!")) + + childTrackers = executor.GlobalAnalyzeMemoryTracker.GetChildrenForTest() + require.Len(t, childTrackers, 0) +} diff --git a/executor/builder.go b/executor/builder.go index b33c57d3de234..d4270397eecd0 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -3511,17 +3511,39 @@ func buildIndexRangeForEachPartition(ctx sessionctx.Context, usedPartitions []ta return nextRange, nil } -func keyColumnsIncludeAllPartitionColumns(keyColumns []int, pe *tables.PartitionExpr) bool { - tmp := make(map[int]struct{}, len(keyColumns)) - for _, offset := range keyColumns { - tmp[offset] = struct{}{} +func getPartitionKeyColOffsets(keyColIDs []int64, pt table.PartitionedTable) []int { + keyColOffsets := make([]int, len(keyColIDs)) + for i, colID := range keyColIDs { + offset := -1 + for j, col := range pt.Cols() { + if colID == col.ID { + offset = j + break + } + } + if offset == -1 { + return nil + } + keyColOffsets[i] = offset + } + + pe, err := pt.(interface { + PartitionExpr() (*tables.PartitionExpr, error) + }).PartitionExpr() + if err != nil { + return nil + } + + offsetMap := make(map[int]struct{}) + for _, offset := range keyColOffsets { + offsetMap[offset] = struct{}{} } for _, offset := range pe.ColumnOffset { - if _, ok := tmp[offset]; !ok { - return false + if _, ok := offsetMap[offset]; !ok { + return nil } } - return true + return keyColOffsets } func (builder *dataReaderBuilder) prunePartitionForInnerExecutor(tbl table.Table, schema *expression.Schema, partitionInfo *plannercore.PartitionInfo, @@ -3536,15 +3558,6 @@ func (builder *dataReaderBuilder) prunePartitionForInnerExecutor(tbl table.Table return nil, false, nil, err } - // check whether can runtime prune. 
- type partitionExpr interface { - PartitionExpr() (*tables.PartitionExpr, error) - } - pe, err := tbl.(partitionExpr).PartitionExpr() - if err != nil { - return nil, false, nil, err - } - // recalculate key column offsets if len(lookUpContent) == 0 { return nil, false, nil, nil @@ -3552,29 +3565,9 @@ func (builder *dataReaderBuilder) prunePartitionForInnerExecutor(tbl table.Table if lookUpContent[0].keyColIDs == nil { return nil, false, nil, plannercore.ErrInternal.GenWithStack("cannot get column IDs when dynamic pruning") } - keyColOffsets := make([]int, len(lookUpContent[0].keyColIDs)) - for i, colID := range lookUpContent[0].keyColIDs { - offset := -1 - for j, col := range partitionTbl.Cols() { - if colID == col.ID { - offset = j - break - } - } - if offset == -1 { - return nil, false, nil, plannercore.ErrInternal.GenWithStack("invalid column offset when dynamic pruning") - } - keyColOffsets[i] = offset - } - - offsetMap := make(map[int]bool) - for _, offset := range keyColOffsets { - offsetMap[offset] = true - } - for _, offset := range pe.ColumnOffset { - if _, ok := offsetMap[offset]; !ok { - return condPruneResult, false, nil, nil - } + keyColOffsets := getPartitionKeyColOffsets(lookUpContent[0].keyColIDs, partitionTbl) + if len(keyColOffsets) == 0 { + return condPruneResult, false, nil, nil } locateKey := make([]types.Datum, len(partitionTbl.Cols())) @@ -4149,12 +4142,6 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte } tbl, _ := builder.is.TableByID(tbInfo.ID) pt := tbl.(table.PartitionedTable) - pe, err := tbl.(interface { - PartitionExpr() (*tables.PartitionExpr, error) - }).PartitionExpr() - if err != nil { - return nil, err - } partitionInfo := &v.PartitionInfo usedPartitionList, err := builder.partitionPruning(pt, partitionInfo.PruningConds, partitionInfo.PartitionNames, partitionInfo.Columns, partitionInfo.ColumnNames) if err != nil { @@ -4165,8 +4152,12 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte usedPartitions[p.GetPhysicalID()] = p } var kvRanges []kv.KeyRange + var keyColOffsets []int + if len(lookUpContents) > 0 { + keyColOffsets = getPartitionKeyColOffsets(lookUpContents[0].keyColIDs, pt) + } if v.IsCommonHandle { - if len(lookUpContents) > 0 && keyColumnsIncludeAllPartitionColumns(lookUpContents[0].keyCols, pe) { + if len(keyColOffsets) > 0 { locateKey := make([]types.Datum, e.Schema().Len()) kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) // lookUpContentsByPID groups lookUpContents by pid(partition) so that kv ranges for same partition can be merged. 
@@ -4212,7 +4203,7 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte handles, lookUpContents := dedupHandles(lookUpContents) - if len(lookUpContents) > 0 && keyColumnsIncludeAllPartitionColumns(lookUpContents[0].keyCols, pe) { + if len(keyColOffsets) > 0 { locateKey := make([]types.Datum, e.Schema().Len()) kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) for _, content := range lookUpContents { diff --git a/executor/executor.go b/executor/executor.go index 90622ce52e527..2e5b5c4a0280f 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1965,6 +1965,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.SysdateIsNow = ctx.GetSessionVars().SysdateIsNow + vars.MemTracker.Detach() vars.MemTracker.UnbindActions() vars.MemTracker.SetBytesLimit(vars.MemQuotaQuery) vars.MemTracker.ResetMaxConsumed() diff --git a/executor/executor_test.go b/executor/executor_test.go index bd64c39e5a134..122fbdbe7dd2f 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6199,38 +6199,6 @@ func TestGlobalMemoryControl2(t *testing.T) { runtime.GC() } -func TestGlobalMemoryControlForAnalyze(t *testing.T) { - store, dom := testkit.CreateMockStoreAndDomain(t) - - tk0 := testkit.NewTestKit(t, store) - tk0.MustExec("set global tidb_mem_oom_action = 'cancel'") - tk0.MustExec("set global tidb_server_memory_limit = 512MB") - tk0.MustExec("set global tidb_server_memory_limit_sess_min_size = 128") - - sm := &testkit.MockSessionManager{ - PS: []*util.ProcessInfo{tk0.Session().ShowProcess()}, - } - dom.ServerMemoryLimitHandle().SetSessionManager(sm) - go dom.ServerMemoryLimitHandle().Run() - - tk0.MustExec("use test") - tk0.MustExec("create table t(a int)") - tk0.MustExec("insert into t select 1") - for i := 1; i <= 8; i++ { - tk0.MustExec("insert into t select * from t") // 256 Lines - } - sql := "analyze table t with 1.0 samplerate;" // Need about 100MB - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/memory/ReadMemStats", `return(536870912)`)) - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume", `return(100)`)) - defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/memory/ReadMemStats")) - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume")) - }() - _, err := tk0.Exec(sql) - require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!")) - runtime.GC() -} - func TestCompileOutOfMemoryQuota(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go index 9a021568b20ee..600f052b1225e 100644 --- a/executor/index_lookup_join_test.go +++ b/executor/index_lookup_join_test.go @@ -428,17 +428,19 @@ PARTITIONS 1`) // Why does the t2.prefiller need be at least 2^32 ? If smaller the bug will not appear!?! tk.MustExec("insert into t2 values ( pow(2,32), 1, 1), ( pow(2,32)+1, 2, 0)") + tk.MustExec(`analyze table t1`) + tk.MustExec(`analyze table t2`) // Why must it be = 1 and not 2? 
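+ // The analyze statements above give the planner real statistics, so the expected output below carries exact row counts (no stats:pseudo on t2), and format='brief' strips the volatile operator ID suffixes from the plan.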
- tk.MustQuery("explain select /* +INL_JOIN(t1,t2) */ t1.id, t1.pc from t1 where id in ( select prefiller from t2 where t2.postfiller = 1 )").Check(testkit.Rows("" + - "IndexJoin_15 10.00 root inner join, inner:TableReader_14, outer key:test.t2.prefiller, inner key:test.t1.id, equal cond:eq(test.t2.prefiller, test.t1.id)]\n" + - "[├─HashAgg_25(Build) 8.00 root group by:test.t2.prefiller, funcs:firstrow(test.t2.prefiller)->test.t2.prefiller]\n" + - "[│ └─TableReader_26 8.00 root data:HashAgg_20]\n" + - "[│ └─HashAgg_20 8.00 cop[tikv] group by:test.t2.prefiller, ]\n" + - "[│ └─Selection_24 10.00 cop[tikv] eq(test.t2.postfiller, 1)]\n" + - "[│ └─TableFullScan_23 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo]\n" + - "[└─TableReader_14(Probe) 8.00 root partition:all data:TableRangeScan_13]\n" + - "[ └─TableRangeScan_13 8.00 cop[tikv] table:t1 range: decided by [eq(test.t1.id, test.t2.prefiller)], keep order:false, stats:pseudo")) + tk.MustQuery("explain format='brief' select /* +INL_JOIN(t1,t2) */ t1.id, t1.pc from t1 where id in ( select prefiller from t2 where t2.postfiller = 1 )").Check(testkit.Rows(""+ + `IndexJoin 1.25 root inner join, inner:TableReader, outer key:test.t2.prefiller, inner key:test.t1.id, equal cond:eq(test.t2.prefiller, test.t1.id)`, + `├─HashAgg(Build) 1.00 root group by:test.t2.prefiller, funcs:firstrow(test.t2.prefiller)->test.t2.prefiller`, + `│ └─TableReader 1.00 root data:HashAgg`, + `│ └─HashAgg 1.00 cop[tikv] group by:test.t2.prefiller, `, + `│ └─Selection 1.00 cop[tikv] eq(test.t2.postfiller, 1)`, + `│ └─TableFullScan 2.00 cop[tikv] table:t2 keep order:false`, + `└─TableReader(Probe) 1.00 root partition:all data:TableRangeScan`, + ` └─TableRangeScan 1.00 cop[tikv] table:t1 range: decided by [eq(test.t1.id, test.t2.prefiller)], keep order:false, stats:pseudo`)) tk.MustQuery("show warnings").Check(testkit.Rows()) // without fix it fails with: "runtime error: index out of range [0] with length 0" tk.MustQuery("select /* +INL_JOIN(t1,t2) */ t1.id, t1.pc from t1 where id in ( select prefiller from t2 where t2.postfiller = 1 )").Check(testkit.Rows()) diff --git a/executor/oomtest/oom_test.go b/executor/oomtest/oom_test.go index 5b348f5c238de..fc95bb47ceab8 100644 --- a/executor/oomtest/oom_test.go +++ b/executor/oomtest/oom_test.go @@ -223,7 +223,8 @@ func (h *oomCapture) Write(entry zapcore.Entry, fields []zapcore.Field) error { return nil } // They are just common background task and not related to the oom. 
- if entry.Message == "SetTiFlashGroupConfig" { + if entry.Message == "SetTiFlashGroupConfig" || + entry.Message == "record table item load status failed due to not finding item" { return nil } diff --git a/executor/partition_table_test.go b/executor/partition_table_test.go index 15d2c2872ca9c..5696b56f6f730 100644 --- a/executor/partition_table_test.go +++ b/executor/partition_table_test.go @@ -3831,3 +3831,72 @@ func TestIssue21732(t *testing.T) { }) } } + +func TestIssue39999(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + + tk.MustExec(`create schema test39999`) + tk.MustExec(`use test39999`) + tk.MustExec(`drop table if exists c, t`) + tk.MustExec("CREATE TABLE `c` (" + + "`serial_id` varchar(24)," + + "`occur_trade_date` date," + + "`txt_account_id` varchar(24)," + + "`capital_sub_class` varchar(10)," + + "`occur_amount` decimal(16,2)," + + "`broker` varchar(10)," + + "PRIMARY KEY (`txt_account_id`,`occur_trade_date`,`serial_id`) /*T![clustered_index] CLUSTERED */," + + "KEY `idx_serial_id` (`serial_id`)" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci " + + "PARTITION BY RANGE COLUMNS(`serial_id`) (" + + "PARTITION `p202209` VALUES LESS THAN ('20221001')," + + "PARTITION `p202210` VALUES LESS THAN ('20221101')," + + "PARTITION `p202211` VALUES LESS THAN ('20221201')" + + ")") + + tk.MustExec("CREATE TABLE `t` ( " + + "`txn_account_id` varchar(24), " + + "`account_id` varchar(32), " + + "`broker` varchar(10), " + + "PRIMARY KEY (`txn_account_id`) /*T![clustered_index] CLUSTERED */ " + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci") + + tk.MustExec("INSERT INTO `c` (serial_id, txt_account_id, capital_sub_class, occur_trade_date, occur_amount, broker) VALUES ('2022111700196920','04482786','CUST','2022-11-17',-2.01,'0009')") + tk.MustExec("INSERT INTO `t` VALUES ('04482786','1142927','0009')") + + tk.MustExec(`set tidb_partition_prune_mode='dynamic'`) + tk.MustExec(`analyze table c`) + tk.MustExec(`analyze table t`) + query := `select + /*+ inl_join(c) */ + c.occur_amount +from + c + join t on c.txt_account_id = t.txn_account_id + and t.broker = '0009' + and c.occur_trade_date = '2022-11-17'` + tk.MustQuery("explain " + query).Check(testkit.Rows(""+ + "IndexJoin_22 1.00 root inner join, inner:TableReader_21, outer key:test39999.t.txn_account_id, inner key:test39999.c.txt_account_id, equal cond:eq(test39999.t.txn_account_id, test39999.c.txt_account_id)", + "├─TableReader_27(Build) 1.00 root data:Selection_26", + "│ └─Selection_26 1.00 cop[tikv] eq(test39999.t.broker, \"0009\")", + "│ └─TableFullScan_25 1.00 cop[tikv] table:t keep order:false", + "└─TableReader_21(Probe) 1.00 root partition:all data:Selection_20", + " └─Selection_20 1.00 cop[tikv] eq(test39999.c.occur_trade_date, 2022-11-17 00:00:00.000000)", + " └─TableRangeScan_19 1.00 cop[tikv] table:c range: decided by [eq(test39999.c.txt_account_id, test39999.t.txn_account_id) eq(test39999.c.occur_trade_date, 2022-11-17 00:00:00.000000)], keep order:false")) + tk.MustQuery(query).Check(testkit.Rows("-2.01")) + + // Add the missing partition key part. 
+ tk.MustExec(`alter table t add column serial_id varchar(24) default '2022111700196920'`) + query += ` and c.serial_id = t.serial_id` + tk.MustQuery(query).Check(testkit.Rows("-2.01")) + tk.MustQuery("explain " + query).Check(testkit.Rows(""+ + `IndexJoin_20 0.80 root inner join, inner:TableReader_19, outer key:test39999.t.txn_account_id, test39999.t.serial_id, inner key:test39999.c.txt_account_id, test39999.c.serial_id, equal cond:eq(test39999.t.serial_id, test39999.c.serial_id), eq(test39999.t.txn_account_id, test39999.c.txt_account_id)`, + `├─TableReader_25(Build) 0.80 root data:Selection_24`, + `│ └─Selection_24 0.80 cop[tikv] eq(test39999.t.broker, "0009"), not(isnull(test39999.t.serial_id))`, + `│ └─TableFullScan_23 1.00 cop[tikv] table:t keep order:false`, + `└─TableReader_19(Probe) 0.80 root partition:all data:Selection_18`, + ` └─Selection_18 0.80 cop[tikv] eq(test39999.c.occur_trade_date, 2022-11-17 00:00:00.000000)`, + ` └─TableRangeScan_17 0.80 cop[tikv] table:c range: decided by [eq(test39999.c.txt_account_id, test39999.t.txn_account_id) eq(test39999.c.serial_id, test39999.t.serial_id) eq(test39999.c.occur_trade_date, 2022-11-17 00:00:00.000000)], keep order:false`)) +} diff --git a/executor/prepared.go b/executor/prepared.go index a9dd9452e3c99..6a5025e0d539b 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -115,7 +115,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { return err } } - stmt, p, paramCnt, err := plannercore.GeneratePlanCacheStmtWithAST(ctx, e.ctx, stmt0) + stmt, p, paramCnt, err := plannercore.GeneratePlanCacheStmtWithAST(ctx, e.ctx, stmt0.Text(), stmt0) if err != nil { return err } diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index c7304642c544a..5a201d906b5a3 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -177,6 +177,7 @@ go_test( "integration_serial_test.go", "integration_test.go", "main_test.go", + "multi_valued_index_test.go", "scalar_function_test.go", "schema_test.go", "typeinfer_test.go", diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index ee66669e638d6..e6257c4dd058c 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -23,6 +23,7 @@ package expression import ( + "fmt" "math" "strconv" "strings" @@ -407,6 +408,70 @@ func (c *castAsDurationFunctionClass) getFunction(ctx sessionctx.Context, args [ return sig, nil } +type castAsArrayFunctionClass struct { + baseFunctionClass + + tp *types.FieldType +} + +func (c *castAsArrayFunctionClass) verifyArgs(args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + + if args[0].GetType().EvalType() != types.ETJson { + return types.ErrInvalidJSONData.GenWithStackByArgs("1", "cast_as_array") + } + + return nil +} + +func (c *castAsArrayFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + arrayType := c.tp.ArrayType() + switch arrayType.GetType() { + case mysql.TypeYear, mysql.TypeJSON: + return nil, ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("CAST-ing data to array of %s", arrayType.String())) + } + if arrayType.EvalType() == types.ETString && arrayType.GetCharset() != charset.CharsetUTF8MB4 && arrayType.GetCharset() != charset.CharsetBin { + return nil, ErrNotSupportedYet.GenWithStackByArgs("specifying charset for multi-valued index", arrayType.String()) + } + + bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, 
c.tp) + if err != nil { + return nil, err + } + sig = &castJSONAsArrayFunctionSig{bf} + return sig, nil +} + +type castJSONAsArrayFunctionSig struct { + baseBuiltinFunc +} + +func (b *castJSONAsArrayFunctionSig) Clone() builtinFunc { + newSig := &castJSONAsArrayFunctionSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + val, isNull, err := b.args[0].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + if val.TypeCode != types.JSONTypeCodeArray { + return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAST-ing Non-JSON Array type to array") + } + + // TODO: impl the cast(... as ... array) function + + return types.BinaryJSON{}, false, nil +} + type castAsJSONFunctionClass struct { baseFunctionClass @@ -1914,6 +1979,13 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp // BuildCastFunction builds a CAST ScalarFunction from the Expression. func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { + res, err := BuildCastFunctionWithCheck(ctx, expr, tp) + terror.Log(err) + return +} + +// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any. +func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression, err error) { argType := expr.GetType() // If source argument's nullable, then target type should be nullable if !mysql.HasNotNullFlag(argType.GetFlag()) { @@ -1933,7 +2005,11 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT case types.ETDuration: fc = &castAsDurationFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETJson: - fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + if tp.IsArray() { + fc = &castAsArrayFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + } else { + fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + } case types.ETString: fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} if expr.GetType().GetType() == mysql.TypeBit { @@ -1941,7 +2017,6 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT } } f, err := fc.getFunction(ctx, []Expression{expr}) - terror.Log(err) res = &ScalarFunction{ FuncName: model.NewCIStr(ast.Cast), RetType: tp, @@ -1950,10 +2025,10 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT // We do not fold CAST if the eval type of this scalar function is ETJson // since we may reset the flag of the field type of CastAsJson later which // would affect the evaluation of it. 
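+ // Constant folding is also skipped when getFunction failed above, so the error (rather than a half-folded expression) reaches the caller.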
- if tp.EvalType() != types.ETJson { + if tp.EvalType() != types.ETJson && err == nil { res = FoldConstant(res) } - return res + return res, err } // WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not diff --git a/expression/errors.go b/expression/errors.go index 0db38645f78d4..c56737ec2fae3 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -37,6 +37,7 @@ var ( ErrInvalidTableSample = dbterror.ClassExpression.NewStd(mysql.ErrInvalidTableSample) ErrInternal = dbterror.ClassOptimizer.NewStd(mysql.ErrInternal) ErrNoDB = dbterror.ClassOptimizer.NewStd(mysql.ErrNoDB) + ErrNotSupportedYet = dbterror.ClassExpression.NewStd(mysql.ErrNotSupportedYet) // All the un-exported errors are defined here: errFunctionNotExists = dbterror.ClassExpression.NewStd(mysql.ErrSpDoesNotExist) diff --git a/expression/expression.go b/expression/expression.go index 024bac00ef960..6d7eb080b29fc 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -59,7 +59,7 @@ var EvalAstExpr func(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, e // RewriteAstExpr rewrites ast expression directly. // Note: initialized in planner/core // import expression and planner/core together to use EvalAstExpr -var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice) (Expression, error) +var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice, allowCastArray bool) (Expression, error) // VecExpr contains all vectorized evaluation methods. type VecExpr interface { @@ -998,7 +998,7 @@ func ColumnInfos2ColumnsAndNames(ctx sessionctx.Context, dbName, tblName model.C if err != nil { return nil, nil, errors.Trace(err) } - e, err := RewriteAstExpr(ctx, expr, mockSchema, names) + e, err := RewriteAstExpr(ctx, expr, mockSchema, names, false) if err != nil { return nil, nil, errors.Trace(err) } @@ -1358,12 +1358,15 @@ func canScalarFuncPushDown(scalarFunc *ScalarFunction, pc PbConverter, storeType panic(errors.Errorf("unspecified PbCode: %T", scalarFunc.Function)) }) } + storageName := storeType.Name() + if storeType == kv.UnSpecified { + storageName = "storage layer" + } + warnErr := errors.New("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ", return type: " + scalarFunc.RetType.CompactStr() + ") is not supported to push down to " + storageName + " now.") if pc.sc.InExplainStmt { - storageName := storeType.Name() - if storeType == kv.UnSpecified { - storageName = "storage layer" - } - pc.sc.AppendWarning(errors.New("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ", return type: " + scalarFunc.RetType.CompactStr() + ") is not supported to push down to " + storageName + " now.")) + pc.sc.AppendWarning(warnErr) + } else { + pc.sc.AppendExtraWarning(warnErr) } return false } @@ -1393,14 +1396,20 @@ func canExprPushDown(expr Expression, pc PbConverter, storeType kv.StoreType, ca if expr.GetType().GetType() == mysql.TypeEnum && canEnumPush { break } + warnErr := errors.New("Expression about '" + expr.String() + "' can not be pushed to TiFlash because it contains unsupported calculation of type '" + types.TypeStr(expr.GetType().GetType()) + "'.") if pc.sc.InExplainStmt { - pc.sc.AppendWarning(errors.New("Expression about '" + expr.String() + "' can not be pushed to TiFlash because it contains unsupported calculation of type '" + types.TypeStr(expr.GetType().GetType()) + "'.")) + 
pc.sc.AppendWarning(warnErr) + } else { + pc.sc.AppendExtraWarning(warnErr) } return false case mysql.TypeNewDecimal: if !expr.GetType().IsDecimalValid() { + warnErr := errors.New("Expression about '" + expr.String() + "' can not be pushed to TiFlash because it contains invalid decimal('" + strconv.Itoa(expr.GetType().GetFlen()) + "','" + strconv.Itoa(expr.GetType().GetDecimal()) + "').") if pc.sc.InExplainStmt { - pc.sc.AppendWarning(errors.New("Expression about '" + expr.String() + "' can not be pushed to TiFlash because it contains invalid decimal('" + strconv.Itoa(expr.GetType().GetFlen()) + "','" + strconv.Itoa(expr.GetType().GetDecimal()) + "').")) + pc.sc.AppendWarning(warnErr) + } else { + pc.sc.AppendExtraWarning(warnErr) } return false } diff --git a/expression/multi_valued_index_test.go b/expression/multi_valued_index_test.go new file mode 100644 index 0000000000000..058d955faa4fb --- /dev/null +++ b/expression/multi_valued_index_test.go @@ -0,0 +1,47 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression_test + +import ( + "testing" + + "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/testkit" +) + +func TestMultiValuedIndexDDL(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustExec("create table t(a json);") + tk.MustGetErrCode("select cast(a as signed array) from t", errno.ErrNotSupportedYet) + tk.MustGetErrCode("select json_extract(cast(a as signed array), '$[0]') from t", errno.ErrNotSupportedYet) + tk.MustGetErrCode("select * from t where cast(a as signed array)", errno.ErrNotSupportedYet) + tk.MustGetErrCode("select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet) + + tk.MustExec("drop table t") + tk.MustGetErrCode("CREATE TABLE t(x INT, KEY k ((1 AND CAST(JSON_ARRAY(x) AS UNSIGNED ARRAY))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(cast(f1 as unsigned array) as unsigned array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->>'$[*]' as unsigned array))));", errno.ErrInvalidJSONData) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as year array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as json array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as char(10) charset gbk array))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("create table t(j json, gc json as ((concat(cast(j->'$[*]' as unsigned array),\"x\"))));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("create table t(j json, gc json as (cast(j->'$[*]' as unsigned array)));", errno.ErrNotSupportedYet) + tk.MustGetErrCode("create view v as select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet) + tk.MustExec("create table t(a json, index idx((cast(a as signed array))));") +} diff --git 
a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 808db9f69b4cf..3343a0cbaa169 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -48,7 +48,7 @@ func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableI return nil, errors.Trace(err) } expr := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr - return RewriteSimpleExprWithTableInfo(ctx, tableInfo, expr) + return RewriteSimpleExprWithTableInfo(ctx, tableInfo, expr, false) } // ParseSimpleExprCastWithTableInfo parses simple expression string to Expression. @@ -63,13 +63,13 @@ func ParseSimpleExprCastWithTableInfo(ctx sessionctx.Context, exprStr string, ta } // RewriteSimpleExprWithTableInfo rewrites simple ast.ExprNode to expression.Expression. -func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo, expr ast.ExprNode) (Expression, error) { +func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo, expr ast.ExprNode, allowCastArray bool) (Expression, error) { dbName := model.NewCIStr(ctx.GetSessionVars().CurrentDB) columns, names, err := ColumnInfos2ColumnsAndNames(ctx, dbName, tbl.Name, tbl.Cols(), tbl) if err != nil { return nil, err } - e, err := RewriteAstExpr(ctx, expr, NewSchema(columns...), names) + e, err := RewriteAstExpr(ctx, expr, NewSchema(columns...), names, allowCastArray) if err != nil { return nil, err } @@ -111,7 +111,7 @@ func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *S // RewriteSimpleExprWithNames rewrites simple ast.ExprNode to expression.Expression. func RewriteSimpleExprWithNames(ctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names []*types.FieldName) (Expression, error) { - e, err := RewriteAstExpr(ctx, expr, schema, names) + e, err := RewriteAstExpr(ctx, expr, schema, names, false) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index df18ccec64b18..f4d3fb93f9e98 100644 --- a/go.mod +++ b/go.mod @@ -116,7 +116,7 @@ require ( golang.org/x/sys v0.3.0 golang.org/x/term v0.3.0 golang.org/x/text v0.5.0 - golang.org/x/time v0.2.0 + golang.org/x/time v0.3.0 golang.org/x/tools v0.2.0 google.golang.org/api v0.74.0 google.golang.org/grpc v1.45.0 diff --git a/go.sum b/go.sum index b7de00223b8a3..8c21c0326306d 100644 --- a/go.sum +++ b/go.sum @@ -1338,8 +1338,8 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.2.0 h1:52I/1L54xyEQAYdtcSuxtiT84KGYTBGXwayxmIpNJhE= -golang.org/x/time v0.2.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index ebf3e9b535893..d4d4c4fa588f7 100644 --- a/infoschema/tables_test.go +++ 
b/infoschema/tables_test.go @@ -531,18 +531,166 @@ func TestSlowQuery(t *testing.T) { slowLogFileName := "tidb_slow.log" prepareSlowLogfile(t, slowLogFileName) defer func() { require.NoError(t, os.Remove(slowLogFileName)) }() + expectedRes := [][]interface{}{ + {"2019-02-12 19:33:56.571953", + "406315658548871171", + "root", + "localhost", + "6", + "57", + "0.12", + "4.895492", + "0.4", + "0.2", + "0.000000003", + "2", + "0.000000002", + "0.00000001", + "0.000000003", + "0.19", + "0.21", + "0.01", + "0", + "0.18", + "[txnLock]", + "0.03", + "0", + "15", + "480", + "1", + "8", + "0.3824278", + "0.161", + "0.101", + "0.092", + "1.71", + "1", + "100001", + "100000", + "100", + "10", + "10", + "10", + "100", + "test", + "", + "0", + "42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772", + "t1:1,t2:2", + "0.1", + "0.2", + "0.03", + "127.0.0.1:20160", + "0.05", + "0.6", + "0.8", + "0.0.0.0:20160", + "70724", + "65536", + "0", + "0", + "0", + "0", + "10", + "", + "0", + "1", + "0", + "0", + "1", + "0", + "0", + "abcd", + "60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4", + "", + "update t set i = 2;", + "select * from t_slim;"}, + {"2021-09-08 14:39:54.506967", + "427578666238083075", + "root", + "172.16.0.0", + "40507", + "0", + "0", + "25.571605962", + "0.002923536", + "0.006800973", + "0.002100764", + "0", + "0", + "0", + "0.000015801", + "25.542014572", + "0", + "0.002294647", + "0.000605473", + "12.483", + "[tikvRPC regionMiss tikvRPC regionMiss regionMiss]", + "0", + "0", + "624", + "172064", + "60", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "rtdb", + "", + "0", + "124acb3a0bec903176baca5f9da00b4e7512a41c93b417923f26502edeb324cc", + "", + "0", + "0", + "0", + "", + "0", + "0", + "0", + "", + "856544", + "0", + "86.635049185", + "0.015486658", + "100.054", + "0", + "0", + "", + "0", + "1", + "0", + "0", + "0", + "0", + "0", + "", + "", + "", + "", + "INSERT INTO ...;", + }, + } tk.MustExec(fmt.Sprintf("set @@tidb_slow_query_file='%v'", slowLogFileName)) tk.MustExec("set time_zone = '+08:00';") re := tk.MustQuery("select * from information_schema.slow_query") - re.Check(testkit.RowsWithSep("|", "2019-02-12 19:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|100|10|10|10|100|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0|10||0|1|0|0|1|0|0|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4||update t set i = 2;|select * from t_slim;", - "2021-09-08|14:39:54.506967|427578666238083075|root|172.16.0.0|40507|0|0|25.571605962|0.002923536|0.006800973|0.002100764|0|0|0|0.000015801|25.542014572|0|0.002294647|0.000605473|12.483|[tikvRPC regionMiss tikvRPC regionMiss regionMiss]|0|0|624|172064|60|0|0|0|0|0|0|0|0|0|0|0|0|0|0|rtdb||0|124acb3a0bec903176baca5f9da00b4e7512a41c93b417923f26502edeb324cc||0|0|0||0|0|0||856544|0|86.635049185|0.015486658|100.054|0|0||0|1|0|0|0|0|0|||||INSERT INTO ...;", - )) + re.Check(expectedRes) + tk.MustExec("set time_zone = '+00:00';") re = tk.MustQuery("select * from information_schema.slow_query") - re.Check(testkit.RowsWithSep("|", "2019-02-12 
11:33:56.571953|406315658548871171|root|localhost|6|57|0.12|4.895492|0.4|0.2|0.000000003|2|0.000000002|0.00000001|0.000000003|0.19|0.21|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.3824278|0.161|0.101|0.092|1.71|1|100001|100000|100|10|10|10|100|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|65536|0|0|0|0|10||0|1|0|0|1|0|0|abcd|60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4||update t set i = 2;|select * from t_slim;", - "2021-09-08|06:39:54.506967|427578666238083075|root|172.16.0.0|40507|0|0|25.571605962|0.002923536|0.006800973|0.002100764|0|0|0|0.000015801|25.542014572|0|0.002294647|0.000605473|12.483|[tikvRPC regionMiss tikvRPC regionMiss regionMiss]|0|0|624|172064|60|0|0|0|0|0|0|0|0|0|0|0|0|0|0|rtdb||0|124acb3a0bec903176baca5f9da00b4e7512a41c93b417923f26502edeb324cc||0|0|0||0|0|0||856544|0|86.635049185|0.015486658|100.054|0|0||0|1|0|0|0|0|0|||||INSERT INTO ...;", - )) + expectedRes[0][0] = "2019-02-12 11:33:56.571953" + expectedRes[1][0] = "2021-09-08 06:39:54.506967" + re.Check(expectedRes) // Test for long query. f, err := os.OpenFile(slowLogFileName, os.O_CREATE|os.O_WRONLY, 0644) diff --git a/parser/model/model.go b/parser/model/model.go index 4622ca2810359..411db3fcf1d20 100644 --- a/parser/model/model.go +++ b/parser/model/model.go @@ -1419,6 +1419,7 @@ type IndexInfo struct { Primary bool `json:"is_primary"` // Whether the index is primary key. Invisible bool `json:"is_invisible"` // Whether the index is invisible. Global bool `json:"is_global"` // Whether the index is global. + MVIndex bool `json:"mv_index"` // Whether the index is multivalued index. } // Clone clones IndexInfo. diff --git a/parser/parser_test.go b/parser/parser_test.go index c06d2076f085a..7b72117f69d16 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1603,6 +1603,7 @@ func TestBuiltin(t *testing.T) { {"select cast(time '2000' as year);", true, "SELECT CAST(TIME '2000' AS YEAR)"}, {"select cast(b as signed array);", true, "SELECT CAST(`b` AS SIGNED ARRAY)"}, + {"select cast(b as char(10) array);", true, "SELECT CAST(`b` AS CHAR(10) ARRAY)"}, // for last_insert_id {"SELECT last_insert_id();", true, "SELECT LAST_INSERT_ID()"}, diff --git a/parser/types/field_type.go b/parser/types/field_type.go index 369ed59fa7a59..ff0ac9793cf17 100644 --- a/parser/types/field_type.go +++ b/parser/types/field_type.go @@ -72,7 +72,7 @@ func NewFieldType(tp byte) *FieldType { // IsDecimalValid checks whether the decimal is valid. func (ft *FieldType) IsDecimalValid() bool { - if ft.tp == mysql.TypeNewDecimal && (ft.decimal < 0 || ft.decimal > mysql.MaxDecimalScale || ft.flen <= 0 || ft.flen > mysql.MaxDecimalWidth || ft.flen < ft.decimal) { + if ft.GetType() == mysql.TypeNewDecimal && (ft.decimal < 0 || ft.decimal > mysql.MaxDecimalScale || ft.flen <= 0 || ft.flen > mysql.MaxDecimalWidth || ft.flen < ft.decimal) { return false } return true @@ -80,7 +80,7 @@ func (ft *FieldType) IsDecimalValid() bool { // IsVarLengthType Determine whether the column type is a variable-length type func (ft *FieldType) IsVarLengthType() bool { - switch ft.tp { + switch ft.GetType() { case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeJSON, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: return true default: @@ -90,6 +90,9 @@ func (ft *FieldType) IsVarLengthType() bool { // GetType returns the type of the FieldType. 
func (ft *FieldType) GetType() byte { + if ft.array { + return mysql.TypeJSON + } return ft.tp } @@ -126,6 +129,7 @@ func (ft *FieldType) GetElems() []string { // SetType sets the type of the FieldType. func (ft *FieldType) SetType(tp byte) { ft.tp = tp + ft.array = false } // SetFlag sets the flag of the FieldType. @@ -160,7 +164,7 @@ func (ft *FieldType) SetFlen(flen int) { // SetFlenUnderLimit sets the length of the field to the value of the argument func (ft *FieldType) SetFlenUnderLimit(flen int) { - if ft.tp == mysql.TypeNewDecimal { + if ft.GetType() == mysql.TypeNewDecimal { ft.flen = mathutil.Min(flen, mysql.MaxDecimalWidth) } else { ft.flen = flen @@ -174,7 +178,7 @@ func (ft *FieldType) SetDecimal(decimal int) { // SetDecimalUnderLimit sets the decimal of the field to the value of the argument func (ft *FieldType) SetDecimalUnderLimit(decimal int) { - if ft.tp == mysql.TypeNewDecimal { + if ft.GetType() == mysql.TypeNewDecimal { ft.decimal = mathutil.Min(decimal, mysql.MaxDecimalScale) } else { ft.decimal = decimal @@ -183,7 +187,7 @@ func (ft *FieldType) SetDecimalUnderLimit(decimal int) { // UpdateFlenAndDecimalUnderLimit updates the length and decimal to the value of the argument func (ft *FieldType) UpdateFlenAndDecimalUnderLimit(old *FieldType, deltaDecimal int, deltaFlen int) { - if ft.tp != mysql.TypeNewDecimal { + if ft.GetType() != mysql.TypeNewDecimal { return } if old.decimal < 0 { @@ -229,6 +233,13 @@ func (ft *FieldType) IsArray() bool { return ft.array } +// ArrayType return the type of the array. +func (ft *FieldType) ArrayType() *FieldType { + clone := ft.Clone() + clone.SetArray(false) + return clone +} + // SetElemWithIsBinaryLit sets the element of the FieldType. func (ft *FieldType) SetElemWithIsBinaryLit(idx int, element string, isBinaryLit bool) { ft.elems[idx] = element @@ -274,7 +285,7 @@ func (ft *FieldType) Equal(other *FieldType) bool { // When tp is float or double with decimal unspecified, do not check whether flen is equal, // because flen for them is useless. // The decimal field can be ignored if the type is int or string. - tpEqual := (ft.tp == other.tp) || (ft.tp == mysql.TypeVarchar && other.tp == mysql.TypeVarString) || (ft.tp == mysql.TypeVarString && other.tp == mysql.TypeVarchar) + tpEqual := (ft.GetType() == other.GetType()) || (ft.GetType() == mysql.TypeVarchar && other.GetType() == mysql.TypeVarString) || (ft.GetType() == mysql.TypeVarString && other.GetType() == mysql.TypeVarchar) flenEqual := ft.flen == other.flen || (ft.EvalType() == ETReal && ft.decimal == UnspecifiedLength) ignoreDecimal := ft.EvalType() == ETInt || ft.EvalType() == ETString partialEqual := tpEqual && @@ -316,7 +327,7 @@ func (ft *FieldType) PartialEqual(other *FieldType, unsafe bool) bool { // EvalType gets the type in evaluation. func (ft *FieldType) EvalType() EvalType { - switch ft.tp { + switch ft.GetType() { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeBit, mysql.TypeYear: return ETInt @@ -342,7 +353,7 @@ func (ft *FieldType) EvalType() EvalType { // Hybrid checks whether a type is a hybrid type, which can represent different types of value in specific context. func (ft *FieldType) Hybrid() bool { - return ft.tp == mysql.TypeEnum || ft.tp == mysql.TypeBit || ft.tp == mysql.TypeSet + return ft.GetType() == mysql.TypeEnum || ft.GetType() == mysql.TypeBit || ft.GetType() == mysql.TypeSet } // Init initializes the FieldType data. 
@@ -355,10 +366,10 @@ func (ft *FieldType) Init(tp byte) { // CompactStr only considers tp/CharsetBin/flen/Deimal. // This is used for showing column type in infoschema. func (ft *FieldType) CompactStr() string { - ts := TypeToStr(ft.tp, ft.charset) + ts := TypeToStr(ft.GetType(), ft.charset) suffix := "" - defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(ft.tp) + defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType()) isDecimalNotDefault := ft.decimal != defaultDecimal && ft.decimal != 0 && ft.decimal != UnspecifiedLength // displayFlen and displayDecimal are flen and decimal values with `-1` substituted with default value. @@ -370,7 +381,7 @@ func (ft *FieldType) CompactStr() string { displayDecimal = defaultDecimal } - switch ft.tp { + switch ft.GetType() { case mysql.TypeEnum, mysql.TypeSet: // Format is ENUM ('e1', 'e2') or SET ('e1', 'e2') es := make([]string, 0, len(ft.elems)) @@ -414,8 +425,8 @@ func (ft *FieldType) CompactStr() string { func (ft *FieldType) InfoSchemaStr() string { suffix := "" if mysql.HasUnsignedFlag(ft.flag) && - ft.tp != mysql.TypeBit && - ft.tp != mysql.TypeYear { + ft.GetType() != mysql.TypeBit && + ft.GetType() != mysql.TypeYear { suffix = " unsigned" } return ft.CompactStr() + suffix @@ -431,11 +442,11 @@ func (ft *FieldType) String() string { if mysql.HasZerofillFlag(ft.flag) { strs = append(strs, "ZEROFILL") } - if mysql.HasBinaryFlag(ft.flag) && ft.tp != mysql.TypeString { + if mysql.HasBinaryFlag(ft.flag) && ft.GetType() != mysql.TypeString { strs = append(strs, "BINARY") } - if IsTypeChar(ft.tp) || IsTypeBlob(ft.tp) { + if IsTypeChar(ft.GetType()) || IsTypeBlob(ft.GetType()) { if ft.charset != "" && ft.charset != charset.CharsetBin { strs = append(strs, fmt.Sprintf("CHARACTER SET %s", ft.charset)) } @@ -449,12 +460,12 @@ func (ft *FieldType) String() string { // Restore implements Node interface. func (ft *FieldType) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord(TypeToStr(ft.tp, ft.charset)) + ctx.WriteKeyWord(TypeToStr(ft.GetType(), ft.charset)) precision := UnspecifiedLength scale := UnspecifiedLength - switch ft.tp { + switch ft.GetType() { case mysql.TypeEnum, mysql.TypeSet: ctx.WritePlain("(") for i, e := range ft.elems { @@ -491,7 +502,7 @@ func (ft *FieldType) Restore(ctx *format.RestoreCtx) error { ctx.WriteKeyWord(" BINARY") } - if IsTypeChar(ft.tp) || IsTypeBlob(ft.tp) { + if IsTypeChar(ft.GetType()) || IsTypeBlob(ft.GetType()) { if ft.charset != "" && ft.charset != charset.CharsetBin { ctx.WriteKeyWord(" CHARACTER SET " + ft.charset) } @@ -519,7 +530,7 @@ func (ft *FieldType) RestoreAsCastType(ctx *format.RestoreCtx, explicitCharset b ctx.WritePlainf("(%d)", ft.flen) } if !explicitCharset { - return + break } if !skipWriteBinary && ft.flag&mysql.BinaryFlag != 0 { ctx.WriteKeyWord(" BINARY") @@ -581,7 +592,7 @@ const VarStorageLen = -1 // StorageLength is the length of stored value for the type. func (ft *FieldType) StorageLength() int { - switch ft.tp { + switch ft.GetType() { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeFloat, mysql.TypeYear, mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp, mysql.TypeEnum, mysql.TypeSet, @@ -599,7 +610,7 @@ func (ft *FieldType) StorageLength() int { // HasCharset indicates if a COLUMN has an associated charset. 
Returning false here prevents some information // statements(like `SHOW CREATE TABLE`) from attaching a CHARACTER SET clause to the column. func HasCharset(ft *FieldType) bool { - switch ft.tp { + switch ft.GetType() { case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: return !mysql.HasBinaryFlag(ft.flag) diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index f11f60b95cfe5..8feb357745853 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -914,11 +914,6 @@ func (e *Explain) explainOpRecursivelyInJSONFormat(flatOp *FlatOperator, flats F textTreeExplainID := texttree.PrettyIdentifier(explainID, flatOp.TextTreeIndent, flatOp.IsLastChild) cur := e.prepareOperatorInfoForJSONFormat(flatOp.Origin, taskTp, textTreeExplainID, explainID) - if e.ctx != nil && e.ctx.GetSessionVars() != nil && e.ctx.GetSessionVars().StmtCtx != nil { - if optimInfo, ok := e.ctx.GetSessionVars().StmtCtx.OptimInfo[flatOp.Origin.ID()]; ok { - e.ctx.GetSessionVars().StmtCtx.AppendNote(errors.New(optimInfo)) - } - } for _, idx := range flatOp.ChildrenIdx { cur.SubOperators = append(cur.SubOperators, @@ -938,11 +933,6 @@ func (e *Explain) explainFlatOpInRowFormat(flatOp *FlatOperator) { flatOp.TextTreeIndent, flatOp.IsLastChild) e.prepareOperatorInfo(flatOp.Origin, taskTp, textTreeExplainID) - if e.ctx != nil && e.ctx.GetSessionVars() != nil && e.ctx.GetSessionVars().StmtCtx != nil { - if optimInfo, ok := e.ctx.GetSessionVars().StmtCtx.OptimInfo[flatOp.Origin.ID()]; ok { - e.ctx.GetSessionVars().StmtCtx.AppendNote(errors.New(optimInfo)) - } - } } func getRuntimeInfoStr(ctx sessionctx.Context, p Plan, runtimeStatsColl *execdetails.RuntimeStatsColl) (actRows, analyzeInfo, memoryInfo, diskInfo string) { diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 768a8c20fc0b5..2817f370ffcec 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2730,8 +2730,11 @@ func (la *LogicalAggregation) checkCanPushDownToMPP() bool { } } if hasUnsupportedDistinct { + warnErr := errors.New("Aggregation can not be pushed to storage layer in mpp mode because it contains agg function with distinct") if la.ctx.GetSessionVars().StmtCtx.InExplainStmt { - la.ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("Aggregation can not be pushed to storage layer in mpp mode because it contains agg function with distinct")) + la.ctx.GetSessionVars().StmtCtx.AppendWarning(warnErr) + } else { + la.ctx.GetSessionVars().StmtCtx.AppendExtraWarning(warnErr) } return false } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index d0ca6e6f8e4cf..ddb905dc5c06b 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -55,7 +55,7 @@ func evalAstExpr(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error if val, ok := expr.(*driver.ValueExpr); ok { return val.Datum, nil } - newExpr, err := rewriteAstExpr(sctx, expr, nil, nil) + newExpr, err := rewriteAstExpr(sctx, expr, nil, nil, false) if err != nil { return types.Datum{}, err } @@ -63,13 +63,14 @@ func evalAstExpr(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error } // rewriteAstExpr rewrites ast expression directly. 
-func rewriteAstExpr(sctx sessionctx.Context, expr ast.ExprNode, schema *expression.Schema, names types.NameSlice) (expression.Expression, error) { +func rewriteAstExpr(sctx sessionctx.Context, expr ast.ExprNode, schema *expression.Schema, names types.NameSlice, allowCastArray bool) (expression.Expression, error) { var is infoschema.InfoSchema // in tests, it may be null if s, ok := sctx.GetInfoSchema().(infoschema.InfoSchema); ok { is = s } b, savedBlockNames := NewPlanBuilder().Init(sctx, is, &hint.BlockHintProcessor{}) + b.allowBuildCastArray = allowCastArray fakePlan := LogicalTableDual{}.Init(sctx, 0) if schema != nil { fakePlan.schema = schema @@ -1183,6 +1184,10 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok er.disableFoldCounter-- } case *ast.FuncCastExpr: + if v.Tp.IsArray() && !er.b.allowBuildCastArray { + er.err = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions") + return retNode, false + } arg := er.ctxStack[len(er.ctxStack)-1] er.err = expression.CheckArgsNotMultiColumnRow(arg) if er.err != nil { @@ -1195,7 +1200,11 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return retNode, false } - castFunction := expression.BuildCastFunction(er.sctx, arg, v.Tp) + castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp) + if err != nil { + er.err = err + return retNode, false + } if v.Tp.EvalType() == types.ETString { castFunction.SetCoercibility(expression.CoercibilityImplicit) if v.Tp.GetCharset() == charset.CharsetASCII { diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index afc5223b9be94..639bc15dbdc98 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -742,7 +742,7 @@ func (ds *DataSource) skylinePruning(prop *property.PhysicalProperty) []*candida } func (ds *DataSource) getPruningInfo(candidates []*candidatePath, prop *property.PhysicalProperty) string { - if !ds.ctx.GetSessionVars().StmtCtx.InVerboseExplain || len(candidates) == len(ds.possibleAccessPaths) { + if len(candidates) == len(ds.possibleAccessPaths) { return "" } if len(candidates) == 1 && len(candidates[0].path.Ranges) == 0 { @@ -889,10 +889,12 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter pruningInfo := ds.getPruningInfo(candidates, prop) defer func() { if err == nil && t != nil && !t.invalid() && pruningInfo != "" { - if ds.ctx.GetSessionVars().StmtCtx.OptimInfo == nil { - ds.ctx.GetSessionVars().StmtCtx.OptimInfo = make(map[int]string) + warnErr := errors.New(pruningInfo) + if ds.ctx.GetSessionVars().StmtCtx.InVerboseExplain { + ds.ctx.GetSessionVars().StmtCtx.AppendNote(warnErr) + } else { + ds.ctx.GetSessionVars().StmtCtx.AppendExtraNote(warnErr) } - ds.ctx.GetSessionVars().StmtCtx.OptimInfo[t.plan().ID()] = pruningInfo } }() diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go index d9df658aaa46e..ab4eb4e4912ab 100644 --- a/planner/core/plan_cache.go +++ b/planner/core/plan_cache.go @@ -115,6 +115,10 @@ func planCachePreprocess(ctx context.Context, sctx sessionctx.Context, isNonPrep func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context, isNonPrepared bool, is infoschema.InfoSchema, stmt *PlanCacheStmt, params []expression.Expression) (plan Plan, names []*types.FieldName, err error) { + if v := ctx.Value("____GetPlanFromSessionPlanCacheErr"); v != nil { // for testing 
+ return nil, nil, errors.New("____GetPlanFromSessionPlanCacheErr") + } + if err := planCachePreprocess(ctx, sctx, isNonPrepared, is, stmt, params); err != nil { return nil, nil, err } diff --git a/planner/core/plan_cache_param.go b/planner/core/plan_cache_param.go index 7c79b2a6416a0..9094edec621c0 100644 --- a/planner/core/plan_cache_param.go +++ b/planner/core/plan_cache_param.go @@ -15,6 +15,7 @@ package core import ( + "context" "errors" "strings" "sync" @@ -70,7 +71,7 @@ func (pr *paramReplacer) Reset() { pr.params = nil } // ParameterizeAST parameterizes this StmtNode. // e.g. `select * from t where a<10 and b<23` --> `select * from t where a `select * from t where a<10 and b<23`. -func RestoreASTWithParams(_ sessionctx.Context, stmt ast.StmtNode, params []*driver.ValueExpr) error { +func RestoreASTWithParams(ctx context.Context, _ sessionctx.Context, stmt ast.StmtNode, params []*driver.ValueExpr) error { + if v := ctx.Value("____RestoreASTWithParamsErr"); v != nil { + return errors.New("____RestoreASTWithParamsErr") + } + pr := paramRestorerPool.Get().(*paramRestorer) defer func() { pr.Reset() diff --git a/planner/core/plan_cache_param_test.go b/planner/core/plan_cache_param_test.go index 5c65b89767a60..ee4a8e9ae65c5 100644 --- a/planner/core/plan_cache_param_test.go +++ b/planner/core/plan_cache_param_test.go @@ -15,6 +15,7 @@ package core import ( + "context" "strings" "testing" @@ -61,7 +62,7 @@ func TestParameterize(t *testing.T) { for _, c := range cases { stmt, err := parser.New().ParseOneStmt(c.sql, "", "") require.Nil(t, err) - paramSQL, params, err := ParameterizeAST(sctx, stmt) + paramSQL, params, err := ParameterizeAST(context.Background(), sctx, stmt) require.Nil(t, err) require.Equal(t, c.paramSQL, paramSQL) require.Equal(t, len(c.params), len(params)) @@ -69,7 +70,7 @@ func TestParameterize(t *testing.T) { require.Equal(t, c.params[i], params[i].Datum.GetValue()) } - err = RestoreASTWithParams(sctx, stmt, params) + err = RestoreASTWithParams(context.Background(), sctx, stmt, params) require.Nil(t, err) var buf strings.Builder rCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &buf) diff --git a/planner/core/plan_cache_test.go b/planner/core/plan_cache_test.go index f541c441fd4f5..7a4ac860d8593 100644 --- a/planner/core/plan_cache_test.go +++ b/planner/core/plan_cache_test.go @@ -15,6 +15,7 @@ package core_test import ( + "context" "errors" "fmt" "math/rand" @@ -81,6 +82,76 @@ func TestInitLRUWithSystemVar(t *testing.T) { require.NotNil(t, lru) } +func TestNonPreparedPlanCacheWithExplain(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec("create table t(a int)") + tk.MustExec("set tidb_enable_non_prepared_plan_cache=1") + tk.MustExec("select * from t where a=1") // cache this plan + + tk.MustQuery("explain select * from t where a=2").Check(testkit.Rows( + `Selection_8 10.00 root eq(test.t.a, 2)`, + `└─TableReader_7 10.00 root data:Selection_6`, + ` └─Selection_6 10.00 cop[tikv] eq(test.t.a, 2)`, + ` └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo`)) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustQuery("explain format=verbose select * from t where a=2").Check(testkit.Rows( + `Selection_8 10.00 169474.57 root eq(test.t.a, 2)`, + `└─TableReader_7 10.00 168975.57 root data:Selection_6`, + ` └─Selection_6 10.00 2534000.00 cop[tikv] eq(test.t.a, 2)`, + ` └─TableFullScan_5 10000.00 2035000.00 cop[tikv] table:t keep order:false, 
stats:pseudo`)) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustQuery("explain analyze select * from t where a=2").CheckAt([]int{0, 1, 2, 3}, [][]interface{}{ + {"Selection_8", "10.00", "0", "root"}, + {"└─TableReader_7", "10.00", "0", "root"}, + {" └─Selection_6", "10.00", "0", "cop[tikv]"}, + {" └─TableFullScan_5", "10000.00", "0", "cop[tikv]"}, + }) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) +} + +func TestNonPreparedPlanCacheFallback(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a int)`) + for i := 0; i < 5; i++ { + tk.MustExec(fmt.Sprintf("insert into t values (%v)", i)) + } + tk.MustExec("set tidb_enable_non_prepared_plan_cache=1") + + // inject a fault to GeneratePlanCacheStmtWithAST + ctx := context.WithValue(context.Background(), "____GeneratePlanCacheStmtWithASTErr", struct{}{}) + tk.MustQueryWithContext(ctx, "select * from t where a in (1, 2)").Sort().Check(testkit.Rows("1", "2")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // cannot generate PlanCacheStmt + tk.MustQueryWithContext(ctx, "select * from t where a in (1, 3)").Sort().Check(testkit.Rows("1", "3")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // cannot generate PlanCacheStmt + tk.MustQuery("select * from t where a in (1, 2)").Sort().Check(testkit.Rows("1", "2")) + tk.MustQuery("select * from t where a in (1, 3)").Sort().Check(testkit.Rows("1", "3")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) // no error + + // inject a fault to GetPlanFromSessionPlanCache + tk.MustQuery("select * from t where a=1").Check(testkit.Rows("1")) // cache this plan + tk.MustQuery("select * from t where a=2").Check(testkit.Rows("2")) // plan from cache + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + ctx = context.WithValue(context.Background(), "____GetPlanFromSessionPlanCacheErr", struct{}{}) + tk.MustQueryWithContext(ctx, "select * from t where a=3").Check(testkit.Rows("3")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // fallback to the normal opt-path + tk.MustQueryWithContext(ctx, "select * from t where a=4").Check(testkit.Rows("4")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // fallback to the normal opt-path + tk.MustQueryWithContext(context.Background(), "select * from t where a=0").Check(testkit.Rows("0")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) // use the cached plan if no error + + // inject a fault to RestoreASTWithParams + ctx = context.WithValue(context.Background(), "____GetPlanFromSessionPlanCacheErr", struct{}{}) + ctx = context.WithValue(ctx, "____RestoreASTWithParamsErr", struct{}{}) + _, err := tk.ExecWithContext(ctx, "select * from t where a=1") + require.NotNil(t, err) +} + func TestNonPreparedPlanCacheBasically(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go index 2b1621857b9ca..8dc867316207d 100644 --- a/planner/core/plan_cache_utils.go +++ b/planner/core/plan_cache_utils.go @@ -64,17 +64,23 @@ func (e *paramMarkerExtractor) Leave(in ast.Node) (ast.Node, bool) { } // GeneratePlanCacheStmtWithAST generates the PlanCacheStmt structure for this AST. 
-func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode) (*PlanCacheStmt, Plan, int, error) { +// paramSQL is the corresponding parameterized sql like 'select * from t where a<?'. +// paramStmt is the Node of paramSQL. +func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, paramSQL string, paramStmt ast.StmtNode) (*PlanCacheStmt, Plan, int, error) { + if v := ctx.Value("____GeneratePlanCacheStmtWithASTErr"); v != nil { // for testing + return nil, nil, 0, errors.New("____GeneratePlanCacheStmtWithASTErr") + } + vars := sctx.GetSessionVars() var extractor paramMarkerExtractor - stmt.Accept(&extractor) + paramStmt.Accept(&extractor) // DDL Statements can not accept parameters - if _, ok := stmt.(ast.DDLNode); ok && len(extractor.markers) > 0 { + if _, ok := paramStmt.(ast.DDLNode); ok && len(extractor.markers) > 0 { return nil, nil, 0, ErrPrepareDDL } - switch stmt.(type) { + switch paramStmt.(type) { case *ast.LoadDataStmt, *ast.PrepareStmt, *ast.ExecuteStmt, *ast.DeallocateStmt, *ast.NonTransactionalDMLStmt: return nil, nil, 0, ErrUnsupportedPs } @@ -86,7 +92,7 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, } ret := &PreprocessorReturn{} - err := Preprocess(ctx, sctx, stmt, InPrepare, WithPreprocessorReturn(ret)) + err := Preprocess(ctx, sctx, paramStmt, InPrepare, WithPreprocessorReturn(ret)) if err != nil { return nil, nil, 0, err } @@ -103,8 +109,8 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, } prepared := &ast.Prepared{ - Stmt: stmt, - StmtType: ast.GetStmtLabel(stmt), + Stmt: paramStmt, + StmtType: ast.GetStmtLabel(paramStmt), Params: extractor.markers, SchemaVersion: ret.InfoSchema.SchemaMetaVersion(), } @@ -117,12 +123,12 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, if !vars.EnablePreparedPlanCache { prepared.UseCache = false } else { - cacheable, reason := CacheableWithCtx(sctx, stmt, ret.InfoSchema) + cacheable, reason := CacheableWithCtx(sctx, paramStmt, ret.InfoSchema) prepared.UseCache = cacheable if !cacheable { sctx.GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("skip plan-cache: " + reason)) } - selectStmtNode, normalizedSQL4PC, digest4PC, err = ExtractSelectAndNormalizeDigest(stmt, vars.CurrentDB) + selectStmtNode, normalizedSQL4PC, digest4PC, err = ExtractSelectAndNormalizeDigest(paramStmt, vars.CurrentDB) if err != nil || selectStmtNode == nil { normalizedSQL4PC = "" digest4PC = "" @@ -138,7 +144,7 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, var p Plan destBuilder, _ := NewPlanBuilder().Init(sctx, ret.InfoSchema, &hint.BlockHintProcessor{}) - p, err = destBuilder.Build(ctx, stmt) + p, err = destBuilder.Build(ctx, paramStmt) if err != nil { return nil, nil, 0, err } @@ -146,7 +152,7 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, preparedObj := &PlanCacheStmt{ PreparedAst: prepared, StmtDB: vars.CurrentDB, - StmtText: stmt.Text(), + StmtText: paramSQL, VisitInfos: destBuilder.GetVisitInfo(), NormalizedSQL: normalizedSQL, SQLDigest: digest, diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 9201f953bdcdc..26508814523d2 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -577,6 +577,9 @@ type PlanBuilder struct { // disableSubQueryPreprocessing indicates whether to pre-process uncorrelated sub-queries in rewriting stage. 
disableSubQueryPreprocessing bool + + // allowBuildCastArray indicates whether allow cast(... as ... array). + allowBuildCastArray bool } type handleColHelper struct { @@ -697,6 +700,14 @@ func (p PlanBuilderOptNoExecution) Apply(builder *PlanBuilder) { builder.disableSubQueryPreprocessing = true } +// PlanBuilderOptAllowCastArray means the plan builder should allow build cast(... as ... array). +type PlanBuilderOptAllowCastArray struct{} + +// Apply implements the interface PlanBuilderOpt. +func (p PlanBuilderOptAllowCastArray) Apply(builder *PlanBuilder) { + builder.allowBuildCastArray = true +} + // NewPlanBuilder creates a new PlanBuilder. func NewPlanBuilder(opts ...PlanBuilderOpt) *PlanBuilder { builder := &PlanBuilder{ @@ -4511,6 +4522,14 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, mysql.SystemDB, "stats_extended", "", authErr) + } else if spec.Tp == ast.AlterTableAddConstraint { + if b.ctx.GetSessionVars().User != nil && spec.Constraint != nil && + spec.Constraint.Tp == ast.ConstraintForeignKey && spec.Constraint.Refer != nil { + authErr = ErrTableaccessDenied.GenWithStackByArgs("REFERENCES", b.ctx.GetSessionVars().User.AuthUsername, + b.ctx.GetSessionVars().User.AuthHostname, spec.Constraint.Refer.Table.Name.L) + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ReferencesPriv, spec.Constraint.Refer.Table.Schema.L, + spec.Constraint.Refer.Table.Name.L, "", authErr) + } } } case *ast.AlterSequenceStmt: diff --git a/planner/core/stats.go b/planner/core/stats.go index f377feac91030..71e1037c52c76 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -343,35 +343,37 @@ func (ds *DataSource) derivePathStatsAndTryHeuristics() error { if selected != nil { ds.possibleAccessPaths[0] = selected ds.possibleAccessPaths = ds.possibleAccessPaths[:1] - if ds.ctx.GetSessionVars().StmtCtx.InVerboseExplain { - var tableName string - if ds.TableAsName.O == "" { - tableName = ds.tableInfo.Name.O + var tableName string + if ds.TableAsName.O == "" { + tableName = ds.tableInfo.Name.O + } else { + tableName = ds.TableAsName.O + } + var sb strings.Builder + if selected.IsTablePath() { + // TODO: primary key / handle / real name? + sb.WriteString(fmt.Sprintf("handle of %s is selected since the path only has point ranges", tableName)) + } else { + if selected.Index.Unique { + sb.WriteString("unique ") + } + sb.WriteString(fmt.Sprintf("index %s of %s is selected since the path", selected.Index.Name.O, tableName)) + if isRefinedPath { + sb.WriteString(" only fetches limited number of rows") } else { - tableName = ds.TableAsName.O + sb.WriteString(" only has point ranges") } - if selected.IsTablePath() { - // TODO: primary key / handle / real name? 
- ds.ctx.GetSessionVars().StmtCtx.AppendNote(fmt.Errorf("handle of %s is selected since the path only has point ranges", tableName)) + if selected.IsSingleScan { + sb.WriteString(" with single scan") } else { - var sb strings.Builder - if selected.Index.Unique { - sb.WriteString("unique ") - } - sb.WriteString(fmt.Sprintf("index %s of %s is selected since the path", selected.Index.Name.O, tableName)) - if isRefinedPath { - sb.WriteString(" only fetches limited number of rows") - } else { - sb.WriteString(" only has point ranges") - } - if selected.IsSingleScan { - sb.WriteString(" with single scan") - } else { - sb.WriteString(" with double scan") - } - ds.ctx.GetSessionVars().StmtCtx.AppendNote(errors.New(sb.String())) + sb.WriteString(" with double scan") } } + if ds.ctx.GetSessionVars().StmtCtx.InVerboseExplain { + ds.ctx.GetSessionVars().StmtCtx.AppendNote(errors.New(sb.String())) + } else { + ds.ctx.GetSessionVars().StmtCtx.AppendExtraNote(errors.New(sb.String())) + } } return nil } @@ -435,8 +437,10 @@ func (ds *DataSource) DeriveStats(_ []*property.StatsInfo, _ *expression.Schema, if needConsiderIndexMerge { // PushDownExprs() will append extra warnings, which is annoying. So we reset warnings here. warnings := stmtCtx.GetWarnings() + extraWarnings := stmtCtx.GetExtraWarnings() _, remaining := expression.PushDownExprs(stmtCtx, indexMergeConds, ds.ctx.GetClient(), kv.UnSpecified) stmtCtx.SetWarnings(warnings) + stmtCtx.SetExtraWarnings(extraWarnings) if len(remaining) != 0 { needConsiderIndexMerge = false } diff --git a/planner/core/task.go b/planner/core/task.go index cc27029d83c8e..99952038688fe 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1239,12 +1239,17 @@ func CheckAggCanPushCop(sctx sessionctx.Context, aggFuncs []*aggregation.AggFunc ret = false } - if !ret && sc.InExplainStmt { + if !ret { storageName := storeType.Name() if storeType == kv.UnSpecified { storageName = "storage layer" } - sc.AppendWarning(errors.New("Aggregation can not be pushed to " + storageName + " because " + reason)) + warnErr := errors.New("Aggregation can not be pushed to " + storageName + " because " + reason) + if sc.InExplainStmt { + sc.AppendWarning(warnErr) + } else { + sc.AppendExtraWarning(warnErr) + } } return ret } diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 14c04c6cfb0ab..d61124de927bd 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -1969,7 +1969,8 @@ " └─TableRangeScan_8 3333.33 923531.15 cop[tikv] table:t range:(1,+inf], keep order:false, stats:pseudo" ], "Warnings": [ - "Note 1105 [t] remain after pruning paths for t given Prop{SortItems: [], TaskTp: rootTask}" + "Note 1105 [t] remain after pruning paths for t given Prop{SortItems: [], TaskTp: rootTask}", + "Note 1105 [t,f,f_g] remain after pruning paths for t given Prop{SortItems: [{test.t.f asc}], TaskTp: rootTask}" ] }, { @@ -2014,7 +2015,8 @@ " └─TableRowIDScan_12(Probe) 10.00 2770.59 cop[tikv] table:t keep order:false, stats:pseudo" ], "Warnings": [ - "Note 1105 [t,g] remain after pruning paths for t given Prop{SortItems: [], TaskTp: rootTask}" + "Note 1105 [t,g] remain after pruning paths for t given Prop{SortItems: [], TaskTp: rootTask}", + "Note 1105 [t,f_g,g] remain after pruning paths for t given Prop{SortItems: [{test.t.f asc}], TaskTp: rootTask}" ] }, { @@ -2026,6 +2028,7 @@ "└─TableRowIDScan_13(Probe) 10.00 2770.59 cop[tikv] table:t keep order:false, stats:pseudo" 
], "Warnings": [ + "Note 1105 [t] remain after pruning paths for t given Prop{SortItems: [], TaskTp: rootTask}", "Note 1105 [t,c_d_e] remain after pruning paths for t given Prop{SortItems: [{test.t.c asc} {test.t.e asc}], TaskTp: rootTask}" ] } diff --git a/planner/optimize.go b/planner/optimize.go index 3a6804d5fa319..5e572d8485368 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -74,18 +74,25 @@ func matchSQLBinding(sctx sessionctx.Context, stmtNode ast.StmtNode) (bindRecord } // getPlanFromNonPreparedPlanCache tries to get an available cached plan from the NonPrepared Plan Cache for this stmt. -func getPlanFromNonPreparedPlanCache(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode, is infoschema.InfoSchema) (core.Plan, types.NameSlice, bool, error) { +func getPlanFromNonPreparedPlanCache(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode, is infoschema.InfoSchema) (p core.Plan, ns types.NameSlice, ok bool, err error) { if sctx.GetSessionVars().StmtCtx.InPreparedPlanBuilding || // already in cached plan rebuilding phase !core.NonPreparedPlanCacheableWithCtx(sctx, stmt, is) { return nil, nil, false, nil } - paramSQL, params, err := core.ParameterizeAST(sctx, stmt) + paramSQL, params, err := core.ParameterizeAST(ctx, sctx, stmt) if err != nil { return nil, nil, false, err } + defer func() { + if err != nil { + // keep the stmt unchanged if err so that it can fallback to the normal optimization path. + // TODO: add metrics + err = core.RestoreASTWithParams(ctx, sctx, stmt, params) + } + }() val := sctx.GetSessionVars().GetNonPreparedPlanCacheStmt(paramSQL) if val == nil { - cachedStmt, _, _, err := core.GeneratePlanCacheStmtWithAST(ctx, sctx, stmt) + cachedStmt, _, _, err := core.GeneratePlanCacheStmtWithAST(ctx, sctx, paramSQL, stmt) if err != nil { return nil, nil, false, err } @@ -234,6 +241,8 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in sessVars.FoundInBinding = true if sessVars.StmtCtx.InVerboseExplain { sessVars.StmtCtx.AppendNote(errors.Errorf("Using the bindSQL: %v", chosenBinding.BindSQL)) + } else { + sessVars.StmtCtx.AppendExtraNote(errors.Errorf("Using the bindSQL: %v", chosenBinding.BindSQL)) } } // Restore the hint to avoid changing the stmt node. 
diff --git a/server/server.go b/server/server.go index 09a20c8cb39c2..ba915c64f23cb 100644 --- a/server/server.go +++ b/server/server.go @@ -733,6 +733,11 @@ func (s *Server) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { conn, ok := s.clients[id] s.rwlock.RUnlock() if !ok { + if s.dom != nil { + if pinfo, ok2 := s.dom.SysProcTracker().GetSysProcessList()[id]; ok2 { + return pinfo, true + } + } return &util.ProcessInfo{}, false } return conn.ctx.ShowProcess(), ok diff --git a/session/session.go b/session/session.go index 2c6aa0567fa66..63bc1c970fe08 100644 --- a/session/session.go +++ b/session/session.go @@ -1988,6 +1988,7 @@ func (s *session) useCurrentSession(execOption sqlexec.ExecOption) (*session, fu s.sessionVars.StmtCtx.OriginalSQL = prevSQL s.sessionVars.StmtCtx.StmtType = prevStmtType s.sessionVars.StmtCtx.Tables = prevTables + s.sessionVars.MemTracker.Detach() }, nil } @@ -2049,6 +2050,7 @@ func (s *session) getInternalSession(execOption sqlexec.ExecOption) (*session, f se.sessionVars.PartitionPruneMode.Store(prePruneMode) se.sessionVars.OptimizerUseInvisibleIndexes = false se.sessionVars.InspectionTableCache = nil + se.sessionVars.MemTracker.Detach() s.sysSessionPool().Put(tmp) }, nil } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 6f9a276691149..47159dc8f8a60 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -209,8 +209,15 @@ type StatementContext struct { copied uint64 touched uint64 - message string - warnings []SQLWarn + message string + warnings []SQLWarn + // extraWarnings records the extra warnings; for now they are consumed only by the slow log. + // If a warning is expected to be output only under some conditions (like in EXPLAIN or EXPLAIN VERBOSE) but it's + // not under such conditions now, it is considered as an extra warning. + // extraWarnings are not printed through SHOW WARNINGS, but we always want to output them through the slow + // log to help diagnostics, so we store them here separately. + extraWarnings []SQLWarn + execDetails execdetails.ExecDetails allExecDetails []*execdetails.DetailsNeedP90 } @@ -299,8 +306,6 @@ type StatementContext struct { LogOnExceed [2]memory.LogOnExceed } - // OptimInfo maps Plan.ID() to optimization information when generating Plan. - OptimInfo map[int]string // InVerboseExplain indicates the statement is "explain format='verbose' ...". InVerboseExplain bool @@ -812,6 +817,47 @@ func (sc *StatementContext) AppendError(warn error) { } } +// GetExtraWarnings gets extra warnings. +func (sc *StatementContext) GetExtraWarnings() []SQLWarn { + sc.mu.Lock() + defer sc.mu.Unlock() + return sc.mu.extraWarnings +} + +// SetExtraWarnings sets extra warnings. +func (sc *StatementContext) SetExtraWarnings(warns []SQLWarn) { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.mu.extraWarnings = warns +} + +// AppendExtraWarning appends an extra warning with level 'Warning'. +func (sc *StatementContext) AppendExtraWarning(warn error) { + sc.mu.Lock() + defer sc.mu.Unlock() + if len(sc.mu.extraWarnings) < math.MaxUint16 { + sc.mu.extraWarnings = append(sc.mu.extraWarnings, SQLWarn{WarnLevelWarning, warn}) + } +} + +// AppendExtraNote appends an extra warning with level 'Note'. +func (sc *StatementContext) AppendExtraNote(warn error) { + sc.mu.Lock() + defer sc.mu.Unlock() + if len(sc.mu.extraWarnings) < math.MaxUint16 { + sc.mu.extraWarnings = append(sc.mu.extraWarnings, SQLWarn{WarnLevelNote, warn}) + } +} + +// AppendExtraError appends an extra warning with level 'Error'. 
+func (sc *StatementContext) AppendExtraError(warn error) { + sc.mu.Lock() + defer sc.mu.Unlock() + if len(sc.mu.extraWarnings) < math.MaxUint16 { + sc.mu.extraWarnings = append(sc.mu.extraWarnings, SQLWarn{WarnLevelError, warn}) + } +} + // HandleTruncate ignores or returns the error based on the StatementContext state. func (sc *StatementContext) HandleTruncate(err error) error { // TODO: At present we have not checked whether the error can be ignored or treated as warning. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index fe4972fb5ff3a..695d0ad48a5a5 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1482,8 +1482,13 @@ func (s *SessionVars) IsMPPEnforced() bool { // TODO: Confirm whether this function will be inlined and // omit the overhead of string construction when calling with false condition. func (s *SessionVars) RaiseWarningWhenMPPEnforced(warning string) { - if s.IsMPPEnforced() && s.StmtCtx.InExplainStmt { + if !s.IsMPPEnforced() { + return + } + if s.StmtCtx.InExplainStmt { s.StmtCtx.AppendWarning(errors.New(warning)) + } else { + s.StmtCtx.AppendExtraWarning(errors.New(warning)) } } diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 00137ca7608cf..c8da3ed5c10e6 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -1107,7 +1107,7 @@ const ( DefAdaptiveClosestReadThreshold = 4096 DefTiDBEnableAnalyzeSnapshot = false DefTiDBGenerateBinaryPlan = true - DefEnableTiDBGCAwareMemoryTrack = true + DefEnableTiDBGCAwareMemoryTrack = false DefTiDBDefaultStrMatchSelectivity = 0.8 DefTiDBEnableTmpStorageOnOOM = true DefTiDBEnableMDL = true diff --git a/testkit/mocksessionmanager.go b/testkit/mocksessionmanager.go index a9e4d085dc34d..550ff69132d91 100644 --- a/testkit/mocksessionmanager.go +++ b/testkit/mocksessionmanager.go @@ -18,6 +18,7 @@ import ( "crypto/tls" "sync" + "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" @@ -31,6 +32,7 @@ type MockSessionManager struct { PSMu sync.RWMutex SerID uint64 TxnInfo []*txninfo.TxnInfo + Dom *domain.Domain conn map[uint64]session.Session mu sync.Mutex } @@ -68,6 +70,11 @@ func (msm *MockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { ret[connID] = pi.ShowProcess() } msm.mu.Unlock() + if msm.Dom != nil { + for connID, pi := range msm.Dom.SysProcTracker().GetSysProcessList() { + ret[connID] = pi + } + } return ret } @@ -85,6 +92,11 @@ func (msm *MockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, boo if sess := msm.conn[id]; sess != nil { return sess.ShowProcess(), true } + if msm.Dom != nil { + if pinfo, ok := msm.Dom.SysProcTracker().GetSysProcessList()[id]; ok { + return pinfo, true + } + } return &util.ProcessInfo{}, false } diff --git a/types/field_type_builder.go b/types/field_type_builder.go index 7c9f3bdc3177d..81554c4585442 100644 --- a/types/field_type_builder.go +++ b/types/field_type_builder.go @@ -114,6 +114,12 @@ func (b *FieldTypeBuilder) SetElems(elems []string) *FieldTypeBuilder { return b } +// SetArray sets array of the ft +func (b *FieldTypeBuilder) SetArray(x bool) *FieldTypeBuilder { + b.ft.SetArray(x) + return b +} + // Build returns the ft func (b *FieldTypeBuilder) Build() FieldType { return b.ft diff --git a/util/cpu/BUILD.bazel b/util/cpu/BUILD.bazel index 58bc047a332c4..08893520caaa0 100644 --- a/util/cpu/BUILD.bazel +++ b/util/cpu/BUILD.bazel @@ -20,5 +20,6 @@ 
go_test( name = "cpu_test", srcs = ["cpu_test.go"], embed = [":cpu"], + flaky = True, deps = ["@com_github_stretchr_testify//require"], ) diff --git a/util/cpu/cpu.go b/util/cpu/cpu.go index 416b3c3eaeb99..2803b4e106c49 100644 --- a/util/cpu/cpu.go +++ b/util/cpu/cpu.go @@ -56,11 +56,13 @@ func NewCPUObserver() *Observer { // Start starts the cpu observer. func (c *Observer) Start() { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() c.wg.Add(1) go func() { - defer c.wg.Done() + ticker := time.NewTicker(100 * time.Millisecond) + defer func() { + ticker.Stop() + c.wg.Done() + }() for { select { case <-ticker.C: diff --git a/util/cpu/cpu_test.go b/util/cpu/cpu_test.go index 6c7e863f9060a..cd330a11e5196 100644 --- a/util/cpu/cpu_test.go +++ b/util/cpu/cpu_test.go @@ -42,9 +42,10 @@ func TestCPUValue(t *testing.T) { } }() } - time.Sleep(30 * time.Second) - require.Greater(t, Observer.observe(), 0.0) - require.Less(t, Observer.observe(), 1.0) + Observer.Start() + time.Sleep(5 * time.Second) + require.GreaterOrEqual(t, GetCPUUsage(), 0.0) + require.Less(t, GetCPUUsage(), 1.0) Observer.Stop() close(exit) wg.Wait() diff --git a/util/memory/tracker.go b/util/memory/tracker.go index 9c2adf31ace14..39261a45355a1 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -762,6 +762,17 @@ func (t *Tracker) CountAllChildrenMemUse() map[string]int64 { return trackerMemUseMap } +// GetChildrenForTest returns children trackers +func (t *Tracker) GetChildrenForTest() []*Tracker { + t.mu.Lock() + defer t.mu.Unlock() + trackers := make([]*Tracker, 0) + for _, list := range t.mu.children { + trackers = append(trackers, list...) + } + return trackers +} + func countChildMem(t *Tracker, familyTreeName string, trackerMemUseMap map[string]int64) { if len(familyTreeName) > 0 { familyTreeName += " <- "
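For reference, here is how the two warning channels introduced in sessionctx/stmtctx fit together from a caller's point of view: regular warnings stay client-visible through SHOW WARNINGS, while the new extra warnings are buffered separately and surface only in the slow log. A small usage sketch, assuming this branch's module path and that a zero-value StatementContext is usable, as it is in unit tests of this vintage:

```go
package main

import (
	"errors"
	"fmt"

	"github.com/pingcap/tidb/sessionctx/stmtctx"
)

func main() {
	// Zero-value StatementContext, as commonly constructed in unit tests.
	sc := &stmtctx.StatementContext{}

	// Regular warnings remain visible to clients via SHOW WARNINGS.
	sc.AppendWarning(errors.New("visible through SHOW WARNINGS"))

	// Extra warnings/notes are kept aside and only written to the slow log,
	// e.g. optimizer notes that used to be emitted only under EXPLAIN.
	sc.AppendExtraNote(errors.New("unique index idx of t is selected since the path only has point ranges"))
	sc.AppendExtraWarning(errors.New("Aggregation can not be pushed to storage layer"))

	fmt.Println(len(sc.GetWarnings()))      // 1
	fmt.Println(len(sc.GetExtraWarnings())) // 2
}
```

This is the same routing decision the planner hunks above make in one place after another: AppendWarning/AppendNote when StmtCtx.InExplainStmt or InVerboseExplain is set, AppendExtraWarning/AppendExtraNote otherwise.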