From 93e356b4ab0b65c71659bd52d73f618edffc96f5 Mon Sep 17 00:00:00 2001 From: Skye Gill Date: Fri, 3 Feb 2023 22:55:59 +0000 Subject: [PATCH] fix: introduced RWMutex to flag state to prevent concurrent r/w of map (#370) ## This PR - Introduces RWMutex on flag state to prevent concurrent read/write of map. ### Related Issues Fixes #368 ### Notes ### Follow-up Tasks ### How to test --------- Signed-off-by: Skye Gill --- pkg/eval/fractional_evaluation_test.go | 6 ++ pkg/eval/json_evaluator.go | 18 +++- pkg/eval/json_evaluator_model.go | 34 ++++++- pkg/eval/json_evaluator_model_test.go | 74 ++++++++++---- pkg/eval/json_evaluator_test.go | 129 +++++++++++++++++++++++++ 5 files changed, 235 insertions(+), 26 deletions(-) diff --git a/pkg/eval/fractional_evaluation_test.go b/pkg/eval/fractional_evaluation_test.go index b9faad41e..6c78fc1a3 100644 --- a/pkg/eval/fractional_evaluation_test.go +++ b/pkg/eval/fractional_evaluation_test.go @@ -1,6 +1,7 @@ package eval import ( + "sync" "testing" "github.com/open-feature/flagd/pkg/logger" @@ -10,6 +11,7 @@ import ( func TestFractionalEvaluation(t *testing.T) { flags := Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "headerColor": { State: "ENABLED", @@ -113,6 +115,7 @@ func TestFractionalEvaluation(t *testing.T) { }, "non even split": { flags: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "headerColor": { State: "ENABLED", @@ -164,6 +167,7 @@ func TestFractionalEvaluation(t *testing.T) { }, "fallback to default variant if no email provided": { flags: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "headerColor": { State: "ENABLED", @@ -206,6 +210,7 @@ func TestFractionalEvaluation(t *testing.T) { }, "fallback to default variant if invalid variant as result of fractional evaluation": { flags: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "headerColor": { State: "ENABLED", @@ -240,6 +245,7 @@ func TestFractionalEvaluation(t *testing.T) { }, "fallback to default variant if percentages don't sum to 100": { flags: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "headerColor": { State: "ENABLED", diff --git a/pkg/eval/json_evaluator.go b/pkg/eval/json_evaluator.go index 565986d77..15196c05b 100644 --- a/pkg/eval/json_evaluator.go +++ b/pkg/eval/json_evaluator.go @@ -8,6 +8,7 @@ import ( "regexp" "strconv" "strings" + mxSync "sync" "github.com/open-feature/flagd/pkg/sync" @@ -47,6 +48,7 @@ func NewJSONEvaluator(logger *logger.Logger) *JSONEvaluator { ), state: Flags{ Flags: map[string]Flag{}, + mx: &mxSync.RWMutex{}, }, } jsonlogic.AddOperator("fractionalEvaluation", ev.fractionalEvaluation) @@ -110,6 +112,8 @@ func (je *JSONEvaluator) ResolveAllValues(reqID string, context *structpb.Struct var variant string var reason string var err error + je.state.mx.RLock() + defer je.state.mx.RUnlock() for flagKey, flag := range je.state.Flags { defaultValue := flag.Variants[flag.DefaultVariant] switch defaultValue.(type) { @@ -161,6 +165,8 @@ func (je *JSONEvaluator) ResolveBooleanValue(reqID string, flagKey string, conte reason string, err error, ) { + je.state.mx.RLock() + defer je.state.mx.RUnlock() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating boolean flag: %s", flagKey)) return resolve[bool](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants) } @@ -171,6 +177,8 @@ func (je *JSONEvaluator) ResolveStringValue(reqID string, flagKey string, contex reason string, err error, ) { + je.state.mx.RLock() + defer je.state.mx.RUnlock() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating string flag: %s", flagKey)) return resolve[string](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants) } @@ -181,6 +189,8 @@ func (je *JSONEvaluator) ResolveFloatValue(reqID string, flagKey string, context reason string, err error, ) { + je.state.mx.RLock() + defer je.state.mx.RUnlock() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating float flag: %s", flagKey)) value, variant, reason, err = resolve[float64]( reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants) @@ -193,6 +203,8 @@ func (je *JSONEvaluator) ResolveIntValue(reqID string, flagKey string, context * reason string, err error, ) { + je.state.mx.RLock() + defer je.state.mx.RUnlock() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating int flag: %s", flagKey)) var val float64 val, variant, reason, err = resolve[float64]( @@ -207,6 +219,8 @@ func (je *JSONEvaluator) ResolveObjectValue(reqID string, flagKey string, contex reason string, err error, ) { + je.state.mx.RLock() + defer je.state.mx.RUnlock() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating object flag: %s", flagKey)) return resolve[map[string]any](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants) } @@ -256,7 +270,7 @@ func (je *JSONEvaluator) evaluateVariant( variant = strings.ReplaceAll(strings.TrimSpace(result.String()), "\"", "") // if this is a valid variant, return it - if _, ok := je.state.Flags[flagKey].Variants[variant]; ok { + if _, ok := flag.Variants[variant]; ok { return variant, model.TargetingMatchReason, nil } @@ -266,7 +280,7 @@ func (je *JSONEvaluator) evaluateVariant( reason = model.StaticReason } - return je.state.Flags[flagKey].DefaultVariant, reason, nil + return flag.DefaultVariant, reason, nil } // configToFlags convert string configurations to flags and store them to pointer newFlags diff --git a/pkg/eval/json_evaluator_model.go b/pkg/eval/json_evaluator_model.go index 67751d3c8..8dd8e157f 100644 --- a/pkg/eval/json_evaluator_model.go +++ b/pkg/eval/json_evaluator_model.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "reflect" + "sync" "github.com/open-feature/flagd/pkg/logger" ) @@ -21,6 +22,7 @@ type Evaluators struct { } type Flags struct { + mx *sync.RWMutex Flags map[string]Flag `json:"flags"` } @@ -29,7 +31,10 @@ func (f Flags) Add(logger *logger.Logger, source string, ff Flags) map[string]in notifications := map[string]interface{}{} for k, newFlag := range ff.Flags { - if storedFlag, ok := f.Flags[k]; ok && storedFlag.Source != source { + f.mx.RLock() + storedFlag, ok := f.Flags[k] + f.mx.RUnlock() + if ok && storedFlag.Source != source { logger.Warn(fmt.Sprintf( "flag with key %s from source %s already exist, overriding this with flag from source %s", k, @@ -45,7 +50,9 @@ func (f Flags) Add(logger *logger.Logger, source string, ff Flags) map[string]in // Store the new version of the flag newFlag.Source = source + f.mx.Lock() f.Flags[k] = newFlag + f.mx.Unlock() } return notifications @@ -56,14 +63,18 @@ func (f Flags) Update(logger *logger.Logger, source string, ff Flags) map[string notifications := map[string]interface{}{} for k, flag := range ff.Flags { - if storedFlag, ok := f.Flags[k]; !ok { + f.mx.RLock() + storedFlag, ok := f.Flags[k] + f.mx.RUnlock() + if !ok { logger.Warn( - fmt.Sprintf("failed to update the flag, flag with key %s from source %s does not exisit.", + fmt.Sprintf("failed to update the flag, flag with key %s from source %s does not exist.", k, source)) continue - } else if storedFlag.Source != source { + } + if storedFlag.Source != source { logger.Warn(fmt.Sprintf( "flag with key %s from source %s already exist, overriding this with flag from source %s", k, @@ -78,7 +89,9 @@ func (f Flags) Update(logger *logger.Logger, source string, ff Flags) map[string } flag.Source = source + f.mx.Lock() f.Flags[k] = flag + f.mx.Unlock() } return notifications @@ -89,13 +102,18 @@ func (f Flags) Delete(logger *logger.Logger, source string, ff Flags) map[string notifications := map[string]interface{}{} for k := range ff.Flags { - if _, ok := f.Flags[k]; ok { + f.mx.RLock() + _, ok := f.Flags[k] + f.mx.RUnlock() + if ok { notifications[k] = map[string]interface{}{ "type": string(NotificationDelete), "source": source, } + f.mx.Lock() delete(f.Flags, k) + f.mx.Unlock() } else { logger.Warn( fmt.Sprintf("failed to remove flag, flag with key %s from source %s does not exisit.", @@ -111,6 +129,7 @@ func (f Flags) Delete(logger *logger.Logger, source string, ff Flags) map[string func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string]interface{} { notifications := map[string]interface{}{} + f.mx.Lock() for k, v := range f.Flags { if v.Source == source { if _, ok := ff.Flags[k]; !ok { @@ -124,11 +143,14 @@ func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string] } } } + f.mx.Unlock() for k, newFlag := range ff.Flags { newFlag.Source = source + f.mx.RLock() storedFlag, ok := f.Flags[k] + f.mx.RUnlock() if !ok { notifications[k] = map[string]interface{}{ "type": string(NotificationCreate), @@ -151,8 +173,10 @@ func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string] } } + f.mx.Lock() // Store the new version of the flag f.Flags[k] = newFlag + f.mx.Unlock() } return notifications diff --git a/pkg/eval/json_evaluator_model_test.go b/pkg/eval/json_evaluator_model_test.go index b4f7dc77a..dd2582c0d 100644 --- a/pkg/eval/json_evaluator_model_test.go +++ b/pkg/eval/json_evaluator_model_test.go @@ -1,6 +1,7 @@ package eval import ( + "sync" "testing" "github.com/open-feature/flagd/pkg/logger" @@ -18,41 +19,56 @@ func TestMergeFlags(t *testing.T) { wantNotifs map[string]interface{} }{ { - name: "both nil", - current: Flags{Flags: nil}, + name: "both nil", + current: Flags{ + mx: &sync.RWMutex{}, + Flags: nil, + }, new: Flags{Flags: nil}, want: Flags{Flags: map[string]Flag{}}, wantNotifs: map[string]interface{}{}, }, { - name: "both empty flags", - current: Flags{Flags: map[string]Flag{}}, + name: "both empty flags", + current: Flags{ + mx: &sync.RWMutex{}, + Flags: map[string]Flag{}, + }, new: Flags{Flags: map[string]Flag{}}, want: Flags{Flags: map[string]Flag{}}, wantNotifs: map[string]interface{}{}, }, { - name: "empty current", - current: Flags{Flags: nil}, + name: "empty current", + current: Flags{ + mx: &sync.RWMutex{}, + Flags: nil, + }, new: Flags{Flags: map[string]Flag{}}, want: Flags{Flags: map[string]Flag{}}, wantNotifs: map[string]interface{}{}, }, { - name: "empty new", - current: Flags{Flags: map[string]Flag{}}, + name: "empty new", + current: Flags{ + mx: &sync.RWMutex{}, + Flags: map[string]Flag{}, + }, new: Flags{Flags: nil}, want: Flags{Flags: map[string]Flag{}}, wantNotifs: map[string]interface{}{}, }, { name: "extra fields on each", - current: Flags{Flags: map[string]Flag{ - "waka": { - DefaultVariant: "off", - Source: "1", + current: Flags{ + mx: &sync.RWMutex{}, + Flags: map[string]Flag{ + "waka": { + DefaultVariant: "off", + Source: "1", + }, }, - }}, + }, new: Flags{Flags: map[string]Flag{ "paka": { DefaultVariant: "on", @@ -75,9 +91,10 @@ func TestMergeFlags(t *testing.T) { }, { name: "override", - current: Flags{Flags: map[string]Flag{ - "waka": {DefaultVariant: "off"}, - }}, + current: Flags{ + mx: &sync.RWMutex{}, + Flags: map[string]Flag{"waka": {DefaultVariant: "off"}}, + }, new: Flags{Flags: map[string]Flag{ "waka": {DefaultVariant: "on"}, "paka": {DefaultVariant: "on"}, @@ -93,9 +110,10 @@ func TestMergeFlags(t *testing.T) { }, { name: "identical", - current: Flags{Flags: map[string]Flag{ - "hello": {DefaultVariant: "off"}, - }}, + current: Flags{ + mx: &sync.RWMutex{}, + Flags: map[string]Flag{"hello": {DefaultVariant: "off"}}, + }, new: Flags{Flags: map[string]Flag{ "hello": {DefaultVariant: "off"}, }}, @@ -137,6 +155,7 @@ func TestFlags_Add(t *testing.T) { { name: "Add success", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, }, @@ -150,6 +169,7 @@ func TestFlags_Add(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, "B": {Source: mockSource}, @@ -160,6 +180,7 @@ func TestFlags_Add(t *testing.T) { { name: "Add multiple success", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, }, @@ -174,6 +195,7 @@ func TestFlags_Add(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, "B": {Source: mockSource}, @@ -185,6 +207,7 @@ func TestFlags_Add(t *testing.T) { { name: "Add success - conflict and override", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, }, @@ -198,6 +221,7 @@ func TestFlags_Add(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockOverrideSource}, }, @@ -239,6 +263,7 @@ func TestFlags_Update(t *testing.T) { { name: "Update success", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource, DefaultVariant: "True"}, }, @@ -252,6 +277,7 @@ func TestFlags_Update(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource, DefaultVariant: "False"}, }, @@ -261,6 +287,7 @@ func TestFlags_Update(t *testing.T) { { name: "Update multiple success", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource, DefaultVariant: "True"}, "B": {Source: mockSource, DefaultVariant: "True"}, @@ -276,6 +303,7 @@ func TestFlags_Update(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource, DefaultVariant: "False"}, "B": {Source: mockSource, DefaultVariant: "False"}, @@ -286,6 +314,7 @@ func TestFlags_Update(t *testing.T) { { name: "Update success - conflict and override", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource, DefaultVariant: "True"}, }, @@ -299,6 +328,7 @@ func TestFlags_Update(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockOverrideSource, DefaultVariant: "True"}, }, @@ -308,6 +338,7 @@ func TestFlags_Update(t *testing.T) { { name: "Update fail", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, }, @@ -321,6 +352,7 @@ func TestFlags_Update(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, }, @@ -357,6 +389,7 @@ func TestFlags_Delete(t *testing.T) { { name: "Remove success", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, "B": {Source: mockSource}, @@ -368,6 +401,7 @@ func TestFlags_Delete(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "B": {Source: mockSource}, }, @@ -377,6 +411,7 @@ func TestFlags_Delete(t *testing.T) { { name: "Nothing to remove", storedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, "B": {Source: mockSource}, @@ -388,6 +423,7 @@ func TestFlags_Delete(t *testing.T) { }, }, expectedState: Flags{ + mx: &sync.RWMutex{}, Flags: map[string]Flag{ "A": {Source: mockSource}, "B": {Source: mockSource}, diff --git a/pkg/eval/json_evaluator_test.go b/pkg/eval/json_evaluator_test.go index 1b659207c..87c52918f 100644 --- a/pkg/eval/json_evaluator_test.go +++ b/pkg/eval/json_evaluator_test.go @@ -6,6 +6,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/open-feature/flagd/pkg/eval" "github.com/open-feature/flagd/pkg/logger" @@ -1065,3 +1066,131 @@ func TestState_Evaluator(t *testing.T) { }) } } + +func TestFlagStateSafeForConcurrentReadWrites(t *testing.T) { + tests := map[string]struct { + dataSyncType sync.Type + flagResolution func(evaluator *eval.JSONEvaluator) error + }{ + "Add_ResolveAllValues": { + dataSyncType: sync.ADD, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + evaluator.ResolveAllValues("", nil) + return nil + }, + }, + "Update_ResolveAllValues": { + dataSyncType: sync.UPDATE, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + evaluator.ResolveAllValues("", nil) + return nil + }, + }, + "Delete_ResolveAllValues": { + dataSyncType: sync.DELETE, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + evaluator.ResolveAllValues("", nil) + return nil + }, + }, + "Add_ResolveBooleanValue": { + dataSyncType: sync.ADD, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + _, _, _, err := evaluator.ResolveBooleanValue("", StaticBoolFlag, nil) + return err + }, + }, + "Update_ResolveStringValue": { + dataSyncType: sync.UPDATE, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + _, _, _, err := evaluator.ResolveBooleanValue("", StaticStringValue, nil) + return err + }, + }, + "Delete_ResolveIntValue": { + dataSyncType: sync.DELETE, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + _, _, _, err := evaluator.ResolveIntValue("", StaticIntFlag, nil) + return err + }, + }, + "Add_ResolveFloatValue": { + dataSyncType: sync.ADD, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + _, _, _, err := evaluator.ResolveFloatValue("", StaticFloatFlag, nil) + return err + }, + }, + "Update_ResolveObjectValue": { + dataSyncType: sync.UPDATE, + flagResolution: func(evaluator *eval.JSONEvaluator) error { + _, _, _, err := evaluator.ResolveObjectValue("", StaticObjectFlag, nil) + return err + }, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + jsonEvaluator := eval.NewJSONEvaluator(logger.NewLogger(nil, false)) + + _, err := jsonEvaluator.SetState(sync.DataSync{FlagData: Flags, Type: sync.ADD}) + if err != nil { + t.Fatal(err) + } + + errChan := make(chan error) + + timeoutDur := 25 * time.Millisecond + + go func() { + defer func() { + if r := recover(); r != nil { + errChan <- fmt.Errorf("%v", r) + } + }() + timeout := time.After(timeoutDur) + + for { + select { + case <-timeout: + errChan <- nil + return + default: + _, err := jsonEvaluator.SetState(sync.DataSync{FlagData: Flags, Type: tt.dataSyncType}) + if err != nil { + errChan <- err + return + } + } + } + }() + + go func() { + defer func() { + if r := recover(); r != nil { + errChan <- fmt.Errorf("%v", r) + } + }() + timeout := time.After(timeoutDur) + + for { + select { + case <-timeout: + errChan <- nil + return + default: + _ = tt.flagResolution(jsonEvaluator) + } + } + }() + + for i := 0; i < 2; i++ { + err := <-errChan + if err != nil { + t.Error(err) + } + } + }) + } +}