Skip to content
This repository has been archived by the owner on Nov 24, 2023. It is now read-only.

*: fix context usage for SQL operation #377

Merged
merged 18 commits into from
Dec 9, 2019
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func NewChecker(cfgs []*config.SubTaskConfig, checkingItems map[string]string) *
}

// Init implements Unit interface
func (c *Checker) Init() (err error) {
func (c *Checker) Init(ctx context.Context) (err error) {
rollbackHolder := fr.NewRollbackHolder("checker")
defer func() {
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion checker/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func CheckSyncConfig(ctx context.Context, cfgs []*config.SubTaskConfig) error {

c := NewChecker(cfgs, checkingItems)

err := c.Init()
err := c.Init(ctx)
if err != nil {
return terror.Annotate(err, "fail to initial checker")
}
Expand Down
10 changes: 8 additions & 2 deletions dm/unit/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package unit

import (
"context"
"time"

"github.com/pingcap/errors"

Expand All @@ -23,14 +24,19 @@ import (
"github.com/pingcap/dm/pkg/terror"
)

const (
// DefaultInitTimeout represents the default timeout value when initializing a process unit.
DefaultInitTimeout = time.Minute
)

// Unit defines interface for sub task process units, like syncer, loader, relay, etc.
type Unit interface {
// Init initializes the dm process unit
// every unit does base initialization in `Init`, and this must pass before start running the sub task
// other setups can be done in `Process`, but this should be treated carefully, let it's compatible with Pause / Resume
// if initialing successfully, the outer caller should call `Close` when the unit (or the task) finished, stopped or canceled (because other units Init fail).
// if initialing fail, Init itself should release resources it acquired before (rolling itself back).
Init() error
Init(ctx context.Context) error
// Process processes sub task
// When ctx.Done, stops the process and returns
// When not in processing, call Process to continue or resume the process
Expand All @@ -52,7 +58,7 @@ type Unit interface {
Type() pb.UnitType
// IsFreshTask return whether is a fresh task (not processed before)
// it will be used to decide where the task should become restoring
IsFreshTask() (bool, error)
IsFreshTask(ctx context.Context) (bool, error)
}

// NewProcessError creates a new ProcessError
Expand Down
5 changes: 4 additions & 1 deletion dm/worker/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"go.uber.org/zap"

"github.com/pingcap/dm/dm/pb"
"github.com/pingcap/dm/dm/unit"
"github.com/pingcap/dm/pkg/log"
"github.com/pingcap/dm/pkg/streamer"
"github.com/pingcap/dm/pkg/terror"
Expand Down Expand Up @@ -125,7 +126,9 @@ func (h *realRelayHolder) Init(interceptors []purger.PurgeInterceptor) (purger.P
streamer.GetReaderHub(),
}

if err := h.relay.Init(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
defer cancel()
if err := h.relay.Init(ctx); err != nil {
return nil, terror.Annotate(err, "initial relay unit")
}

Expand Down
2 changes: 1 addition & 1 deletion dm/worker/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewDummyRelay(cfg *relay.Config) relay.Process {
}

// Init implements Process interface
func (d *DummyRelay) Init() error {
func (d *DummyRelay) Init(ctx context.Context) error {
return d.initErr
}

Expand Down
8 changes: 6 additions & 2 deletions dm/worker/subtask.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ func (st *SubTask) Init() error {
// other setups can be done in `Process`, like Loader's prepare which depends on Mydumper's output
// but setups in `Process` should be treated carefully, let it's compatible with Pause / Resume
for i, u := range st.units {
err := u.Init()
ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
err := u.Init(ctx)
cancel()
if err != nil {
initializeUnitSuccess = false
// when init fail, other units initialized before should be closed
Expand All @@ -140,7 +142,9 @@ func (st *SubTask) Init() error {
var skipIdx = 0
for i := len(st.units) - 1; i > 0; i-- {
u := st.units[i]
isFresh, err := u.IsFreshTask()
ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
isFresh, err := u.IsFreshTask(ctx)
cancel()
if err != nil {
initializeUnitSuccess = false
return terror.Annotatef(err, "fail to get fresh status of subtask %s %s", st.cfg.Name, u.Type())
Expand Down
4 changes: 2 additions & 2 deletions dm/worker/subtask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func NewMockUnit(typ pb.UnitType) *MockUnit {
}
}

func (m *MockUnit) Init() error {
func (m *MockUnit) Init(ctx context.Context) error {
return m.errInit
}

Expand Down Expand Up @@ -121,7 +121,7 @@ func (m *MockUnit) Error() interface{} { return nil }

func (m *MockUnit) Type() pb.UnitType { return m.typ }

func (m *MockUnit) IsFreshTask() (bool, error) { return m.isFresh, m.errFresh }
func (m *MockUnit) IsFreshTask(ctx context.Context) (bool, error) { return m.isFresh, m.errFresh }

func (m *MockUnit) InjectProcessError(ctx context.Context, err error) error {
newCtx, cancel := context.WithTimeout(ctx, time.Second)
Expand Down
57 changes: 30 additions & 27 deletions loader/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type CheckPoint interface {
// Load loads all checkpoints recorded before.
// because of no checkpoints updated in memory when error occurred
// when resuming, Load will be called again to load checkpoints
Load() error
Load(tctx *tcontext.Context) error

// GetRestoringFileInfo get restoring data files for table
GetRestoringFileInfo(db, table string) map[string][]int64
Expand All @@ -48,19 +48,19 @@ type CheckPoint interface {
CalcProgress(allFiles map[string]Tables2DataFiles) error

// Init initialize checkpoint data in tidb
Init(filename string, endpos int64) error
Init(tctx *tcontext.Context, filename string, endpos int64) error

// ResetConn resets database connections owned by the Checkpoint
ResetConn() error
ResetConn(tctx *tcontext.Context) error

// Close closes the CheckPoint
Close()

// Clear clears all recorded checkpoints
Clear() error
Clear(tctx *tcontext.Context) error

// Count returns recorded checkpoints' count
Count() (int, error)
Count(tctx *tcontext.Context) (int, error)

// GenSQL generates sql to update checkpoint to DB
GenSQL(filename string, offset int64) string
Expand All @@ -85,8 +85,6 @@ func newRemoteCheckPoint(tctx *tcontext.Context, cfg *config.SubTaskConfig, id s
return nil, err
}

newtctx := tctx.WithLogger(tctx.L().WithFields(zap.String("component", "remote checkpoint")))

cp := &RemoteCheckPoint{
db: db,
conn: dbConns[0],
Expand All @@ -95,36 +93,36 @@ func newRemoteCheckPoint(tctx *tcontext.Context, cfg *config.SubTaskConfig, id s
finishedTables: make(map[string]struct{}),
schema: cfg.MetaSchema,
table: fmt.Sprintf("%s_loader_checkpoint", cfg.Name),
tctx: newtctx,
tctx: tctx.WithLogger(tctx.L().WithFields(zap.String("component", "remote checkpoint"))),
}

err = cp.prepare()
err = cp.prepare(tctx)
if err != nil {
return nil, err
}

return cp, nil
}

func (cp *RemoteCheckPoint) prepare() error {
func (cp *RemoteCheckPoint) prepare(tctx *tcontext.Context) error {
// create schema
if err := cp.createSchema(); err != nil {
if err := cp.createSchema(tctx); err != nil {
return err
}
// create table
if err := cp.createTable(); err != nil {
if err := cp.createTable(tctx); err != nil {
return err
}
return nil
}

func (cp *RemoteCheckPoint) createSchema() error {
func (cp *RemoteCheckPoint) createSchema(tctx *tcontext.Context) error {
sql2 := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`", cp.schema)
err := cp.conn.executeSQL(cp.tctx, []string{sql2})
err := cp.conn.executeSQL(tctx, []string{sql2})
return terror.WithScope(err, terror.ScopeDownstream)
}

func (cp *RemoteCheckPoint) createTable() error {
func (cp *RemoteCheckPoint) createTable(tctx *tcontext.Context) error {
tableName := fmt.Sprintf("`%s`.`%s`", cp.schema, cp.table)
createTable := `CREATE TABLE IF NOT EXISTS %s (
id char(32) NOT NULL,
Expand All @@ -139,19 +137,19 @@ func (cp *RemoteCheckPoint) createTable() error {
);
`
sql2 := fmt.Sprintf(createTable, tableName)
err := cp.conn.executeSQL(cp.tctx, []string{sql2})
err := cp.conn.executeSQL(tctx, []string{sql2})
return terror.WithScope(err, terror.ScopeDownstream)
}

// Load implements CheckPoint.Load
func (cp *RemoteCheckPoint) Load() error {
func (cp *RemoteCheckPoint) Load(tctx *tcontext.Context) error {
begin := time.Now()
defer func() {
cp.tctx.L().Info("load checkpoint", zap.Duration("cost time", time.Since(begin)))
}()

query := fmt.Sprintf("SELECT `filename`,`cp_schema`,`cp_table`,`offset`,`end_pos` from `%s`.`%s` where `id`=?", cp.schema, cp.table)
rows, err := cp.conn.querySQL(cp.tctx, query, cp.id)
rows, err := cp.conn.querySQL(tctx, query, cp.id)
if err != nil {
return terror.WithScope(err, terror.ScopeDownstream)
}
Expand Down Expand Up @@ -266,7 +264,7 @@ func (cp *RemoteCheckPoint) allFilesFinished(files map[string][]int64) bool {
}

// Init implements CheckPoint.Init
func (cp *RemoteCheckPoint) Init(filename string, endPos int64) error {
func (cp *RemoteCheckPoint) Init(tctx *tcontext.Context, filename string, endPos int64) error {
idx := strings.Index(filename, ".sql")
if idx < 0 {
return terror.ErrCheckpointInvalidTableFile.Generate(filename)
Expand All @@ -288,7 +286,7 @@ func (cp *RemoteCheckPoint) Init(filename string, endPos int64) error {
zap.Int64("offset", 0),
zap.Int64("end position", endPos))
args := []interface{}{cp.id, filename, fields[0], fields[1], 0, endPos}
err := cp.conn.executeSQL(cp.tctx, []string{sql2}, args)
err := cp.conn.executeSQL(tctx, []string{sql2}, args)
if err != nil {
if isErrDupEntry(err) {
cp.tctx.L().Info("checkpoint record already exists, skip it.", zap.String("id", cp.id), zap.String("filename", filename))
Expand All @@ -300,8 +298,8 @@ func (cp *RemoteCheckPoint) Init(filename string, endPos int64) error {
}

// ResetConn implements CheckPoint.ResetConn
func (cp *RemoteCheckPoint) ResetConn() error {
return cp.conn.resetConn(cp.tctx)
func (cp *RemoteCheckPoint) ResetConn(tctx *tcontext.Context) error {
return cp.conn.resetConn(tctx)
}

// Close implements CheckPoint.Close
Expand All @@ -320,16 +318,16 @@ func (cp *RemoteCheckPoint) GenSQL(filename string, offset int64) string {
}

// Clear implements CheckPoint.Clear
func (cp *RemoteCheckPoint) Clear() error {
func (cp *RemoteCheckPoint) Clear(tctx *tcontext.Context) error {
sql2 := fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE `id` = '%s'", cp.schema, cp.table, cp.id)
err := cp.conn.executeSQL(cp.tctx, []string{sql2})
err := cp.conn.executeSQL(tctx, []string{sql2})
return terror.WithScope(err, terror.ScopeDownstream)
}

// Count implements CheckPoint.Count
func (cp *RemoteCheckPoint) Count() (int, error) {
func (cp *RemoteCheckPoint) Count(tctx *tcontext.Context) (int, error) {
query := fmt.Sprintf("SELECT COUNT(id) FROM `%s`.`%s` WHERE `id` = ?", cp.schema, cp.table)
rows, err := cp.conn.querySQL(cp.tctx, query, cp.id)
rows, err := cp.conn.querySQL(tctx, query, cp.id)
if err != nil {
return 0, terror.WithScope(err, terror.ScopeDownstream)
}
Expand All @@ -349,7 +347,12 @@ func (cp *RemoteCheckPoint) Count() (int, error) {
}

func (cp *RemoteCheckPoint) String() string {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can refine this function later, it used to log, but need visit the database, maybe we can just use information saved in RemoteCheckPoint

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree.

if err := cp.Load(); err != nil {
// `String` is often used to log something, it's not a big problem even fail,
// so 1min should be enough.
tctx2, cancel := cp.tctx.WithTimeout(time.Minute)
defer cancel()

if err := cp.Load(tctx2); err != nil {
return err.Error()
}

Expand Down
22 changes: 11 additions & 11 deletions loader/checkpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,26 @@ func (t *testCheckPointSuite) TestForDB(c *C) {
c.Assert(err, IsNil)
defer cp.Close()

cp.Clear()
cp.Clear(tctx)

// no checkpoint exist
err = cp.Load()
err = cp.Load(tctx)
c.Assert(err, IsNil)

infos := cp.GetAllRestoringFileInfo()
c.Assert(len(infos), Equals, 0)

count, err := cp.Count()
count, err := cp.Count(tctx)
c.Assert(err, IsNil)
c.Assert(count, Equals, 0)

// insert default checkpoints
for _, cs := range cases {
err = cp.Init(cs.filename, cs.endPos)
err = cp.Init(tctx, cs.filename, cs.endPos)
c.Assert(err, IsNil)
}

err = cp.Load()
err = cp.Load(tctx)
c.Assert(err, IsNil)

infos = cp.GetAllRestoringFileInfo()
Expand All @@ -108,7 +108,7 @@ func (t *testCheckPointSuite) TestForDB(c *C) {
c.Assert(info[1], Equals, cs.endPos)
}

count, err = cp.Count()
count, err = cp.Count(tctx)
c.Assert(err, IsNil)
c.Assert(count, Equals, len(cases))

Expand All @@ -126,7 +126,7 @@ func (t *testCheckPointSuite) TestForDB(c *C) {
c.Assert(err, IsNil)
}

err = cp.Load()
err = cp.Load(tctx)
c.Assert(err, IsNil)

infos = cp.GetAllRestoringFileInfo()
Expand All @@ -139,22 +139,22 @@ func (t *testCheckPointSuite) TestForDB(c *C) {
c.Assert(info[1], Equals, cs.endPos)
}

count, err = cp.Count()
count, err = cp.Count(tctx)
c.Assert(err, IsNil)
c.Assert(count, Equals, len(cases))

// clear all
cp.Clear()
cp.Clear(tctx)

// no checkpoint exist
err = cp.Load()
err = cp.Load(tctx)
c.Assert(err, IsNil)

infos = cp.GetAllRestoringFileInfo()
c.Assert(len(infos), Equals, 0)

// obtain count again
count, err = cp.Count()
count, err = cp.Count(tctx)
c.Assert(err, IsNil)
c.Assert(count, Equals, 0)
}
Loading