Skip to content

Commit

Permalink
Fix conflict error handling (#2469)
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexshtin authored Feb 9, 2022
1 parent 670d99c commit 01d071a
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 88 deletions.
76 changes: 43 additions & 33 deletions common/persistence/cassandra/errors.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,52 +57,62 @@ type (
)

func convertErrors(
record map[string]interface{},
iter gocql.Iter,
conflictRecord map[string]interface{},
conflictIter gocql.Iter,
requestShardID int32,
requestRangeID int64,
requestCurrentRunID string,
requestExecutionCASConditions []executionCASCondition,
) error {

records := []map[string]interface{}{record}
conflictRecords := []map[string]interface{}{conflictRecord}
errors := extractErrors(
record,
conflictRecord,
requestShardID,
requestRangeID,
requestCurrentRunID,
requestExecutionCASConditions,
)

record = make(map[string]interface{})
for iter.MapScan(record) {
records = append(records, record)
conflictRecord = make(map[string]interface{})
for conflictIter.MapScan(conflictRecord) {
if conflictRecord["[applied]"].(bool) {
// Should never happen. All records in batch should have [applied]=false.
continue
}

conflictRecords = append(conflictRecords, conflictRecord)
errors = append(errors, extractErrors(
record,
conflictRecord,
requestShardID,
requestRangeID,
requestCurrentRunID,
requestExecutionCASConditions,
)...)

record = make(map[string]interface{})
conflictRecord = make(map[string]interface{})
}

errors = sortErrors(errors)
if len(errors) == 0 {
// This means that extractErrors wasn't able to extract error from the conflicting records.
// Most likely record to update is not found in the DB by WHERE clause and is NOT in conflictRecords slice.
// Unfortunately, there is no way to get the missing record w/o extra call to DB.
// Most likely it is current workflow execution record.
return &p.ConditionFailedError{
Msg: fmt.Sprintf("Encounter unknown error: shard ID: %v, range ID: %v, error: %v",
Msg: fmt.Sprintf("Encounter unknown condition update error: shard ID: %v, range ID: %v, possibly conflicting records:%v",
requestShardID,
requestRangeID,
printRecords(records),
printRecords(conflictRecords),
),
}
}

errors = sortErrors(errors)
return errors[0]
}

func extractErrors(
record map[string]interface{},
conflictRecord map[string]interface{},
requestShardID int32,
requestRangeID int64,
requestCurrentRunID string,
Expand All @@ -112,23 +122,23 @@ func extractErrors(
var errors []error

if err := extractShardOwnershipLostError(
record,
conflictRecord,
requestShardID,
requestRangeID,
); err != nil {
errors = append(errors, err)
}

if err := extractCurrentWorkflowConflictError(
record,
conflictRecord,
requestCurrentRunID,
); err != nil {
errors = append(errors, err)
}

for _, condition := range requestExecutionCASConditions {
if err := extractWorkflowConflictError(
record,
conflictRecord,
condition.runID,
condition.dbVersion,
condition.nextEventID,
Expand Down Expand Up @@ -158,11 +168,11 @@ func sortErrors(
}

func extractShardOwnershipLostError(
record map[string]interface{},
conflictRecord map[string]interface{},
requestShardID int32,
requestRangeID int64,
) error {
rowType, ok := record["type"].(int)
rowType, ok := conflictRecord["type"].(int)
if !ok {
// this case should not happen, maybe panic?
return nil
Expand All @@ -171,7 +181,7 @@ func extractShardOwnershipLostError(
return nil
}

actualRangeID := record["range_id"].(int64)
actualRangeID := conflictRecord["range_id"].(int64)
if actualRangeID != requestRangeID {
return &p.ShardOwnershipLostError{
ShardID: requestShardID,
Expand All @@ -185,25 +195,25 @@ func extractShardOwnershipLostError(
}

func extractCurrentWorkflowConflictError(
record map[string]interface{},
conflictRecord map[string]interface{},
requestCurrentRunID string,
) error {
rowType, ok := record["type"].(int)
rowType, ok := conflictRecord["type"].(int)
if !ok {
// this case should not happen, maybe panic?
return nil
}
if rowType != rowTypeExecution {
return nil
}
if runID := gocql.UUIDToString(record["run_id"]); runID != permanentRunID {
if runID := gocql.UUIDToString(conflictRecord["run_id"]); runID != permanentRunID {
return nil
}

actualCurrentRunID := gocql.UUIDToString(record["current_run_id"])
actualCurrentRunID := gocql.UUIDToString(conflictRecord["current_run_id"])
if actualCurrentRunID != requestCurrentRunID {
binary, _ := record["execution_state"].([]byte)
encoding, _ := record["execution_state_encoding"].(string)
binary, _ := conflictRecord["execution_state"].([]byte)
encoding, _ := conflictRecord["execution_state_encoding"].(string)
executionState := &persistencespb.WorkflowExecutionState{}
if state, err := serialization.WorkflowExecutionStateFromBlob(
binary,
Expand All @@ -213,12 +223,12 @@ func extractCurrentWorkflowConflictError(
}
// if err != nil, this means execution state cannot be parsed, just use default values

lastWriteVersion, _ := record["workflow_last_write_version"].(int64)
lastWriteVersion, _ := conflictRecord["workflow_last_write_version"].(int64)

// TODO maybe assert actualCurrentRunID == executionState.RunId ?

return &p.CurrentWorkflowConditionFailedError{
Msg: fmt.Sprintf("Encounter concurrent workflow error, request run ID: %v, actual run ID: %v",
Msg: fmt.Sprintf("Encounter current workflow error, request run ID: %v, actual run ID: %v",
requestCurrentRunID,
actualCurrentRunID,
),
Expand All @@ -233,25 +243,25 @@ func extractCurrentWorkflowConflictError(
}

func extractWorkflowConflictError(
record map[string]interface{},
conflictRecord map[string]interface{},
requestRunID string,
requestDBVersion int64,
requestNextEventID int64, // TODO deprecate this variable once DB version comparison is the default
) error {
rowType, ok := record["type"].(int)
rowType, ok := conflictRecord["type"].(int)
if !ok {
// this case should not happen, maybe panic?
return nil
}
if rowType != rowTypeExecution {
return nil
}
if runID := gocql.UUIDToString(record["run_id"]); runID != requestRunID {
if runID := gocql.UUIDToString(conflictRecord["run_id"]); runID != requestRunID {
return nil
}

actualNextEventID, _ := record["next_event_id"].(int64)
actualDBVersion, _ := record["db_version"].(int64)
actualNextEventID, _ := conflictRecord["next_event_id"].(int64)
actualDBVersion, _ := conflictRecord["db_record_version"].(int64)

// TODO remove this block once DB version comparison is the default
if requestDBVersion == 0 {
Expand All @@ -270,7 +280,7 @@ func extractWorkflowConflictError(

if actualDBVersion != requestDBVersion {
return &p.WorkflowConditionFailedError{
Msg: fmt.Sprintf("Encounter workflow db version mismatch, request db version ID: %v, actual db version ID: %v",
Msg: fmt.Sprintf("Encounter workflow db version mismatch, request db version: %v, actual db version: %v",
requestDBVersion,
actualDBVersion,
),
Expand Down
24 changes: 12 additions & 12 deletions common/persistence/cassandra/errors_test.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -244,23 +244,23 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed() {
s.NoError(err)

err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeShard,
"run_id": gocql.UUID(runID),
"db_version": dbVersion,
"type": rowTypeShard,
"run_id": gocql.UUID(runID),
"db_record_version": dbVersion,
}, runID.String(), dbVersion+1, rand.Int63())
s.NoError(err)

err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"run_id": gocql.UUID([16]byte{}),
"db_version": dbVersion,
"type": rowTypeExecution,
"run_id": gocql.UUID([16]byte{}),
"db_record_version": dbVersion,
}, runID.String(), dbVersion+1, rand.Int63())
s.NoError(err)

err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"run_id": gocql.UUID(runID),
"db_version": dbVersion,
"type": rowTypeExecution,
"run_id": gocql.UUID(runID),
"db_record_version": dbVersion,
}, runID.String(), dbVersion, rand.Int63())
s.NoError(err)
}
Expand All @@ -269,9 +269,9 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Success() {
runID := uuid.New()
dbVersion := rand.Int63() + 1
record := map[string]interface{}{
"type": rowTypeExecution,
"run_id": gocql.UUID(runID),
"db_version": dbVersion,
"type": rowTypeExecution,
"run_id": gocql.UUID(runID),
"db_record_version": dbVersion,
}

err := extractWorkflowConflictError(record, runID.String(), dbVersion+1, rand.Int63())
Expand Down
30 changes: 15 additions & 15 deletions common/persistence/cassandra/mutable_state_store.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -451,19 +451,19 @@ func (d *MutableStateStore) CreateWorkflowExecution(
request.RangeID,
)

record := make(map[string]interface{})
applied, iter, err := d.Session.MapExecuteBatchCAS(batch, record)
conflictRecord := make(map[string]interface{})
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
if err != nil {
return nil, gocql.ConvertError("CreateWorkflowExecution", err)
}
defer func() {
_ = iter.Close()
_ = conflictIter.Close()
}()

if !applied {
return nil, convertErrors(
record,
iter,
conflictRecord,
conflictIter,
shardID,
request.RangeID,
requestCurrentRunID,
Expand Down Expand Up @@ -670,19 +670,19 @@ func (d *MutableStateStore) UpdateWorkflowExecution(
request.RangeID,
)

record := make(map[string]interface{})
applied, iter, err := d.Session.MapExecuteBatchCAS(batch, record)
conflictRecord := make(map[string]interface{})
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
if err != nil {
return gocql.ConvertError("UpdateWorkflowExecution", err)
}
defer func() {
_ = iter.Close()
_ = conflictIter.Close()
}()

if !applied {
return convertErrors(
record,
iter,
conflictRecord,
conflictIter,
request.ShardID,
request.RangeID,
updateWorkflow.ExecutionState.RunId,
Expand Down Expand Up @@ -819,13 +819,13 @@ func (d *MutableStateStore) ConflictResolveWorkflowExecution(
request.RangeID,
)

record := make(map[string]interface{})
applied, iter, err := d.Session.MapExecuteBatchCAS(batch, record)
conflictRecord := make(map[string]interface{})
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
if err != nil {
return gocql.ConvertError("ConflictResolveWorkflowExecution", err)
}
defer func() {
_ = iter.Close()
_ = conflictIter.Close()
}()

if !applied {
Expand All @@ -846,8 +846,8 @@ func (d *MutableStateStore) ConflictResolveWorkflowExecution(
})
}
return convertErrors(
record,
iter,
conflictRecord,
conflictIter,
request.ShardID,
request.RangeID,
currentRunID,
Expand Down
Loading

0 comments on commit 01d071a

Please sign in to comment.