Skip to content

Commit

Permalink
feat: add support for subtree roots proof verification
Browse files Browse the repository at this point in the history
  • Loading branch information
rach-id committed May 29, 2024
1 parent 51df389 commit 4daae42
Show file tree
Hide file tree
Showing 4 changed files with 1,125 additions and 107 deletions.
35 changes: 35 additions & 0 deletions nmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,41 @@ func (n *NamespacedMerkleTree) updateMinMaxID(id namespace.ID) {
}
}

// ComputeSubtreeRoot takes a leaf range and returns the corresponding subtree root.
// This method requires the merkle tree size to be a power of two.
// Also, it requires the start and end range to correctly reference an inner node.
func (n *NamespacedMerkleTree) ComputeSubtreeRoot(start, end int) ([]byte, error) {
// check if the tree's number of leaves is a power of two.
if !isPowerOfTwo(n.Size()) {
return nil, fmt.Errorf("the tree size %d needs to be a power of two", n.Size())
}
if start < 0 {
return nil, fmt.Errorf("start %d shouldn't be strictly negative", start)
}
if end <= start {
return nil, fmt.Errorf("end %d should be stricly bigger than start %d", end, start)
}
uStart, err := safeIntToUint(start)
if err != nil {
return nil, err

Check warning on line 663 in nmt.go

View check run for this annotation

Codecov / codecov/patch

nmt.go#L663

Added line #L663 was not covered by tests
}
uEnd, err := safeIntToUint(end)
if err != nil {
return nil, err

Check warning on line 667 in nmt.go

View check run for this annotation

Codecov / codecov/patch

nmt.go#L667

Added line #L667 was not covered by tests
}
// check if the provided range correctly references an inner node.
// calculates the ideal tree from the provided range, and verifies if it is the same as the range
if idealTreeRange := nextSubtreeSize(uint64(uStart), uint64(uEnd)); end-start != idealTreeRange {
return nil, fmt.Errorf("the provided range [%d, %d) does not construct a valid subtree root range", start, end)
}
return n.computeRoot(start, end)
}

// isPowerOfTwo checks if a number is a power of two
func isPowerOfTwo(n int) bool {
return n > 0 && (n&(n-1)) == 0
}

type leafRange struct {
// start and end denote the indices of a leaf in the tree. start ranges from
// 0 up to the total number of leaves minus 1 end ranges from 1 up to the
Expand Down
172 changes: 172 additions & 0 deletions nmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,20 @@ func exampleNMT(nidSize int, ignoreMaxNamespace bool, leavesNIDs ...byte) *Names
return tree
}

// exampleNMT2 Replica of exampleNMT except that it uses the namespace IDs in the
// leaves instead of the index.
func exampleNMT2(nidSize int, ignoreMaxNamespace bool, leavesNIDs ...byte) *NamespacedMerkleTree {
tree := New(sha256.New(), NamespaceIDSize(nidSize), IgnoreMaxNamespace(ignoreMaxNamespace))
for _, nid := range leavesNIDs {
namespace := bytes.Repeat([]byte{nid}, nidSize)
d := append(namespace, []byte(fmt.Sprintf("leaf_%d", nid))...)
if err := tree.Push(d); err != nil {
panic(fmt.Sprintf("unexpected error: %v", err))
}
}
return tree
}

func swap(slice [][]byte, i int, j int) {
temp := slice[i]
slice[i] = slice[j]
Expand Down Expand Up @@ -1175,3 +1189,161 @@ func TestForcedOutOfOrderNamespacedMerkleTree(t *testing.T) {
assert.NoError(t, err)
}
}

func TestIsPowerOfTwo(t *testing.T) {
tests := []struct {
input int
expected bool
}{
{input: 0, expected: false},
{input: 1, expected: true},
{input: 2, expected: true},
{input: 3, expected: false},
{input: 4, expected: true},
{input: 5, expected: false},
{input: 8, expected: true},
{input: 16, expected: true},
{input: -1, expected: false},
{input: -2, expected: false},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("input=%d", tt.input), func(t *testing.T) {
result := isPowerOfTwo(tt.input)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

func TestComputeSubtreeRoot(t *testing.T) {
n := exampleNMT2(1, true, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
tests := []struct {
start, end int
tree *NamespacedMerkleTree
expectedRoot []byte
expectError bool
}{
{
start: 0,
end: 16,
tree: n,
expectedRoot: func() []byte {
root, err := n.Root()
require.NoError(t, err)
return root
}(),
},
{
start: 0,
end: 8,
tree: n,
expectedRoot: func() []byte {
// because the root of the range [0,8) coincides with the root of this tree
root, err := exampleNMT2(1, true, 0, 1, 2, 3, 4, 5, 6, 7).Root()
require.NoError(t, err)
return root
}(),
},
{
start: 8,
end: 16,
tree: n,
expectedRoot: func() []byte {
// because the root of the range [8,16) coincides with the root of this tree
root, err := exampleNMT2(1, true, 8, 9, 10, 11, 12, 13, 14, 15).Root()
require.NoError(t, err)
return root
}(),
},
{
start: 8,
end: 12,
tree: n,
expectedRoot: func() []byte {
// because the root of the range [8,12) coincides with the root of this tree
root, err := exampleNMT2(1, true, 8, 9, 10, 11).Root()
require.NoError(t, err)
return root
}(),
},
{
start: 4,
end: 8,
tree: n,
expectedRoot: func() []byte {
// because the root of the range [4,8) coincides with the root of this tree
root, err := exampleNMT2(1, true, 4, 5, 6, 7).Root()
require.NoError(t, err)
return root
}(),
},
{
start: 4,
end: 6,
tree: n,
expectedRoot: func() []byte {
// because the root of the range [4,6) coincides with the root of this tree
root, err := exampleNMT2(1, true, 4, 5).Root()
require.NoError(t, err)
return root
}(),
},
{
start: 4,
end: 5,
tree: n,
expectedRoot: func() []byte {
// because the root of the range [4,5) coincides with the root of this tree
root, err := exampleNMT2(1, true, 4).Root()
require.NoError(t, err)
return root
}(),
},
{ // doesn't correctly reference an inner node
start: 2,
end: 6,
tree: n,
expectError: true,
},
{
start: -1, // invalid start
end: 4,
tree: n,
expectError: true,
},
{
start: 4,
end: 4, // start == end
tree: n,
expectError: true,
},
{
start: 5, // start >= end
end: 4,
tree: n,
expectError: true,
},
{
start: 0,
end: 16,
tree: func() *NamespacedMerkleTree {
return exampleNMT2(1, true, 0, 1, 2, 3, 4) // tree leaves are not a power of 2
}(),
expectError: true,
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("treeSize=%d,start=%d,end=%d", tt.tree.Size(), tt.start, tt.end), func(t *testing.T) {
root, err := tt.tree.ComputeSubtreeRoot(tt.start, tt.end)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedRoot, root)
}
})
}
}
Loading

0 comments on commit 4daae42

Please sign in to comment.