From 9d6c29dc852c853e6d473eed18d6c8b3a996f4e5 Mon Sep 17 00:00:00 2001 From: Denis Date: Tue, 29 Oct 2024 11:27:06 +0100 Subject: [PATCH] Add a function to perform case-insensitive search in mapstr.M (#244) `FindFold` (similar to `strings.EqualFold`) traverses the map and tries to perform a case-insensitive match of each key segment on each map level. --- mapstr/mapstr.go | 72 +++++++++++++++++++++++++++-- mapstr/mapstr_test.go | 103 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 4 deletions(-) diff --git a/mapstr/mapstr.go b/mapstr/mapstr.go index 39b7fd3f..e507ce5a 100644 --- a/mapstr/mapstr.go +++ b/mapstr/mapstr.go @@ -41,6 +41,10 @@ const ( var ( // ErrKeyNotFound indicates that the specified key was not found. ErrKeyNotFound = errors.New("key not found") + // ErrKeyCollision indicates that during the case-insensitive search multiple keys matched + ErrKeyCollision = errors.New("key collision") + // ErrNotMapType indicates that the given value is not map type + ErrNotMapType = errors.New("value is not a map") ) // EventMetadata contains fields and tags that can be added to an event via @@ -172,6 +176,62 @@ func (m M) HasKey(key string) (bool, error) { return hasKey, err } +// FindFold accepts a key and traverses the map trying to match every key segment +// using `strings.FindFold` (case-insensitive match) and returns the actual +// key of the map that matched the given key and the value stored under this key. +// Returns `ErrKeyCollision` if multiple keys match the same request. +// Returns `ErrNotMapType` when one of the values on the path is not a map and cannot be traversed. +func (m M) FindFold(key string) (matchedKey string, value interface{}, err error) { + path := strings.Split(key, ".") + // the initial value must be `true` for the first iteration to work + found := true + // start with the root + current := m + // allocate only once + var mapType bool + + for i, segment := range path { + if !found { + return "", nil, ErrKeyNotFound + } + found = false + + // we have to go through the list of all key on each level to detect case-insensitive collisions + for k := range current { + if !strings.EqualFold(segment, k) { + continue + } + // if already found on this level, it's a collision + if found { + return "", nil, fmt.Errorf("key collision on the same path %q, previous match - %q, another subkey - %q: %w", key, matchedKey, k, ErrKeyCollision) + } + + // mark for collision detection + found = true + + // build the result with the currently matched segment + matchedKey += k + value = current[k] + + // if it's the last segment, we don't need to go deeper + if i == len(path)-1 { + continue + } + + // if it's not the last segment we put the separator dot + matchedKey += "." + + // go one level deeper + current, mapType = tryToMapStr(current[k]) + if !mapType { + return "", nil, fmt.Errorf("cannot continue path %q (full: %q), next element is not a map: %w", matchedKey, key, ErrNotMapType) + } + } + } + + return matchedKey, value, nil +} + // GetValue gets a value from the map. If the key does not exist then an error // is returned. func (m M) GetValue(key string) (interface{}, error) { @@ -266,10 +326,12 @@ func (m M) Format(f fmt.State, c rune) { // Flatten flattens the given M and returns a flat M. // // Example: -// "hello": M{"world": "test" } +// +// "hello": M{"world": "test" } // // This is converted to: -// "hello.world": "test" +// +// "hello.world": "test" // // This can be useful for testing or logging. func (m M) Flatten() M { @@ -299,10 +361,12 @@ func flatten(prefix string, in, out M) M { // FlattenKeys flattens given MapStr keys and returns a containing array pointer // // Example: -// "hello": MapStr{"world": "test" } +// +// "hello": MapStr{"world": "test" } // // This is converted to: -// ["hello.world"] +// +// ["hello.world"] func (m M) FlattenKeys() *[]string { out := make([]string, 0) flattenKeys("", m, &out) diff --git a/mapstr/mapstr_test.go b/mapstr/mapstr_test.go index 2497d0f3..85be4f29 100644 --- a/mapstr/mapstr_test.go +++ b/mapstr/mapstr_test.go @@ -1107,3 +1107,106 @@ func TestFormat(t *testing.T) { }) } } + +func TestFindFold(t *testing.T) { + field1level2 := M{ + "level3_Field1": "value2", + } + field1level1 := M{ + "non_map": "value1", + "level2_Field1": field1level2, + } + + input := M{ + // baseline + "level1_Field1": field1level1, + // fold equal testing + "Level1_fielD2": M{ + "lEvel2_fiEld2": M{ + "levEl3_fIeld2": "value3", + }, + }, + // collision testing + "level1_field2": M{ + "level2_field2": M{ + "level3_field2": "value4", + }, + }, + } + + cases := []struct { + name string + key string + expKey string + expVal interface{} + expErr string + }{ + { + name: "returns normal key, full match", + key: "level1_Field1.level2_Field1.level3_Field1", + expKey: "level1_Field1.level2_Field1.level3_Field1", + expVal: "value2", + }, + { + name: "returns normal key, partial match", + key: "level1_Field1.level2_Field1", + expKey: "level1_Field1.level2_Field1", + expVal: field1level2, + }, + { + name: "returns normal key, one level", + key: "level1_Field1", + expKey: "level1_Field1", + expVal: field1level1, + }, + { + name: "returns case-insensitive full match", + key: "level1_field1.level2_field1.level3_field1", + expKey: "level1_Field1.level2_Field1.level3_Field1", + expVal: "value2", + }, + { + name: "returns case-insensitive partial match", + key: "level1_field1.level2_field1", + expKey: "level1_Field1.level2_Field1", + expVal: field1level2, + }, + { + name: "returns case-insensitive one-level match", + key: "level1_field1", + expKey: "level1_Field1", + expVal: field1level1, + }, + { + name: "returns collision error", + key: "level1_field2.level2_field2.level3_field2", + expErr: "collision", + }, + { + name: "returns non-map error", + key: "level1_field1.non_map.some_key", + expErr: "next element is not a map", + }, + { + name: "returns non-found error", + key: "level1_field1.not_exists.some_key", + expErr: "key not found", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + key, val, err := input.FindFold(tc.key) + if tc.expErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expErr) + assert.Nil(t, val) + assert.Empty(t, key) + return + } + require.NoError(t, err) + assert.Equal(t, tc.expKey, key) + assert.Equal(t, tc.expVal, val) + }) + } +}