Skip to content

Commit

Permalink
fix(lib/trie): Check for root in EncodeAndHash (#2359)
Browse files Browse the repository at this point in the history
EncodeAndHash now takes a parameter for whether the given node is root or not.

If a node is root, we return blake2b hash of its encoding as hash regardless of whether length of encoding is smaller or greater than 32
  • Loading branch information
kishansagathiya committed Mar 14, 2022
1 parent 8105cd4 commit 087db89
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 35 deletions.
8 changes: 4 additions & 4 deletions internal/trie/node/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (b *Branch) GetHash() []byte {
// the blake2b hash digest of the encoding of the branch.
// If the encoding is less than 32 bytes, the hash returned
// is the encoding and not the hash of the encoding.
func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) {
func (b *Branch) EncodeAndHash(isRoot bool) (encoding, hash []byte, err error) {
if !b.Dirty && b.Encoding != nil && b.HashDigest != nil {
return b.Encoding, b.HashDigest, nil
}
Expand All @@ -49,7 +49,7 @@ func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) {
copy(b.Encoding, bufferBytes)
encoding = b.Encoding // no need to copy

if buffer.Len() < 32 {
if !isRoot && buffer.Len() < 32 {
b.HashDigest = make([]byte, len(bufferBytes))
copy(b.HashDigest, bufferBytes)
hash = b.HashDigest // no need to copy
Expand Down Expand Up @@ -86,7 +86,7 @@ func (l *Leaf) GetHash() []byte {
// the blake2b hash digest of the encoding of the leaf.
// If the encoding is less than 32 bytes, the hash returned
// is the encoding and not the hash of the encoding.
func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) {
func (l *Leaf) EncodeAndHash(isRoot bool) (encoding, hash []byte, err error) {
if !l.IsDirty() && l.Encoding != nil && l.HashDigest != nil {
return l.Encoding, l.HashDigest, nil
}
Expand All @@ -108,7 +108,7 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) {
copy(l.Encoding, bufferBytes)
encoding = l.Encoding // no need to copy

if len(bufferBytes) < 32 {
if !isRoot && len(bufferBytes) < 32 {
l.HashDigest = make([]byte, len(bufferBytes))
copy(l.HashDigest, bufferBytes)
hash = l.HashDigest // no need to copy
Expand Down
42 changes: 40 additions & 2 deletions internal/trie/node/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func Test_Branch_EncodeAndHash(t *testing.T) {
expectedBranch *Branch
encoding []byte
hash []byte
isRoot bool
errWrapped error
errMessage string
}{
Expand All @@ -56,6 +57,7 @@ func Test_Branch_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0x80, 0x0, 0x0},
hash: []byte{0x80, 0x0, 0x0},
isRoot: false,
},
"small branch encoding": {
branch: &Branch{
Expand All @@ -68,6 +70,20 @@ func Test_Branch_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2},
hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2},
isRoot: false,
},
"small branch encoding for root node": {
branch: &Branch{
Key: []byte{1},
Value: []byte{2},
},
expectedBranch: &Branch{
Encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2},
HashDigest: []byte{0x48, 0x3c, 0xf6, 0x87, 0xcc, 0x5a, 0x60, 0x42, 0xd3, 0xcf, 0xa6, 0x91, 0xe6, 0x88, 0xfb, 0xdc, 0x1b, 0x38, 0x39, 0x5d, 0x6, 0x0, 0xbf, 0xc3, 0xb, 0x4b, 0x5d, 0x6a, 0x37, 0xd9, 0xc5, 0x1c}, // nolint: lll
},
encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2},
hash: []byte{0x48, 0x3c, 0xf6, 0x87, 0xcc, 0x5a, 0x60, 0x42, 0xd3, 0xcf, 0xa6, 0x91, 0xe6, 0x88, 0xfb, 0xdc, 0x1b, 0x38, 0x39, 0x5d, 0x6, 0x0, 0xbf, 0xc3, 0xb, 0x4b, 0x5d, 0x6a, 0x37, 0xd9, 0xc5, 0x1c}, // nolint: lll
isRoot: true,
},
"branch dirty with precomputed encoding and hash": {
branch: &Branch{
Expand All @@ -83,6 +99,7 @@ func Test_Branch_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2},
hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2},
isRoot: false,
},
"branch not dirty with precomputed encoding and hash": {
branch: &Branch{
Expand All @@ -100,6 +117,7 @@ func Test_Branch_EncodeAndHash(t *testing.T) {
},
encoding: []byte{3},
hash: []byte{4},
isRoot: false,
},
"large branch encoding": {
branch: &Branch{
Expand All @@ -111,6 +129,7 @@ func Test_Branch_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0xbf, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0, 0x0}, //nolint:lll
hash: []byte{0x6b, 0xd8, 0xcc, 0xac, 0x71, 0x77, 0x44, 0x17, 0xfe, 0xe0, 0xde, 0xda, 0xd5, 0x97, 0x6e, 0x69, 0xeb, 0xe9, 0xdd, 0x80, 0x1d, 0x4b, 0x51, 0xf1, 0x5b, 0xf3, 0x4a, 0x93, 0x27, 0x32, 0x2c, 0xb0}, //nolint:lll
isRoot: false,
},
}

Expand All @@ -119,7 +138,7 @@ func Test_Branch_EncodeAndHash(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

encoding, hash, err := testCase.branch.EncodeAndHash()
encoding, hash, err := testCase.branch.EncodeAndHash(testCase.isRoot)

assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
Expand Down Expand Up @@ -167,6 +186,7 @@ func Test_Leaf_EncodeAndHash(t *testing.T) {
expectedLeaf *Leaf
encoding []byte
hash []byte
isRoot bool
errWrapped error
errMessage string
}{
Expand All @@ -178,6 +198,7 @@ func Test_Leaf_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0x40, 0x0},
hash: []byte{0x40, 0x0},
isRoot: false,
},
"small leaf encoding": {
leaf: &Leaf{
Expand All @@ -190,6 +211,20 @@ func Test_Leaf_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0x41, 0x1, 0x4, 0x2},
hash: []byte{0x41, 0x1, 0x4, 0x2},
isRoot: false,
},
"small leaf encoding for root node": {
leaf: &Leaf{
Key: []byte{1},
Value: []byte{2},
},
expectedLeaf: &Leaf{
Encoding: []byte{0x41, 0x1, 0x4, 0x2},
HashDigest: []byte{0x60, 0x51, 0x6d, 0xb, 0xb6, 0xe1, 0xbb, 0xfb, 0x12, 0x93, 0xf1, 0xb2, 0x76, 0xea, 0x95, 0x5, 0xe9, 0xf4, 0xa4, 0xe7, 0xd9, 0x8f, 0x62, 0xd, 0x5, 0x11, 0x5e, 0xb, 0x85, 0x27, 0x4a, 0xe1}, //nolint: lll
},
encoding: []byte{0x41, 0x1, 0x4, 0x2},
hash: []byte{0x60, 0x51, 0x6d, 0xb, 0xb6, 0xe1, 0xbb, 0xfb, 0x12, 0x93, 0xf1, 0xb2, 0x76, 0xea, 0x95, 0x5, 0xe9, 0xf4, 0xa4, 0xe7, 0xd9, 0x8f, 0x62, 0xd, 0x5, 0x11, 0x5e, 0xb, 0x85, 0x27, 0x4a, 0xe1}, // nolint: lll
isRoot: true,
},
"leaf dirty with precomputed encoding and hash": {
leaf: &Leaf{
Expand All @@ -205,6 +240,7 @@ func Test_Leaf_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0x41, 0x1, 0x4, 0x2},
hash: []byte{0x41, 0x1, 0x4, 0x2},
isRoot: false,
},
"leaf not dirty with precomputed encoding and hash": {
leaf: &Leaf{
Expand All @@ -222,6 +258,7 @@ func Test_Leaf_EncodeAndHash(t *testing.T) {
},
encoding: []byte{3},
hash: []byte{4},
isRoot: false,
},
"large leaf encoding": {
leaf: &Leaf{
Expand All @@ -233,6 +270,7 @@ func Test_Leaf_EncodeAndHash(t *testing.T) {
},
encoding: []byte{0x7f, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0}, //nolint:lll
hash: []byte{0xfb, 0xae, 0x31, 0x4b, 0xef, 0x31, 0x9, 0xc7, 0x62, 0x99, 0x9d, 0x40, 0x9b, 0xd4, 0xdc, 0x64, 0xe7, 0x39, 0x46, 0x8b, 0xd3, 0xaf, 0xe8, 0x63, 0x9d, 0xf9, 0x41, 0x40, 0x76, 0x40, 0x10, 0xa3}, //nolint:lll
isRoot: false,
},
}

Expand All @@ -241,7 +279,7 @@ func Test_Leaf_EncodeAndHash(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

encoding, hash, err := testCase.leaf.EncodeAndHash()
encoding, hash, err := testCase.leaf.EncodeAndHash(testCase.isRoot)

assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/trie/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import "github.com/qdm12/gotree"
// Node is a node in the trie and can be a leaf or a branch.
type Node interface {
Encode(buffer Buffer) (err error) // TODO change to io.Writer
EncodeAndHash() (encoding []byte, hash []byte, err error)
EncodeAndHash(isRoot bool) (encoding []byte, hash []byte, err error)
ScaleEncodeHash() (encoding []byte, err error)
IsDirty() bool
SetDirty(dirty bool)
Expand Down
28 changes: 4 additions & 24 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (t *Trie) store(db chaindb.Batch, n Node) error {
return nil
}

encoding, hash, err := n.EncodeAndHash()
encoding, hash, err := n.EncodeAndHash(n == t.root)
if err != nil {
return err
}
Expand Down Expand Up @@ -97,7 +97,7 @@ func (t *Trie) loadFromProof(rawProof [][]byte, rootHash []byte) error {
decodedNode.SetDirty(dirty)
decodedNode.SetEncodingAndHash(rawNode, nil)

_, hash, err := decodedNode.EncodeAndHash()
_, hash, err := decodedNode.EncodeAndHash(false)
if err != nil {
return fmt.Errorf("cannot encode and hash node at index %d: %w", i, err)
}
Expand Down Expand Up @@ -370,23 +370,13 @@ func (t *Trie) writeDirty(db chaindb.Batch, n Node) error {
return nil
}

encoding, hash, err := n.EncodeAndHash()
encoding, hash, err := n.EncodeAndHash(n == t.root)
if err != nil {
return fmt.Errorf(
"cannot encode and hash node with hash 0x%x: %w",
n.GetHash(), err)
}

if n == t.root {
// hash root node even if its encoding is under 32 bytes
encodingDigest, err := common.Blake2bHash(encoding)
if err != nil {
return fmt.Errorf("cannot hash root node encoding: %w", err)
}

hash = encodingDigest[:]
}

err = db.Put(hash, encoding)
if err != nil {
return fmt.Errorf(
Expand Down Expand Up @@ -446,23 +436,13 @@ func (t *Trie) getInsertedNodeHashes(n Node, hashes map[common.Hash]struct{}) (e
return nil
}

encoding, hash, err := n.EncodeAndHash()
_, hash, err := n.EncodeAndHash(n == t.root)
if err != nil {
return fmt.Errorf(
"cannot encode and hash node with hash 0x%x: %w",
n.GetHash(), err)
}

if n == t.root && len(encoding) < 32 {
// hash root node even if its encoding is under 32 bytes
encodingDigest, err := common.Blake2bHash(encoding)
if err != nil {
return fmt.Errorf("cannot hash root node encoding: %w", err)
}

hash = encodingDigest[:]
}

hashes[common.BytesToHash(hash)] = struct{}{}

switch n.Type() {
Expand Down
8 changes: 4 additions & 4 deletions lib/trie/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ type recorder interface {

// findAndRecord search for a desired key recording all the nodes in the path including the desired node
func findAndRecord(t *Trie, key []byte, recorder recorder) error {
return find(t.root, key, recorder)
return find(t.root, key, recorder, true)
}

func find(parent Node, key []byte, recorder recorder) error {
enc, hash, err := parent.EncodeAndHash()
func find(parent Node, key []byte, recorder recorder, isCurrentRoot bool) error {
enc, hash, err := parent.EncodeAndHash(isCurrentRoot)
if err != nil {
return err
}
Expand All @@ -49,5 +49,5 @@ func find(parent Node, key []byte, recorder recorder) error {
return nil
}

return find(b.Children[key[length]], key[length+1:], recorder)
return find(b.Children[key[length]], key[length+1:], recorder, false)
}

0 comments on commit 087db89

Please sign in to comment.