Skip to content

Commit

Permalink
Change message look ahead (#1136)
Browse files Browse the repository at this point in the history
Change message look ahead
  • Loading branch information
Quinn-With-Two-Ns authored Jun 14, 2023
1 parent 5075346 commit 2990ebf
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 48 deletions.
46 changes: 31 additions & 15 deletions internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"go.temporal.io/api/workflowservice/v1"

"go.temporal.io/sdk/internal/common/retry"
"go.temporal.io/sdk/internal/protocol"

"go.temporal.io/sdk/converter"
"go.temporal.io/sdk/internal/common"
Expand Down Expand Up @@ -292,21 +293,21 @@ func isCommandEvent(eventType enumspb.EventType) bool {

// NextCommandEvents returns events that there processed as new by the next command.
// TODO(maxim): Refactor to return a struct instead of multiple parameters
func (eh *history) NextCommandEvents() (result []*historypb.HistoryEvent, markers []*historypb.HistoryEvent, binaryChecksum string, sdkFlags []sdkFlag, err error) {
func (eh *history) NextCommandEvents() (result []*historypb.HistoryEvent, markers []*historypb.HistoryEvent, binaryChecksum string, sdkFlags []sdkFlag, msgs []*protocolpb.Message, err error) {
if eh.next == nil {
eh.next, _, eh.nextFlags, err = eh.nextCommandEvents()
eh.next, _, eh.nextFlags, _, err = eh.nextCommandEvents()
if err != nil {
return result, markers, eh.binaryChecksum, sdkFlags, err
return result, markers, eh.binaryChecksum, sdkFlags, msgs, err
}
}

result = eh.next
checksum := eh.binaryChecksum
sdkFlags = eh.nextFlags
if len(result) > 0 {
eh.next, markers, eh.nextFlags, err = eh.nextCommandEvents()
eh.next, markers, eh.nextFlags, msgs, err = eh.nextCommandEvents()
}
return result, markers, checksum, sdkFlags, err
return result, markers, checksum, sdkFlags, msgs, err
}

func (eh *history) hasMoreEvents() bool {
Expand Down Expand Up @@ -334,12 +335,12 @@ func (eh *history) verifyAllEventsProcessed() error {
return nil
}

func (eh *history) nextCommandEvents() (nextEvents []*historypb.HistoryEvent, markers []*historypb.HistoryEvent, sdkFlags []sdkFlag, err error) {
func (eh *history) nextCommandEvents() (nextEvents []*historypb.HistoryEvent, markers []*historypb.HistoryEvent, sdkFlags []sdkFlag, msgs []*protocolpb.Message, err error) {
if eh.currentIndex == len(eh.loadedEvents) && !eh.hasMoreEvents() {
if err := eh.verifyAllEventsProcessed(); err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
return []*historypb.HistoryEvent{}, []*historypb.HistoryEvent{}, []sdkFlag{}, nil
return []*historypb.HistoryEvent{}, []*historypb.HistoryEvent{}, []sdkFlag{}, []*protocolpb.Message{}, nil
}

// Process events
Expand Down Expand Up @@ -391,6 +392,15 @@ OrderEvents:
default:
if isPreloadMarkerEvent(event) {
markers = append(markers, event)
} else if attrs := event.GetWorkflowExecutionUpdateAcceptedEventAttributes(); attrs != nil {
msgs = append(msgs, &protocolpb.Message{
Id: attrs.GetAcceptedRequestMessageId(),
ProtocolInstanceId: attrs.GetProtocolInstanceId(),
SequencingId: &protocolpb.Message_EventId{
EventId: attrs.GetAcceptedRequestSequencingEventId(),
},
Body: protocol.MustMarshalAny(attrs.GetAcceptedRequest()),
})
}
nextEvents = append(nextEvents, event)
}
Expand All @@ -408,7 +418,7 @@ OrderEvents:

eh.currentIndex = 0

return nextEvents, markers, sdkFlags, nil
return nextEvents, markers, sdkFlags, msgs, nil
}

func isPreloadMarkerEvent(event *historypb.HistoryEvent) bool {
Expand Down Expand Up @@ -566,7 +576,6 @@ func (wth *workflowTaskHandlerImpl) createWorkflowContext(task *workflowservice.
if taskQueue == nil || taskQueue.Name == "" {
return nil, errors.New("nil or empty TaskQueue in WorkflowExecutionStarted event")
}
task.Messages = append(inferMessages(task.GetHistory().GetEvents()), task.Messages...)

runID := task.WorkflowExecution.GetRunId()
workflowID := task.WorkflowExecution.GetWorkflowId()
Expand Down Expand Up @@ -713,7 +722,6 @@ func (w *workflowExecutionContextImpl) resetStateIfDestroyed(task *workflowservi
return err
}
}
task.Messages = append(inferMessages(task.GetHistory().GetEvents()), task.Messages...)
if w.workflowInfo != nil {
// Reset the search attributes and memos from the WorkflowExecutionStartedEvent.
// The search attributes and memo may have been modified by calls like UpsertMemo
Expand Down Expand Up @@ -878,8 +886,7 @@ func (w *workflowExecutionContextImpl) ProcessWorkflowTask(workflowTask *workflo
var replayCommands []*commandpb.Command
var respondEvents []*historypb.HistoryEvent

msgs := indexMessagesByEventID(workflowTask.task.GetMessages())

taskMessages := workflowTask.task.GetMessages()
skipReplayCheck := w.skipReplayCheck()
shouldForceReplayCheck := func() bool {
isInReplayer := IsReplayNamespace(w.wth.namespace)
Expand All @@ -899,7 +906,17 @@ func (w *workflowExecutionContextImpl) ProcessWorkflowTask(workflowTask *workflo

ProcessEvents:
for {
reorderedEvents, markers, binaryChecksum, flags, err := reorderedHistory.NextCommandEvents()
reorderedEvents, markers, binaryChecksum, flags, historyMessages, err := reorderedHistory.NextCommandEvents()
// Check if we are replaying so we know if we should use the messages in the WFT or the history
isReplay := len(reorderedEvents) > 0 && reorderedHistory.IsReplayEvent(reorderedEvents[len(reorderedEvents)-1])
var msgs *eventMsgIndex
if isReplay {
msgs = indexMessagesByEventID(historyMessages)
} else {
msgs = indexMessagesByEventID(taskMessages)
taskMessages = []*protocolpb.Message{}
}

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -993,7 +1010,6 @@ ProcessEvents:
}
}
}
isReplay := len(reorderedEvents) > 0 && reorderedHistory.IsReplayEvent(reorderedEvents[len(reorderedEvents)-1])
if isReplay {
eventCommands := eventHandler.commandsHelper.getCommands(true)
if !skipReplayCheck {
Expand Down
71 changes: 68 additions & 3 deletions internal/internal_task_handlers_interfaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
historypb "go.temporal.io/api/history/v1"
"go.temporal.io/api/sdk/v1"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
updatepb "go.temporal.io/api/update/v1"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/api/workflowservicemock/v1"
)
Expand Down Expand Up @@ -179,7 +180,7 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommands() {

eh := newHistory(workflowTask, nil)

events, _, _, _, err := eh.NextCommandEvents()
events, _, _, _, _, err := eh.NextCommandEvents()

s.NoError(err)
s.Equal(3, len(events))
Expand Down Expand Up @@ -222,7 +223,7 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommandsSdkFlags() {

eh := newHistory(workflowTask, nil)

events, _, _, sdkFlags, err := eh.NextCommandEvents()
events, _, _, sdkFlags, _, err := eh.NextCommandEvents()

s.NoError(err)
s.Equal(2, len(events))
Expand All @@ -232,7 +233,7 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommandsSdkFlags() {
s.Equal(1, len(sdkFlags))
s.EqualValues(SDKFlagLimitChangeVersionSASize, sdkFlags[0])

events, _, _, sdkFlags, err = eh.NextCommandEvents()
events, _, _, sdkFlags, _, err = eh.NextCommandEvents()

s.NoError(err)
s.Equal(4, len(events))
Expand All @@ -243,3 +244,67 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommandsSdkFlags() {

s.Equal(0, len(sdkFlags))
}

func (s *PollLayerInterfacesTestSuite) TestMessageCommands() {
// Schedule an activity and see if we complete workflow.
taskQueue := "tq1"
testEvents := []*historypb.HistoryEvent{
createTestEventWorkflowExecutionStarted(1, &historypb.WorkflowExecutionStartedEventAttributes{TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}}),
createTestEventWorkflowTaskScheduled(2, &historypb.WorkflowTaskScheduledEventAttributes{TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}}),
createTestEventWorkflowTaskStarted(3),
{
EventId: 4,
EventType: enumspb.EVENT_TYPE_WORKFLOW_TASK_FAILED,
},
createTestEventWorkflowTaskScheduled(5, &historypb.WorkflowTaskScheduledEventAttributes{TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}}),
createTestEventWorkflowTaskStarted(6),
createTestEventWorkflowTaskCompleted(7, &historypb.WorkflowTaskCompletedEventAttributes{
ScheduledEventId: 5,
StartedEventId: 6,
}),
{
EventId: 8,
EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED,
Attributes: &historypb.HistoryEvent_WorkflowExecutionUpdateAcceptedEventAttributes{
WorkflowExecutionUpdateAcceptedEventAttributes: &historypb.WorkflowExecutionUpdateAcceptedEventAttributes{
ProtocolInstanceId: "test",
AcceptedRequest: &updatepb.Request{},
},
},
},
createTestEventWorkflowTaskScheduled(9, &historypb.WorkflowTaskScheduledEventAttributes{TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}}),
createTestEventWorkflowTaskStarted(10),
}
task := createWorkflowTaskWithQueries(testEvents[0:3], 0, "HelloWorld_Workflow", nil, false)

historyIterator := &historyIteratorImpl{
iteratorFunc: func(nextToken []byte) (*historypb.History, []byte, error) {
return &historypb.History{
Events: testEvents[3:],
}, nil, nil
},
nextPageToken: []byte("test"),
}

workflowTask := &workflowTask{task: task, historyIterator: historyIterator}

eh := newHistory(workflowTask, nil)

events, _, _, _, msgs, err := eh.NextCommandEvents()
s.NoError(err)
s.Equal(2, len(events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, events[1].GetEventType())

s.Equal(1, len(msgs))
s.Equal("test", msgs[0].GetProtocolInstanceId())

events, _, _, _, msgs, err = eh.NextCommandEvents()
s.NoError(err)
s.Equal(3, len(events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_COMPLETED, events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED, events[1].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, events[2].GetEventType())

s.Equal(0, len(msgs))
}
Loading

0 comments on commit 2990ebf

Please sign in to comment.