Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove merkledb codec struct #2883

Merged
merged 4 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 59 additions & 85 deletions x/merkledb/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ const (
)

var (
_ encoderDecoder = (*codecImpl)(nil)

trueBytes = []byte{trueByte}
falseBytes = []byte{falseByte}

Expand All @@ -49,131 +47,107 @@ var (
errIntOverflow = errors.New("value overflows int")
)

// encoderDecoder defines the interface needed by merkleDB to marshal
// and unmarshal relevant types.
type encoderDecoder interface {
encoder
decoder
}

type encoder interface {
// Assumes [n] is non-nil.
encodeDBNode(n *dbNode) []byte
encodedDBNodeSize(n *dbNode) int

// Returns the bytes that will be hashed to generate [n]'s ID.
// Assumes [n] is non-nil.
encodeHashValues(n *node) []byte
encodeKey(key Key) []byte
}

type decoder interface {
// Assumes [n] is non-nil.
decodeDBNode(bytes []byte, n *dbNode) error
decodeKey(bytes []byte) (Key, error)
}

func newCodec() encoderDecoder {
return &codecImpl{}
}

// Note that bytes.Buffer.Write always returns nil, so we
// can ignore its return values in [codecImpl] methods.
type codecImpl struct{}
// Note that bytes.Buffer.Write always returns nil, so we ignore its return
// values in all encode methods.

func (c *codecImpl) childSize(index byte, childEntry *child) int {
func childSize(index byte, childEntry *child) int {
// * index
// * child ID
// * child key
// * bool indicating whether the child has a value
return c.uintSize(uint64(index)) + ids.IDLen + c.keySize(childEntry.compressedKey) + boolLen
return uintSize(uint64(index)) + ids.IDLen + keySize(childEntry.compressedKey) + boolLen
}

// based on the current implementation of codecImpl.encodeUint which uses binary.PutUvarint
func (*codecImpl) uintSize(value uint64) int {
// based on the implementation of encodeUint which uses binary.PutUvarint
func uintSize(value uint64) int {
if value == 0 {
return 1
}
return (bits.Len64(value) + 6) / 7
}

func (c *codecImpl) keySize(p Key) int {
return c.uintSize(uint64(p.length)) + bytesNeeded(p.length)
func keySize(p Key) int {
return uintSize(uint64(p.length)) + bytesNeeded(p.length)
}

func (c *codecImpl) encodedDBNodeSize(n *dbNode) int {
// Assumes [n] is non-nil.
func encodedDBNodeSize(n *dbNode) int {
// * number of children
// * bool indicating whether [n] has a value
// * the value (optional)
// * children
size := c.uintSize(uint64(len(n.children))) + boolLen
size := uintSize(uint64(len(n.children))) + boolLen
if n.value.HasValue() {
valueLen := len(n.value.Value())
size += c.uintSize(uint64(valueLen)) + valueLen
size += uintSize(uint64(valueLen)) + valueLen
}
// for each non-nil entry, we add the additional size of the child entry
for index, entry := range n.children {
size += c.childSize(index, entry)
size += childSize(index, entry)
}
return size
}

func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
buf := bytes.NewBuffer(make([]byte, 0, c.encodedDBNodeSize(n)))
c.encodeMaybeByteSlice(buf, n.value)
c.encodeUint(buf, uint64(len(n.children)))
// Assumes [n] is non-nil.
func encodeDBNode(n *dbNode) []byte {
buf := bytes.NewBuffer(make([]byte, 0, encodedDBNodeSize(n)))
encodeMaybeByteSlice(buf, n.value)
encodeUint(buf, uint64(len(n.children)))
// Note we insert children in order of increasing index
// for determinism.
keys := maps.Keys(n.children)
slices.Sort(keys)
for _, index := range keys {
entry := n.children[index]
c.encodeUint(buf, uint64(index))
c.encodeKeyToBuffer(buf, entry.compressedKey)
encodeUint(buf, uint64(index))
encodeKeyToBuffer(buf, entry.compressedKey)
_, _ = buf.Write(entry.id[:])
c.encodeBool(buf, entry.hasValue)
encodeBool(buf, entry.hasValue)
}
return buf.Bytes()
}

func (c *codecImpl) encodeHashValues(n *node) []byte {
// Returns the bytes that will be hashed to generate [n]'s ID.
// Assumes [n] is non-nil.
func encodeHashValues(n *node) []byte {
var (
numChildren = len(n.children)
// Estimate size [hv] to prevent memory allocations
estimatedLen = minVarIntLen + numChildren*hashValuesChildLen + estimatedValueLen + estimatedKeyLen
buf = bytes.NewBuffer(make([]byte, 0, estimatedLen))
)

c.encodeUint(buf, uint64(numChildren))
encodeUint(buf, uint64(numChildren))

// ensure that the order of entries is consistent
keys := maps.Keys(n.children)
slices.Sort(keys)
for _, index := range keys {
entry := n.children[index]
c.encodeUint(buf, uint64(index))
encodeUint(buf, uint64(index))
_, _ = buf.Write(entry.id[:])
}
c.encodeMaybeByteSlice(buf, n.valueDigest)
c.encodeKeyToBuffer(buf, n.key)
encodeMaybeByteSlice(buf, n.valueDigest)
encodeKeyToBuffer(buf, n.key)

return buf.Bytes()
}

func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
// Assumes [n] is non-nil.
func decodeDBNode(b []byte, n *dbNode) error {
if minDBNodeLen > len(b) {
return io.ErrUnexpectedEOF
}

src := bytes.NewReader(b)

value, err := c.decodeMaybeByteSlice(src)
value, err := decodeMaybeByteSlice(src)
if err != nil {
return err
}
n.value = value

numChildren, err := c.decodeUint(src)
numChildren, err := decodeUint(src)
switch {
case err != nil:
return err
Expand All @@ -184,7 +158,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
n.children = make(map[byte]*child, numChildren)
var previousChild uint64
for i := uint64(0); i < numChildren; i++ {
index, err := c.decodeUint(src)
index, err := decodeUint(src)
if err != nil {
return err
}
Expand All @@ -193,15 +167,15 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
}
previousChild = index

compressedKey, err := c.decodeKeyFromReader(src)
compressedKey, err := decodeKeyFromReader(src)
if err != nil {
return err
}
childID, err := c.decodeID(src)
childID, err := decodeID(src)
if err != nil {
return err
}
hasValue, err := c.decodeBool(src)
hasValue, err := decodeBool(src)
if err != nil {
return err
}
Expand All @@ -217,15 +191,15 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
return nil
}

func (*codecImpl) encodeBool(dst *bytes.Buffer, value bool) {
func encodeBool(dst *bytes.Buffer, value bool) {
bytesValue := falseBytes
if value {
bytesValue = trueBytes
}
_, _ = dst.Write(bytesValue)
}

func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
func decodeBool(src *bytes.Reader) (bool, error) {
boolByte, err := src.ReadByte()
switch {
case err == io.EOF:
Expand All @@ -241,7 +215,7 @@ func (*codecImpl) decodeBool(src *bytes.Reader) (bool, error) {
}
}

func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
func decodeUint(src *bytes.Reader) (uint64, error) {
// To ensure encoding/decoding is canonical, we need to check for leading
// zeroes in the varint.
// The last byte of the varint we read is the most significant byte.
Expand Down Expand Up @@ -274,43 +248,43 @@ func (*codecImpl) decodeUint(src *bytes.Reader) (uint64, error) {
return val64, nil
}

func (*codecImpl) encodeUint(dst *bytes.Buffer, value uint64) {
func encodeUint(dst *bytes.Buffer, value uint64) {
var buf [binary.MaxVarintLen64]byte
size := binary.PutUvarint(buf[:], value)
_, _ = dst.Write(buf[:size])
}

func (c *codecImpl) encodeMaybeByteSlice(dst *bytes.Buffer, maybeValue maybe.Maybe[[]byte]) {
func encodeMaybeByteSlice(dst *bytes.Buffer, maybeValue maybe.Maybe[[]byte]) {
hasValue := maybeValue.HasValue()
c.encodeBool(dst, hasValue)
encodeBool(dst, hasValue)
if hasValue {
c.encodeByteSlice(dst, maybeValue.Value())
encodeByteSlice(dst, maybeValue.Value())
}
}

func (c *codecImpl) decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte], error) {
func decodeMaybeByteSlice(src *bytes.Reader) (maybe.Maybe[[]byte], error) {
if minMaybeByteSliceLen > src.Len() {
return maybe.Nothing[[]byte](), io.ErrUnexpectedEOF
}

if hasValue, err := c.decodeBool(src); err != nil || !hasValue {
if hasValue, err := decodeBool(src); err != nil || !hasValue {
return maybe.Nothing[[]byte](), err
}

rawBytes, err := c.decodeByteSlice(src)
rawBytes, err := decodeByteSlice(src)
if err != nil {
return maybe.Nothing[[]byte](), err
}

return maybe.Some(rawBytes), nil
}

func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
func decodeByteSlice(src *bytes.Reader) ([]byte, error) {
if minByteSliceLen > src.Len() {
return nil, io.ErrUnexpectedEOF
}

length, err := c.decodeUint(src)
length, err := decodeUint(src)
switch {
case err == io.EOF:
return nil, io.ErrUnexpectedEOF
Expand All @@ -330,14 +304,14 @@ func (c *codecImpl) decodeByteSlice(src *bytes.Reader) ([]byte, error) {
return result, err
}

func (c *codecImpl) encodeByteSlice(dst *bytes.Buffer, value []byte) {
c.encodeUint(dst, uint64(len(value)))
func encodeByteSlice(dst *bytes.Buffer, value []byte) {
encodeUint(dst, uint64(len(value)))
if value != nil {
_, _ = dst.Write(value)
}
}

func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
func decodeID(src *bytes.Reader) (ids.ID, error) {
if ids.IDLen > src.Len() {
return ids.ID{}, io.ErrUnexpectedEOF
}
Expand All @@ -350,21 +324,21 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
return id, err
}

func (c *codecImpl) encodeKey(key Key) []byte {
func encodeKey(key Key) []byte {
estimatedLen := binary.MaxVarintLen64 + len(key.Bytes())
dst := bytes.NewBuffer(make([]byte, 0, estimatedLen))
c.encodeKeyToBuffer(dst, key)
encodeKeyToBuffer(dst, key)
return dst.Bytes()
}

func (c *codecImpl) encodeKeyToBuffer(dst *bytes.Buffer, key Key) {
c.encodeUint(dst, uint64(key.length))
func encodeKeyToBuffer(dst *bytes.Buffer, key Key) {
encodeUint(dst, uint64(key.length))
_, _ = dst.Write(key.Bytes())
}

func (c *codecImpl) decodeKey(b []byte) (Key, error) {
func decodeKey(b []byte) (Key, error) {
src := bytes.NewReader(b)
key, err := c.decodeKeyFromReader(src)
key, err := decodeKeyFromReader(src)
if err != nil {
return Key{}, err
}
Expand All @@ -374,12 +348,12 @@ func (c *codecImpl) decodeKey(b []byte) (Key, error) {
return key, err
}

func (c *codecImpl) decodeKeyFromReader(src *bytes.Reader) (Key, error) {
func decodeKeyFromReader(src *bytes.Reader) (Key, error) {
if minKeyLen > src.Len() {
return Key{}, io.ErrUnexpectedEOF
}

length, err := c.decodeUint(src)
length, err := decodeUint(src)
if err != nil {
return Key{}, err
}
Expand Down
Loading
Loading