Skip to content

Commit

Permalink
Track mutable side effect calls
Browse files Browse the repository at this point in the history
  • Loading branch information
Quinn-With-Two-Ns committed Feb 27, 2023
1 parent 16d73c5 commit d8c78af
Show file tree
Hide file tree
Showing 9 changed files with 1,135 additions and 45 deletions.
25 changes: 16 additions & 9 deletions internal/internal_command_state_machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,13 @@ const (
localActivityMarkerName = "LocalActivity"
mutableSideEffectMarkerName = "MutableSideEffect"

sideEffectMarkerIDName = "side-effect-id"
sideEffectMarkerDataName = "data"
versionMarkerChangeIDName = "change-id"
versionMarkerDataName = "version"
localActivityMarkerDataName = "data"
localActivityResultName = "result"
sideEffectMarkerIDName = "side-effect-id"
sideEffectMarkerDataName = "data"
versionMarkerChangeIDName = "change-id"
versionMarkerDataName = "version"
localActivityMarkerDataName = "data"
localActivityResultName = "result"
mutableSideEffectCallCounterName = "mutable-side-effect-call-counter"
)

func (d commandState) String() string {
Expand Down Expand Up @@ -1146,7 +1147,7 @@ func (h *commandsHelper) recordLocalActivityMarker(activityID string, details ma
return command
}

func (h *commandsHelper) recordMutableSideEffectMarker(mutableSideEffectID string, data *commonpb.Payloads, dc converter.DataConverter) commandStateMachine {
func (h *commandsHelper) recordMutableSideEffectMarker(mutableSideEffectID string, callCountHint int, data *commonpb.Payloads, dc converter.DataConverter) commandStateMachine {
// In order to avoid duplicate marker IDs, we must append the counter to the
// user-provided ID
mutableSideEffectID = fmt.Sprintf("%v_%v", mutableSideEffectID, h.nextCommandEventID)
Expand All @@ -1157,11 +1158,17 @@ func (h *commandsHelper) recordMutableSideEffectMarker(mutableSideEffectID strin
panic(err)
}

mutableSideEffectCounterPayload, err := dc.ToPayloads(callCountHint)
if err != nil {
panic(err)
}

attributes := &commandpb.RecordMarkerCommandAttributes{
MarkerName: mutableSideEffectMarkerName,
Details: map[string]*commonpb.Payloads{
sideEffectMarkerIDName: mutableSideEffectIDPayload,
sideEffectMarkerDataName: data,
sideEffectMarkerIDName: mutableSideEffectIDPayload,
sideEffectMarkerDataName: data,
mutableSideEffectCallCounterName: mutableSideEffectCounterPayload,
},
}
command := h.newMarkerCommandStateMachine(markerID, attributes)
Expand Down
110 changes: 84 additions & 26 deletions internal/internal_event_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"errors"
"fmt"
"reflect"
"sort"
"sync"
"time"

Expand Down Expand Up @@ -109,14 +110,18 @@ type (
changeVersions map[string]Version
pendingLaTasks map[string]*localActivityTask
completedLaAttemptsThisWFT uint32
mutableSideEffect map[string]*commonpb.Payloads
mutableSideEffect map[string]map[int]*commonpb.Payloads
unstartedLaTasks map[string]struct{}
openSessions map[string]*SessionInfo

// Set of mutable side effect IDs that are recorded on the next task for use
// during replay to determine whether a command should be created. The keys
// are the user-provided IDs + "_" + the command counter.
mutableSideEffectsRecorded map[string]bool
// Records the number of times a mutable side effect was called per ID over the
// life of the workflow. Used to help distinguish multiple calls to MutableSideEffect in the same
// WorkflowTask.
mutableSideEffectCallCounter map[string]int

// LocalActivities have a separate, individual counter instead of relying on actual commandEventIDs.
// This is because command IDs are only incremented on activity completion, which breaks
Expand Down Expand Up @@ -196,22 +201,23 @@ func newWorkflowExecutionEventHandler(
deadlockDetectionTimeout time.Duration,
) workflowExecutionEventHandler {
context := &workflowEnvironmentImpl{
workflowInfo: workflowInfo,
commandsHelper: newCommandsHelper(),
sideEffectResult: make(map[int64]*commonpb.Payloads),
mutableSideEffect: make(map[string]*commonpb.Payloads),
changeVersions: make(map[string]Version),
pendingLaTasks: make(map[string]*localActivityTask),
unstartedLaTasks: make(map[string]struct{}),
openSessions: make(map[string]*SessionInfo),
completeHandler: completeHandler,
enableLoggingInReplay: enableLoggingInReplay,
registry: registry,
dataConverter: dataConverter,
failureConverter: failureConverter,
contextPropagators: contextPropagators,
deadlockDetectionTimeout: deadlockDetectionTimeout,
protocols: protocol.NewRegistry(),
workflowInfo: workflowInfo,
commandsHelper: newCommandsHelper(),
sideEffectResult: make(map[int64]*commonpb.Payloads),
mutableSideEffect: make(map[string]map[int]*commonpb.Payloads),
changeVersions: make(map[string]Version),
pendingLaTasks: make(map[string]*localActivityTask),
unstartedLaTasks: make(map[string]struct{}),
openSessions: make(map[string]*SessionInfo),
completeHandler: completeHandler,
enableLoggingInReplay: enableLoggingInReplay,
registry: registry,
dataConverter: dataConverter,
failureConverter: failureConverter,
contextPropagators: contextPropagators,
deadlockDetectionTimeout: deadlockDetectionTimeout,
protocols: protocol.NewRegistry(),
mutableSideEffectCallCounter: make(map[string]int),
}
context.logger = ilog.NewReplayLogger(
log.With(logger,
Expand Down Expand Up @@ -782,15 +788,51 @@ func (wc *workflowEnvironmentImpl) SideEffect(f func() (*commonpb.Payloads, erro
wc.logger.Debug("SideEffect Marker added", tagSideEffectID, sideEffectID)
}

func (wc *workflowEnvironmentImpl) lookupMutableSideEffect(id string) (*commonpb.Payloads, bool) {
if payloadAtCallCount, ok := wc.mutableSideEffect[id]; ok {
currentCallCount := wc.mutableSideEffectCallCounter[id]
// Sort the calls
calls := make([]int, 0)
for k := range payloadAtCallCount {
calls = append(calls, k)
}
sort.Ints(calls)
// Find the most recent call at/before the current call count
var payload *commonpb.Payloads
var foundIndex int
for i := len(calls) - 1; i >= 0; i-- {
if calls[i] <= currentCallCount {
payload = payloadAtCallCount[calls[i]]
foundIndex = i
break
}
}
if payload == nil {
return nil, false
}
// Garbage collect old entries
// TODO(quinn) unclear if this should be done aggressively at WFT boundry
for i := 0; i < foundIndex; i++ {
delete(payloadAtCallCount, calls[i])
}

return payload, true
}
return nil, false
}

func (wc *workflowEnvironmentImpl) MutableSideEffect(id string, f func() interface{}, equals func(a, b interface{}) bool) converter.EncodedValue {
if result, ok := wc.mutableSideEffect[id]; ok {
wc.mutableSideEffectCallCounter[id]++
callCount := wc.mutableSideEffectCallCounter[id]

if result, ok := wc.lookupMutableSideEffect(id); ok {
encodedResult := newEncodedValue(result, wc.GetDataConverter())
if wc.isReplay {
// During replay, we only generate a command if there was a known marker
// recorded on the next task. We have to append the current command
// counter to the user-provided ID to avoid duplicates.
if wc.mutableSideEffectsRecorded[fmt.Sprintf("%v_%v", id, wc.commandsHelper.nextCommandEventID)] {
return wc.recordMutableSideEffect(id, result)
if _, ok := wc.mutableSideEffect[id][callCount]; ok && wc.mutableSideEffectsRecorded[fmt.Sprintf("%v_%v", id, wc.commandsHelper.nextCommandEventID)] {
return wc.recordMutableSideEffect(id, callCount, result)
}
return encodedResult
}
Expand All @@ -800,15 +842,15 @@ func (wc *workflowEnvironmentImpl) MutableSideEffect(id string, f func() interfa
return encodedResult
}

return wc.recordMutableSideEffect(id, wc.encodeValue(newValue))
return wc.recordMutableSideEffect(id, callCount, wc.encodeValue(newValue))
}

if wc.isReplay {
// This should not happen
panic(fmt.Sprintf("Non deterministic workflow code change detected. MutableSideEffect API call doesn't have a correspondent event in the workflow history. MutableSideEffect ID: %s", id))
}

return wc.recordMutableSideEffect(id, wc.encodeValue(f()))
return wc.recordMutableSideEffect(id, callCount, wc.encodeValue(f()))
}

func (wc *workflowEnvironmentImpl) isEqualValue(newValue interface{}, encodedOldValue *commonpb.Payloads, equals func(a, b interface{}) bool) bool {
Expand Down Expand Up @@ -844,13 +886,16 @@ func (wc *workflowEnvironmentImpl) encodeArg(arg interface{}) (*commonpb.Payload
return wc.GetDataConverter().ToPayloads(arg)
}

func (wc *workflowEnvironmentImpl) recordMutableSideEffect(id string, data *commonpb.Payloads) converter.EncodedValue {
func (wc *workflowEnvironmentImpl) recordMutableSideEffect(id string, callCountHint int, data *commonpb.Payloads) converter.EncodedValue {
details, err := encodeArgs(wc.GetDataConverter(), []interface{}{id, data})
if err != nil {
panic(err)
}
wc.commandsHelper.recordMutableSideEffectMarker(id, details, wc.dataConverter)
wc.mutableSideEffect[id] = data
wc.commandsHelper.recordMutableSideEffectMarker(id, callCountHint, details, wc.dataConverter)
if wc.mutableSideEffect[id] == nil {
wc.mutableSideEffect[id] = make(map[int]*commonpb.Payloads)
}
wc.mutableSideEffect[id][callCountHint] = data
return newEncodedValue(data, wc.GetDataConverter())
}

Expand Down Expand Up @@ -1311,7 +1356,20 @@ func (weh *workflowExecutionEventHandlerImpl) handleMarkerRecorded(
err = weh.dataConverter.FromPayloads(sideEffectDataPayload, &sideEffectDataID, &sideEffectDataContents)
}
if err == nil {
weh.mutableSideEffect[sideEffectDataID] = &sideEffectDataContents
counterHintPayload, ok := attributes.GetDetails()[mutableSideEffectCallCounterName]
var counterHint int
if ok {
err = weh.dataConverter.FromPayloads(counterHintPayload, &counterHint)
} else {
// An old version of the SDK did not write the counter hint so we have to assume.
// If multiple mutable side effects on the same ID are in a WFT only the last value is used.
counterHint = weh.mutableSideEffectCallCounter[sideEffectDataID]
}

if weh.mutableSideEffect[sideEffectDataID] == nil {
weh.mutableSideEffect[sideEffectDataID] = make(map[int]*commonpb.Payloads)
}
weh.mutableSideEffect[sideEffectDataID][counterHint] = &sideEffectDataContents
// We must mark that it is recorded so we can know whether a command
// needs to be generated during replay
if weh.mutableSideEffectsRecorded == nil {
Expand Down
44 changes: 34 additions & 10 deletions internal/internal_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,11 +1082,13 @@ func (aw *AggregatedWorker) Stop() {

// WorkflowReplayer is used to replay workflow code from an event history
type WorkflowReplayer struct {
registry *registry
dataConverter converter.DataConverter
failureConverter converter.FailureConverter
contextPropagators []ContextPropagator
enableLoggingInReplay bool
registry *registry
dataConverter converter.DataConverter
failureConverter converter.FailureConverter
contextPropagators []ContextPropagator
enableLoggingInReplay bool
mu sync.Mutex
workflowExecutionResults map[string]*commonpb.Payloads
}

// WorkflowReplayerOptions are options for creating a workflow replayer.
Expand Down Expand Up @@ -1130,11 +1132,12 @@ func NewWorkflowReplayer(options WorkflowReplayerOptions) (*WorkflowReplayer, er
registry := newRegistryWithOptions(registryOptions{disableAliasing: options.DisableRegistrationAliasing})
registry.interceptors = options.Interceptors
return &WorkflowReplayer{
registry: registry,
dataConverter: options.DataConverter,
failureConverter: options.FailureConverter,
contextPropagators: options.ContextPropagators,
enableLoggingInReplay: options.EnableLoggingInReplay,
registry: registry,
dataConverter: options.DataConverter,
failureConverter: options.FailureConverter,
contextPropagators: options.ContextPropagators,
enableLoggingInReplay: options.EnableLoggingInReplay,
workflowExecutionResults: make(map[string]*commonpb.Payloads),
}, nil
}

Expand Down Expand Up @@ -1235,6 +1238,24 @@ func (aw *WorkflowReplayer) ReplayWorkflowExecution(ctx context.Context, service
return aw.replayWorkflowHistory(logger, service, namespace, execution, &history)
}

// GetWorkflowResult get the result of a succesfully replayed workflow.
func (aw *WorkflowReplayer) GetWorkflowResult(workflowID string, valuePtr interface{}) error {
aw.mu.Lock()
defer aw.mu.Unlock()
if workflowID == "" {
workflowID = "ReplayId"
}
payloads, ok := aw.workflowExecutionResults[workflowID]
if !ok {
return errors.New("workflow result not found")
}
dc := aw.dataConverter
if dc == nil {
dc = converter.GetDefaultDataConverter()
}
return dc.FromPayloads(payloads, valuePtr)
}

// inferMessages extracts the set of *interactionpb.Invocation objects that
// should be attached to a workflow task (i.e. the
// PollWorkflowTaskQueueResponse.Messages) if that task were to carry the
Expand Down Expand Up @@ -1348,6 +1369,9 @@ func (aw *WorkflowReplayer) replayWorkflowHistory(logger log.Logger, service wor
}
if d.GetCommandType() == enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION {
if last.GetEventType() == enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED {
aw.mu.Lock()
defer aw.mu.Unlock()
aw.workflowExecutionResults[execution.WorkflowId] = d.GetCompleteWorkflowExecutionCommandAttributes().Result
return nil
}
}
Expand Down
Loading

0 comments on commit d8c78af

Please sign in to comment.