Skip to content

Commit

Permalink
fix(trie): Refactor
Browse files Browse the repository at this point in the history
fix(trie): Add missing copyrights

fix(trie): Hash conversion

fix(trie): Simplify memorydb
  • Loading branch information
dimartiro committed Jul 28, 2023
1 parent de4c7eb commit afcc752
Show file tree
Hide file tree
Showing 25 changed files with 737 additions and 419 deletions.
16 changes: 16 additions & 0 deletions internal/trie/hashdb/hashdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2022 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package hashdb

import "github.com/ChainSafe/gossamer/lib/common"

type Prefix struct {
Data []byte
Padded *byte
}

type HashDB interface {
Get(key []byte) (value []byte, err error)
Insert(prefix Prefix, value []byte) common.Hash
}
26 changes: 4 additions & 22 deletions lib/trie/db/memory.go → internal/trie/memorydb/memory.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package db
package memorydb

import (
"bytes"
"fmt"

"github.com/ChainSafe/gossamer/internal/trie/hashdb"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie"
)

type KeyFunction func(key common.Hash, prefix trie.Prefix) common.Hash

type MemoryDBItem struct {
data []byte
//Reference count
Expand All @@ -23,7 +21,6 @@ type MemoryDB struct {
data map[common.Hash]MemoryDBItem
hashedNullNode common.Hash
nullNodeData []byte
keyFunction KeyFunction
}

func NewMemoryDB() *MemoryDB {
Expand All @@ -37,7 +34,6 @@ func newMemoryDBFromNullNode(nullKey []byte, nullNodeData []byte) *MemoryDB {
data: make(map[common.Hash]MemoryDBItem),
hashedNullNode: hashedKey,
nullNodeData: nullNodeData,
keyFunction: hashKey,
}
}

Expand All @@ -55,16 +51,7 @@ func (mdb *MemoryDB) Get(key []byte) (value []byte, err error) {
return nil, nil
}

func (mdb *MemoryDB) GetWithPrefix(key []byte, prefix trie.Prefix) (value []byte, err error) {
if bytes.Equal(key, mdb.hashedNullNode[:]) {
return mdb.nullNodeData, nil
}

computatedKey := mdb.keyFunction(common.Hash(key[:]), prefix)
return mdb.Get(computatedKey[:])
}

func (mdb *MemoryDB) Insert(prefix trie.Prefix, value []byte) common.Hash {
func (mdb *MemoryDB) Insert(prefix hashdb.Prefix, value []byte) common.Hash {
if bytes.Equal(value, mdb.nullNodeData) {
return mdb.hashedNullNode
}
Expand All @@ -74,12 +61,11 @@ func (mdb *MemoryDB) Insert(prefix trie.Prefix, value []byte) common.Hash {
return key
}

func (mdb *MemoryDB) emplace(key common.Hash, prefix trie.Prefix, value []byte) {
func (mdb *MemoryDB) emplace(key common.Hash, prefix hashdb.Prefix, value []byte) {
if bytes.Equal(value, mdb.nullNodeData) {
return
}

key = mdb.keyFunction(key, prefix)
data, ok := mdb.data[key]
if !ok {
mdb.data[key] = MemoryDBItem{value, 0}
Expand All @@ -91,7 +77,3 @@ func (mdb *MemoryDB) emplace(key common.Hash, prefix trie.Prefix, value []byte)
}
data.rc++
}

func hashKey(key common.Hash, prefix trie.Prefix) common.Hash {
return key
}
22 changes: 11 additions & 11 deletions internal/trie/node/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,21 @@ const hashLength = common.HashLength
// For branch decoding, see the comments on decodeBranch.
// For leaf decoding, see the comments on decodeLeaf.
func Decode(reader io.Reader) (n *Node, err error) {
variant, partialKeyLength, err := decodeHeader(reader)
variant, partialKeyLength, err := DecodeHeader(reader)
if err != nil {
return nil, fmt.Errorf("decoding header: %w", err)
}

switch variant {
case emptyVariant:
case EmptyVariant:
return EmptyNode, nil
case leafVariant, leafWithHashedValueVariant:
case LeafVariant, LeafWithHashedValueVariant:
n, err = decodeLeaf(reader, variant, partialKeyLength)
if err != nil {
return nil, fmt.Errorf("cannot decode leaf: %w", err)
}
return n, nil
case branchVariant, branchWithValueVariant, branchWithHashedValueVariant:
case BranchVariant, BranchWithValueVariant, BranchWithHashedValueVariant:
n, err = decodeBranch(reader, variant, partialKeyLength)
if err != nil {
return nil, fmt.Errorf("cannot decode branch: %w", err)
Expand All @@ -67,13 +67,13 @@ func Decode(reader io.Reader) (n *Node, err error) {
// reconstructing the child nodes from the encoding. This function instead stubs where the
// children are known to be with an empty leaf. The children nodes hashes are then used to
// find other storage values using the persistent database.
func decodeBranch(reader io.Reader, variant variant, partialKeyLength uint16) (
func decodeBranch(reader io.Reader, variant Variant, partialKeyLength uint16) (
node *Node, err error) {
node = &Node{
Children: make([]*Node, ChildrenCapacity),
}

node.PartialKey, err = decodeKey(reader, partialKeyLength)
node.PartialKey, err = DecodeKey(reader, partialKeyLength)
if err != nil {
return nil, fmt.Errorf("cannot decode key: %w", err)
}
Expand All @@ -87,12 +87,12 @@ func decodeBranch(reader io.Reader, variant variant, partialKeyLength uint16) (
sd := scale.NewDecoder(reader)

switch variant {
case branchWithValueVariant:
case BranchWithValueVariant:
err := sd.Decode(&node.StorageValue)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err)
}
case branchWithHashedValueVariant:
case BranchWithHashedValueVariant:
hashedValue, err := decodeHashedValue(reader)
if err != nil {
return nil, err
Expand Down Expand Up @@ -136,17 +136,17 @@ func decodeBranch(reader io.Reader, variant variant, partialKeyLength uint16) (
}

// decodeLeaf reads from a reader and decodes to a leaf node.
func decodeLeaf(reader io.Reader, variant variant, partialKeyLength uint16) (node *Node, err error) {
func decodeLeaf(reader io.Reader, variant Variant, partialKeyLength uint16) (node *Node, err error) {
node = &Node{}

node.PartialKey, err = decodeKey(reader, partialKeyLength)
node.PartialKey, err = DecodeKey(reader, partialKeyLength)
if err != nil {
return nil, fmt.Errorf("cannot decode key: %w", err)
}

sd := scale.NewDecoder(reader)

if variant == leafWithHashedValueVariant {
if variant == LeafWithHashedValueVariant {
hashedValue, err := decodeHashedValue(reader)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit afcc752

Please sign in to comment.