diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go
index 12cb12362c147..adae8148cd51a 100644
--- a/store/tikv/gcworker/gc_worker.go
+++ b/store/tikv/gcworker/gc_worker.go
@@ -1256,61 +1256,60 @@ func (w *GCWorker) removeLockObservers(ctx context.Context, safePoint uint64, st
 
 // physicalScanAndResolveLocks performs physical scan lock and resolves these locks. Returns successful stores
 func (w *GCWorker) physicalScanAndResolveLocks(ctx context.Context, safePoint uint64, stores map[uint64]*metapb.Store) (map[uint64]interface{}, error) {
+	ctx, cancel := context.WithCancel(ctx)
+	// Cancel all spawned goroutines for lock scanning and resolving.
+	defer cancel()
+
 	scanner := newMergeLockScanner(safePoint, w.store.GetTiKVClient(), stores)
 	err := scanner.Start(ctx)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
 
-	innerCtx, cancel := context.WithCancel(ctx)
-	wg := &sync.WaitGroup{}
-	defer func() {
-		cancel()
-		wg.Wait()
-	}()
-
 	taskCh := make(chan []*tikv.Lock, len(stores))
 	errCh := make(chan error, len(stores))
 
+	wg := &sync.WaitGroup{}
 	for range stores {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
 			for {
 				select {
-				case <-innerCtx.Done():
-					return
 				case locks, ok := <-taskCh:
 					if !ok {
-						// Closed
+						// All locks have been resolved.
 						return
 					}
-					err := w.resolveLocksAcrossRegions(innerCtx, locks)
+					err := w.resolveLocksAcrossRegions(ctx, locks)
 					if err != nil {
-						logutil.Logger(innerCtx).Error("resolve locks failed", zap.Error(err))
+						logutil.Logger(ctx).Error("resolve locks failed", zap.Error(err))
 						errCh <- err
 					}
+				case <-ctx.Done():
+					return
 				}
 			}
 		}()
 	}
 
 	for {
-		select {
-		case err := <-errCh:
-			return nil, errors.Trace(err)
-		default:
-		}
-
 		locks := scanner.NextBatch(128)
 		if len(locks) == 0 {
 			break
 		}
 
-		taskCh <- locks
+		select {
+		case taskCh <- locks:
+		case err := <-errCh:
+			return nil, errors.Trace(err)
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
 	}
 
 	close(taskCh)
+	// Wait for all locks resolved.
 	wg.Wait()
 
 	select {
@@ -1323,6 +1322,12 @@ func (w *GCWorker) physicalScanAndResolveLocks(ctx context.Context, safePoint ui
 }
 
 func (w *GCWorker) resolveLocksAcrossRegions(ctx context.Context, locks []*tikv.Lock) error {
+	failpoint.Inject("resolveLocksAcrossRegionsErr", func(v failpoint.Value) {
+		ms := v.(int)
+		time.Sleep(time.Duration(ms) * time.Millisecond)
+		failpoint.Return(errors.New("injectedError"))
+	})
+
 	bo := tikv.NewBackoffer(ctx, tikv.GcResolveLockMaxBackoff)
 
 	for {
@@ -1902,7 +1907,11 @@ func (s *mergeLockScanner) Start(ctx context.Context) error {
 					zap.Uint64("safePoint", s.safePoint),
 					zap.Any("store", store1),
 					zap.Error(err))
-				ch <- scanLockResult{Err: err}
+
+				select {
+				case ch <- scanLockResult{Err: err}:
+				case <-ctx.Done():
+				}
 			}
 		}()
 		receivers = append(receivers, &receiver{Ch: ch, StoreID: storeID})
@@ -1939,7 +1948,7 @@ func (s *mergeLockScanner) NextBatch(batchSize int) []*tikv.Lock {
 	for len(result) < batchSize {
 		lock := s.Next()
 		if lock == nil {
-			return result
+			break
 		}
 		result = append(result, lock)
 	}
@@ -1991,7 +2000,11 @@ func (s *mergeLockScanner) physicalScanLocksForStore(ctx context.Context, safePo
 		nextKey = append(nextKey, 0)
 
 		for _, lockInfo := range resp.Locks {
-			lockCh <- scanLockResult{Lock: tikv.NewLock(lockInfo)}
+			select {
+			case lockCh <- scanLockResult{Lock: tikv.NewLock(lockInfo)}:
+			case <-ctx.Done():
+				return ctx.Err()
+			}
 		}
 
 		if len(resp.Locks) < int(s.scanLockLimit) {
diff --git a/store/tikv/gcworker/gc_worker_test.go b/store/tikv/gcworker/gc_worker_test.go
index c6d5fcb80af71..814d3f061f2b4 100644
--- a/store/tikv/gcworker/gc_worker_test.go
+++ b/store/tikv/gcworker/gc_worker_test.go
@@ -1165,3 +1165,47 @@ func (s *testGCWorkerSuite) TestMergeLockScanner(c *C) {
 		c.Assert(scanner.GetSucceededStores(), DeepEquals, makeIDSet(storeIDs, 0, 1, 2))
 	}
 }
+
+func (s *testGCWorkerSuite) TestPhysicalScanLockDeadlock(c *C) {
+	ctx := context.Background()
+	stores := s.cluster.GetAllStores()
+	c.Assert(len(stores), Greater, 1)
+
+	s.client.physicalScanLockHandler = func(addr string, req *tikvrpc.Request) (*tikvrpc.Response, error) {
+		c.Assert(addr, Equals, stores[0].Address)
+		scanReq := req.PhysicalScanLock()
+		scanLockLimit := int(scanReq.Limit)
+		locks := make([]*kvrpcpb.LockInfo, 0, scanReq.Limit)
+		for i := 0; i < scanLockLimit; i++ {
+			// The order of keys doesn't matter.
+			locks = append(locks, &kvrpcpb.LockInfo{Key: []byte{byte(i)}})
+		}
+		return &tikvrpc.Response{
+			Resp: &kvrpcpb.PhysicalScanLockResponse{
+				Locks: locks,
+				Error: "",
+			},
+		}, nil
+	}
+
+	// Sleep 1000ms to let the main goroutine block on sending tasks.
+	// Inject an error into the goroutines resolving locks so that the main goroutine will block forever if it doesn't handle the channels properly.
+	c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/gcworker/resolveLocksAcrossRegionsErr", "return(1000)"), IsNil)
+	defer func() {
+		c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/gcworker/resolveLocksAcrossRegionsErr"), IsNil)
+	}()
+
+	done := make(chan interface{})
+	go func() {
+		defer close(done)
+		storesMap := map[uint64]*metapb.Store{stores[0].Id: stores[0]}
+		succeeded, err := s.gcWorker.physicalScanAndResolveLocks(ctx, 10000, storesMap)
+		c.Assert(succeeded, IsNil)
+		c.Assert(err, ErrorMatches, "injectedError")
+	}()
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		c.Fatal("physicalScanAndResolveLocks blocks")
+	}
+}
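
Not part of the patch above: a minimal, self-contained sketch of the select-guarded channel pattern the fix applies in physicalScanAndResolveLocks. The producer never performs a bare send on the task channel, so it cannot block forever once a worker has reported an error or the context is cancelled, and the workers likewise exit on cancellation. The dispatch/process names, integer batches, and worker count here are hypothetical illustration only; none of this is TiDB code or API.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

// process stands in for resolving one batch of work; it fails on an empty batch.
func process(batch []int) error {
	if len(batch) == 0 {
		return errors.New("empty batch")
	}
	return nil
}

// dispatch feeds batches to worker goroutines. Every send on taskCh is wrapped
// in a select together with errCh and ctx.Done(), so the producer never blocks
// forever once a worker has failed or the caller has cancelled the context.
func dispatch(ctx context.Context, batches [][]int, workers int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel() // stop all spawned workers on any return path

	taskCh := make(chan []int, workers)
	errCh := make(chan error, workers)

	wg := &sync.WaitGroup{}
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case batch, ok := <-taskCh:
					if !ok {
						return // channel closed: all work dispatched
					}
					if err := process(batch); err != nil {
						errCh <- err // buffered (cap == workers), each worker sends at most once
						return
					}
				case <-ctx.Done():
					return
				}
			}
		}()
	}

	for _, batch := range batches {
		select {
		case taskCh <- batch:
		case err := <-errCh:
			return err // a worker failed; the deferred cancel unblocks the rest
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	close(taskCh)
	wg.Wait()

	// Report an error that surfaced after the last batch was queued.
	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}

func main() {
	err := dispatch(context.Background(), [][]int{{1, 2}, {}, {3}}, 2)
	fmt.Println("dispatch returned:", err)
}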