diff --git a/pkg/scale/README.md b/pkg/scale/README.md index 029c022abc..6e97e6b1eb 100644 --- a/pkg/scale/README.md +++ b/pkg/scale/README.md @@ -88,7 +88,7 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) -func basicExample() { +func ExampleBasic() { // compact length encoded uint var ui uint = 999 bytes, err := scale.Marshal(ui) @@ -117,7 +117,7 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) -func structExample() { +func ExampleStruct() { type MyStruct struct { Baz bool `scale:"3"` Bar int32 `scale:"2"` @@ -165,7 +165,7 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) -func resultExample() { +func ExampleResult() { // pass in zero or non-zero values of the types for Ok and Err cases res := scale.NewResult(bool(false), string("")) @@ -210,10 +210,8 @@ func resultExample() { ### Varying Data Type -A `VaryingDataType` is analogous to a Rust enum. A `VaryingDataType` needs to be registered using the `RegisterVaryingDataType` function with its associated `VaryingDataTypeValue` types. `VaryingDataTypeValue` is an -interface with one `Index() uint` method that needs to be implemented. The returned `uint` index should be unique per type and needs to be the same index as defined in the Rust enum to ensure interopability. - -> TODO: The only custom `VaryingDataTypeValue` types supported are currently `struct`, `int`, and `int16`. Need to add other supported primitives. +A `VaryingDataType` is analogous to a Rust enum. A `VaryingDataType` needs to be constructed using the `NewVaryingDataType` constructor. `VaryingDataTypeValue` is an +interface with one `Index() uint` method that needs to be implemented. The returned `uint` index should be unique per type and needs to be the same index as defined in the Rust enum to ensure interopability. To set the value of the `VaryingDataType`, the `VaryingDataType.Set()` function should be called with an associated `VaryingDataTypeValue`. ``` import ( @@ -247,39 +245,82 @@ func (mi16 MyInt16) Index() uint { return 3 } -type MyVaryingDataType scale.VaryingDataType +func ExampleVaryingDataType() { + vdt, err := scale.NewVaryingDataType(MyStruct{}, MyOtherStruct{}, MyInt16(0)) + if err != nil { + panic(err) + } + + err = vdt.Set(MyStruct{ + Baz: true, + Bar: 999, + Foo: []byte{1, 2}, + }) + if err != nil { + panic(err) + } + + bytes, err := scale.Marshal(vdt) + if err != nil { + panic(err) + } + + vdt1, err := scale.NewVaryingDataType(MyStruct{}, MyOtherStruct{}, MyInt16(0)) + if err != nil { + panic(err) + } + + err = scale.Unmarshal(bytes, &vdt1) + if err != nil { + panic(err) + } + + if !reflect.DeepEqual(vdt, vdt1) { + panic(fmt.Errorf("uh oh: %+v %+v", vdt, vdt1)) + } +} +``` + +A `VaryingDataTypeSlice` is a slice containing multiple `VaryingDataType` elements. Each `VaryingDataTypeValue` must be of a supported type of the `VaryingDataType` passed into the `NewVaryingDataTypeSlice` constructor. The method to call to add `VaryingDataTypeValue` instances is `VaryingDataTypeSlice.Add()`. -func varyingDataTypeExample() { - err := scale.RegisterVaryingDataType(MyVaryingDataType{}, MyStruct{}, MyOtherStruct{}, MyInt16(0)) +``` +func ExampleVaryingDataTypeSlice() { + vdt, err := scale.NewVaryingDataType(MyStruct{}, MyOtherStruct{}, MyInt16(0)) if err != nil { panic(err) } - mvdt := MyVaryingDataType{ + vdts := scale.NewVaryingDataTypeSlice(vdt) + + err = vdts.Add( MyStruct{ Baz: true, Bar: 999, Foo: []byte{1, 2}, }, - MyOtherStruct{ - Foo: "hello", - Bar: 999, - Baz: 888, - }, - MyInt16(111), + MyInt16(1), + ) + if err != nil { + panic(err) } - bytes, err := scale.Marshal(mvdt) + + bytes, err := scale.Marshal(vdts) if err != nil { panic(err) } - var unmarshaled MyVaryingDataType - err = scale.Unmarshal(bytes, &unmarshaled) + vdts1 := scale.NewVaryingDataTypeSlice(vdt) if err != nil { panic(err) } - // [{Baz:true Bar:999 Foo:[1 2]} {Foo:hello Bar:999 Baz:888} 111] - fmt.Printf("%+v", unmarshaled) + err = scale.Unmarshal(bytes, &vdts1) + if err != nil { + panic(err) + } + + if !reflect.DeepEqual(vdts, vdts1) { + panic(fmt.Errorf("uh oh: %+v %+v", vdts, vdts1)) + } } ``` \ No newline at end of file diff --git a/pkg/scale/comparison_test.go b/pkg/scale/comparison_test.go index a1da3697c2..54f4225da3 100644 --- a/pkg/scale/comparison_test.go +++ b/pkg/scale/comparison_test.go @@ -64,7 +64,7 @@ func (prd SealDigest) Index() uint { } func TestOldVsNewEncoding(t *testing.T) { - oldDigest := types.Digest{ + oldDigests := types.Digest{ &types.ChangesTrieRootDigest{ Hash: common.Hash{0, 91, 50, 25, 214, 94, 119, 36, 71, 216, 33, 152, 85, 184, 34, 120, 61, 161, 164, 223, 76, 53, 40, 246, 76, 38, 235, 204, 43, 31, 179, 28}, }, @@ -81,34 +81,53 @@ func TestOldVsNewEncoding(t *testing.T) { Data: []byte{1, 3, 5, 7}, }, } - oldEncode, err := oldDigest.Encode() + oldEncode, err := oldDigests.Encode() if err != nil { t.Errorf("unexpected err: %v", err) return } - type Digests VaryingDataType - err = RegisterVaryingDataType(Digests{}, ChangesTrieRootDigest{}, PreRuntimeDigest{}, ConsensusDigest{}, SealDigest{}) + vdt, err := NewVaryingDataType(ChangesTrieRootDigest{}, PreRuntimeDigest{}, ConsensusDigest{}, SealDigest{}) if err != nil { t.Errorf("unexpected err: %v", err) return } - newDigest := Digests{ - ChangesTrieRootDigest{ - Hash: common.Hash{0, 91, 50, 25, 214, 94, 119, 36, 71, 216, 33, 152, 85, 184, 34, 120, 61, 161, 164, 223, 76, 53, 40, 246, 76, 38, 235, 204, 43, 31, 179, 28}, - }, - PreRuntimeDigest{ - ConsensusEngineID: types.BabeEngineID, - Data: []byte{1, 3, 5, 7}, - }, - ConsensusDigest{ - ConsensusEngineID: types.BabeEngineID, - Data: []byte{1, 3, 5, 7}, - }, - SealDigest{ - ConsensusEngineID: types.BabeEngineID, - Data: []byte{1, 3, 5, 7}, - }, + err = vdt.Set(ChangesTrieRootDigest{ + Hash: common.Hash{0, 91, 50, 25, 214, 94, 119, 36, 71, 216, 33, 152, 85, 184, 34, 120, 61, 161, 164, 223, 76, 53, 40, 246, 76, 38, 235, 204, 43, 31, 179, 28}, + }) + if err != nil { + t.Errorf("unexpected err: %v", err) + return + } + + newDigest := []VaryingDataType{ + mustNewVaryingDataTypeAndSet( + ChangesTrieRootDigest{ + Hash: common.Hash{0, 91, 50, 25, 214, 94, 119, 36, 71, 216, 33, 152, 85, 184, 34, 120, 61, 161, 164, 223, 76, 53, 40, 246, 76, 38, 235, 204, 43, 31, 179, 28}, + }, + ChangesTrieRootDigest{}, PreRuntimeDigest{}, ConsensusDigest{}, SealDigest{}, + ), + mustNewVaryingDataTypeAndSet( + PreRuntimeDigest{ + ConsensusEngineID: types.BabeEngineID, + Data: []byte{1, 3, 5, 7}, + }, + ChangesTrieRootDigest{}, PreRuntimeDigest{}, ConsensusDigest{}, SealDigest{}, + ), + mustNewVaryingDataTypeAndSet( + ConsensusDigest{ + ConsensusEngineID: types.BabeEngineID, + Data: []byte{1, 3, 5, 7}, + }, + ChangesTrieRootDigest{}, PreRuntimeDigest{}, ConsensusDigest{}, SealDigest{}, + ), + mustNewVaryingDataTypeAndSet( + SealDigest{ + ConsensusEngineID: types.BabeEngineID, + Data: []byte{1, 3, 5, 7}, + }, + ChangesTrieRootDigest{}, PreRuntimeDigest{}, ConsensusDigest{}, SealDigest{}, + ), } newEncode, err := Marshal(newDigest) @@ -120,13 +139,14 @@ func TestOldVsNewEncoding(t *testing.T) { t.Errorf("encodeState.encodeStruct() = %v, want %v", oldEncode, newEncode) } - var decoded Digests + decoded := NewVaryingDataTypeSlice(vdt) err = Unmarshal(newEncode, &decoded) if err != nil { t.Errorf("unexpected err: %v", err) } - if !reflect.DeepEqual(decoded, newDigest) { - t.Errorf("Unmarshal() = %v, want %v", decoded, newDigest) + // decoded.Types + if !reflect.DeepEqual(decoded.Types, newDigest) { + t.Errorf("Unmarshal() = %v, want %v", decoded.Types, newDigest) } } diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 0f3d31fb85..f08029485e 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -25,6 +25,50 @@ import ( "reflect" ) +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +func indirect(dstv reflect.Value) (elem reflect.Value) { + dstv0 := dstv + haveAddr := false + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if dstv.Kind() == reflect.Interface && !dstv.IsNil() { + e := dstv.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && e.Elem().Kind() == reflect.Ptr { + haveAddr = false + dstv = e + continue + } + } + if dstv.Kind() != reflect.Ptr { + break + } + if dstv.CanSet() { + break + } + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if dstv.Elem().Kind() == reflect.Interface && dstv.Elem().Elem() == dstv { + dstv = dstv.Elem() + break + } + if dstv.IsNil() { + dstv.Set(reflect.New(dstv.Type().Elem())) + } + if haveAddr { + dstv = dstv0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + dstv = dstv.Elem() + } + } + elem = dstv + return +} + +// Unmarshal takes data and a destination pointer to unmarshal the data to. func Unmarshal(data []byte, dst interface{}) (err error) { dstv := reflect.ValueOf(dst) if dstv.Kind() != reflect.Ptr || dstv.IsNil() { @@ -32,6 +76,11 @@ func Unmarshal(data []byte, dst interface{}) (err error) { return } + elem := indirect(dstv) + if err != nil { + return + } + buf := &bytes.Buffer{} ds := decodeState{} _, err = buf.Write(data) @@ -39,7 +88,8 @@ func Unmarshal(data []byte, dst interface{}) (err error) { return } ds.Buffer = *buf - err = ds.unmarshal(dstv.Elem()) + + err = ds.unmarshal(elem) if err != nil { return } @@ -69,6 +119,10 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { err = ds.decodeBool(dstv) case Result: err = ds.decodeResult(dstv) + case VaryingDataType: + err = ds.decodeVaryingDataType(dstv) + case VaryingDataTypeSlice: + err = ds.decodeVaryingDataTypeSlice(dstv) default: t := reflect.TypeOf(in) switch t.Kind() { @@ -83,14 +137,7 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { case reflect.Array: err = ds.decodeArray(dstv) case reflect.Slice: - t := reflect.TypeOf(in) - // check if this is a convertible to VaryingDataType, if so decode using encodeVaryingDataType - switch t.ConvertibleTo(reflect.TypeOf(VaryingDataType{})) { - case true: - err = ds.decodeVaryingDataType(dstv) - case false: - err = ds.decodeSlice(dstv) - } + err = ds.decodeSlice(dstv) default: err = fmt.Errorf("unsupported type: %T", in) } @@ -229,56 +276,72 @@ func (ds *decodeState) decodePointer(dstv reflect.Value) (err error) { case 0x00: // nil case case 0x01: - elemType := reflect.TypeOf(dstv.Interface()).Elem() - tempElem := reflect.New(elemType) - err = ds.unmarshal(tempElem.Elem()) - if err != nil { - break + switch dstv.IsZero() { + case false: + if dstv.Elem().Kind() == reflect.Ptr { + err = ds.unmarshal(dstv.Elem().Elem()) + } else { + err = ds.unmarshal(dstv.Elem()) + } + case true: + elemType := reflect.TypeOf(dstv.Interface()).Elem() + tempElem := reflect.New(elemType) + err = ds.unmarshal(tempElem.Elem()) + if err != nil { + return + } + dstv.Set(tempElem) } - dstv.Set(tempElem) default: err = fmt.Errorf("unsupported Option value: %v, bytes: %v", rb, ds.Bytes()) } return } -func (ds *decodeState) decodeVaryingDataType(dstv reflect.Value) (err error) { +func (ds *decodeState) decodeVaryingDataTypeSlice(dstv reflect.Value) (err error) { + vdts := dstv.Interface().(VaryingDataTypeSlice) l, err := ds.decodeLength() if err != nil { return } - - dstt := reflect.TypeOf(dstv.Interface()) - key := fmt.Sprintf("%s.%s", dstt.PkgPath(), dstt.Name()) - mappedValues, ok := vdtCache[key] - if !ok { - err = fmt.Errorf("unable to find registered custom VaryingDataType: %T", dstv.Interface()) - return - } - - temp := reflect.New(dstt) for i := 0; i < l; i++ { - var b byte - b, err = ds.ReadByte() + vdt := vdts.VaryingDataType + vdtv := reflect.New(reflect.TypeOf(vdt)) + vdtv.Elem().Set(reflect.ValueOf(vdt)) + err = ds.unmarshal(vdtv.Elem()) if err != nil { return } + vdts.Types = append(vdts.Types, vdtv.Elem().Interface().(VaryingDataType)) + } + dstv.Set(reflect.ValueOf(vdts)) + return +} - val, ok := mappedValues[uint(b)] - if !ok { - err = fmt.Errorf("unable to find registered VaryingDataTypeValue for type: %T", dstv.Interface()) - return - } +func (ds *decodeState) decodeVaryingDataType(dstv reflect.Value) (err error) { + var b byte + b, err = ds.ReadByte() + if err != nil { + return + } - tempVal := reflect.New(reflect.TypeOf(val)).Elem() - err = ds.unmarshal(tempVal) - if err != nil { - return - } + vdt := dstv.Interface().(VaryingDataType) + val, ok := vdt.cache[uint(b)] + if !ok { + err = fmt.Errorf("unable to find VaryingDataTypeValue with index: %d", uint(b)) + return + } - temp.Elem().Set(reflect.Append(temp.Elem(), tempVal)) + tempVal := reflect.New(reflect.TypeOf(val)).Elem() + err = ds.unmarshal(tempVal) + if err != nil { + return } - dstv.Set(temp.Elem()) + err = vdt.Set(tempVal.Interface().(VaryingDataTypeValue)) + if err != nil { + return + } + dstv.Set(reflect.ValueOf(vdt)) return } @@ -510,56 +573,56 @@ func (ds *decodeState) decodeFixedWidthInt(dstv reflect.Value) (err error) { var b byte b, err = ds.ReadByte() if err != nil { - break + return } out = int8(b) case uint8: var b byte b, err = ds.ReadByte() if err != nil { - break + return } - out = uint8(b) + out = uint8(b) // nolint case int16: buf := make([]byte, 2) _, err = ds.Read(buf) if err != nil { - break + return } out = int16(binary.LittleEndian.Uint16(buf)) case uint16: buf := make([]byte, 2) _, err = ds.Read(buf) if err != nil { - break + return } out = binary.LittleEndian.Uint16(buf) case int32: buf := make([]byte, 4) _, err = ds.Read(buf) if err != nil { - break + return } out = int32(binary.LittleEndian.Uint32(buf)) case uint32: buf := make([]byte, 4) _, err = ds.Read(buf) if err != nil { - break + return } out = binary.LittleEndian.Uint32(buf) case int64: buf := make([]byte, 8) _, err = ds.Read(buf) if err != nil { - break + return } out = int64(binary.LittleEndian.Uint64(buf)) case uint64: buf := make([]byte, 8) _, err = ds.Read(buf) if err != nil { - break + return } out = binary.LittleEndian.Uint64(buf) default: diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index f8051b56da..22412b7d8e 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -145,7 +145,7 @@ func Test_decodeState_decodeSlice(t *testing.T) { func Test_unmarshal_optionality(t *testing.T) { var ptrTests tests - for _, t := range allTests { + for _, t := range append(tests{}, allTests...) { ptrTest := test{ name: t.name, in: t.in, @@ -159,22 +159,38 @@ func Test_unmarshal_optionality(t *testing.T) { } for _, tt := range ptrTests { t.Run(tt.name, func(t *testing.T) { - // this becomes a pointer to a zero value of the underlying value - dst := reflect.New(reflect.TypeOf(tt.in)).Interface() - if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { - t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) - return - } - var diff string - if tt.out != nil { - diff = cmp.Diff(reflect.ValueOf(dst).Elem().Interface(), reflect.ValueOf(tt.out).Interface(), cmpopts.IgnoreUnexported(tt.in)) - } else { - diff = cmp.Diff(reflect.ValueOf(dst).Elem().Interface(), reflect.ValueOf(tt.in).Interface(), cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}, MyStructWithPrivate{})) - } - if diff != "" { - t.Errorf("decodeState.unmarshal() = %s", diff) - } + switch in := tt.in.(type) { + case VaryingDataType: + // copy the inputted vdt cause we need the cached values + copy := in + vdt := copy + vdt.value = nil + var dst interface{} = &vdt + if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { + t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + diff := cmp.Diff(vdt.value, tt.in.(VaryingDataType).value, cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}, MyStructWithPrivate{})) + if diff != "" { + t.Errorf("decodeState.unmarshal() = %s", diff) + } + default: + dst := reflect.New(reflect.TypeOf(tt.in)).Interface() + if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { + t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + var diff string + if tt.out != nil { + diff = cmp.Diff(reflect.ValueOf(dst).Elem().Interface(), reflect.ValueOf(tt.out).Interface(), cmpopts.IgnoreUnexported(tt.in)) + } else { + diff = cmp.Diff(reflect.ValueOf(dst).Elem().Interface(), reflect.ValueOf(tt.in).Interface(), cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}, MyStructWithPrivate{})) + } + if diff != "" { + t.Errorf("decodeState.unmarshal() = %s", diff) + } + } }) } } diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index 458d36695c..7c62a77bd5 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -24,6 +24,7 @@ import ( "reflect" ) +// Marshal takes in an interface{} and attempts to marshal into []byte func Marshal(v interface{}) (b []byte, err error) { es := encodeState{ fieldScaleIndicesCache: cache, @@ -61,6 +62,10 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeBool(in) case Result: err = es.encodeResult(in) + case VaryingDataType: + err = es.encodeVaryingDataType(in) + case VaryingDataTypeSlice: + err = es.encodeVaryingDataTypeSlice(in) default: switch reflect.TypeOf(in).Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, @@ -85,15 +90,7 @@ func (es *encodeState) marshal(in interface{}) (err error) { case reflect.Array: err = es.encodeArray(in) case reflect.Slice: - t := reflect.TypeOf(in) - // check if this is a convertible to VaryingDataType, if so encode using encodeVaryingDataType - switch t.ConvertibleTo(reflect.TypeOf(VaryingDataType{})) { - case true: - invdt := reflect.ValueOf(in).Convert(reflect.TypeOf(VaryingDataType{})) - err = es.encodeVaryingDataType(invdt.Interface().(VaryingDataType)) - case false: - err = es.encodeSlice(in) - } + err = es.encodeSlice(in) default: err = fmt.Errorf("unsupported type: %T", in) } @@ -134,6 +131,7 @@ func (es *encodeState) encodeCustomPrimitive(in interface{}) (err error) { err = es.marshal(in) return } + func (es *encodeState) encodeResult(res Result) (err error) { if !res.IsSet() { err = fmt.Errorf("Result is not set: %+v", res) @@ -163,20 +161,17 @@ func (es *encodeState) encodeResult(res Result) (err error) { return } -func (es *encodeState) encodeVaryingDataType(values VaryingDataType) (err error) { - err = es.encodeLength(len(values)) +func (es *encodeState) encodeVaryingDataType(vdt VaryingDataType) (err error) { + err = es.WriteByte(byte(vdt.value.Index())) if err != nil { return } - for _, val := range values { - // TODO: type checking of val against vdtCache to ensure it is a registered type - // encode type.Index (idx) for varying data type - err = es.WriteByte(byte(val.Index())) - if err != nil { - return - } - err = es.marshal(val) - } + err = es.marshal(vdt.value) + return +} + +func (es *encodeState) encodeVaryingDataTypeSlice(vdts VaryingDataTypeSlice) (err error) { + err = es.marshal(vdts.Types) return } diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index b0d93d2f94..4f9deedbd4 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -511,14 +511,7 @@ var ( }, } - nilPtrMyStruct *MyStruct - ptrMystruct *MyStruct = &MyStruct{ - Foo: []byte{0x01}, - Bar: 2, - Baz: true, - } - nilPtrMyStruct2 *MyStruct = nil - structTests = tests{ + structTests = tests{ { name: "struct {[]byte, int32}", in: MyStruct{ @@ -604,6 +597,7 @@ var ( Foo: []byte{0x01}, Bar: 2, Baz: true, + priv1: []byte{0x00}, }, want: []byte{0x04, 0x01, 0x02, 0, 0, 0, 0x01}, }, @@ -614,6 +608,9 @@ var ( Bar: 2, Baz: true, Ignore: "me", + somethingElse: &struct { + fields int + }{1}, }, want: []byte{0x04, 0x01, 0x02, 0, 0, 0, 0x01}, out: MyStructWithIgnore{ @@ -854,7 +851,8 @@ var ( allTests = newTests( fixedWidthIntegerTests, variableWidthIntegerTests, stringTests, - boolTests, structTests, sliceTests, arrayTests, varyingDataTypeTests, + boolTests, structTests, sliceTests, arrayTests, + varyingDataTypeTests, ) ) diff --git a/pkg/scale/result.go b/pkg/scale/result.go index 237d9609a6..5fd30879a4 100644 --- a/pkg/scale/result.go +++ b/pkg/scale/result.go @@ -21,11 +21,15 @@ import ( "reflect" ) +// ResultMode is the mode the Result is set to type ResultMode int const ( + // Unset ResultMode is zero value mode Unset ResultMode = iota + // OK case OK + // Err case Err ) @@ -37,7 +41,7 @@ type Result struct { } // NewResult is constructor for Result. Use nil to represent empty tuple () in Rust. -func NewResult(okIn interface{}, errIn interface{}) (res Result) { +func NewResult(okIn, errIn interface{}) (res Result) { switch okIn { case nil: res.ok = empty{} @@ -53,18 +57,25 @@ func NewResult(okIn interface{}, errIn interface{}) (res Result) { return } +// Set takes in a mode (OK/Err) and the associated interface and sets the Result value func (r *Result) Set(mode ResultMode, in interface{}) (err error) { switch mode { case OK: - if reflect.TypeOf(r.ok) != reflect.TypeOf(in) { + if reflect.TypeOf(r.ok) == reflect.TypeOf(empty{}) && in == nil { + r.mode = mode + return + } else if reflect.TypeOf(r.ok) != reflect.TypeOf(in) { err = fmt.Errorf("type mistmatch for result.ok: %T, and inputted: %T", r.ok, in) return } r.ok = in r.mode = mode case Err: - if reflect.TypeOf(r.err) != reflect.TypeOf(in) { - err = fmt.Errorf("type mistmatch for result.ok: %T, and inputted: %T", r.ok, in) + if reflect.TypeOf(r.err) == reflect.TypeOf(empty{}) && in == nil { + r.mode = mode + return + } else if reflect.TypeOf(r.err) != reflect.TypeOf(in) { + err = fmt.Errorf("type mistmatch for result.err: %T, and inputted: %T", r.ok, in) return } r.err = in @@ -75,9 +86,10 @@ func (r *Result) Set(mode ResultMode, in interface{}) (err error) { return } +// UnsetResult is error when Result is unset with a value. type UnsetResult error -// Result returns the result in go standard wrapping the Err case in a ResultErr +// Unwrap returns the result in go standard wrapping the Err case in a ResultErr func (r *Result) Unwrap() (ok interface{}, err error) { if !r.IsSet() { err = UnsetResult(fmt.Errorf("result is not set")) @@ -117,10 +129,12 @@ func (r *Result) IsSet() bool { type empty struct{} +// WrappedErr is returned by Result.Unwrap(). The underlying Err value is wrapped and stored in Err attribute type WrappedErr struct { Err interface{} } +// Error fulfils the error interface func (r WrappedErr) Error() string { return fmt.Sprintf("ResultErr %+v", r.Err) } diff --git a/pkg/scale/result_example_test.go b/pkg/scale/result_example_test.go new file mode 100644 index 0000000000..43d7084133 --- /dev/null +++ b/pkg/scale/result_example_test.go @@ -0,0 +1,53 @@ +package scale_test + +import ( + "fmt" + "testing" + + "github.com/ChainSafe/gossamer/pkg/scale" +) + +func ExampleResult() { + // pass in zero or non-zero values of the types for Ok and Err cases + res := scale.NewResult(bool(false), string("")) + + // set the OK case with a value of true, any values for OK that are not bool will return an error + err := res.Set(scale.OK, true) + if err != nil { + panic(err) + } + + bytes, err := scale.Marshal(res) + if err != nil { + panic(err) + } + + // [0x00, 0x01] + fmt.Printf("%v\n", bytes) + + res1 := scale.NewResult(bool(false), string("")) + + err = scale.Unmarshal(bytes, &res1) + if err != nil { + panic(err) + } + + // res1 should be Set with OK mode and value of true + ok, err := res1.Unwrap() + if err != nil { + panic(err) + } + + switch ok := ok.(type) { + case bool: + if !ok { + panic(fmt.Errorf("unexpected ok value: %v", ok)) + } + default: + panic(fmt.Errorf("unexpected type: %T", ok)) + } +} + +func TestExampleResult(t *testing.T) { + ExampleResult() +} diff --git a/pkg/scale/result_test.go b/pkg/scale/result_test.go index 4140a4254b..d5b5320974 100644 --- a/pkg/scale/result_test.go +++ b/pkg/scale/result_test.go @@ -216,29 +216,45 @@ func TestResult_Set(t *testing.T) { in interface{} } tests := []struct { - name string - res Result - args args - wantErr bool + name string + res Result + args args + wantErr bool + wantResult Result }{ - // TODO: Add test cases. { args: args{ mode: Unset, }, + res: NewResult(nil, nil), wantErr: true, + wantResult: Result{ + ok: empty{}, err: empty{}, + }, }, { args: args{ mode: OK, in: nil, }, + res: NewResult(nil, nil), + wantResult: Result{ + ok: empty{}, + err: empty{}, + mode: OK, + }, }, { args: args{ mode: Err, in: nil, }, + res: NewResult(nil, nil), + wantResult: Result{ + ok: empty{}, + err: empty{}, + mode: Err, + }, }, { args: args{ @@ -246,6 +262,11 @@ func TestResult_Set(t *testing.T) { in: true, }, res: NewResult(true, nil), + wantResult: Result{ + ok: true, + err: empty{}, + mode: OK, + }, }, { args: args{ @@ -253,6 +274,11 @@ func TestResult_Set(t *testing.T) { in: true, }, res: NewResult(nil, true), + wantResult: Result{ + ok: empty{}, + err: true, + mode: Err, + }, }, { args: args{ @@ -261,6 +287,10 @@ func TestResult_Set(t *testing.T) { }, res: NewResult("ok", "err"), wantErr: true, + wantResult: Result{ + ok: "ok", + err: "err", + }, }, { args: args{ @@ -269,6 +299,10 @@ func TestResult_Set(t *testing.T) { }, res: NewResult(nil, true), wantErr: true, + wantResult: Result{ + ok: empty{}, + err: true, + }, }, } for _, tt := range tests { @@ -277,6 +311,9 @@ func TestResult_Set(t *testing.T) { if err := r.Set(tt.args.mode, tt.args.in); (err != nil) != tt.wantErr { t.Errorf("Result.Set() error = %v, wantErr %v", err, tt.wantErr) } + if !reflect.DeepEqual(tt.wantResult, r) { + t.Errorf("Result.Unwrap() = %v, want %v", tt.wantResult, r) + } }) } } diff --git a/pkg/scale/uint128.go b/pkg/scale/uint128.go index 53dcf05282..12e00c42b9 100644 --- a/pkg/scale/uint128.go +++ b/pkg/scale/uint128.go @@ -87,7 +87,7 @@ func NewUint128(in interface{}, order ...binary.ByteOrder) (u *Uint128, err erro return } -// Bytes returns the Uint128 in little endian format by default. A variadic paramter +// Bytes returns the Uint128 in little endian format by default. A variadic parameter // order can be used to specify the binary.ByteOrder used func (u *Uint128) Bytes(order ...binary.ByteOrder) (b []byte) { var o binary.ByteOrder = binary.LittleEndian @@ -108,7 +108,7 @@ func (u *Uint128) Bytes(order ...binary.ByteOrder) (b []byte) { return } -// Cmp returns 1 if the receiver is greater than other, 0 if they are equal, and -1 otherwise. +// Compare returns 1 if the receiver is greater than other, 0 if they are equal, and -1 otherwise. func (u *Uint128) Compare(other *Uint128) int { switch { case u.Upper > other.Upper: diff --git a/pkg/scale/varying_data_type.go b/pkg/scale/varying_data_type.go index 8c2bf36f67..de524014e2 100644 --- a/pkg/scale/varying_data_type.go +++ b/pkg/scale/varying_data_type.go @@ -18,34 +18,93 @@ package scale import ( "fmt" - "reflect" ) -type varyingDataTypeCache map[string]map[uint]VaryingDataTypeValue +// VaryingDataTypeValue is used to represent scale encodable types of an associated VaryingDataType +type VaryingDataTypeValue interface { + Index() uint +} -var vdtCache varyingDataTypeCache = make(varyingDataTypeCache) +// VaryingDataTypeSlice is used to represent []VaryingDataType. SCALE requires knowledge +// of the underlying data, so it is required to have the VaryingDataType required for decoding +type VaryingDataTypeSlice struct { + VaryingDataType + Types []VaryingDataType +} -type VaryingDataType []VaryingDataTypeValue +// Add takes variadic parameter values to add VaryingDataTypeValue(s) +func (vdts *VaryingDataTypeSlice) Add(values ...VaryingDataTypeValue) (err error) { + for _, val := range values { + copied := vdts.VaryingDataType + err = copied.Set(val) + if err != nil { + return + } + vdts.Types = append(vdts.Types, copied) + } + return +} -func RegisterVaryingDataType(in interface{}, values ...VaryingDataTypeValue) (err error) { - t := reflect.TypeOf(in) - if !t.ConvertibleTo(reflect.TypeOf(VaryingDataType{})) { - err = fmt.Errorf("%T is not a VaryingDataType", in) - return +// NewVaryingDataTypeSlice is constructor for VaryingDataTypeSlice +func NewVaryingDataTypeSlice(vdt VaryingDataType) (vdts VaryingDataTypeSlice) { + vdts.VaryingDataType = vdt + vdts.Types = make([]VaryingDataType, 0) + return +} + +func mustNewVaryingDataTypeSliceAndSet(vdt VaryingDataType, values ...VaryingDataTypeValue) (vdts VaryingDataTypeSlice) { + vdts = NewVaryingDataTypeSlice(vdt) + if err := vdts.Add(values...); err != nil { + panic(err) } + return +} - key := fmt.Sprintf("%s.%s", t.PkgPath(), t.Name()) - _, ok := vdtCache[key] +// VaryingDataType is analogous to a rust enum. Name is taken from polkadot spec. +type VaryingDataType struct { + value VaryingDataTypeValue + cache map[uint]VaryingDataTypeValue +} + +// Set will set the VaryingDataType value +func (vdt *VaryingDataType) Set(value VaryingDataTypeValue) (err error) { + _, ok := vdt.cache[value.Index()] if !ok { - vdtCache[key] = make(map[uint]VaryingDataTypeValue) + err = fmt.Errorf("unable to append VaryingDataTypeValue: %T, not in cache", value) + return } - for _, val := range values { - vdtCache[key][val.Index()] = val + vdt.value = value + return +} + +// Value returns value stored in vdt +func (vdt *VaryingDataType) Value() VaryingDataTypeValue { + return vdt.value +} + +// NewVaryingDataType is constructor for VaryingDataType +func NewVaryingDataType(values ...VaryingDataTypeValue) (vdt VaryingDataType, err error) { + if len(values) == 0 { + err = fmt.Errorf("must provide atleast one VaryingDataTypeValue") + return + } + vdt.cache = make(map[uint]VaryingDataTypeValue) + for _, value := range values { + _, ok := vdt.cache[value.Index()] + if ok { + err = fmt.Errorf("duplicate index with VaryingDataType: %T with index: %d", value, value.Index()) + return + } + vdt.cache[value.Index()] = value } return } -// VaryingDataType is used to represent scale encodable types -type VaryingDataTypeValue interface { - Index() uint +// MustNewVaryingDataType is constructor for VaryingDataType +func MustNewVaryingDataType(values ...VaryingDataTypeValue) (vdt VaryingDataType) { + vdt, err := NewVaryingDataType(values...) + if err != nil { + panic(err) + } + return } diff --git a/pkg/scale/varying_data_type_example_test.go b/pkg/scale/varying_data_type_example_test.go new file mode 100644 index 0000000000..b14d5c605c --- /dev/null +++ b/pkg/scale/varying_data_type_example_test.go @@ -0,0 +1,115 @@ +package scale_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/ChainSafe/gossamer/pkg/scale" +) + +type MyStruct struct { + Baz bool + Bar uint32 + Foo []byte +} + +func (ms MyStruct) Index() uint { + return 1 +} + +type MyOtherStruct struct { + Foo string + Bar uint64 + Baz uint +} + +func (mos MyOtherStruct) Index() uint { + return 2 +} + +type MyInt16 int16 + +func (mi16 MyInt16) Index() uint { + return 3 +} + +func ExampleVaryingDataType() { + vdt, err := scale.NewVaryingDataType(MyStruct{}, MyOtherStruct{}, MyInt16(0)) + if err != nil { + panic(err) + } + + err = vdt.Set(MyStruct{ + Baz: true, + Bar: 999, + Foo: []byte{1, 2}, + }) + if err != nil { + panic(err) + } + + bytes, err := scale.Marshal(vdt) + if err != nil { + panic(err) + } + + vdt1, err := scale.NewVaryingDataType(MyStruct{}, MyOtherStruct{}, MyInt16(0)) + if err != nil { + panic(err) + } + + err = scale.Unmarshal(bytes, &vdt1) + if err != nil { + panic(err) + } + + if !reflect.DeepEqual(vdt, vdt1) { + panic(fmt.Errorf("uh oh: %+v %+v", vdt, vdt1)) + } +} + +func ExampleVaryingDataTypeSlice() { + vdt, err := scale.NewVaryingDataType(MyStruct{}, MyOtherStruct{}, MyInt16(0)) + if err != nil { + panic(err) + } + + vdts := scale.NewVaryingDataTypeSlice(vdt) + + err = vdts.Add( + MyStruct{ + Baz: true, + Bar: 999, + Foo: []byte{1, 2}, + }, + MyInt16(1), + ) + if err != nil { + panic(err) + } + + bytes, err := scale.Marshal(vdts) + if err != nil { + panic(err) + } + + vdts1 := scale.NewVaryingDataTypeSlice(vdt) + if err != nil { + panic(err) + } + + err = scale.Unmarshal(bytes, &vdts1) + if err != nil { + panic(err) + } + + if !reflect.DeepEqual(vdts, vdts1) { + panic(fmt.Errorf("uh oh: %+v %+v", vdts, vdts1)) + } +} + +func TestExamples(t *testing.T) { + ExampleVaryingDataType() + ExampleVaryingDataTypeSlice() +} diff --git a/pkg/scale/varying_data_type_test.go b/pkg/scale/varying_data_type_test.go index 78200cf3f5..d92dd7c723 100644 --- a/pkg/scale/varying_data_type_test.go +++ b/pkg/scale/varying_data_type_test.go @@ -25,6 +25,23 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" ) +func mustNewVaryingDataType(values ...VaryingDataTypeValue) (vdt VaryingDataType) { + vdt, err := NewVaryingDataType(values...) + if err != nil { + panic(err) + } + return +} + +func mustNewVaryingDataTypeAndSet(value VaryingDataTypeValue, values ...VaryingDataTypeValue) (vdt VaryingDataType) { + vdt = mustNewVaryingDataType(values...) + err := vdt.Set(value) + if err != nil { + panic(err) + } + return +} + type VDTValue struct { A *big.Int B int @@ -98,24 +115,14 @@ func (ctrd VDTValue3) Index() uint { return 4 } -type testVDT VaryingDataType - -func init() { - err := RegisterVaryingDataType(testVDT{}, VDTValue{}, VDTValue2{}, VDTValue1{}, VDTValue3(0)) - if err != nil { - panic(err) - } -} - var varyingDataTypeTests = tests{ { - in: testVDT{ - VDTValue1{ - O: newBigIntPtr(big.NewInt(1073741823)), - }, - }, + in: mustNewVaryingDataTypeAndSet( + VDTValue1{O: newBigIntPtr(big.NewInt(1073741823))}, + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), want: []byte{ - 4, 2, + 2, 0x01, 0xfe, 0xff, 0xff, 0xff, 0x00, 0x00, @@ -133,7 +140,7 @@ var varyingDataTypeTests = tests{ }, }, { - in: testVDT{ + in: mustNewVaryingDataTypeAndSet( VDTValue{ A: big.NewInt(1073741823), B: int(1073741823), @@ -150,6 +157,32 @@ var varyingDataTypeTests = tests{ M: testStrings[1], N: true, }, + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), + want: newWant( + // index of VDTValue + []byte{1}, + // encoding of struct + []byte{ + 0xfe, 0xff, 0xff, 0xff, + 0xfe, 0xff, 0xff, 0xff, + 0xfe, 0xff, 0xff, 0xff, + 0x01, + 0x01, + 0xff, 0x3f, + 0xff, 0x3f, + 0xff, 0xff, 0xff, 0x3f, + 0xff, 0xff, 0xff, 0x3f, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + }, + append([]byte{0x01, 0x01}, byteArray(64)...), + append([]byte{0xC2, 0x02, 0x01, 0x00}, testStrings[1]...), + []byte{0x01}, + ), + }, + { + in: mustNewVaryingDataTypeAndSet( VDTValue1{ O: newBigIntPtr(big.NewInt(1073741823)), P: newIntPtr(int(1073741823)), @@ -166,6 +199,32 @@ var varyingDataTypeTests = tests{ AA: newStringPtr(testStrings[1]), AB: newBoolPtr(true), }, + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), + want: newWant( + // index of VDTValue1 + []byte{2}, + // encoding of struct + []byte{ + 0x01, 0xfe, 0xff, 0xff, 0xff, + 0x01, 0xfe, 0xff, 0xff, 0xff, + 0x01, 0xfe, 0xff, 0xff, 0xff, + 0x01, 0x01, + 0x01, 0x01, + 0x01, 0xff, 0x3f, + 0x01, 0xff, 0x3f, + 0x01, 0xff, 0xff, 0xff, 0x3f, + 0x01, 0xff, 0xff, 0xff, 0x3f, + 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + }, + append([]byte{0x01, 0x01, 0x01}, byteArray(64)...), + append([]byte{0x01, 0xC2, 0x02, 0x01, 0x00}, testStrings[1]...), + []byte{0x01, 0x01}, + ), + }, + { + in: mustNewVaryingDataTypeAndSet( VDTValue2{ A: MyStruct{ Foo: []byte{0x01}, @@ -202,51 +261,9 @@ var varyingDataTypeTests = tests{ O: [2][]byte{{0x00, 0x01}, {0x01, 0x00}}, P: [2][2]byte{{0x00, 0x01}, {0x01, 0x00}}, }, - VDTValue3(16383), - }, + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), want: newWant( - // length encoding of 3 - []byte{16}, - // index of VDTValue - []byte{1}, - // encoding of struct - []byte{ - 0xfe, 0xff, 0xff, 0xff, - 0xfe, 0xff, 0xff, 0xff, - 0xfe, 0xff, 0xff, 0xff, - 0x01, - 0x01, - 0xff, 0x3f, - 0xff, 0x3f, - 0xff, 0xff, 0xff, 0x3f, - 0xff, 0xff, 0xff, 0x3f, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, - }, - append([]byte{0x01, 0x01}, byteArray(64)...), - append([]byte{0xC2, 0x02, 0x01, 0x00}, testStrings[1]...), - []byte{0x01}, - - // index of VDTValue1 - []byte{2}, - // encoding of struct - []byte{ - 0x01, 0xfe, 0xff, 0xff, 0xff, - 0x01, 0xfe, 0xff, 0xff, 0xff, - 0x01, 0xfe, 0xff, 0xff, 0xff, - 0x01, 0x01, - 0x01, 0x01, - 0x01, 0xff, 0x3f, - 0x01, 0xff, 0x3f, - 0x01, 0xff, 0xff, 0xff, 0x3f, - 0x01, 0xff, 0xff, 0xff, 0x3f, - 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, - 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, - }, - append([]byte{0x01, 0x01, 0x01}, byteArray(64)...), - append([]byte{0x01, 0xC2, 0x02, 0x01, 0x00}, testStrings[1]...), - []byte{0x01, 0x01}, - // index of VDTValue2 []byte{3}, // encoding of struct @@ -268,7 +285,14 @@ var varyingDataTypeTests = tests{ []byte{0x00, 0x04}, []byte{0x08, 0x00, 0x01, 0x08, 0x01, 0x00}, []byte{0x00, 0x01, 0x01, 0x00}, - + ), + }, + { + in: mustNewVaryingDataTypeAndSet( + VDTValue3(16383), + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), + want: newWant( // index of VDTValue2 []byte{4}, // encoding of int16 @@ -281,7 +305,8 @@ func Test_encodeState_encodeVaryingDataType(t *testing.T) { for _, tt := range varyingDataTypeTests { t.Run(tt.name, func(t *testing.T) { es := &encodeState{fieldScaleIndicesCache: cache} - if err := es.marshal(tt.in); (err != nil) != tt.wantErr { + vdt := tt.in.(VaryingDataType) + if err := es.marshal(vdt); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr) } if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { @@ -294,12 +319,319 @@ func Test_encodeState_encodeVaryingDataType(t *testing.T) { func Test_decodeState_decodeVaryingDataType(t *testing.T) { for _, tt := range varyingDataTypeTests { t.Run(tt.name, func(t *testing.T) { - dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface() + dst, err := NewVaryingDataType(VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0)) + if err != nil { + t.Errorf("%v", err) + return + } if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) return } - diff := cmp.Diff(dst, tt.in, cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{})) + vdt := tt.in.(VaryingDataType) + diff := cmp.Diff(dst.Value(), vdt.Value(), cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{})) + if diff != "" { + t.Errorf("decodeState.unmarshal() = %s", diff) + } + }) + } +} + +func TestNewVaryingDataType(t *testing.T) { + type args struct { + values []VaryingDataTypeValue + } + tests := []struct { + name string + args args + wantVdt VaryingDataType + wantErr bool + }{ + { + args: args{ + values: []VaryingDataTypeValue{}, + }, + wantErr: true, + }, + { + args: args{ + values: []VaryingDataTypeValue{ + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + }, + }, + wantVdt: VaryingDataType{ + cache: map[uint]VaryingDataTypeValue{ + VDTValue{}.Index(): VDTValue{}, + VDTValue1{}.Index(): VDTValue1{}, + VDTValue2{}.Index(): VDTValue2{}, + VDTValue3(0).Index(): VDTValue3(0), + }, + }, + }, + { + args: args{ + values: []VaryingDataTypeValue{ + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), VDTValue{}, + }, + }, + wantVdt: VaryingDataType{ + cache: map[uint]VaryingDataTypeValue{ + VDTValue{}.Index(): VDTValue{}, + VDTValue1{}.Index(): VDTValue1{}, + VDTValue2{}.Index(): VDTValue2{}, + VDTValue3(0).Index(): VDTValue3(0), + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotVdt, err := NewVaryingDataType(tt.args.values...) + if (err != nil) != tt.wantErr { + t.Errorf("NewVaryingDataType() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotVdt, tt.wantVdt) { + t.Errorf("NewVaryingDataType() = %v, want %v", gotVdt, tt.wantVdt) + } + }) + } +} + +func TestVaryingDataType_Set(t *testing.T) { + type args struct { + value VaryingDataTypeValue + } + tests := []struct { + name string + vdt VaryingDataType + args args + wantErr bool + }{ + { + vdt: mustNewVaryingDataType(VDTValue1{}), + args: args{ + value: VDTValue1{}, + }, + }, + { + vdt: mustNewVaryingDataType(VDTValue1{}, VDTValue2{}), + args: args{ + value: VDTValue1{}, + }, + }, + { + vdt: mustNewVaryingDataType(VDTValue1{}, VDTValue2{}), + args: args{ + value: VDTValue2{}, + }, + }, + { + vdt: mustNewVaryingDataType(VDTValue1{}, VDTValue2{}), + args: args{ + value: VDTValue3(0), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vdt := tt.vdt + if err := vdt.Set(tt.args.value); (err != nil) != tt.wantErr { + t.Errorf("VaryingDataType.SetValue() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestVaryingDataTypeSlice_Add(t *testing.T) { + type args struct { + values []VaryingDataTypeValue + } + tests := []struct { + name string + vdts VaryingDataTypeSlice + args args + wantErr bool + wantValues []VaryingDataType + }{ + { + name: "happy path", + vdts: NewVaryingDataTypeSlice(MustNewVaryingDataType(VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0))), + args: args{ + values: []VaryingDataTypeValue{ + VDTValue{ + B: 1, + }, + }, + }, + wantValues: []VaryingDataType{ + mustNewVaryingDataTypeAndSet( + VDTValue{ + B: 1, + }, + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), + }, + }, + { + name: "invalid value error case", + vdts: NewVaryingDataTypeSlice(MustNewVaryingDataType(VDTValue{}, VDTValue1{}, VDTValue2{})), + args: args{ + values: []VaryingDataTypeValue{ + VDTValue3(0), + }, + }, + wantValues: []VaryingDataType{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vdts := &tt.vdts + if err := vdts.Add(tt.args.values...); (err != nil) != tt.wantErr { + t.Errorf("VaryingDataTypeSlice.Add() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(vdts.Types, tt.wantValues) { + t.Errorf("NewVaryingDataType() = %v, want %v", vdts.Types, tt.wantValues) + } + }) + } +} + +var varyingDataTypeSliceTests = tests{ + { + in: mustNewVaryingDataTypeSliceAndSet( + mustNewVaryingDataType( + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), + VDTValue1{O: newBigIntPtr(big.NewInt(1073741823))}, + ), + want: newWant( + []byte{ + // length + 4, + // index + 2, + // value + 0x01, 0xfe, 0xff, 0xff, 0xff, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + }, + ), + }, + { + in: mustNewVaryingDataTypeSliceAndSet( + mustNewVaryingDataType( + VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0), + ), + VDTValue1{O: newBigIntPtr(big.NewInt(1073741823))}, + VDTValue{ + A: big.NewInt(1073741823), + B: int(1073741823), + C: uint(1073741823), + D: int8(1), + E: uint8(1), + F: int16(16383), + G: uint16(16383), + H: int32(1073741823), + I: uint32(1073741823), + J: int64(9223372036854775807), + K: uint64(9223372036854775807), + L: byteArray(64), + M: testStrings[1], + N: true, + }, + ), + want: newWant( + []byte{ + // length + 8, + }, + []byte{ + // index + 2, + // value + 0x01, 0xfe, 0xff, 0xff, 0xff, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + }, + []byte{ + // index + 1, + // value + 0xfe, 0xff, 0xff, 0xff, + 0xfe, 0xff, 0xff, 0xff, + 0xfe, 0xff, 0xff, 0xff, + 0x01, + 0x01, + 0xff, 0x3f, + 0xff, 0x3f, + 0xff, 0xff, 0xff, 0x3f, + 0xff, 0xff, 0xff, 0x3f, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + }, + append([]byte{0x01, 0x01}, byteArray(64)...), + append([]byte{0xC2, 0x02, 0x01, 0x00}, testStrings[1]...), + []byte{0x01}, + ), + }, +} + +func Test_encodeState_encodeVaryingDataTypeSlice(t *testing.T) { + for _, tt := range varyingDataTypeSliceTests { + t.Run(tt.name, func(t *testing.T) { + vdt := tt.in.(VaryingDataTypeSlice) + b, err := Marshal(vdt) + if (err != nil) != tt.wantErr { + t.Errorf("Marshal() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(b, tt.want) { + t.Errorf("Marshal() = %v, want %v", b, tt.want) + } + }) + } +} + +func Test_decodeState_decodeVaryingDataTypeSlice(t *testing.T) { + opt := cmp.Comparer(func(x, y VaryingDataType) bool { + return reflect.DeepEqual(x.value, y.value) && reflect.DeepEqual(x.cache, y.cache) + }) + + for _, tt := range varyingDataTypeSliceTests { + t.Run(tt.name, func(t *testing.T) { + dst := tt.in.(VaryingDataTypeSlice) + dst.Types = make([]VaryingDataType, 0) + if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + vdts := tt.in.(VaryingDataTypeSlice) + diff := cmp.Diff(dst, vdts, cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}), opt) if diff != "" { t.Errorf("decodeState.unmarshal() = %s", diff) }