Skip to content

Commit

Permalink
*: add methods to session.Context and refactor some code about ddl (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei authored May 30, 2022
1 parent 7f023bd commit eb46685
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 25 deletions.
24 changes: 12 additions & 12 deletions ddl/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func TestColumnBasic(t *testing.T) {

require.Nil(t, table.FindCol(tbl.Cols(), "c4"))

jobID := testCreateColumn(tk, t, testNewContext(store), tableID, "c4", "after c3", 100, dom)
jobID := testCreateColumn(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c4", "after c3", 100, dom)
testCheckJobDone(t, store, jobID, true)

tbl = testGetTable(t, dom, tableID)
Expand Down Expand Up @@ -221,7 +221,7 @@ func TestColumnBasic(t *testing.T) {
require.Len(t, values, 4)
require.Equal(t, values[3].GetInt64(), int64(14))

jobID = testDropColumnInternal(tk, t, testNewContext(store), tableID, "c4", false, dom)
jobID = testDropColumnInternal(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c4", false, dom)
testCheckJobDone(t, store, jobID, false)

tbl = testGetTable(t, dom, tableID)
Expand All @@ -231,7 +231,7 @@ func TestColumnBasic(t *testing.T) {
require.Len(t, values, 3)
require.Equal(t, values[2].GetInt64(), int64(13))

jobID = testCreateColumn(tk, t, testNewContext(store), tableID, "c4", "", 111, dom)
jobID = testCreateColumn(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c4", "", 111, dom)
testCheckJobDone(t, store, jobID, true)

tbl = testGetTable(t, dom, tableID)
Expand All @@ -241,7 +241,7 @@ func TestColumnBasic(t *testing.T) {
require.Len(t, values, 4)
require.Equal(t, values[3].GetInt64(), int64(111))

jobID = testCreateColumn(tk, t, testNewContext(store), tableID, "c5", "", 101, dom)
jobID = testCreateColumn(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c5", "", 101, dom)
testCheckJobDone(t, store, jobID, true)

tbl = testGetTable(t, dom, tableID)
Expand All @@ -251,7 +251,7 @@ func TestColumnBasic(t *testing.T) {
require.Len(t, values, 5)
require.Equal(t, values[4].GetInt64(), int64(101))

jobID = testCreateColumn(tk, t, testNewContext(store), tableID, "c6", "first", 202, dom)
jobID = testCreateColumn(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c6", "first", 202, dom)
testCheckJobDone(t, store, jobID, true)

tbl = testGetTable(t, dom, tableID)
Expand All @@ -277,7 +277,7 @@ func TestColumnBasic(t *testing.T) {
require.Equal(t, values[0].GetInt64(), int64(202))
require.Equal(t, values[5].GetInt64(), int64(101))

jobID = testDropColumnInternal(tk, t, testNewContext(store), tableID, "c2", false, dom)
jobID = testDropColumnInternal(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c2", false, dom)
testCheckJobDone(t, store, jobID, false)

tbl = testGetTable(t, dom, tableID)
Expand All @@ -288,22 +288,22 @@ func TestColumnBasic(t *testing.T) {
require.Equal(t, values[0].GetInt64(), int64(202))
require.Equal(t, values[4].GetInt64(), int64(101))

jobID = testDropColumnInternal(tk, t, testNewContext(store), tableID, "c1", false, dom)
jobID = testDropColumnInternal(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c1", false, dom)
testCheckJobDone(t, store, jobID, false)

jobID = testDropColumnInternal(tk, t, testNewContext(store), tableID, "c3", false, dom)
jobID = testDropColumnInternal(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c3", false, dom)
testCheckJobDone(t, store, jobID, false)

jobID = testDropColumnInternal(tk, t, testNewContext(store), tableID, "c4", false, dom)
jobID = testDropColumnInternal(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c4", false, dom)
testCheckJobDone(t, store, jobID, false)

jobID = testCreateIndex(tk, t, testNewContext(store), tableID, false, "c5_idx", "c5", dom)
jobID = testCreateIndex(tk, t, testkit.NewTestKit(t, store).Session(), tableID, false, "c5_idx", "c5", dom)
testCheckJobDone(t, store, jobID, true)

jobID = testDropColumnInternal(tk, t, testNewContext(store), tableID, "c5", false, dom)
jobID = testDropColumnInternal(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c5", false, dom)
testCheckJobDone(t, store, jobID, false)

jobID = testDropColumnInternal(tk, t, testNewContext(store), tableID, "c6", true, dom)
jobID = testDropColumnInternal(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c6", true, dom)
testCheckJobDone(t, store, jobID, false)

testDropTable(tk, t, "test", "t1", dom)
Expand Down
18 changes: 13 additions & 5 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,7 @@ func (d *ddl) Start(ctxPool *pools.ResourcePool) error {
asyncNotify(worker.ddlJobCh)
}

err = kv.RunInNewTxn(d.ctx, d.store, true, func(ctx context.Context, txn kv.Transaction) error {
t := meta.NewMeta(txn)
d.ddlSeqNumMu.seqNum, err = t.GetHistoryDDLCount()
return err
})
d.ddlSeqNumMu.seqNum, err = d.GetNextDDLSeqNum()
if err != nil {
return err
}
Expand All @@ -526,6 +522,18 @@ func (d *ddl) Start(ctxPool *pools.ResourcePool) error {
return nil
}

// GetNextDDLSeqNum return the next ddl seq num.
func (d *ddl) GetNextDDLSeqNum() (uint64, error) {
var count uint64
err := kv.RunInNewTxn(d.ctx, d.store, true, func(ctx context.Context, txn kv.Transaction) error {
t := meta.NewMeta(txn)
var err error
count, err = t.GetHistoryDDLCount()
return err
})
return count, err
}

func (d *ddl) close() {
if isChanClosed(d.ctx.Done()) {
return
Expand Down
2 changes: 1 addition & 1 deletion ddl/fail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestFailBeforeDecodeArgs(t *testing.T) {
}
d.SetHook(tc)
defaultValue := int64(3)
jobID := testCreateColumn(tk, t, testNewContext(store), tableID, "c3", "", defaultValue, dom)
jobID := testCreateColumn(tk, t, testkit.NewTestKit(t, store).Session(), tableID, "c3", "", defaultValue, dom)
// Make sure the schema state only appears once.
require.Equal(t, 1, stateCnt)
testCheckJobDone(t, store, jobID, true)
Expand Down
2 changes: 1 addition & 1 deletion executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ type ShowDDLJobsExec struct {
// nolint:structcheck
type DDLJobRetriever struct {
runningJobs []*model.Job
historyJobIter *meta.LastJobIterator
historyJobIter meta.LastJobIterator
cursor int
is infoschema.InfoSchema
activeRoles []*auth.RoleIdentity
Expand Down
15 changes: 10 additions & 5 deletions meta/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -1045,23 +1045,28 @@ func (m *Meta) GetLastNHistoryDDLJobs(num int) ([]*model.Job, error) {
}

// LastJobIterator is the iterator for gets latest history.
type LastJobIterator struct {
iter *structure.ReverseHashIterator
type LastJobIterator interface {
GetLastJobs(num int, jobs []*model.Job) ([]*model.Job, error)
}

// GetLastHistoryDDLJobsIterator gets latest N history ddl jobs iterator.
func (m *Meta) GetLastHistoryDDLJobsIterator() (*LastJobIterator, error) {
func (m *Meta) GetLastHistoryDDLJobsIterator() (LastJobIterator, error) {
iter, err := structure.NewHashReverseIter(m.txn, mDDLJobHistoryKey)
if err != nil {
return nil, err
}
return &LastJobIterator{
return &HLastJobIterator{
iter: iter,
}, nil
}

// HLastJobIterator is the iterator for gets the latest history.
type HLastJobIterator struct {
iter *structure.ReverseHashIterator
}

// GetLastJobs gets last several jobs.
func (i *LastJobIterator) GetLastJobs(num int, jobs []*model.Job) ([]*model.Job, error) {
func (i *HLastJobIterator) GetLastJobs(num int, jobs []*model.Job) ([]*model.Job, error) {
if len(jobs) < num {
jobs = make([]*model.Job, 0, num)
}
Expand Down
8 changes: 7 additions & 1 deletion sessionctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"time"

"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/model"
Expand Down Expand Up @@ -48,7 +49,12 @@ type Context interface {
NewTxn(context.Context) error
// NewStaleTxnWithStartTS initializes a staleness transaction with the given StartTS.
NewStaleTxnWithStartTS(ctx context.Context, startTS uint64) error

// SetDiskFullOpt set the disk full opt when tikv disk full happened.
SetDiskFullOpt(level kvrpcpb.DiskFullOpt)
// RollbackTxn rolls back the current transaction.
RollbackTxn(ctx context.Context)
// CommitTxn commits the current transaction.
CommitTxn(ctx context.Context) error
// Txn returns the current transaction which is created before executing a statement.
// The returned kv.Transaction is not nil, but it maybe pending or invalid.
// If the active parameter is true, call this function will wait for the pending txn
Expand Down
18 changes: 18 additions & 0 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
Expand Down Expand Up @@ -247,6 +248,23 @@ func (c *Context) RefreshVars(ctx context.Context) error {
return nil
}

// RollbackTxn indicates an expected call of RollbackTxn.
func (c *Context) RollbackTxn(ctx context.Context) {
defer c.sessionVars.SetInTxn(false)
if c.txn.Valid() {
terror.Log(c.txn.Rollback())
}
}

// CommitTxn indicates an expected call of CommitTxn.
func (c *Context) CommitTxn(ctx context.Context) error {
defer c.sessionVars.SetInTxn(false)
if c.txn.Valid() {
return c.txn.Commit(ctx)
}
return nil
}

// InitTxnWithStartTS implements the sessionctx.Context interface with startTS.
func (c *Context) InitTxnWithStartTS(startTS uint64) error {
if c.txn.Valid() {
Expand Down

0 comments on commit eb46685

Please sign in to comment.