Skip to content

Commit

Permalink
Use tree height instead of chunk limit for merkleizer (#74)
Browse files Browse the repository at this point in the history
More closely matches implementation and avoids generic bloat.

Renamed to avoid accidental limit-vs-height confusion.
  • Loading branch information
arnetheduck authored Jan 11, 2024
1 parent f87c99b commit 66de36a
Showing 1 changed file with 74 additions and 59 deletions.
133 changes: 74 additions & 59 deletions ssz_serialization/merkleization.nim
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func binaryTreeHeight*(totalElements: Limit): int =
bitWidth nextPow2(uint64 totalElements)

type
SszMerkleizer*[limit: static[Limit]] = object
SszMerkleizer2*[height: static[int]] = object
## The Merkleizer incrementally computes the SSZ-style Merkle root of a tree
## with `limit` leaf nodes.
## with `2**(height-1)` leaf nodes.
##
## As chunks are added, the combined hash of each pair of chunks is computed
## and partially propagated up the tree in the `combinedChunks` array -
Expand All @@ -65,17 +65,26 @@ type
# `sha256.update` otherwise would do.
# The two digests represent the left and right nodes that get combined to
# a parent node in the tree.
# `SszMerkleizer` used chunk count as limit
# TODO it's possible to further parallelize by using even wider buffers here

combinedChunks: array[binaryTreeHeight limit, (Digest, Digest)]
combinedChunks: array[height, (Digest, Digest)]
totalChunks*: uint64 # Public for historical reasons
topIndex: int
internal: bool
# Avoid copying chunk data into merkleizer when not needed - maw result
# in an incomplete root-to-leaf proof

template getChunkCount*(m: SszMerkleizer): uint64 =
template limit*(T: type SszMerkleizer2): Limit =
if T.height == 0: 0'i64 else: 1'i64 shl (T.height - 1)

template limit*(v: SszMerkleizer2): Limit =
typeof(v).limit

template getChunkCount*(m: SszMerkleizer2): uint64 =
m.totalChunks

func getCombinedChunks*(m: SszMerkleizer): seq[Digest] =
func getCombinedChunks*(m: SszMerkleizer2): seq[Digest] =
mapIt(toOpenArray(m.combinedChunks, 0, m.topIndex), it[0])

when USE_BLST_SHA256:
Expand Down Expand Up @@ -182,7 +191,7 @@ func computeZeroHashes: array[sizeof(Limit) * 8, Digest] =

const zeroHashes* = computeZeroHashes()

template combineChunks(merkleizer: var SszMerkleizer, start: int) =
func combineChunks(merkleizer: var SszMerkleizer2, start: int) =
for i in start..<merkleizer.topIndex:
trs "CALLING MERGE BRANCHES"
if getBitLE(merkleizer.totalChunks, i + 1):
Expand All @@ -195,31 +204,35 @@ template combineChunks(merkleizer: var SszMerkleizer, start: int) =
merkleizer.combinedChunks[i + 1][0])
break

template addChunkDirect(merkleizer: var SszMerkleizer, body: untyped) =
template addChunkDirect(merkleizer: var SszMerkleizer2, body: untyped) =
# add chunk allowing `body` to write directly to `chunk` memory thus avoiding
# an extra copy - body must completely fill the chunk, including any zero
# padding

# the following mixin is a workaround for nim 1.6.12
# and the bug seems to be fixed in nim 1.6.14
mixin combineChunks

# TODO panic here isn't great - turn this into a bool-returning function?
doAssert merkleizer.totalChunks < merkleizer.limit.uint64,
"Adding chunks would exceed merklelizer limit " & $merkleizer.limit

if getBitLE(merkleizer.totalChunks, 0):
template chunk: Digest {.inject.} = merkleizer.combinedChunks[0][1]
let
odd = getBitLE(merkleizer.totalChunks, 0)
# addr needed to work around compile-time evaluation issue
chunkAddr = if odd:
addr merkleizer.combinedChunks[0][1]
else:
addr merkleizer.combinedChunks[0][0]

block:
template chunk: Digest {.inject.} = chunkAddr[]
body

if odd:
merkleizer.combineChunks(0)
else:
template chunk: Digest {.inject.} = merkleizer.combinedChunks[0][0]
body
trs "WROTE BASE CHUNK ", toHex(merkleizer.combinedChunks[0][0].data)
trs "WROTE BASE CHUNK ", toHex(chunkAddr[].data)

inc merkleizer.totalChunks

func addChunk*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
func addChunk*(merkleizer: var SszMerkleizer2, data: openArray[byte]) =
doAssert data.len > 0 and data.len <= bytesPerChunk

when merkleizer.limit > 0:
Expand All @@ -233,7 +246,7 @@ func addChunk*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
template isOdd(x: SomeNumber): bool =
(x and 1) != 0

func addChunks*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
func addChunks*(merkleizer: var SszMerkleizer2, data: openArray[byte]) =
doAssert merkleizer.totalChunks == 0
doAssert merkleizer.limit * bytesPerChunk >= data.len,
"Adding chunks would exceed merklelizer limit " & $merkleizer.limit
Expand Down Expand Up @@ -275,7 +288,7 @@ func addChunks*(merkleizer: var SszMerkleizer, data: openArray[byte]) =
merkleizer.addChunk(data.toOpenArray(done, data.high))
break

func addChunkAndGenMerkleProof*(merkleizer: var SszMerkleizer,
func addChunkAndGenMerkleProof*(merkleizer: var SszMerkleizer2,
hash: Digest,
outProof: var openArray[Digest]) =
var
Expand All @@ -297,7 +310,7 @@ func addChunkAndGenMerkleProof*(merkleizer: var SszMerkleizer,

merkleizer.totalChunks += 1

func completeStartedChunk(merkleizer: var SszMerkleizer,
func completeStartedChunk(merkleizer: var SszMerkleizer2,
hash: Digest, atLevel: int) =
when false:
let
Expand All @@ -313,7 +326,7 @@ func completeStartedChunk(merkleizer: var SszMerkleizer,
merkleizer.combinedChunks[i][0] = hash
break

func addChunksAndGenMerkleProofs*(merkleizer: var SszMerkleizer,
func addChunksAndGenMerkleProofs*(merkleizer: var SszMerkleizer2,
chunks: openArray[Digest]): seq[Digest] =
doAssert chunks.len > 0 and merkleizer.topIndex > 0

Expand Down Expand Up @@ -465,36 +478,38 @@ func addChunksAndGenMerkleProofs*(merkleizer: var SszMerkleizer,

merkleizer.totalChunks = newTotalChunks

func init*(S: type SszMerkleizer): S =
func init*(S: type SszMerkleizer2): S =
S(
topIndex: binaryTreeHeight(S.limit) - 1,
topIndex: S.height - 1,
totalChunks: 0)

func init*(S: type SszMerkleizer,
func init*(S: type SszMerkleizer2,
combinedChunks: openArray[Digest],
totalChunks: uint64): S =
for i in 0..<combinedChunks.len:
result.combinedChunks[i][0] = combinedChunks[i]
result.topIndex = binaryTreeHeight(S.limit) - 1
result.topIndex = S.height - 1
result.totalChunks = totalChunks

func copy*[L: static[Limit]](cloned: SszMerkleizer[L]): SszMerkleizer[L] {.deprecated.} =
cloned

template createMerkleizer*(
totalElements: static Limit, topLayer = 0,
template createMerkleizer2*(
height: static Limit, topLayer = 0,
internalParam = false): auto =
trs "CREATING A MERKLEIZER FOR ", totalElements, " (topLayer: ", topLayer, ")"
trs "CREATING A MERKLEIZER FOR ", height, " (topLayer: ", topLayer, ")"

const treeHeight = binaryTreeHeight totalElements
let topIndex = treeHeight - 1 - topLayer
let topIndex = height - 1 - topLayer

SszMerkleizer[totalElements](
SszMerkleizer2[height](
topIndex: if (topIndex < 0): 0 else: topIndex,
totalChunks: 0,
internal: internalParam)

func getFinalHash(merkleizer: var SszMerkleizer, res: var Digest) =
template createMerkleizer*(
totalElements: static Limit, topLayer = 0,
internalParam = false): auto =
const treeHeight = binaryTreeHeight totalElements
createMerkleizer2(treeHeight, topLayer, internalParam)

func getFinalHash(merkleizer: var SszMerkleizer2, res: var Digest) =
if merkleizer.totalChunks == 0:
res = zeroHashes[merkleizer.topIndex]
return
Expand Down Expand Up @@ -556,7 +571,7 @@ func getFinalHash(merkleizer: var SszMerkleizer, res: var Digest) =
for i in bottomHashIdx + 1 ..< topHashIdx:
mergeBranches(res, zeroHashes[i], res)

func getFinalHash*(merkleizer: var SszMerkleizer): Digest {.noinit.} =
func getFinalHash*(merkleizer: var SszMerkleizer2): Digest {.noinit.} =
getFinalHash(merkleizer, result)

func mixInLength(root: Digest, length: int, res: var Digest) =
Expand Down Expand Up @@ -598,7 +613,7 @@ template writeBytesLE(chunk: var array[bytesPerChunk, byte], atParam: int,
chunk[at ..< at + sizeof(val)] = toBytesLE(val)

func chunkedHashTreeRoot[T: BasicType](
merkleizer: var SszMerkleizer, arr: openArray[T],
merkleizer: var SszMerkleizer2, arr: openArray[T],
firstIdx, numFromFirst: Limit, res: var Digest) =
static:
doAssert bytesPerChunk mod sizeof(T) == 0
Expand Down Expand Up @@ -631,15 +646,15 @@ func chunkedHashTreeRoot[T: BasicType](
getFinalHash(merkleizer, res)

func chunkedHashTreeRoot[T: not BasicType](
merkleizer: var SszMerkleizer, arr: openArray[T],
merkleizer: var SszMerkleizer2, arr: openArray[T],
firstIdx, numFromFirst: Limit, res: var Digest) =
for i in 0 ..< numFromFirst:
addChunkDirect(merkleizer):
chunk = hash_tree_root(arr[firstIdx + i])
getFinalHash(merkleizer, res)

template chunkedHashTreeRoot[T](
totalChunks: static Limit, arr: openArray[T],
func chunkedHashTreeRoot[T](
height: static Limit, arr: openArray[T],
chunks: Slice[Limit], topLayer: int, res: var Digest) =
const valuesPerChunk =
when T is BasicType:
Expand All @@ -648,25 +663,23 @@ template chunkedHashTreeRoot[T](
1
let firstIdx = chunks.a * valuesPerChunk
if arr.len <= firstIdx:
const treeHeight = binaryTreeHeight totalChunks
res = zeroHashes[treeHeight - 1 - topLayer]
res = zeroHashes[height - 1 - topLayer]
else:
var merkleizer = createMerkleizer(totalChunks, topLayer, internalParam = true)
var merkleizer = createMerkleizer2(height, topLayer, internalParam = true)
let numFromFirst =
min((chunks.b - chunks.a + 1) * valuesPerChunk, arr.len - firstIdx)
chunkedHashTreeRoot(merkleizer, arr, firstIdx, numFromFirst, res)

template chunkedHashTreeRoot[T](
totalChunks: static Limit, arr: openArray[T], res: var Digest) =
func chunkedHashTreeRoot[T](
height: static Limit, arr: openArray[T], res: var Digest) =
if arr.len <= 0:
const treeHeight = binaryTreeHeight totalChunks
res = zeroHashes[treeHeight - 1]
res = zeroHashes[height - 1]
else:
var merkleizer = createMerkleizer(totalChunks, internalParam = true)
var merkleizer = createMerkleizer2(height, internalParam = true)
chunkedHashTreeRoot(merkleizer, arr, 0, arr.len, res)

func bitListHashTreeRoot(
merkleizer: var SszMerkleizer, x: BitSeq, chunks: Slice[Limit],
merkleizer: var SszMerkleizer2, x: BitSeq, chunks: Slice[Limit],
res: var Digest) =
# TODO: Switch to a simpler BitList representation and
# replace this with `chunkedHashTreeRoot`
Expand Down Expand Up @@ -806,11 +819,11 @@ func hashTreeRootAux[T](x: T, res: var Digest) =
else:
trs "FIXED TYPE; USE CHUNK STREAM"
const totalChunks = maxChunksCount(T, x.len)
chunkedHashTreeRoot(totalChunks, x, res)
chunkedHashTreeRoot(binaryTreeHeight totalChunks, x, res)
elif T is List:
const totalChunks = maxChunksCount(T, x.maxLen)
var contentsHash {.noinit.}: Digest
chunkedHashTreeRoot(totalChunks, asSeq x, contentsHash)
chunkedHashTreeRoot(binaryTreeHeight totalChunks, asSeq x, contentsHash)
mixInLength(contentsHash, x.len, res)
elif T is OptionalType:
if x.isSome:
Expand Down Expand Up @@ -893,11 +906,11 @@ func hashTreeRootAux[T](
index = indexAt(i)
indexLayer = log2trunc(index)
if index == 1.GeneralizedIndex:
chunkedHashTreeRoot(totalChunks, x, rootAt(i))
chunkedHashTreeRoot(binaryTreeHeight totalChunks, x, rootAt(i))
inc i
elif indexLayer <= chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(totalChunks, x, chunks, indexLayer, rootAt(i))
chunkedHashTreeRoot(binaryTreeHeight totalChunks, x, chunks, indexLayer, rootAt(i))
inc i
else:
when ElemType(typeof(x)) is BasicType: return unsupportedIndex
Expand Down Expand Up @@ -928,14 +941,14 @@ func hashTreeRootAux[T](
indexLayer = log2trunc(index)
if index == 1.GeneralizedIndex:
var contentsHash {.noinit.}: Digest
chunkedHashTreeRoot(totalChunks, asSeq x, contentsHash)
chunkedHashTreeRoot(binaryTreeHeight totalChunks, asSeq x, contentsHash)
mixInLength(contentsHash, x.len, rootAt(i))
inc i
elif index == 3.GeneralizedIndex:
hashTreeRootAux(x.len.uint64, rootAt(i))
inc i
elif index == 2.GeneralizedIndex:
chunkedHashTreeRoot(totalChunks, asSeq x, rootAt(i))
chunkedHashTreeRoot(binaryTreeHeight totalChunks, asSeq x, rootAt(i))
inc i
elif (index shr (indexLayer - 1)) == 2.GeneralizedIndex:
let
Expand All @@ -945,7 +958,7 @@ func hashTreeRootAux[T](
if indexLayer <= chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(
totalChunks, asSeq x, chunks, indexLayer, rootAt(i))
binaryTreeHeight totalChunks, asSeq x, chunks, indexLayer, rootAt(i))
inc i
else:
when ElemType(typeof(x)) is BasicType: return unsupportedIndex
Expand Down Expand Up @@ -1020,6 +1033,7 @@ func hashTreeRootAux[T](
trs "MERKLEIZING FIELDS"
const
totalChunks = totalSerializedFields(T)
treeHeight = binaryTreeHeight(totalChunks)
firstChunkIndex = nextPow2(totalChunks.uint64)
chunkLayer = log2trunc(firstChunkIndex)
var
Expand All @@ -1029,7 +1043,7 @@ func hashTreeRootAux[T](
index {.noinit.}: GeneralizedIndex
indexLayer {.noinit.}: int
chunks {.noinit.}: Slice[Limit]
merkleizer {.noinit.}: SszMerkleizer[totalSerializedFields(T)]
merkleizer {.noinit.}: SszMerkleizer2[treeHeight]
chunk {.noinit.}: Limit
nextField {.noinit.}: Limit
x.enumerateSubFields(f):
Expand Down Expand Up @@ -1221,7 +1235,8 @@ func hashTreeRootCached*(
inc i
elif indexLayer == chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(totalChunks, x.data, chunks, indexLayer, rootAt(i))
chunkedHashTreeRoot(
binaryTreeHeight totalChunks, x.data, chunks, indexLayer, rootAt(i))
inc i
elif indexLayer < chunkLayer:
rootAt(i) = hashTreeRootCachedPtr(x, index.int64)[]
Expand Down Expand Up @@ -1279,7 +1294,7 @@ func hashTreeRootCached*(
if indexLayer == chunkLayer:
let chunks = chunksForIndex(index)
chunkedHashTreeRoot(
totalChunks, asSeq x.data, chunks, indexLayer, rootAt(i))
binaryTreeHeight totalChunks, asSeq x.data, chunks, indexLayer, rootAt(i))
inc i
elif indexLayer < chunkLayer:
rootAt(i) = hashTreeRootCachedPtr(x, index.int64)[]
Expand Down

0 comments on commit 66de36a

Please sign in to comment.