Skip to content

Commit

Permalink
Merge branch 'development' into eclesio/fix/import-block-announce
Browse files Browse the repository at this point in the history
  • Loading branch information
EclesioMeloJunior committed Oct 21, 2022
2 parents 5b004be + 405db51 commit b04ed3b
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 2 deletions.
30 changes: 30 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
err = ds.decodeArray(dstv)
case reflect.Slice:
err = ds.decodeSlice(dstv)
case reflect.Map:
err = ds.decodeMap(dstv)
default:
err = fmt.Errorf("%w: %T", ErrUnsupportedType, in)
}
Expand Down Expand Up @@ -417,6 +419,34 @@ func (ds *decodeState) decodeArray(dstv reflect.Value) (err error) {
return
}

func (ds *decodeState) decodeMap(dstv reflect.Value) (err error) {
numberOfTuples, err := ds.decodeLength()
if err != nil {
return fmt.Errorf("decoding length: %w", err)
}
in := dstv.Interface()

for i := uint(0); i < numberOfTuples; i++ {
tempKeyType := reflect.TypeOf(in).Key()
tempKey := reflect.New(tempKeyType).Elem()
err = ds.unmarshal(tempKey)
if err != nil {
return fmt.Errorf("decoding key %d of %d: %w", i+1, numberOfTuples, err)
}

tempElemType := reflect.TypeOf(in).Elem()
tempElem := reflect.New(tempElemType).Elem()
err = ds.unmarshal(tempElem)
if err != nil {
return fmt.Errorf("decoding value %d of %d: %w", i+1, numberOfTuples, err)
}

dstv.SetMapIndex(tempKey, tempElem)
}

return nil
}

// decodeStruct decodes a byte array representing a SCALE tuple. The order of data is
// determined by the source tuple in rust, or the struct field order in a go struct
func (ds *decodeState) decodeStruct(dstv reflect.Value) (err error) {
Expand Down
124 changes: 123 additions & 1 deletion pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,121 @@ func Test_decodeState_decodeSlice(t *testing.T) {
}
}

// // Rust code to encode a map of string to struct.
// let mut btree_map: BTreeMap<String, User> = BTreeMap::new();
// match btree_map.entry("string1".to_string()) {
// Entry::Vacant(entry) => {
// entry.insert(User{
// active: true,
// username: "lorem".to_string(),
// email: "lorem@ipsum.org".to_string(),
// sign_in_count: 1,
// });
// ()
// },
// Entry::Occupied(_) => (),
// }
// match btree_map.entry("string2".to_string()) {
// Entry::Vacant(entry) => {
// entry.insert(User{
// active: false,
// username: "john".to_string(),
// email: "jack@gmail.com".to_string(),
// sign_in_count: 73,
// });
// ()
// },
// Entry::Occupied(_) => (),
// }
// println!("{:?}", btree_map.encode());

type user struct {
Active bool
Username string
Email string
SignInCount uint64
}

func Test_decodeState_decodeMap(t *testing.T) {
mapTests1 := []struct {
name string
input []byte
wantErr bool
expectedOutput map[int8][]byte
}{
{
name: "testing a map of int8 to a byte array 1",
input: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103},
expectedOutput: map[int8][]byte{2: []byte("some string")},
},
{
name: "testing a map of int8 to a byte array 2",
input: []byte{
8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32,
105, 112, 115, 117, 109,
},
expectedOutput: map[int8][]byte{
2: []byte("some string"),
16: []byte("lorem ipsum"),
},
},
}

for _, tt := range mapTests1 {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actualOutput := make(map[int8][]byte)
if err := Unmarshal(tt.input, &actualOutput); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}

if !reflect.DeepEqual(actualOutput, tt.expectedOutput) {
t.Errorf("decodeState.unmarshal() = %v, want %v", actualOutput, tt.expectedOutput)
}
})
}

mapTests2 := []struct {
name string
input []byte
wantErr bool
expectedOutput map[string]user
}{
{
name: "testing a map of string to struct",
input: []byte{8, 28, 115, 116, 114, 105, 110, 103, 49, 1, 20, 108, 111, 114, 101, 109, 60, 108, 111, 114, 101, 109, 64, 105, 112, 115, 117, 109, 46, 111, 114, 103, 1, 0, 0, 0, 0, 0, 0, 0, 28, 115, 116, 114, 105, 110, 103, 50, 0, 16, 106, 111, 104, 110, 56, 106, 97, 99, 107, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 73, 0, 0, 0, 0, 0, 0, 0}, //nolint:lll
expectedOutput: map[string]user{
"string1": {
Active: true,
Username: "lorem",
Email: "lorem@ipsum.org",
SignInCount: 1,
},
"string2": {
Active: false,
Username: "john",
Email: "jack@gmail.com",
SignInCount: 73,
},
},
},
}

for _, tt := range mapTests2 {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actualOutput := make(map[string]user)
if err := Unmarshal(tt.input, &actualOutput); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}

if !reflect.DeepEqual(actualOutput, tt.expectedOutput) {
t.Errorf("decodeState.unmarshal() = %v, want %v", actualOutput, tt.expectedOutput)
}
})
}
}

func Test_unmarshal_optionality(t *testing.T) {
var ptrTests tests
for _, t := range append(tests{}, allTests...) {
Expand Down Expand Up @@ -167,7 +282,14 @@ func Test_unmarshal_optionality(t *testing.T) {
t.Errorf("decodeState.unmarshal() = %s", diff)
}
default:
dst := reflect.New(reflect.TypeOf(tt.in)).Interface()
var dst interface{}

if reflect.TypeOf(tt.in).Kind().String() == "map" {
dst = &(map[int8][]byte{})
} else {
dst = reflect.New(reflect.TypeOf(tt.in)).Interface()
}

if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
37 changes: 37 additions & 0 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"math/big"
"reflect"
"sort"
)

// Encoder scale encodes to a given io.Writer.
Expand Down Expand Up @@ -106,6 +107,8 @@ func (es *encodeState) marshal(in interface{}) (err error) {
err = es.encodeArray(in)
case reflect.Slice:
err = es.encodeSlice(in)
case reflect.Map:
err = es.encodeMap(in)
default:
err = fmt.Errorf("%w: %T", ErrUnsupportedType, in)
}
Expand Down Expand Up @@ -223,6 +226,40 @@ func (es *encodeState) encodeArray(in interface{}) (err error) {
return
}

func (es *encodeState) encodeMap(in interface{}) (err error) {
v := reflect.ValueOf(in)
err = es.encodeLength(v.Len())
if err != nil {
return fmt.Errorf("encoding length: %w", err)
}

mapKeys := v.MapKeys()

sort.Slice(mapKeys, func(i, j int) bool {
keyByteOfI, _ := Marshal(mapKeys[i].Interface())
keyByteOfJ, _ := Marshal(mapKeys[j].Interface())
return bytes.Compare(keyByteOfI, keyByteOfJ) < 0
})

for _, key := range mapKeys {
err = es.marshal(key.Interface())
if err != nil {
return fmt.Errorf("encoding map key: %w", err)
}

mapValue := v.MapIndex(key)
if !mapValue.CanInterface() {
continue
}

err = es.marshal(mapValue.Interface())
if err != nil {
return fmt.Errorf("encoding map value: %w", err)
}
}
return nil
}

// encodeBigInt performs the same encoding as encodeInteger, except on a big.Int.
// if 2^30 <= n < 2^536 write
// [lower 2 bits of first byte = 11] [upper 6 bits of first byte = # of bytes following less 4]
Expand Down
40 changes: 39 additions & 1 deletion pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,9 +909,28 @@ var (
},
}

mapTests = tests{
{
name: "testMap1",
in: map[int8][]byte{2: []byte("some string")},
want: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103},
},
{
name: "testMap2",
in: map[int8][]byte{
2: []byte("some string"),
16: []byte("lorem ipsum"),
},
want: []byte{
8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32,
105, 112, 115, 117, 109,
},
},
}

allTests = newTests(
fixedWidthIntegerTests, variableWidthIntegerTests, stringTests,
boolTests, structTests, sliceTests, arrayTests,
boolTests, structTests, sliceTests, arrayTests, mapTests,
varyingDataTypeTests,
)
)
Expand Down Expand Up @@ -1096,6 +1115,25 @@ func Test_encodeState_encodeArray(t *testing.T) {
}
}

func Test_encodeState_encodeMap(t *testing.T) {
for _, tt := range mapTests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeMap() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeMap() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
}

func Test_marshal_optionality(t *testing.T) {
var ptrTests tests
for i := range allTests {
Expand Down

0 comments on commit b04ed3b

Please sign in to comment.