diff --git a/service/interpreter/stateExecutionCounter.go b/service/interpreter/stateExecutionCounter.go index b6291466..90b383c0 100644 --- a/service/interpreter/stateExecutionCounter.go +++ b/service/interpreter/stateExecutionCounter.go @@ -6,7 +6,7 @@ import ( "github.com/indeedeng/iwf/service" "github.com/indeedeng/iwf/service/common/compatibility" "github.com/indeedeng/iwf/service/common/ptr" - "slices" + "reflect" ) type StateExecutionCounter struct { @@ -81,9 +81,6 @@ func (e *StateExecutionCounter) MarkStateIdExecutingIfNotYet(stateReqs []StateRe e.provider.GetLogger(e.ctx).Error("error for GetSearchAttributes", err) } - // TODO: Address the !ok case - currentSAs := sas[service.SearchAttributeExecutingStateIds] - needsUpdateSA := false numOfNew := 0 for _, sr := range stateReqs { @@ -98,18 +95,14 @@ func (e *StateExecutionCounter) MarkStateIdExecutingIfNotYet(stateReqs []StateRe // do nothing case iwfidl.ENABLED_FOR_ALL: e.increaseStateIdCurrentlyExecutingCounts(s) - if !slices.Contains(currentSAs.StringArrayValue, s.StateId) { - needsUpdateSA = true - } + needsUpdateSA = true case iwfidl.ENABLED_FOR_STATES_WITH_WAIT_UNTIL: fallthrough default: options := s.GetStateOptions() if !compatibility.GetSkipWaitUntilApi(&options) { e.increaseStateIdCurrentlyExecutingCounts(s) - if !slices.Contains(currentSAs.StringArrayValue, s.StateId) { - needsUpdateSA = true - } + needsUpdateSA = true } } } else { @@ -124,8 +117,46 @@ func (e *StateExecutionCounter) MarkStateIdExecutingIfNotYet(stateReqs []StateRe } e.totalCurrentlyExecutingCount += numOfNew + var currentSAsValues []string + + currentSAs, ok := sas[service.SearchAttributeExecutingStateIds] + if ok { + currentSAsValues = currentSAs.StringArrayValue + } else { + e.provider.GetLogger(e.ctx).Error("search attribute IwfExecutingStateIds is not found", err) + } + + // Optimization: don't upsert SAs if currentSAsValues == stateReqs + if e.globalVersioner.IsAfterVersionOfExecutingStateIdMode() && needsUpdateSA { + switch mode := config.GetExecutingStateIdMode(); mode { + // Should never get here, but keeping to address all possible modes + case iwfidl.DISABLED: + // noop + case iwfidl.ENABLED_FOR_ALL: + var stateReqStates []string + for _, sr := range stateReqs { + stateReqStates = append(stateReqStates, sr.GetStateId()) + } + if reflect.DeepEqual(currentSAsValues, stateReqStates) { + needsUpdateSA = false + } + case iwfidl.ENABLED_FOR_STATES_WITH_WAIT_UNTIL: + fallthrough + default: + var stateReqStates []string + for _, sr := range stateReqs { + if !sr.GetStateStartRequest().StateOptions.GetSkipWaitUntil() { + stateReqStates = append(stateReqStates, sr.GetStateId()) + } + } + if reflect.DeepEqual(currentSAsValues, stateReqStates) { + needsUpdateSA = false + } + } + } + if needsUpdateSA { - return e.updateStateIdSearchAttribute() + return e.refreshIwfExecutingStateIdSearchAttribute() } return nil } @@ -151,7 +182,7 @@ func (e *StateExecutionCounter) MarkStateExecutionCompleted(currentState iwfidl. return nil case iwfidl.ENABLED_FOR_ALL: e.decreaseStateIdCurrentlyExecutingCounts(currentState) - shouldSkipUpsert := determineIfShouldSkipUpsert(currentState, nextStates) + shouldSkipUpsert := determineIfShouldSkipRefresh(currentState, nextStates) if shouldSkipUpsert { return nil } @@ -162,7 +193,7 @@ func (e *StateExecutionCounter) MarkStateExecutionCompleted(currentState iwfidl. return nil } else { e.decreaseStateIdCurrentlyExecutingCounts(currentState) - shouldSkipUpsert := determineIfShouldSkipUpsert(currentState, nextStates) + shouldSkipUpsert := determineIfShouldSkipRefresh(currentState, nextStates) if shouldSkipUpsert { return nil } @@ -176,16 +207,16 @@ func (e *StateExecutionCounter) MarkStateExecutionCompleted(currentState iwfidl. } } - return e.updateStateIdSearchAttribute() + return e.refreshIwfExecutingStateIdSearchAttribute() } -func determineIfShouldSkipUpsert(currentState iwfidl.StateMovement, nextStates []iwfidl.StateMovement) bool { +func determineIfShouldSkipRefresh(currentState iwfidl.StateMovement, nextStates []iwfidl.StateMovement) bool { // Case: State loops back to itself; Outcome: do not upsert SAs if len(nextStates) == 1 && currentState.StateId == nextStates[0].StateId { return true } - // Check if all nextStates skip waitUntil; omit currentState in case it loops back + // Check if any of nextStates skips waitUntil; omit currentState in case it loops back var nextStagesWithNoCurrent []iwfidl.StateMovement for _, s := range nextStates { if s.StateId != currentState.StateId { @@ -193,16 +224,13 @@ func determineIfShouldSkipUpsert(currentState iwfidl.StateMovement, nextStates [ } } - shouldSkipUpsertingSAs := true - for _, s := range nextStagesWithNoCurrent { - if !s.StateOptions.GetSkipWaitUntil() { - shouldSkipUpsertingSAs = false - break + if s.StateOptions.GetSkipWaitUntil() { + return true } } - return shouldSkipUpsertingSAs + return false } func (e *StateExecutionCounter) decreaseStateIdCurrentlyExecutingCounts(state iwfidl.StateMovement) { @@ -216,7 +244,7 @@ func (e *StateExecutionCounter) GetTotalCurrentlyExecutingCount() int { return e.totalCurrentlyExecutingCount } -func (e *StateExecutionCounter) updateStateIdSearchAttribute() error { +func (e *StateExecutionCounter) refreshIwfExecutingStateIdSearchAttribute() error { var executingStateIds []string for sid := range e.stateIdCurrentlyExecutingCounts { executingStateIds = append(executingStateIds, sid) diff --git a/service/interpreter/workflowImpl.go b/service/interpreter/workflowImpl.go index 6444e603..89cc7a25 100644 --- a/service/interpreter/workflowImpl.go +++ b/service/interpreter/workflowImpl.go @@ -109,7 +109,7 @@ func InterpreterImpl( var forceCompleteWf bool var shouldGracefulComplete bool - // this is for an optimization for StateId Search attribute, see updateStateIdSearchAttribute in stateExecutionCounter + // this is for an optimization for StateId Search attribute, see refreshIwfExecutingStateIdSearchAttribute in stateExecutionCounter // Because it will check totalCurrentlyExecutingCount == 0, so it will also work for continueAsNew case defer stateExecutionCounter.ClearExecutingStateIdsSearchAttributeFinally()