Skip to content

Commit

Permalink
60% state replay speedup (#4434)
Browse files Browse the repository at this point in the history
* 60% state replay speedup

* don't use HashList for epoch participation - in addition to the code
currently clearing the caches several times redundantly, clearing has to
be done for each block, nullifying the benefit of the hash cache (35%)
* introduce active balance cache - computing it is slow due to cache
unfriendliness in the random access pattern and bounds checking and we
do it for every block - this cache follows the same update pattern as
the active validator index cache (20%)
* avoid recomputing base reward several times per attestation (5%)

Applying 1024 blocks goes from 20s to ~8s on my laptop - these kinds of
requests happen on historical REST queries but also whenever there's a
reorg.

* fix test and diffs
  • Loading branch information
arnetheduck authored Dec 19, 2022
1 parent 064d164 commit 7501f10
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 40 deletions.
20 changes: 12 additions & 8 deletions beacon_chain/spec/beaconstate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,13 @@ func get_total_active_balance*(state: ForkyBeaconState, cache: var StateCache):

let epoch = state.get_current_epoch()

get_total_balance(
state, cache.get_shuffled_active_validator_indices(state, epoch))
cache.total_active_balance.withValue(epoch, tab) do:
return tab[]
do:
let tab = get_total_balance(
state, cache.get_shuffled_active_validator_indices(state, epoch))
cache.total_active_balance[epoch] = tab
return tab

# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#get_base_reward_per_increment
func get_base_reward_per_increment_sqrt*(
Expand Down Expand Up @@ -704,15 +709,15 @@ func get_proposer_reward*(state: ForkyBeaconState,
state, attestation.data, state.slot - attestation.data.slot)
for index in get_attesting_indices(
state, attestation.data, attestation.aggregation_bits, cache):
let
base_reward = get_base_reward(state, index, base_reward_per_increment)
for flag_index, weight in PARTICIPATION_FLAG_WEIGHTS:
if flag_index in participation_flag_indices and
not has_flag(epoch_participation.item(index), flag_index):
epoch_participation[index] =
asList(epoch_participation)[index] =
add_flag(epoch_participation.item(index), flag_index)
# these are all valid; TODO statically verify or do it type-safely
result += get_base_reward(
state, index, base_reward_per_increment) * weight.uint64
epoch_participation.asHashList.clearCache()
result += base_reward * weight.uint64

let proposer_reward_denominator =
(WEIGHT_DENOMINATOR.uint64 - PROPOSER_WEIGHT.uint64) *
Expand Down Expand Up @@ -860,8 +865,7 @@ func upgrade_to_altair*(cfg: RuntimeConfig, pre: phase0.BeaconState):
empty_participation: EpochParticipationFlags
inactivity_scores = HashList[uint64, Limit VALIDATOR_REGISTRY_LIMIT]()

doAssert empty_participation.data.setLen(pre.validators.len)
empty_participation.asHashList.resetCache()
doAssert empty_participation.asList.setLen(pre.validators.len)

doAssert inactivity_scores.data.setLen(pre.validators.len)
inactivity_scores.resetCache()
Expand Down
48 changes: 32 additions & 16 deletions beacon_chain/spec/datatypes/altair.nim
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ type
ParticipationFlags* = uint8

EpochParticipationFlags* =
distinct HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
distinct List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]

# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#syncaggregate
SyncAggregate* = object
Expand Down Expand Up @@ -558,10 +558,8 @@ type
# Represent in full; for the next epoch, current_epoch_participation in
# epoch n is previous_epoch_participation in epoch n+1 but this doesn't
# generalize.
previous_epoch_participation*:
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
current_epoch_participation*:
List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]
previous_epoch_participation*: EpochParticipationFlags
current_epoch_participation*: EpochParticipationFlags

justification_bits*: JustificationBits
previous_justified_checkpoint*: Checkpoint
Expand Down Expand Up @@ -589,26 +587,44 @@ template `[]`*(arr: array[SYNC_COMMITTEE_SIZE, auto] | seq;
makeLimitedU8(SyncSubcommitteeIndex, SYNC_COMMITTEE_SUBNET_COUNT)
makeLimitedU16(IndexInSyncCommittee, SYNC_COMMITTEE_SIZE)

template asHashList*(epochFlags: EpochParticipationFlags): untyped =
HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] epochFlags
# Read-only conversion of the participation flags to the `List` type that
# `EpochParticipationFlags` is a distinct wrapper of.
template asList*(epochFlags: EpochParticipationFlags): untyped =
  List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT](epochFlags)
# Mutable view of the participation flags as the underlying `List`.
# The address of `epochFlags` is reinterpreted as a pointer to the list type
# and dereferenced, so the result aliases `epochFlags`: writes made through it
# modify the original value.
# NOTE(review): presumably the pointer cast is used instead of a plain type
# conversion for the same reason as the Nim issue 21123 workaround noted
# further down in this module - confirm.
template asList*(epochFlags: var EpochParticipationFlags): untyped =
let tmp = cast[ptr List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]](addr epochFlags)
tmp[]

# Read-only conversion of the participation flags to a plain
# `seq[ParticipationFlags]`, going through the `List` conversion first.
template asSeq*(epochFlags: EpochParticipationFlags): untyped =
  seq[ParticipationFlags](asList(epochFlags))

# Mutable view of the participation flags as a plain `seq[ParticipationFlags]`.
# The address of `epochFlags` is reinterpreted as a pointer to a seq and
# dereferenced, so the result aliases `epochFlags`.
# NOTE(review): assumes `List` shares its in-memory representation with
# `seq` - TODO confirm against the `List` definition.
template asSeq*(epochFlags: var EpochParticipationFlags): untyped =
let tmp = cast[ptr seq[ParticipationFlags]](addr epochFlags)
tmp[]

template item*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex): ParticipationFlags =
asHashList(epochFlags).item(idx)
asList(epochFlags)[idx]

template `[]`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex|uint64): ParticipationFlags =
asHashList(epochFlags)[idx]
template `[]`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex|uint64|int): ParticipationFlags =
asList(epochFlags)[idx]

template `[]=`*(epochFlags: EpochParticipationFlags, idx: ValidatorIndex, flags: ParticipationFlags) =
asHashList(epochFlags)[idx] = flags
asList(epochFlags)[idx] = flags

template add*(epochFlags: var EpochParticipationFlags, flags: ParticipationFlags): bool =
asHashList(epochFlags).add flags
asList(epochFlags).add flags

template len*(epochFlags: EpochParticipationFlags): int =
asHashList(epochFlags).len

template data*(epochFlags: EpochParticipationFlags): untyped =
asHashList(epochFlags).data
asList(epochFlags).len

# Lowest valid index of the participation flags, taken from their seq view.
template low*(epochFlags: EpochParticipationFlags): int =
  low(asSeq(epochFlags))
# Highest valid index of the participation flags, taken from their seq view.
template high*(epochFlags: EpochParticipationFlags): int =
  high(asSeq(epochFlags))

# Copies `src` into `v` by delegating to the `assign` overload that is visible
# at the instantiation site (hence the `mixin`), operating on the underlying
# seq of both values.
#
# `v`   - destination; its storage is reinterpreted as `seq[ParticipationFlags]`
# `src` - source; unwrapped with `distinctBase`
template assign*(v: var EpochParticipationFlags, src: EpochParticipationFlags) =
  # TODO https://github.com/nim-lang/Nim/issues/21123
  mixin assign
  # The pointer itself is never reassigned, so `let` suffices; dereferencing
  # it still yields a mutable location that aliases `v`.
  let tmp = cast[ptr seq[ParticipationFlags]](addr v)
  assign(tmp[], distinctBase src)

func shortLog*(v: SomeBeaconBlock): auto =
(
Expand Down
10 changes: 10 additions & 0 deletions beacon_chain/spec/datatypes/base.nim
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ type
# This doesn't know about forks or branches in the DAG. It's for straight,
# linear chunks of the chain.
StateCache* = object
total_active_balance*: Table[Epoch, Gwei]
shuffled_active_validator_indices*: Table[Epoch, seq[ValidatorIndex]]
beacon_proposer_indices*: Table[Slot, Option[ValidatorIndex]]
sync_committees*: Table[SyncCommitteePeriod, SyncCommitteeCache]
Expand Down Expand Up @@ -923,6 +924,14 @@ func prune*(cache: var StateCache, epoch: Epoch) =
pruneEpoch = epoch - 2

var drops: seq[Slot]
block:
  # Drop the cached total active balance of every epoch older than
  # `pruneEpoch`. Stale keys are collected first and deleted afterwards,
  # because a table must not be modified while it is being iterated.
  for k in cache.total_active_balance.keys:
    if k < pruneEpoch:
      # Record the stale key itself. Recording `pruneEpoch` here would delete
      # the `pruneEpoch` entry (which is not older than `pruneEpoch`) and
      # leave every stale entry in the table.
      drops.add k.start_slot
  for drop in drops:
    cache.total_active_balance.del drop.epoch
  # Reset the shared scratch list so later pruning steps start empty.
  drops.setLen(0)

block:
for k in cache.shuffled_active_validator_indices.keys:
if k < pruneEpoch:
Expand All @@ -948,6 +957,7 @@ func prune*(cache: var StateCache, epoch: Epoch) =
drops.setLen(0)

func clear*(cache: var StateCache) =
cache.total_active_balance.clear
cache.shuffled_active_validator_indices.clear
cache.beacon_proposer_indices.clear
cache.sync_committees.clear
Expand Down
7 changes: 2 additions & 5 deletions beacon_chain/spec/eth2_apis/eth2_rest_serialization.nim
Original file line number Diff line number Diff line change
Expand Up @@ -612,15 +612,12 @@ proc readValue*(reader: var JsonReader[RestJson], value: var Epoch) {.
proc writeValue*(writer: var JsonWriter[RestJson],
epochFlags: EpochParticipationFlags)
{.raises: [IOError, Defect].} =
for e in writer.stepwiseArrayCreation(epochFlags.asHashList):
for e in writer.stepwiseArrayCreation(epochFlags.asList):
writer.writeValue $e

proc readValue*(reader: var JsonReader[RestJson],
epochFlags: var EpochParticipationFlags)
{.raises: [SerializationError, IOError, Defect].} =
# Please note that this function won't compute the cached hash tree roots
# immediately. They will be computed on the first HTR attempt.

for e in reader.readArray(string):
let parsed = try:
parseBiggestUInt(e)
Expand All @@ -632,7 +629,7 @@ proc readValue*(reader: var JsonReader[RestJson],
reader.raiseUnexpectedValue(
"The usigned integer value should fit in 8 bits")

if not epochFlags.data.add(uint8(parsed)):
if not epochFlags.asList.add(uint8(parsed)):
reader.raiseUnexpectedValue(
"The participation flags list size exceeds limit")

Expand Down
8 changes: 5 additions & 3 deletions beacon_chain/spec/ssz_codec.nim
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import
./datatypes/base

from ./datatypes/altair import
ParticipationFlags, EpochParticipationFlags, asHashList
ParticipationFlags, EpochParticipationFlags

export codec, base, typetraits, EpochParticipationFlags

Expand All @@ -28,7 +28,7 @@ template toSszType*(v: BlsCurveType): auto = toRaw(v)
template toSszType*(v: ForkDigest|GraffitiBytes): auto = distinctBase(v)
template toSszType*(v: Version): auto = distinctBase(v)
template toSszType*(v: JustificationBits): auto = distinctBase(v)
template toSszType*(epochFlags: EpochParticipationFlags): auto = asHashList epochFlags
template toSszType*(v: EpochParticipationFlags): auto = asList v

func fromSszBytes*(T: type GraffitiBytes, data: openArray[byte]): T {.raisesssz.} =
if data.len != sizeof(result):
Expand Down Expand Up @@ -60,4 +60,6 @@ func fromSszBytes*(T: type JustificationBits, bytes: openArray[byte]): T {.raise
copyMem(result.addr, unsafeAddr bytes[0], sizeof(result))

func fromSszBytes*(T: type EpochParticipationFlags, bytes: openArray[byte]): T {.raisesssz.} =
readSszValue(bytes, HashList[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT] result)
# TODO https://github.com/nim-lang/Nim/issues/21123
let tmp = cast[ptr List[ParticipationFlags, Limit VALIDATOR_REGISTRY_LIMIT]](addr result)
readSszValue(bytes, tmp[])
1 change: 1 addition & 0 deletions beacon_chain/spec/state_transition.nim
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func process_slot*(
hash_tree_root(state.latest_block_header)

func clear_epoch_from_cache(cache: var StateCache, epoch: Epoch) =
cache.total_active_balance.del epoch
cache.shuffled_active_validator_indices.del epoch

for slot in epoch.slots():
Expand Down
6 changes: 2 additions & 4 deletions beacon_chain/spec/state_transition_epoch.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1005,13 +1005,11 @@ func process_participation_flag_updates*(

const zero = 0.ParticipationFlags
for i in 0 ..< state.current_epoch_participation.len:
state.current_epoch_participation.data[i] = zero
asList(state.current_epoch_participation)[i] = zero

# Shouldn't be wasted zeroing, because state.current_epoch_participation only
# grows. New elements are automatically initialized to 0, as required.
doAssert state.current_epoch_participation.data.setLen(state.validators.len)

state.current_epoch_participation.asHashList.resetCache()
doAssert state.current_epoch_participation.asList.setLen(state.validators.len)

# https://github.com/ethereum/consensus-specs/blob/v1.3.0-alpha.2/specs/altair/beacon-chain.md#sync-committee-updates
func process_sync_committee_updates*(
Expand Down
8 changes: 4 additions & 4 deletions beacon_chain/statediff.nim
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ func diffStates*(state0, state1: bellatrix.BeaconState): BeaconStateDiff =
slashing: state1.slashings[state0.slot.epoch.uint64 mod
EPOCHS_PER_HISTORICAL_VECTOR.uint64],

previous_epoch_participation: state1.previous_epoch_participation.data,
current_epoch_participation: state1.current_epoch_participation.data,
previous_epoch_participation: state1.previous_epoch_participation,
current_epoch_participation: state1.current_epoch_participation,

justification_bits: state1.justification_bits,
previous_justified_checkpoint: state1.previous_justified_checkpoint,
Expand Down Expand Up @@ -192,9 +192,9 @@ func applyDiff*(
assign(state.slashings.mitem(epochIndex), stateDiff.slashing)

assign(
state.previous_epoch_participation.data, stateDiff.previous_epoch_participation)
state.previous_epoch_participation, stateDiff.previous_epoch_participation)
assign(
state.current_epoch_participation.data, stateDiff.current_epoch_participation)
state.current_epoch_participation, stateDiff.current_epoch_participation)

state.justification_bits = stateDiff.justification_bits
assign(
Expand Down

0 comments on commit 7501f10

Please sign in to comment.