Skip to content

Commit

Permalink
[CHIA-1087] Simplify batch pre validate blocks (#18602)
Browse files Browse the repository at this point in the history
* minor simplification of stacked if-conditions and early exits on failure paths

* Simplify NPCResult -> SpendBundleConditions

* make include_spends() take SpendBundleConditions, rather than NPCResult
  • Loading branch information
arvidn authored Sep 26, 2024
1 parent 439cd07 commit 0059b5e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 47 deletions.
9 changes: 4 additions & 5 deletions chia/consensus/block_body_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple, Union

from chia_rs import AugSchemeMPL, BLSCache, G1Element
from chia_rs import AugSchemeMPL, BLSCache, G1Element, SpendBundleConditions
from chiabip158 import PyBIP158

from chia.consensus.block_record import BlockRecord
Expand Down Expand Up @@ -85,7 +85,7 @@ def reset(self, fork_height: int, header_hash: bytes32) -> None:
self.removals_since_fork = {}
self.block_hashes = []

def include_spends(self, npc_result: Optional[NPCResult], block: FullBlock, header_hash: bytes32) -> None:
def include_spends(self, conds: Optional[SpendBundleConditions], block: FullBlock, header_hash: bytes32) -> None:
height = block.height

assert self.peak_height == height - 1
Expand All @@ -97,11 +97,10 @@ def include_spends(self, npc_result: Optional[NPCResult], block: FullBlock, head
self.peak_height = int(block.height)
self.peak_hash = header_hash

if npc_result is not None:
assert npc_result.conds is not None
if conds is not None:
assert block.foliage_transaction_block is not None
timestamp = block.foliage_transaction_block.timestamp
for spend in npc_result.conds.spends:
for spend in conds.spends:
self.removals_since_fork[bytes32(spend.coin_id)] = ForkRem(bytes32(spend.puzzle_hash), height)
for puzzle_hash, amount, hint in spend.create_coin:
coin = Coin(bytes32(spend.coin_id), bytes32(puzzle_hash), uint64(amount))
Expand Down
6 changes: 3 additions & 3 deletions chia/consensus/blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def run_single_block(self, block: FullBlock, fork_info: ForkInfo) -> None:
)
assert npc.error is None

fork_info.include_spends(npc, block, block.header_hash)
fork_info.include_spends(None if npc is None else npc.conds, block, block.header_hash)

async def add_block(
self,
Expand Down Expand Up @@ -412,7 +412,7 @@ async def add_block(
# main chain, we still need to re-run it to update the additions and
# removals in fork_info.
await self.advance_fork_info(block, fork_info)
fork_info.include_spends(npc_result, block, header_hash)
fork_info.include_spends(None if npc_result is None else npc_result.conds, block, header_hash)
self.add_block_record(block_rec)
return AddBlockResult.ALREADY_HAVE_BLOCK, None, None

Expand Down Expand Up @@ -444,7 +444,7 @@ async def add_block(
# case we're validating blocks on a fork, the next block validation will
# need to know of these additions and removals. Also, _reconsider_peak()
# will need these results
fork_info.include_spends(npc_result, block, header_hash)
fork_info.include_spends(None if npc_result is None else npc_result.conds, block, header_hash)

# block_to_block_record() require the previous block in the cache
if not genesis and prev_block is not None:
Expand Down
80 changes: 41 additions & 39 deletions chia/consensus/multiprocess_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

from chia_rs import AugSchemeMPL
from chia_rs import AugSchemeMPL, SpendBundleConditions

from chia.consensus.block_header_validation import validate_finished_header_block
from chia.consensus.block_record import BlockRecord
Expand Down Expand Up @@ -53,7 +53,7 @@ def batch_pre_validate_blocks(
blocks_pickled: Dict[bytes, bytes],
full_blocks_pickled: List[bytes],
prev_transaction_generators: List[Optional[List[bytes]]],
npc_results: Dict[uint32, bytes],
conditions: Dict[uint32, bytes],
expected_difficulty: List[uint64],
expected_sub_slot_iters: List[uint64],
validate_signatures: bool,
Expand All @@ -71,16 +71,14 @@ def batch_pre_validate_blocks(
block: FullBlock = FullBlock.from_bytes_unchecked(full_blocks_pickled[i])
tx_additions: List[Coin] = []
removals: List[bytes32] = []
npc_result: Optional[NPCResult] = None
if block.height in npc_results:
npc_result = NPCResult.from_bytes(npc_results[block.height])
assert npc_result is not None
if npc_result.conds is not None:
removals, tx_additions = tx_removals_and_additions(npc_result.conds)
else:
removals, tx_additions = [], []

if block.transactions_generator is not None and npc_result is None:
conds: Optional[SpendBundleConditions] = None
if block.height in conditions:
conds = SpendBundleConditions.from_bytes(conditions[block.height])
removals, tx_additions = tx_removals_and_additions(conds)
elif block.transactions_generator is not None:
# TODO: this function would be simpler if conditions were
# required to be passed in for all transaction blocks. We would
# no longer need prev_transaction_generators
prev_generators = prev_transaction_generators[i]
assert prev_generators is not None
assert block.transactions_info is not None
Expand All @@ -93,15 +91,17 @@ def batch_pre_validate_blocks(
height=block.height,
constants=constants,
)
removals, tx_additions = tx_removals_and_additions(npc_result.conds)
if npc_result is not None and npc_result.error is not None:
validation_time = time.monotonic() - validation_start
results.append(
PreValidationResult(
uint16(npc_result.error), None, npc_result, False, uint32(validation_time * 1000)
if npc_result.error is not None:
validation_time = time.monotonic() - validation_start
results.append(
PreValidationResult(
uint16(npc_result.error), None, npc_result, False, uint32(validation_time * 1000)
)
)
)
continue
continue
assert npc_result.conds is not None
conds = npc_result.conds
removals, tx_additions = tx_removals_and_additions(conds)

header_block = get_block_header(block, tx_additions, removals)
prev_ses_block = None
Expand All @@ -123,28 +123,28 @@ def batch_pre_validate_blocks(
error_int = uint16(error.code.value)

successfully_validated_signatures = False
# If we failed CLVM, no need to validate signature, the block is already invalid
if error_int is None:
# If this is False, it means either we don't have a signature (not a tx block) or we have an invalid
# signature (which also puts in an error) or we didn't validate the signature because we want to
# validate it later. add_block will attempt to validate the signature later.
if validate_signatures:
if npc_result is not None and block.transactions_info is not None:
assert npc_result.conds
pairs_pks, pairs_msgs = pkm_pairs(npc_result.conds, constants.AGG_SIG_ME_ADDITIONAL_DATA)
if not AugSchemeMPL.aggregate_verify(
pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature
):
error_int = uint16(Err.BAD_AGGREGATE_SIGNATURE.value)
else:
successfully_validated_signatures = True
# If we failed header block validation, no need to validate
# signature, the block is already invalid If this is False, it means
# either we don't have a signature (not a tx block) or we have an
# invalid signature (which also puts in an error) or we didn't
# validate the signature because we want to validate it later.
# add_block will attempt to validate the signature later.
if error_int is None and validate_signatures and conds is not None:
assert block.transactions_info is not None
pairs_pks, pairs_msgs = pkm_pairs(conds, constants.AGG_SIG_ME_ADDITIONAL_DATA)
if not AugSchemeMPL.aggregate_verify(
pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature
):
error_int = uint16(Err.BAD_AGGREGATE_SIGNATURE.value)
else:
successfully_validated_signatures = True

validation_time = time.monotonic() - validation_start
results.append(
PreValidationResult(
error_int,
required_iters,
npc_result,
None if conds is None else NPCResult(None, conds),
successfully_validated_signatures,
uint32(validation_time * 1000),
)
Expand Down Expand Up @@ -274,9 +274,11 @@ async def pre_validate_blocks_multiprocessing(
if block_rec.sub_epoch_summary_included is not None:
prev_ses_block = block_rec

npc_results_pickled = {}
conditions_pickled = {}
for k, v in npc_results.items():
npc_results_pickled[k] = bytes(v)
assert v.error is None
assert v.conds is not None
conditions_pickled[k] = bytes(v.conds)
futures = []
# Pool of workers to validate blocks concurrently
recent_blocks_bytes = {bytes(k): bytes(v) for k, v in recent_blocks.items()} # convert to bytes
Expand Down Expand Up @@ -321,7 +323,7 @@ async def pre_validate_blocks_multiprocessing(
recent_blocks_bytes,
b_pickled,
previous_generators,
npc_results_pickled,
conditions_pickled,
[diff_ssis[j][0] for j in range(i, end_i)],
[diff_ssis[j][1] for j in range(i, end_i)],
validate_signatures,
Expand Down

0 comments on commit 0059b5e

Please sign in to comment.