Skip to content

Commit

Permalink
Make matching service executeWithRetry function stateless (#2305)
Browse files Browse the repository at this point in the history
* Make matching service executeWithRetry function stateless
* Polish new db task logic for later integration
  • Loading branch information
wxing1292 authored Dec 16, 2021
1 parent 8e9a8d4 commit a619fd4
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 35 deletions.
10 changes: 7 additions & 3 deletions service/matching/db_task_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package matching

import (
"context"
"sync/atomic"
"time"

Expand All @@ -41,6 +42,9 @@ import (
)

const (
dbTaskInitialRangeID = 1
dbTaskStickyTaskQueueTTL = 24 * time.Hour

dbTaskFlushInterval = 24 * time.Millisecond

dbTaskDeletionInterval = 10 * time.Second
Expand All @@ -58,7 +62,7 @@ type (
taskReader *dbTaskWriter
taskWriter *dbTaskReader

dispatchTaskFn func(*internalTask) error
dispatchTaskFn func(context.Context, *internalTask) error
finishTaskFn func(*persistencespb.AllocatedTaskInfo, error)
logger log.Logger

Expand All @@ -74,7 +78,7 @@ func newDBTaskManager(
taskIDRangeSize int64,
store persistence.TaskManager,
logger log.Logger,
dispatchTaskFn func(*internalTask) error,
dispatchTaskFn func(context.Context, *internalTask) error,
finishTaskFn func(*persistencespb.AllocatedTaskInfo, error),
) (*dbTaskManager, error) {
taskOwnership := newDBTaskQueueOwnership(
Expand Down Expand Up @@ -241,7 +245,7 @@ func (d *dbTaskManager) mustDispatch(
return
}

err := d.dispatchTaskFn(newInternalTask(
err := d.dispatchTaskFn(context.Background(), newInternalTask(
task,
d.finishTaskFn,
enumsspb.TASK_SOURCE_DB_BACKLOG,
Expand Down
6 changes: 3 additions & 3 deletions service/matching/db_task_queue_ownership.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (m *dbTaskQueueOwnershipImpl) takeTaskQueueOwnership() error {

case *serviceerror.NotFound:
if _, err := m.store.CreateTaskQueue(&persistence.CreateTaskQueueRequest{
RangeID: initialRangeID,
RangeID: dbTaskInitialRangeID,
TaskQueueInfo: &persistencespb.TaskQueueInfo{
NamespaceId: m.taskQueueKey.NamespaceID,
Name: m.taskQueueKey.TaskQueueName,
Expand All @@ -169,7 +169,7 @@ func (m *dbTaskQueueOwnershipImpl) takeTaskQueueOwnership() error {
return err
}
m.stateLastUpdateTime = timestamp.TimePtr(m.timeSource.Now())
m.updateStateLocked(initialRangeID, 0)
m.updateStateLocked(dbTaskInitialRangeID, 0)
m.status = dbTaskQueueOwnershipStatusOwned
return nil

Expand Down Expand Up @@ -275,7 +275,7 @@ func (m *dbTaskQueueOwnershipImpl) expiryTime() *time.Time {
case enumspb.TASK_QUEUE_KIND_NORMAL:
return nil
case enumspb.TASK_QUEUE_KIND_STICKY:
return timestamp.TimePtr(m.timeSource.Now().Add(stickyTaskQueueTTL))
return timestamp.TimePtr(m.timeSource.Now().Add(dbTaskStickyTaskQueueTTL))
default:
panic(fmt.Sprintf("taskQueueDB encountered unknown task kind: %v", m.taskQueueKind))
}
Expand Down
8 changes: 4 additions & 4 deletions service/matching/db_task_queue_ownership_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (s *dbTaskOwnershipSuite) TestTaskOwnership_Create_Success() {
TaskType: s.taskQueueType,
}).Return(nil, serviceerror.NewNotFound("random error message"))
s.taskStore.EXPECT().CreateTaskQueue(&persistence.CreateTaskQueueRequest{
RangeID: initialRangeID,
RangeID: dbTaskInitialRangeID,
TaskQueueInfo: &persistencespb.TaskQueueInfo{
NamespaceId: s.namespaceID,
Name: s.taskQueueName,
Expand All @@ -130,12 +130,12 @@ func (s *dbTaskOwnershipSuite) TestTaskOwnership_Create_Success() {
},
}).Return(&persistence.CreateTaskQueueResponse{}, nil)

minTaskID, maxTaskID := rangeIDToTaskIDRange(initialRangeID, s.taskIDRangeSize)
minTaskID, maxTaskID := rangeIDToTaskIDRange(dbTaskInitialRangeID, s.taskIDRangeSize)
err := s.taskOwnership.takeTaskQueueOwnership()
s.NoError(err)
s.Equal(s.now, *s.taskOwnership.stateLastUpdateTime)
s.Equal(dbTaskQueueOwnershipState{
rangeID: initialRangeID,
rangeID: dbTaskInitialRangeID,
ackedTaskID: 0,
lastAllocatedTaskID: 0,
minTaskIDExclusive: minTaskID,
Expand All @@ -156,7 +156,7 @@ func (s *dbTaskOwnershipSuite) TestTaskOwnership_Create_Failed() {
TaskType: s.taskQueueType,
}).Return(nil, serviceerror.NewNotFound("random error message"))
s.taskStore.EXPECT().CreateTaskQueue(&persistence.CreateTaskQueueRequest{
RangeID: initialRangeID,
RangeID: dbTaskInitialRangeID,
TaskQueueInfo: &persistencespb.TaskQueueInfo{
NamespaceId: s.namespaceID,
Name: s.taskQueueName,
Expand Down
40 changes: 18 additions & 22 deletions service/matching/taskQueueManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,34 +278,34 @@ func (c *taskQueueManagerImpl) AddTask(
}

var syncMatch bool
_, err := c.executeWithRetry(func() (interface{}, error) {
td := params.taskInfo
err := executeWithRetry(func() error {
taskInfo := params.taskInfo

namespaceEntry, err := c.namespaceRegistry.GetNamespaceByID(namespace.ID(td.GetNamespaceId()))
namespaceEntry, err := c.namespaceRegistry.GetNamespaceByID(namespace.ID(taskInfo.GetNamespaceId()))
if err != nil {
return nil, err
return err
}

if !namespaceEntry.ActiveInCluster(c.clusterMeta.GetCurrentClusterName()) {
r, err := c.taskWriter.appendTask(params.execution, td)
_, err := c.taskWriter.appendTask(params.execution, taskInfo)
syncMatch = false
return r, err
return err
}

syncMatch, err = c.trySyncMatch(ctx, params)
if syncMatch {
return &persistence.CreateTasksResponse{}, err
return err
}

if params.forwardedFrom != "" {
// forwarded from child partition - only do sync match
// child partition will persist the task when sync match fails
return &persistence.CreateTasksResponse{}, errRemoteSyncMatchFailed
return errRemoteSyncMatchFailed
}

resp, err := c.taskWriter.appendTask(params.execution, params.taskInfo)
_, err = c.taskWriter.appendTask(params.execution, taskInfo)
c.signalIfFatal(err)
return resp, err
return err
})
if !syncMatch && err == nil {
c.taskReader.Signal()
Expand Down Expand Up @@ -465,9 +465,10 @@ func (c *taskQueueManagerImpl) completeTask(task *persistencespb.AllocatedTaskIn
// again the underlying reason for failing to start will be resolved.
// Note that RecordTaskStarted only fails after retrying for a long time, so a single task will not be
// re-written to persistence frequently.
_, err = c.executeWithRetry(func() (interface{}, error) {
err = executeWithRetry(func() error {
wf := &commonpb.WorkflowExecution{WorkflowId: task.Data.GetWorkflowId(), RunId: task.Data.GetRunId()}
return c.taskWriter.appendTask(wf, task.Data)
_, err := c.taskWriter.appendTask(wf, task.Data)
return err
})

if err != nil {
Expand Down Expand Up @@ -497,15 +498,10 @@ func rangeIDToTaskIDBlock(rangeID int64, rangeSize int64) taskIDBlock {
}

// Retry operation on transient error.
func (c *taskQueueManagerImpl) executeWithRetry(
operation func() (interface{}, error)) (result interface{}, err error) {

op := func() error {
result, err = operation()
return err
}

err = backoff.Retry(op, persistenceOperationRetryPolicy, func(err error) bool {
func executeWithRetry(
operation func() error,
) error {
err := backoff.Retry(operation, persistenceOperationRetryPolicy, func(err error) bool {
if common.IsContextDeadlineExceededErr(err) || common.IsContextCanceledErr(err) {
return false
}
Expand All @@ -514,7 +510,7 @@ func (c *taskQueueManagerImpl) executeWithRetry(
}
return common.IsPersistenceTransientError(err)
})
return
return err
}

func (c *taskQueueManagerImpl) trySyncMatch(ctx context.Context, params addTaskParams) (bool, error) {
Expand Down
9 changes: 6 additions & 3 deletions service/matching/taskReader.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,16 @@ Loop:
}

func (tr *taskReader) getTaskBatchWithRange(readLevel int64, maxReadLevel int64) ([]*persistencespb.AllocatedTaskInfo, error) {
response, err := tr.tlMgr.executeWithRetry(func() (interface{}, error) {
return tr.tlMgr.db.GetTasks(readLevel, maxReadLevel, tr.tlMgr.config.GetTasksBatchSize())
var response *persistence.GetTasksResponse
var err error
err = executeWithRetry(func() error {
response, err = tr.tlMgr.db.GetTasks(readLevel, maxReadLevel, tr.tlMgr.config.GetTasksBatchSize())
return err
})
if err != nil {
return nil, err
}
return response.(*persistence.GetTasksResponse).Tasks, err
return response.Tasks, err
}

// Returns a batch of tasks from persistence starting from current read level.
Expand Down

0 comments on commit a619fd4

Please sign in to comment.