diff --git a/merkle/node.go b/merkle/node.go new file mode 100644 index 00000000..18ef5f4d --- /dev/null +++ b/merkle/node.go @@ -0,0 +1,30 @@ +package merkle + +type Node struct { + hash []byte + + parent *Node + left *Node + right *Node + + // first index of elements in this node subnodes (inclusive) + firstIndex int + // last index of elements in this node subnodes (inclusive) + lastIndex int +} + +func (n *Node) GetIndexProofs(i int) []Proof { + proofs := make([]Proof, 0) + + if n.left != nil && i >= n.left.firstIndex && i <= n.left.lastIndex { + proofs = n.left.GetIndexProofs(i) + proofs = append(proofs, Proof{hash: n.right.hash, leftSide: false}) + } + + if n.right != nil && i >= n.right.firstIndex && i <= n.right.lastIndex { + proofs = n.right.GetIndexProofs(i) + proofs = append(proofs, Proof{hash: n.left.hash, leftSide: true}) + } + + return proofs +} diff --git a/merkle/proof.go b/merkle/proof.go index 44f8e6ff..968f2f4a 100644 --- a/merkle/proof.go +++ b/merkle/proof.go @@ -1,23 +1,20 @@ package merkle -import "crypto/sha256" +import ( + "hash" +) type Proof struct { - left bool // where proof should be placed for concat with hash (left or right) - hash []byte + leftSide bool // where proof should be placed to sum with hash (left or right side) + hash []byte } // calculate sum hash -func (p *Proof) ConcatWith(hash []byte) []byte { - h := sha256.New() +func (p *Proof) SumWith(hashF hash.Hash, hash []byte) []byte { - if p.left { - h.Write(p.hash) - h.Write(hash) + if p.leftSide { + return sum(hashF, p.hash, hash) } else { - h.Write(hash) - h.Write(p.hash) + return sum(hashF, hash, p.hash) } - - return h.Sum(nil) } diff --git a/merkle/subtree.go b/merkle/subtree.go new file mode 100644 index 00000000..3a92f7f0 --- /dev/null +++ b/merkle/subtree.go @@ -0,0 +1,62 @@ +package merkle + +import "hash" + +// subtrees always should have power of 2 number of elements. +// tree could contain few of subtrees. +// root hash calculates from right to left by summing subtree roots hashes. +type Subtree struct { + root *Node // root node + + left *Subtree // left subtree from this one + right *Subtree // right subtree from this one + + height int // height of subtree + // hash function to hash sum nodes and hash data + hashF hash.Hash +} + +// get proofs for root of this subtree +func (t *Subtree) GetRootProofs() []Proof { + proofs := make([]Proof, 0) + + proofs = append(proofs, t.getRightSubtreesProof()...) + proofs = append(proofs, t.getLeftSubtreesProof()...) + + return proofs +} + +func (t *Subtree) getLeftSubtreesProof() []Proof { + proofs := make([]Proof, 0) + current := t.left + for current != nil { + proofs = append(proofs, Proof{hash: current.root.hash, leftSide: false}) + current = current.left + } + return proofs +} + +// right proof is only one cause we have to sum all right subtrees +// we have to sum hashes from right to left +func (t *Subtree) getRightSubtreesProof() []Proof { + + if t.right == nil { + return make([]Proof, 0) + } + + hashesToSum := make([][]byte, 0) + + rightTree := t.right + for rightTree != nil { + hashesToSum = append(hashesToSum, rightTree.root.hash) + rightTree = rightTree.right + } + + n := len(hashesToSum) - 1 + proofHash := hashesToSum[n] + for i := n - 1; i >= 0; i-- { + proofHash = sum(t.hashF, proofHash, hashesToSum[i]) + } + + return []Proof{{hash: proofHash, leftSide: true}} +} diff --git a/merkle/tree.go b/merkle/tree.go index cb25319d..2ba759a9 100644 --- a/merkle/tree.go +++ b/merkle/tree.go @@ -2,121 +2,33 @@ package merkle import ( "bytes" - "crypto/sha256" + "hash" "math" ) -type Subtree struct { - root *Node - - left *Subtree - right *Subtree - - height int -} - -func (t *Subtree) GetRootProofs() []Proof { - proofs := make([]Proof, 0) - - proofs = append(proofs, t.getRightProofs()...) - proofs = append(proofs, t.getLeftProofs()...) - - return proofs -} - -func (t *Subtree) getLeftProofs() []Proof { - proofs := make([]Proof, 0, 1) - current := t.left - for current != nil { - proofs = append(proofs, Proof{hash: current.root.hash, left: false}) - current = current.left - } - return proofs -} - -// right proof is only one cause we have to merge all right trees -func (t *Subtree) getRightProofs() []Proof { - - if t.right == nil { - return make([]Proof, 0) - } - - hashesToSum := make([][]byte, 0) - - rightTree := t.right - for rightTree != nil { - hashesToSum = append(hashesToSum, rightTree.root.hash) - rightTree = rightTree.right - } - - n := len(hashesToSum) - 1 - proofHash := hashesToSum[n] - for i := n - 1; i >= 0; i-- { - h := sha256.New() - h.Write(proofHash) - h.Write(hashesToSum[i]) - proofHash = h.Sum(nil) - } - - return []Proof{{hash: proofHash, left: true}} -} - -type Node struct { - hash []byte - - parent *Node - left *Node - right *Node - - // first index of elements in this node subnodes (inclusive) - firstIndex int - // last index of elements in this node subnodes (inclusive) - lastIndex int -} - -func (n *Node) GetIndexProofs(i int) []Proof { - proofs := make([]Proof, 0) - - if n.left != nil && i >= n.left.firstIndex && i <= n.left.lastIndex { - proofs = n.left.GetIndexProofs(i) - proofs = append(proofs, Proof{hash: n.right.hash, left: false}) - } - - if n.right != nil && i >= n.right.firstIndex && i <= n.right.lastIndex { - proofs = n.right.GetIndexProofs(i) - proofs = append(proofs, Proof{hash: n.left.hash, left: true}) - } - - return proofs -} - -// we separate whole tree to sub trees where nodes count equal power of 2 +// Merkle tree data structure based on RFC-6962 standard (https://tools.ietf.org/html/rfc6962#section-2.1) +// we separate whole tree to subtrees where nodes count equal power of 2 +// root hash calculates from right to left by summing subtree roots hashes. type Tree struct { - // this tree subtrees start from lowest height (from last right subtree) + // this tree subtrees start from lowest height (extreme right subtree) subTree *Subtree - // first index of elements in this tree (inclusive) - firstIndex int // last index of elements in this tree (exclusive) lastIndex int - hash []byte + hashF hash.Hash // DON'T USE IT FOR PARALLEL CALCULATION (results in errors) } -func NewTree() Tree { - return Tree{} +func NewTree(hashF hash.Hash) Tree { + return Tree{hashF: hashF} } func (t *Tree) joinAllSubtrees() { for t.subTree.left != nil && t.subTree.height == t.subTree.left.height { - newRootHash := sha256.New() - newRootHash.Write(t.subTree.left.root.hash) - newRootHash.Write(t.subTree.root.hash) - newSubtreeRoot := &Node{ - hash: newRootHash.Sum(nil), + hash: sum(t.hashF, t.subTree.left.root.hash, t.subTree.root.hash), parent: nil, left: t.subTree.left.root, right: t.subTree.root, @@ -132,6 +44,7 @@ func (t *Tree) joinAllSubtrees() { right: nil, left: t.subTree.left.left, height: t.subTree.height + 1, + hashF: t.hashF, } if t.subTree.left != nil { @@ -141,14 +54,48 @@ func (t *Tree) joinAllSubtrees() { } } +func (t *Tree) Reset() { + t.lastIndex = 0 + t.subTree = nil +} + +// build completely new tree with data +// works the same (by time) as using Push method one by one +func (t *Tree) BuildNew(data [][]byte) { + t.Reset() + itemsLeft := int64(len(data)) + + nextSubtreeLen := int64(math.Pow(2, float64(int64(math.Log2(float64(itemsLeft)))))) + startIndex := int64(0) + endIndex := startIndex + nextSubtreeLen + + for nextSubtreeLen != 0 { + + nextSubtree := buildSubTree(t.hashF, int(startIndex), data[startIndex:endIndex]) + + if t.subTree != nil { + t.subTree.right = nextSubtree + nextSubtree.left = t.subTree + t.subTree = nextSubtree + } else { + t.subTree = nextSubtree + } + + itemsLeft = itemsLeft - nextSubtreeLen + nextSubtreeLen = int64(math.Pow(2, float64(int64(math.Log2(float64(itemsLeft)))))) + startIndex = endIndex + endIndex = startIndex + nextSubtreeLen + } + + t.lastIndex = int(endIndex) + +} + // n*log(n) func (t *Tree) Push(data []byte) { - hash := sha256.New() - hash.Write(data) - newSubtreeRoot := &Node{ - hash: hash.Sum(nil), + hash: sum(t.hashF, data), parent: nil, left: nil, right: nil, @@ -163,6 +110,7 @@ func (t *Tree) Push(data []byte) { right: nil, left: t.subTree, height: 0, + hashF: t.hashF, } if t.subTree.left != nil { @@ -170,8 +118,6 @@ func (t *Tree) Push(data []byte) { } t.joinAllSubtrees() - - t.hash = t.GetRootHash() } // going from right trees to left @@ -199,18 +145,16 @@ func (t *Tree) ValidateIndex(i int, data []byte) bool { func (t *Tree) ValidateIndexByProofs(i int, data []byte, proofs []Proof) bool { - h := sha256.New() - h.Write(data) - - rootHash := h.Sum(nil) + rootHash := sum(t.hashF, data) for _, proof := range proofs { - rootHash = proof.ConcatWith(rootHash) + rootHash = proof.SumWith(t.hashF, rootHash) } - return bytes.Equal(rootHash, t.GetRootHash()) + return bytes.Equal(rootHash, t.RootHash()) } -func (t *Tree) GetRootHash() []byte { +// root hash calculates from right to left by summing subtree roots hashes. +func (t *Tree) RootHash() []byte { if t.subTree == nil { return nil @@ -220,11 +164,7 @@ func (t *Tree) GetRootHash() []byte { current := t.subTree.left for current != nil { - h := sha256.New() - h.Write(rootHash) - h.Write(current.root.hash) - - rootHash = h.Sum(nil) + rootHash = sum(t.hashF, rootHash, current.root.hash) current = current.left } diff --git a/merkle/tree_test.go b/merkle/tree_test.go index c4af8109..94d8859b 100644 --- a/merkle/tree_test.go +++ b/merkle/tree_test.go @@ -1,14 +1,15 @@ package merkle import ( + "crypto/sha256" "encoding/binary" "github.com/stretchr/testify/require" "testing" ) -func TestProofs(t *testing.T) { +func TestPushAndProofs(t *testing.T) { - tree := NewTree() + tree := NewTree(sha256.New()) data := make([]byte, 8) @@ -25,3 +26,53 @@ func TestProofs(t *testing.T) { } } + +func TestBuildNewAndProofs(t *testing.T) { + + tree := NewTree(sha256.New()) + + allData := make([][]byte, 0, 31) + + for i := 0; i < 31; i++ { + data := make([]byte, 8) + binary.LittleEndian.PutUint64(data, uint64(i)) + allData = append(allData, data) + } + + tree.BuildNew(allData) + + // Check all proofs + for i := 0; i < 31; i++ { + proofs := tree.GetIndexProofs(i) + data := make([]byte, 8) + binary.LittleEndian.PutUint64(data, uint64(i)) + require.Equal(t, true, tree.ValidateIndexByProofs(i, data, proofs)) + } + +} + +func TestEqualityOfBuildNewAndPush(t *testing.T) { + + tree1 := NewTree(sha256.New()) + + data := make([]byte, 8) + + for i := 0; i < 31; i++ { + binary.LittleEndian.PutUint64(data, uint64(i)) + tree1.Push(data) + } + + tree2 := NewTree(sha256.New()) + + allData := make([][]byte, 0, 31) + + for i := 0; i < 31; i++ { + data := make([]byte, 8) + binary.LittleEndian.PutUint64(data, uint64(i)) + allData = append(allData, data) + } + + tree2.BuildNew(allData) + + require.Equal(t, tree1.RootHash(), tree2.RootHash()) +} diff --git a/merkle/util.go b/merkle/util.go new file mode 100644 index 00000000..8f94f7a4 --- /dev/null +++ b/merkle/util.go @@ -0,0 +1,69 @@ +package merkle + +import ( + "hash" + "math" +) + +func sum(h hash.Hash, data ...[]byte) []byte { + h.Reset() + for _, d := range data { + // the Hash interface specifies that Write never returns an error + _, _ = h.Write(d) + } + return h.Sum(nil) +} + +// number of data elements should be power of 2 +// not suitable for parallel calculations cause using same hash.Hash +func buildSubTree(h hash.Hash, startIndex int, data [][]byte) *Subtree { + + nodes := make([]*Node, len(data)) + for i := 0; i < len(data); i++ { + + nodes[i] = &Node{ + hash: sum(h, data[i]), + firstIndex: startIndex + i, + lastIndex: startIndex + i, + } + + } + + root := sumNodes(h, nodes)[0] + + return &Subtree{ + root: root, + left: nil, + right: nil, + height: int(math.Log2(float64(len(data)))), + hashF: h, + } +} + +func sumNodes(h hash.Hash, nodes []*Node) []*Node { + + if len(nodes) == 1 { + return nodes + } + + newNodes := make([]*Node, len(nodes)/2) + for i := 0; i < len(nodes); i += 2 { + newNodes[i/2] = joinNodes(h, nodes[i], nodes[i+1]) + } + + return sumNodes(h, newNodes) +} + +func joinNodes(h hash.Hash, left *Node, right *Node) *Node { + newNode := &Node{ + firstIndex: left.firstIndex, + lastIndex: right.lastIndex, + hash: sum(h, left.hash, right.hash), + left: left, + right: right, + parent: nil, + } + left.parent = newNode + right.parent = newNode + return newNode +}