diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index f4dd530e53..6acc567fca 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -150,6 +150,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { err = ds.decodeArray(dstv) case reflect.Slice: err = ds.decodeSlice(dstv) + case reflect.Map: + err = ds.decodeMap(dstv) default: err = fmt.Errorf("%w: %T", ErrUnsupportedType, in) } @@ -417,6 +419,34 @@ func (ds *decodeState) decodeArray(dstv reflect.Value) (err error) { return } +func (ds *decodeState) decodeMap(dstv reflect.Value) (err error) { + numberOfTuples, err := ds.decodeLength() + if err != nil { + return fmt.Errorf("decoding length: %w", err) + } + in := dstv.Interface() + + for i := uint(0); i < numberOfTuples; i++ { + tempKeyType := reflect.TypeOf(in).Key() + tempKey := reflect.New(tempKeyType).Elem() + err = ds.unmarshal(tempKey) + if err != nil { + return fmt.Errorf("decoding key %d of %d: %w", i+1, numberOfTuples, err) + } + + tempElemType := reflect.TypeOf(in).Elem() + tempElem := reflect.New(tempElemType).Elem() + err = ds.unmarshal(tempElem) + if err != nil { + return fmt.Errorf("decoding value %d of %d: %w", i+1, numberOfTuples, err) + } + + dstv.SetMapIndex(tempKey, tempElem) + } + + return nil +} + // decodeStruct decodes a byte array representing a SCALE tuple. The order of data is // determined by the source tuple in rust, or the struct field order in a go struct func (ds *decodeState) decodeStruct(dstv reflect.Value) (err error) { diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index ac644871ce..d97dcaac73 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -132,6 +132,121 @@ func Test_decodeState_decodeSlice(t *testing.T) { } } +// // Rust code to encode a map of string to struct. +// let mut btree_map: BTreeMap = BTreeMap::new(); +// match btree_map.entry("string1".to_string()) { +// Entry::Vacant(entry) => { +// entry.insert(User{ +// active: true, +// username: "lorem".to_string(), +// email: "lorem@ipsum.org".to_string(), +// sign_in_count: 1, +// }); +// () +// }, +// Entry::Occupied(_) => (), +// } +// match btree_map.entry("string2".to_string()) { +// Entry::Vacant(entry) => { +// entry.insert(User{ +// active: false, +// username: "john".to_string(), +// email: "jack@gmail.com".to_string(), +// sign_in_count: 73, +// }); +// () +// }, +// Entry::Occupied(_) => (), +// } +// println!("{:?}", btree_map.encode()); + +type user struct { + Active bool + Username string + Email string + SignInCount uint64 +} + +func Test_decodeState_decodeMap(t *testing.T) { + mapTests1 := []struct { + name string + input []byte + wantErr bool + expectedOutput map[int8][]byte + }{ + { + name: "testing a map of int8 to a byte array 1", + input: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103}, + expectedOutput: map[int8][]byte{2: []byte("some string")}, + }, + { + name: "testing a map of int8 to a byte array 2", + input: []byte{ + 8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32, + 105, 112, 115, 117, 109, + }, + expectedOutput: map[int8][]byte{ + 2: []byte("some string"), + 16: []byte("lorem ipsum"), + }, + }, + } + + for _, tt := range mapTests1 { + tt := tt + t.Run(tt.name, func(t *testing.T) { + actualOutput := make(map[int8][]byte) + if err := Unmarshal(tt.input, &actualOutput); (err != nil) != tt.wantErr { + t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } + + if !reflect.DeepEqual(actualOutput, tt.expectedOutput) { + t.Errorf("decodeState.unmarshal() = %v, want %v", actualOutput, tt.expectedOutput) + } + }) + } + + mapTests2 := []struct { + name string + input []byte + wantErr bool + expectedOutput map[string]user + }{ + { + name: "testing a map of string to struct", + input: []byte{8, 28, 115, 116, 114, 105, 110, 103, 49, 1, 20, 108, 111, 114, 101, 109, 60, 108, 111, 114, 101, 109, 64, 105, 112, 115, 117, 109, 46, 111, 114, 103, 1, 0, 0, 0, 0, 0, 0, 0, 28, 115, 116, 114, 105, 110, 103, 50, 0, 16, 106, 111, 104, 110, 56, 106, 97, 99, 107, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 73, 0, 0, 0, 0, 0, 0, 0}, //nolint:lll + expectedOutput: map[string]user{ + "string1": { + Active: true, + Username: "lorem", + Email: "lorem@ipsum.org", + SignInCount: 1, + }, + "string2": { + Active: false, + Username: "john", + Email: "jack@gmail.com", + SignInCount: 73, + }, + }, + }, + } + + for _, tt := range mapTests2 { + tt := tt + t.Run(tt.name, func(t *testing.T) { + actualOutput := make(map[string]user) + if err := Unmarshal(tt.input, &actualOutput); (err != nil) != tt.wantErr { + t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } + + if !reflect.DeepEqual(actualOutput, tt.expectedOutput) { + t.Errorf("decodeState.unmarshal() = %v, want %v", actualOutput, tt.expectedOutput) + } + }) + } +} + func Test_unmarshal_optionality(t *testing.T) { var ptrTests tests for _, t := range append(tests{}, allTests...) { @@ -167,7 +282,14 @@ func Test_unmarshal_optionality(t *testing.T) { t.Errorf("decodeState.unmarshal() = %s", diff) } default: - dst := reflect.New(reflect.TypeOf(tt.in)).Interface() + var dst interface{} + + if reflect.TypeOf(tt.in).Kind().String() == "map" { + dst = &(map[int8][]byte{}) + } else { + dst = reflect.New(reflect.TypeOf(tt.in)).Interface() + } + if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index 1ab0bbc9cc..e32696bb60 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "reflect" + "sort" ) // Encoder scale encodes to a given io.Writer. @@ -106,6 +107,8 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeArray(in) case reflect.Slice: err = es.encodeSlice(in) + case reflect.Map: + err = es.encodeMap(in) default: err = fmt.Errorf("%w: %T", ErrUnsupportedType, in) } @@ -223,6 +226,40 @@ func (es *encodeState) encodeArray(in interface{}) (err error) { return } +func (es *encodeState) encodeMap(in interface{}) (err error) { + v := reflect.ValueOf(in) + err = es.encodeLength(v.Len()) + if err != nil { + return fmt.Errorf("encoding length: %w", err) + } + + mapKeys := v.MapKeys() + + sort.Slice(mapKeys, func(i, j int) bool { + keyByteOfI, _ := Marshal(mapKeys[i].Interface()) + keyByteOfJ, _ := Marshal(mapKeys[j].Interface()) + return bytes.Compare(keyByteOfI, keyByteOfJ) < 0 + }) + + for _, key := range mapKeys { + err = es.marshal(key.Interface()) + if err != nil { + return fmt.Errorf("encoding map key: %w", err) + } + + mapValue := v.MapIndex(key) + if !mapValue.CanInterface() { + continue + } + + err = es.marshal(mapValue.Interface()) + if err != nil { + return fmt.Errorf("encoding map value: %w", err) + } + } + return nil +} + // encodeBigInt performs the same encoding as encodeInteger, except on a big.Int. // if 2^30 <= n < 2^536 write // [lower 2 bits of first byte = 11] [upper 6 bits of first byte = # of bytes following less 4] diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index fd43c17201..6864c256de 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -909,9 +909,28 @@ var ( }, } + mapTests = tests{ + { + name: "testMap1", + in: map[int8][]byte{2: []byte("some string")}, + want: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103}, + }, + { + name: "testMap2", + in: map[int8][]byte{ + 2: []byte("some string"), + 16: []byte("lorem ipsum"), + }, + want: []byte{ + 8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32, + 105, 112, 115, 117, 109, + }, + }, + } + allTests = newTests( fixedWidthIntegerTests, variableWidthIntegerTests, stringTests, - boolTests, structTests, sliceTests, arrayTests, + boolTests, structTests, sliceTests, arrayTests, mapTests, varyingDataTypeTests, ) ) @@ -1096,6 +1115,25 @@ func Test_encodeState_encodeArray(t *testing.T) { } } +func Test_encodeState_encodeMap(t *testing.T) { + for _, tt := range mapTests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } + if err := es.marshal(tt.in); (err != nil) != tt.wantErr { + t.Errorf("encodeState.encodeMap() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeMap() = %v, want %v", buffer.Bytes(), tt.want) + } + }) + } +} + func Test_marshal_optionality(t *testing.T) { var ptrTests tests for i := range allTests {