diff --git a/cl/cltypes/beacon_header.go b/cl/cltypes/beacon_header.go index ba72b328b24..107a1c46fd2 100644 --- a/cl/cltypes/beacon_header.go +++ b/cl/cltypes/beacon_header.go @@ -25,12 +25,7 @@ func (b *BeaconBlockHeader) EncodeSSZ(dst []byte) ([]byte, error) { } func (b *BeaconBlockHeader) DecodeSSZ(buf []byte) error { - b.Slot = ssz.UnmarshalUint64SSZ(buf) - b.ProposerIndex = ssz.UnmarshalUint64SSZ(buf[8:]) - copy(b.ParentRoot[:], buf[16:]) - copy(b.Root[:], buf[48:]) - copy(b.BodyRoot[:], buf[80:]) - return nil + return ssz.Decode(b, buf) } func (b *BeaconBlockHeader) HashSSZ() ([32]byte, error) { @@ -54,12 +49,7 @@ func (b *SignedBeaconBlockHeader) EncodeSSZ(dst []byte) ([]byte, error) { } func (b *SignedBeaconBlockHeader) DecodeSSZ(buf []byte) error { - b.Header = new(BeaconBlockHeader) - if err := b.Header.DecodeSSZ(buf); err != nil { - return err - } - copy(b.Signature[:], buf[b.Header.EncodingSizeSSZ():]) - return nil + return ssz.Decode(b, buf) } func (b *SignedBeaconBlockHeader) HashSSZ() ([32]byte, error) { diff --git a/cl/cltypes/bls_to_execution_change.go b/cl/cltypes/bls_to_execution_change.go index bdceac096be..c047ea58045 100644 --- a/cl/cltypes/bls_to_execution_change.go +++ b/cl/cltypes/bls_to_execution_change.go @@ -22,13 +22,7 @@ func (b *BLSToExecutionChange) HashSSZ() ([32]byte, error) { } func (b *BLSToExecutionChange) DecodeSSZ(buf []byte) error { - if len(buf) < b.EncodingSizeSSZ() { - return ssz.ErrLowBufferSize - } - b.ValidatorIndex = ssz.UnmarshalUint64SSZ(buf) - copy(b.From[:], buf[8:]) - copy(b.To[:], buf[56:]) - return nil + return ssz.Decode(b, buf) } func (*BLSToExecutionChange) EncodingSizeSSZ() int { @@ -45,15 +39,7 @@ func (s *SignedBLSToExecutionChange) EncodeSSZ(buf []byte) ([]byte, error) { } func (s *SignedBLSToExecutionChange) DecodeSSZ(buf []byte) error { - if len(buf) < s.EncodingSizeSSZ() { - return ssz.ErrLowBufferSize - } - s.Message = new(BLSToExecutionChange) - if err := s.Message.DecodeSSZ(buf); err != nil { - return err - } - copy(s.Signature[:], buf[s.Message.EncodingSizeSSZ():]) - return nil + return ssz.Decode(s, buf) } func (s *SignedBLSToExecutionChange) DecodeSSZWithVersion(buf []byte, _ int) error { diff --git a/cl/cltypes/checkpoint.go b/cl/cltypes/checkpoint.go index 0d858fb4869..6726df7785d 100644 --- a/cl/cltypes/checkpoint.go +++ b/cl/cltypes/checkpoint.go @@ -22,15 +22,7 @@ func (c *Checkpoint) EncodeSSZ(buf []byte) ([]byte, error) { } func (c *Checkpoint) DecodeSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < uint64(c.EncodingSizeSSZ()) { - return ssz.ErrLowBufferSize - } - c.Epoch = ssz.UnmarshalUint64SSZ(buf[0:8]) - copy(c.Root[:], buf[8:40]) - - return err + return ssz.Decode(c, buf) } func (c *Checkpoint) EncodingSizeSSZ() int { diff --git a/cl/cltypes/eth1_data.go b/cl/cltypes/eth1_data.go index 10d165d588e..c154a78c46c 100644 --- a/cl/cltypes/eth1_data.go +++ b/cl/cltypes/eth1_data.go @@ -25,17 +25,7 @@ func (e *Eth1Data) EncodeSSZ(buf []byte) ([]byte, error) { // DecodeSSZ ssz unmarshals the Eth1Data object func (e *Eth1Data) DecodeSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 72 { - return ssz.ErrLowBufferSize - } - - copy(e.Root[:], buf[0:32]) - e.DepositCount = ssz.UnmarshalUint64SSZ(buf[32:40]) - copy(e.BlockHash[:], buf[40:72]) - - return err + return ssz.Decode(e, buf) } func (e *Eth1Data) DecodeSSZWithVersion(buf []byte, _ int) error { diff --git a/cl/cltypes/fork.go b/cl/cltypes/fork.go index f4a999dcbec..9692b6fd2db 100644 --- a/cl/cltypes/fork.go +++ b/cl/cltypes/fork.go @@ -18,13 +18,7 @@ func (f *Fork) EncodeSSZ(dst []byte) ([]byte, error) { } func (f *Fork) DecodeSSZ(buf []byte) error { - if len(buf) < f.EncodingSizeSSZ() { - return ssz.ErrLowBufferSize - } - copy(f.PreviousVersion[:], buf) - copy(f.CurrentVersion[:], buf[clparams.VersionLength:]) - f.Epoch = ssz.UnmarshalUint64SSZ(buf[clparams.VersionLength*2:]) - return nil + return ssz.Decode(f, buf) } func (f *Fork) EncodingSizeSSZ() int { diff --git a/cl/cltypes/historical_summary.go b/cl/cltypes/historical_summary.go index 34f2e864266..2b7673e1189 100644 --- a/cl/cltypes/historical_summary.go +++ b/cl/cltypes/historical_summary.go @@ -17,12 +17,7 @@ func (h *HistoricalSummary) EncodeSSZ(buf []byte) ([]byte, error) { } func (h *HistoricalSummary) DecodeSSZ(buf []byte) error { - if len(buf) < h.EncodingSizeSSZ() { - return ssz.ErrLowBufferSize - } - copy(h.BlockSummaryRoot[:], buf) - copy(h.StateSummaryRoot[:], buf[length.Hash:]) - return nil + return ssz.Decode(h, buf) } func (h *HistoricalSummary) DecodeSSZWithVersion(buf []byte, _ int) error { diff --git a/cl/cltypes/network.go b/cl/cltypes/network.go index 8aec6282939..26d1392d880 100644 --- a/cl/cltypes/network.go +++ b/cl/cltypes/network.go @@ -50,16 +50,15 @@ func (m *Metadata) DecodeSSZWithVersion(buf []byte, _ int) error { // Ping is a test P2P message, used to test out liveness of our peer/signaling disconnection. type Ping struct { - Id uint64 + Id uint64 `ssz:"true"` } func (p *Ping) EncodeSSZ(buf []byte) ([]byte, error) { - return append(buf, ssz.Uint64SSZ(p.Id)...), nil + return ssz.Encode(p, buf) } func (p *Ping) DecodeSSZ(buf []byte) error { - p.Id = ssz.UnmarshalUint64SSZ(buf) - return nil + return ssz.Decode(p, buf) } func (p *Ping) EncodingSizeSSZ() int { @@ -72,16 +71,15 @@ func (p *Ping) DecodeSSZWithVersion(buf []byte, _ int) error { // P2P Message for bootstrap type SingleRoot struct { - Root [32]byte + Root [32]byte `ssz:"true"` } func (s *SingleRoot) EncodeSSZ(buf []byte) ([]byte, error) { - return append(buf, s.Root[:]...), nil + return ssz.Encode(s, buf) } func (s *SingleRoot) DecodeSSZ(buf []byte) error { - copy(s.Root[:], buf) - return nil + return ssz.Decode(s, buf) } func (s *SingleRoot) EncodingSizeSSZ() int { @@ -118,9 +116,7 @@ func (l *LightClientUpdatesByRangeRequest) EncodeSSZ(buf []byte) ([]byte, error) } func (l *LightClientUpdatesByRangeRequest) DecodeSSZ(buf []byte) error { - l.Period = ssz.UnmarshalUint64SSZ(buf) - l.Count = ssz.UnmarshalUint64SSZ(buf[8:]) - return nil + return ssz.Decode(l, buf) } func (l *LightClientUpdatesByRangeRequest) EncodingSizeSSZ() int { @@ -141,10 +137,7 @@ func (b *BeaconBlocksByRangeRequest) EncodeSSZ(buf []byte) ([]byte, error) { } func (b *BeaconBlocksByRangeRequest) DecodeSSZ(buf []byte) error { - b.StartSlot = ssz.UnmarshalUint64SSZ(buf) - b.Count = ssz.UnmarshalUint64SSZ(buf[8:]) - b.Step = ssz.UnmarshalUint64SSZ(buf[16:]) - return nil + return ssz.Decode(b, buf) } func (b *BeaconBlocksByRangeRequest) DecodeSSZWithVersion(buf []byte, _ int) error { @@ -176,12 +169,7 @@ func (s *Status) EncodeSSZ(buf []byte) ([]byte, error) { } func (s *Status) DecodeSSZ(buf []byte) error { - copy(s.ForkDigest[:], buf) - copy(s.FinalizedRoot[:], buf[4:]) - s.FinalizedEpoch = ssz.UnmarshalUint64SSZ(buf[36:]) - copy(s.HeadRoot[:], buf[44:]) - s.HeadSlot = ssz.UnmarshalUint64SSZ(buf[76:]) - return nil + return ssz.Decode(s, buf) } func (s *Status) DecodeSSZWithVersion(buf []byte, _ int) error { diff --git a/cl/cltypes/network_test.go b/cl/cltypes/network_test.go index 0af1cbdd523..8572acec71c 100644 --- a/cl/cltypes/network_test.go +++ b/cl/cltypes/network_test.go @@ -1,6 +1,7 @@ package cltypes_test import ( + "fmt" "testing" libcommon "github.com/ledgerwatch/erigon-lib/common" @@ -115,6 +116,7 @@ func TestMarshalNetworkTypes(t *testing.T) { &cltypes.LightClientBootstrap{}, } for i, tc := range cases { + fmt.Println("c") marshalledBytes, err := tc.EncodeSSZ(nil) require.NoError(t, err) require.Equal(t, len(marshalledBytes), tc.EncodingSizeSSZ()) diff --git a/cl/cltypes/slashings.go b/cl/cltypes/slashings.go index 9ddd0381be5..4ba614e9c28 100644 --- a/cl/cltypes/slashings.go +++ b/cl/cltypes/slashings.go @@ -15,12 +15,7 @@ func (p *ProposerSlashing) EncodeSSZ(dst []byte) ([]byte, error) { } func (p *ProposerSlashing) DecodeSSZ(buf []byte) error { - p.Header1 = new(SignedBeaconBlockHeader) - p.Header2 = new(SignedBeaconBlockHeader) - if err := p.Header1.DecodeSSZ(buf); err != nil { - return err - } - return p.Header2.DecodeSSZ(buf[p.Header1.EncodingSizeSSZ():]) + return ssz.Decode(p, buf) } func (p *ProposerSlashing) DecodeSSZWithVersion(buf []byte, _ int) error { diff --git a/cl/cltypes/ssz/decode.go b/cl/cltypes/ssz/decode.go new file mode 100644 index 00000000000..50c6ecc9fd7 --- /dev/null +++ b/cl/cltypes/ssz/decode.go @@ -0,0 +1,112 @@ +package ssz + +import ( + "reflect" + + "github.com/ledgerwatch/erigon-lib/common" + libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/common/length" +) + +// This package is a working progress. only base functionality is supported. no more no less. + +// Decodes just decodes a specific struct. +func Decode(x any, buf []byte) error { + // Get values in struct. + reflValue := reflect.ValueOf(x) + reflType := reflect.TypeOf(x) + + // Iterate over all the fields. + _, err := decodeValue(reflValue, reflType, buf) + return err + +} + +func decodeValue(reflValue reflect.Value, reflType reflect.Type, buf []byte) (int, error) { + if reflValue.Kind() == reflect.Ptr { + reflValue = reflValue.Elem() + reflType = reflType.Elem() + } + pos := 0 + for i := 0; i < reflValue.NumField(); i++ { + // Process each field. + field := reflValue.Field(i) + if !field.CanInterface() || reflType.Field(i).Tag.Get(TagSSZ) != TagTrueFlag { + continue + } + switch fieldVal := field.Interface().(type) { + // Base field can just be appended + case uint64: + if len(buf) < 8 { + return 0, ErrLowBufferSize + } + num := UnmarshalUint64SSZ(buf[pos:]) + field.Set(reflect.ValueOf(num)) + pos += 8 + case libcommon.Hash: + if len(buf) < length.Hash { + return 0, ErrLowBufferSize + } + field.Set(reflect.ValueOf(common.BytesToHash(buf[pos : pos+length.Hash]))) + pos += length.Hash + // Will be fixed in the future. + case [32]byte: + if len(buf) < length.Hash { + return 0, ErrLowBufferSize + } + field.Set(reflect.ValueOf(common.BytesToHash(buf[pos : pos+length.Hash]))) + pos += length.Hash + case [48]byte: + var val [48]byte + if len(buf) < 48 { + return 0, ErrLowBufferSize + } + copy(val[:], buf[pos:]) + field.Set(reflect.ValueOf(val)) + pos += 48 + case [96]byte: + var val [96]byte + if len(buf) < 96 { + return 0, ErrLowBufferSize + } + copy(val[:], buf[pos:]) + field.Set(reflect.ValueOf(val)) + pos += 96 + case [4]byte: + var val [4]byte + if len(buf) < 4 { + return 0, ErrLowBufferSize + } + copy(val[:], buf[pos:]) + field.Set(reflect.ValueOf(val)) + pos += 4 + case libcommon.Address: + if len(buf) < length.Addr { + return 0, ErrLowBufferSize + } + field.Set(reflect.ValueOf(common.BytesToAddress(buf[pos : pos+length.Addr]))) + pos += length.Addr + case bool: + if len(buf) == 0 { + return 0, ErrLowBufferSize + } + if buf[pos] == 0x01 { + field.Set(reflect.ValueOf(true)) + } else { + field.Set(reflect.ValueOf(false)) + } + pos++ + default: + // Create pointer to default. + t := reflect.TypeOf(fieldVal) + v := reflect.New(t.Elem()) + n, err := decodeValue(v, t, buf[pos:]) + if err != nil { + return 0, err + } + field.Set(v) + pos += n + } + } + return pos, nil +} diff --git a/cl/cltypes/validator.go b/cl/cltypes/validator.go index 2deec79b938..bb272e78138 100644 --- a/cl/cltypes/validator.go +++ b/cl/cltypes/validator.go @@ -28,11 +28,7 @@ func (d *DepositData) EncodeSSZ(dst []byte) ([]byte, error) { } func (d *DepositData) DecodeSSZ(buf []byte) error { - copy(d.PubKey[:], buf) - copy(d.WithdrawalCredentials[:], buf[48:]) - d.Amount = ssz.UnmarshalUint64SSZ(buf[80:]) - copy(d.Signature[:], buf[88:]) - return nil + return ssz.Decode(d, buf) } func (d *DepositData) EncodingSizeSSZ() int { @@ -321,18 +317,7 @@ func (v *Validator) DecodeSSZWithVersion(buf []byte, _ int) error { } func (v *Validator) DecodeSSZ(buf []byte) error { - if len(buf) < v.EncodingSizeSSZ() { - return ssz.ErrLowBufferSize - } - copy(v.PublicKey[:], buf) - copy(v.WithdrawalCredentials[:], buf[48:]) - v.EffectiveBalance = ssz.UnmarshalUint64SSZ(buf[80:]) - v.Slashed = buf[88] == 1 - v.ActivationEligibilityEpoch = ssz.UnmarshalUint64SSZ(buf[89:]) - v.ActivationEpoch = ssz.UnmarshalUint64SSZ(buf[97:]) - v.ExitEpoch = ssz.UnmarshalUint64SSZ(buf[105:]) - v.WithdrawableEpoch = ssz.UnmarshalUint64SSZ(buf[113:]) - return nil + return ssz.Decode(v, buf) } func (v *Validator) EncodingSizeSSZ() int {