Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tree height instead of chunk limit for merkleizer #74

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 74 additions & 59 deletions ssz_serialization/merkleization.nim
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func binaryTreeHeight*(totalElements: Limit): int =
bitWidth nextPow2(uint64 totalElements)

type
SszMerkleizer*[limit: static[Limit]] = object
SszMerkleizer2*[height: static[int]] = object
## The Merkleizer incrementally computes the SSZ-style Merkle root of a tree
## with `limit` leaf nodes.
## with `2**(height-1)` leaf nodes.
##
## As chunks are added, the combined hash of each pair of chunks is computed
## and partially propagated up the tree in the `combinedChunks` array -
Expand All @@ -65,17 +65,26 @@ type
# `sha256.update` otherwise would do.
# The two digests represent the left and right nodes that get combined to
# a parent node in the tree.
# `SszMerkleizer` used chunk count as limit
# TODO it's possible to further parallelize by using even wider buffers here

combinedChunks: array[binaryTreeHeight limit, (Digest, Digest)]
combinedChunks: array[height, (Digest, Digest)]
totalChunks*: uint64 # Public for historical reasons
topIndex: int
internal: bool
# Avoid copying chunk data into merkleizer when not needed - maw result
# in an incomplete root-to-leaf proof

template getChunkCount*(m: SszMerkleizer): uint64 =
template limit*(T: type SszMerkleizer2): Limit =
if T.height == 0: 0'i64 else: 1'i64 shl (T.height - 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hrm, what about the cases where the limit is not a power of 2? We have to review more carefully how this code is being used.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no critical usages, ie it doesn't materially matter since all trees are extended to power-of-2 anyway


template limit*(v: SszMerkleizer2): Limit =
typeof(v).limit

template getChunkCount*(m: SszMerkleizer2): uint64 =
m.totalChunks

func getCombinedChunks*(m: SszMerkleizer): seq[Digest] =
func getCombinedChunks*(m: SszMerkleizer2): seq[Digest] =
mapIt(toOpenArray(m.combinedChunks, 0, m.topIndex), it[0])

when USE_BLST_SHA256:
Expand Down Expand Up @@ -182,7 +191,7 @@ func computeZeroHashes: array[sizeof(Limit) * 8, Digest] =

const zeroHashes* = computeZeroHashes()

template combineChunks(merkleizer: var SszMerkleizer, start: int) =
func combineChunks(merkleizer: var SszMerkleizer2, start: int) =
for i in start..<merkleizer.topIndex:
trs "CALLING MERGE BRANCHES"
if getBitLE(merkleizer.totalChunks, i + 1):
Expand All @@ -195,31 +204,35 @@ template combineChunks(merkleizer: var SszMerkleizer, start: int) =
merkleizer.combinedChunks[i + 1][0])
break

template addChunkDirect(merkleizer: var SszMerkleizer, body: untyped) =
template addChunkDirect(merkleizer: var SszMerkleizer2, body: untyped) =
# add chunk allowing `body` to write directly to `chunk` memory thus avoiding
# an extra copy - body must completely fill the chunk, including any zero
# padding

# the following mixin is a workaround for nim 1.6.12
# and the bug seems to be fixed in nim 1.6.14
mixin combineChunks

# TODO panic here isn't great - turn this into a bool-returning function?
doAssert merkleizer.totalChunks < merkleizer.limit.uint64,
"Adding chunks would exceed merklelizer limit " & $merkleizer.limit

if getBitLE(merkleizer.totalChunks, 0):
template chunk: Digest {.inject.} = merkleizer.combinedChunks[0][1]
let
odd = getBitLE(merkleizer.totalChunks, 0)
# addr needed to work around compile-time evaluation issue
chunkAddr = if odd:
addr merkleizer.combinedChunks[0][1]
else:
addr merkleizer.combinedChunks[0][0]

block:
template chunk: Digest {.inject.} = chunkAddr[]
body

if odd:
merkleizer.combineChunks(0)
else:
template chunk: Digest {.inject.} = merkleizer.combinedChunks[0][0]
body
trs "WROTE BASE CHUNK ", toHex(merkleizer.combinedChunks[0][0].data)
trs "WROTE BASE CHUNK ", toHex(chunkAddr[].data)

inc merkleizer.totalChunks

func addChunk*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
func addChunk*(merkleizer: var SszMerkleizer2, data: openArray[byte]) =
doAssert data.len > 0 and data.len <= bytesPerChunk

when merkleizer.limit > 0:
Expand All @@ -233,7 +246,7 @@ func addChunk*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
template isOdd(x: SomeNumber): bool =
(x and 1) != 0

func addChunks*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
func addChunks*(merkleizer: var SszMerkleizer2, data: openArray[byte]) =
doAssert merkleizer.totalChunks == 0
doAssert merkleizer.limit * bytesPerChunk >= data.len,
"Adding chunks would exceed merklelizer limit " & $merkleizer.limit
Expand Down Expand Up @@ -275,7 +288,7 @@ func addChunks*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
merkleizer.addChunk(data.toOpenArray(done, data.high))
break

func addChunkAndGenMerkleProof*(merkleizer: var SszMerkleizer,
func addChunkAndGenMerkleProof*(merkleizer: var SszMerkleizer2,
hash: Digest,
outProof: var openArray[Digest]) =
var
Expand All @@ -297,7 +310,7 @@ func addChunkAndGenMerkleProof*(merkleizer: var SszMerkleizer,

merkleizer.totalChunks += 1

func completeStartedChunk(merkleizer: var SszMerkleizer,
func completeStartedChunk(merkleizer: var SszMerkleizer2,
hash: Digest, atLevel: int) =
when false:
let
Expand All @@ -313,7 +326,7 @@ func completeStartedChunk(merkleizer: var SszMerkleizer,
merkleizer.combinedChunks[i][0] = hash
break

func addChunksAndGenMerkleProofs*(merkleizer: var SszMerkleizer,
func addChunksAndGenMerkleProofs*(merkleizer: var SszMerkleizer2,
chunks: openArray[Digest]): seq[Digest] =
doAssert chunks.len > 0 and merkleizer.topIndex > 0

Expand Down Expand Up @@ -465,36 +478,38 @@ func addChunksAndGenMerkleProofs*(merkleizer: var SszMerkleizer,

merkleizer.totalChunks = newTotalChunks

func init*(S: type SszMerkleizer): S =
func init*(S: type SszMerkleizer2): S =
S(
topIndex: binaryTreeHeight(S.limit) - 1,
topIndex: S.height - 1,
totalChunks: 0)

func init*(S: type SszMerkleizer,
func init*(S: type SszMerkleizer2,
combinedChunks: openArray[Digest],
totalChunks: uint64): S =
for i in 0..<combinedChunks.len:
result.combinedChunks[i][0] = combinedChunks[i]
result.topIndex = binaryTreeHeight(S.limit) - 1
result.topIndex = S.height - 1
result.totalChunks = totalChunks

func copy*[L: static[Limit]](cloned: SszMerkleizer[L]): SszMerkleizer[L] {.deprecated.} =
cloned

template createMerkleizer*(
totalElements: static Limit, topLayer = 0,
template createMerkleizer2*(
height: static Limit, topLayer = 0,
internalParam = false): auto =
trs "CREATING A MERKLEIZER FOR ", totalElements, " (topLayer: ", topLayer, ")"
trs "CREATING A MERKLEIZER FOR ", height, " (topLayer: ", topLayer, ")"

const treeHeight = binaryTreeHeight totalElements
let topIndex = treeHeight - 1 - topLayer
let topIndex = height - 1 - topLayer

SszMerkleizer[totalElements](
SszMerkleizer2[height](
topIndex: if (topIndex < 0): 0 else: topIndex,
totalChunks: 0,
internal: internalParam)

func getFinalHash(merkleizer: var SszMerkleizer, res: var Digest) =
template createMerkleizer*(
totalElements: static Limit, topLayer = 0,
internalParam = false): auto =
const treeHeight = binaryTreeHeight totalElements
createMerkleizer2(treeHeight, topLayer, internalParam)

func getFinalHash(merkleizer: var SszMerkleizer2, res: var Digest) =
if merkleizer.totalChunks == 0:
res = zeroHashes[merkleizer.topIndex]
return
Expand Down Expand Up @@ -556,7 +571,7 @@ func getFinalHash(merkleizer: var SszMerkleizer, res: var Digest) =
for i in bottomHashIdx + 1 ..< topHashIdx:
mergeBranches(res, zeroHashes[i], res)

func getFinalHash*(merkleizer: var SszMerkleizer): Digest {.noinit.} =
func getFinalHash*(merkleizer: var SszMerkleizer2): Digest {.noinit.} =
getFinalHash(merkleizer, result)

func mixInLength(root: Digest, length: int, res: var Digest) =
Expand Down Expand Up @@ -598,7 +613,7 @@ template writeBytesLE(chunk: var array[bytesPerChunk, byte], atParam: int,
chunk[at ..< at + sizeof(val)] = toBytesLE(val)

func chunkedHashTreeRoot[T: BasicType](
merkleizer: var SszMerkleizer, arr: openArray[T],
merkleizer: var SszMerkleizer2, arr: openArray[T],
firstIdx, numFromFirst: Limit, res: var Digest) =
static:
doAssert bytesPerChunk mod sizeof(T) == 0
Expand Down Expand Up @@ -631,15 +646,15 @@ func chunkedHashTreeRoot[T: BasicType](
getFinalHash(merkleizer, res)

func chunkedHashTreeRoot[T: not BasicType](
merkleizer: var SszMerkleizer, arr: openArray[T],
merkleizer: var SszMerkleizer2, arr: openArray[T],
firstIdx, numFromFirst: Limit, res: var Digest) =
for i in 0 ..< numFromFirst:
addChunkDirect(merkleizer):
chunk = hash_tree_root(arr[firstIdx + i])
getFinalHash(merkleizer, res)

template chunkedHashTreeRoot[T](
totalChunks: static Limit, arr: openArray[T],
func chunkedHashTreeRoot[T](
height: static Limit, arr: openArray[T],
chunks: Slice[Limit], topLayer: int, res: var Digest) =
const valuesPerChunk =
when T is BasicType:
Expand All @@ -648,25 +663,23 @@ template chunkedHashTreeRoot[T](
1
let firstIdx = chunks.a * valuesPerChunk
if arr.len <= firstIdx:
const treeHeight = binaryTreeHeight totalChunks
res = zeroHashes[treeHeight - 1 - topLayer]
res = zeroHashes[height - 1 - topLayer]
else:
var merkleizer = createMerkleizer(totalChunks, topLayer, internalParam = true)
var merkleizer = createMerkleizer2(height, topLayer, internalParam = true)
let numFromFirst =
min((chunks.b - chunks.a + 1) * valuesPerChunk, arr.len - firstIdx)
chunkedHashTreeRoot(merkleizer, arr, firstIdx, numFromFirst, res)

template chunkedHashTreeRoot[T](
totalChunks: static Limit, arr: openArray[T], res: var Digest) =
func chunkedHashTreeRoot[T](
height: static Limit, arr: openArray[T], res: var Digest) =
if arr.len <= 0:
const treeHeight = binaryTreeHeight totalChunks
res = zeroHashes[treeHeight - 1]
res = zeroHashes[height - 1]
else:
var merkleizer = createMerkleizer(totalChunks, internalParam = true)
var merkleizer = createMerkleizer2(height, internalParam = true)
chunkedHashTreeRoot(merkleizer, arr, 0, arr.len, res)

func bitListHashTreeRoot(
merkleizer: var SszMerkleizer, x: BitSeq, chunks: Slice[Limit],
merkleizer: var SszMerkleizer2, x: BitSeq, chunks: Slice[Limit],
res: var Digest) =
# TODO: Switch to a simpler BitList representation and
# replace this with `chunkedHashTreeRoot`
Expand Down Expand Up @@ -806,11 +819,11 @@ func hashTreeRootAux[T](x: T, res: var Digest) =
else:
trs "FIXED TYPE; USE CHUNK STREAM"
const totalChunks = maxChunksCount(T, x.len)
chunkedHashTreeRoot(totalChunks, x, res)
chunkedHashTreeRoot(binaryTreeHeight totalChunks, x, res)
elif T is List:
const totalChunks = maxChunksCount(T, x.maxLen)
var contentsHash {.noinit.}: Digest
chunkedHashTreeRoot(totalChunks, asSeq x, contentsHash)
chunkedHashTreeRoot(binaryTreeHeight totalChunks, asSeq x, contentsHash)
mixInLength(contentsHash, x.len, res)
elif T is OptionalType:
if x.isSome:
Expand Down Expand Up @@ -893,11 +906,11 @@ func hashTreeRootAux[T](
index = indexAt(i)
indexLayer = log2trunc(index)
if index == 1.GeneralizedIndex:
chunkedHashTreeRoot(totalChunks, x, rootAt(i))
chunkedHashTreeRoot(binaryTreeHeight totalChunks, x, rootAt(i))
inc i
elif indexLayer <= chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(totalChunks, x, chunks, indexLayer, rootAt(i))
chunkedHashTreeRoot(binaryTreeHeight totalChunks, x, chunks, indexLayer, rootAt(i))
inc i
else:
when ElemType(typeof(x)) is BasicType: return unsupportedIndex
Expand Down Expand Up @@ -928,14 +941,14 @@ func hashTreeRootAux[T](
indexLayer = log2trunc(index)
if index == 1.GeneralizedIndex:
var contentsHash {.noinit.}: Digest
chunkedHashTreeRoot(totalChunks, asSeq x, contentsHash)
chunkedHashTreeRoot(binaryTreeHeight totalChunks, asSeq x, contentsHash)
mixInLength(contentsHash, x.len, rootAt(i))
inc i
elif index == 3.GeneralizedIndex:
hashTreeRootAux(x.len.uint64, rootAt(i))
inc i
elif index == 2.GeneralizedIndex:
chunkedHashTreeRoot(totalChunks, asSeq x, rootAt(i))
chunkedHashTreeRoot(binaryTreeHeight totalChunks, asSeq x, rootAt(i))
inc i
elif (index shr (indexLayer - 1)) == 2.GeneralizedIndex:
let
Expand All @@ -945,7 +958,7 @@ func hashTreeRootAux[T](
if indexLayer <= chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(
totalChunks, asSeq x, chunks, indexLayer, rootAt(i))
binaryTreeHeight totalChunks, asSeq x, chunks, indexLayer, rootAt(i))
inc i
else:
when ElemType(typeof(x)) is BasicType: return unsupportedIndex
Expand Down Expand Up @@ -1020,6 +1033,7 @@ func hashTreeRootAux[T](
trs "MERKLEIZING FIELDS"
const
totalChunks = totalSerializedFields(T)
treeHeight = binaryTreeHeight(totalChunks)
firstChunkIndex = nextPow2(totalChunks.uint64)
chunkLayer = log2trunc(firstChunkIndex)
var
Expand All @@ -1029,7 +1043,7 @@ func hashTreeRootAux[T](
index {.noinit.}: GeneralizedIndex
indexLayer {.noinit.}: int
chunks {.noinit.}: Slice[Limit]
merkleizer {.noinit.}: SszMerkleizer[totalSerializedFields(T)]
merkleizer {.noinit.}: SszMerkleizer2[treeHeight]
chunk {.noinit.}: Limit
nextField {.noinit.}: Limit
x.enumerateSubFields(f):
Expand Down Expand Up @@ -1221,7 +1235,8 @@ func hashTreeRootCached*(
inc i
elif indexLayer == chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(totalChunks, x.data, chunks, indexLayer, rootAt(i))
chunkedHashTreeRoot(
binaryTreeHeight totalChunks, x.data, chunks, indexLayer, rootAt(i))
inc i
elif indexLayer < chunkLayer:
rootAt(i) = hashTreeRootCachedPtr(x, index.int64)[]
Expand Down Expand Up @@ -1279,7 +1294,7 @@ func hashTreeRootCached*(
if indexLayer == chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(
totalChunks, asSeq x.data, chunks, indexLayer, rootAt(i))
binaryTreeHeight totalChunks, asSeq x.data, chunks, indexLayer, rootAt(i))
inc i
elif indexLayer < chunkLayer:
rootAt(i) = hashTreeRootCachedPtr(x, index.int64)[]
Expand Down
Loading