diff --git a/chia/full_node/coin_store.py b/chia/full_node/coin_store.py
index dcfbef07a951..b741687a6794 100644
--- a/chia/full_node/coin_store.py
+++ b/chia/full_node/coin_store.py
@@ -411,6 +411,113 @@ async def get_coin_states_by_ids(
 
         return coins
 
+    async def batch_coin_states_by_puzzle_hashes(
+        self,
+        puzzle_hashes: List[bytes32],
+        *,
+        min_height: uint32 = uint32(0),
+        include_spent: bool = True,
+        include_unspent: bool = True,
+        include_hinted: bool = True,
+        max_items: int = 50000,
+    ) -> Tuple[List[CoinState], Optional[uint32]]:
+        """
+        Returns the coin states, as well as the next block height (or `None` if finished).
+        Note that the maximum number of puzzle hashes is currently set to 15000.
+
+        Coin states are ordered by the greater of their created and spent heights, and only
+        those where either height is at least `min_height` are considered. `include_spent` and
+        `include_unspent` select coins by spent status, and `include_hinted` also matches coins
+        whose hint (rather than puzzle hash) is in `puzzle_hashes`.
+
+        At most `max_items` coin states are returned. When more are available, the batch is cut
+        at a block boundary and the height to resume from (as `min_height`) is returned with it.
+        """
+
+        # This number is chosen such that it's below half of the Python 3.8+ SQLite variable limit.
+        # It can be changed later without breaking the protocol, but this is a practical limit for now.
+        assert len(puzzle_hashes) <= 15000
+
+        # The queries below always emit at least one `?` placeholder per `IN` list, so an empty
+        # list would not match the number of bound parameters. Nothing can match anyway.
+        if len(puzzle_hashes) == 0:
+            return [], None
+
+        coin_states: List[CoinState] = []
+
+        async with self.db_wrapper.reader_no_transaction() as conn:
+            puzzle_hashes_db = tuple(puzzle_hashes)
+            puzzle_hash_count = len(puzzle_hashes_db)
+
+            if include_hinted:
+                require_spent = "cr.spent_index>0"
+                require_unspent = "cr.spent_index=0"
+            else:
+                require_spent = "spent_index>0"
+                require_unspent = "spent_index=0"
+
+            if include_spent and include_unspent:
+                height_filter = ""
+            elif include_spent:
+                height_filter = f"AND {require_spent}"
+            elif include_unspent:
+                height_filter = f"AND {require_unspent}"
+            else:
+                # There are no coins which are both spent and unspent, so we're finished.
+                return [], None
+
+            if include_hinted:
+                cursor = await conn.execute(
+                    f"SELECT cr.confirmed_index, cr.spent_index, cr.coinbase, cr.puzzle_hash, "
+                    f"cr.coin_parent, cr.amount, cr.timestamp FROM coin_record cr "
+                    f"LEFT JOIN hints h ON cr.coin_name = h.coin_id "
+                    f'WHERE (cr.puzzle_hash in ({"?," * (puzzle_hash_count - 1)}?) '
+                    f'OR h.hint in ({"?," * (puzzle_hash_count - 1)}?)) '
+                    f"AND (cr.confirmed_index>=? OR cr.spent_index>=?) "
+                    f"{height_filter} "
+                    f"ORDER BY MAX(cr.confirmed_index, cr.spent_index) ASC "
+                    f"LIMIT ?",
+                    puzzle_hashes_db + puzzle_hashes_db + (min_height, min_height, max_items + 1),
+                )
+            else:
+                cursor = await conn.execute(
+                    f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
+                    f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash "
+                    f'WHERE puzzle_hash in ({"?," * (puzzle_hash_count - 1)}?) '
+                    f"AND (confirmed_index>=? OR spent_index>=?) "
+                    f"{height_filter} "
+                    f"ORDER BY MAX(confirmed_index, spent_index) ASC "
+                    f"LIMIT ?",
+                    puzzle_hashes_db + (min_height, min_height, max_items + 1),
+                )
+
+            for row in await cursor.fetchall():
+                coin_states.append(self.row_to_coin_state(row))
+
+            # If there aren't too many coin states, we've finished syncing these hashes.
+            # There is no next height to start from, so return `None`.
+            if len(coin_states) <= max_items:
+                return coin_states, None
+
+            # The last item is the start of the next batch of coin states.
+            next_coin_state = coin_states.pop()
+            next_height = uint32(max(next_coin_state.created_height or 0, next_coin_state.spent_height or 0))
+
+            # In order to prevent blocks from being split up between batches, remove
+            # all coin states whose max height is the same as the last coin state's height.
+            # NOTE: If every fetched coin state shares `next_height`, the batch ends up empty while
+            # the returned height does not advance past it, so `max_items` must exceed the number
+            # of matching coin states in any single block for callers to make progress.
+            while len(coin_states) > 0:
+                last_coin_state = coin_states[-1]
+                height = uint32(max(last_coin_state.created_height or 0, last_coin_state.spent_height or 0))
+                if height != next_height:
+                    break
+
+                coin_states.pop()
+
+        return coin_states, next_height
+
     async def rollback_to_block(self, block_index: int) -> List[CoinRecord]:
         """
         Note that block_index can be negative, in which case everything is rolled back
diff --git a/tests/core/full_node/stores/test_coin_store.py b/tests/core/full_node/stores/test_coin_store.py
index 818dde3a52ce..12dd038bfcbc 100644
--- a/tests/core/full_node/stores/test_coin_store.py
+++ b/tests/core/full_node/stores/test_coin_store.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Set, Tuple
 
@@ -11,7 +12,9 @@
 from chia.consensus.coinbase import create_farmer_coin, create_pool_coin
 from chia.full_node.block_store import BlockStore
 from chia.full_node.coin_store import CoinStore
+from chia.full_node.hint_store import HintStore
 from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions
+from chia.protocols.wallet_protocol import CoinState
 from chia.simulator.block_tools import BlockTools, test_constants
 from chia.simulator.wallet_tools import WalletTool
 from chia.types.blockchain_format.coin import Coin
@@ -493,6 +496,204 @@ async def test_get_coin_states(db_version: int) -> None:
         assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0), max_items=10000)) == 600
 
 
+@dataclass(frozen=True)
+class RandomCoinRecords:
+    items: List[CoinRecord]
+    puzzle_hashes: List[bytes32]
+    hints: List[Tuple[bytes32, bytes]]
+
+
+@pytest.fixture(scope="session")
+def random_coin_records() -> RandomCoinRecords:
+    coin_records: List[CoinRecord] = []
+    puzzle_hashes: List[bytes32] = []
+    hints: List[Tuple[bytes32, bytes]] = []
+
+    for i in range(50000):
+        is_spent = i % 2 == 0
+        is_hinted = i % 7 == 0
+        created_height = uint32(i)
+        spent_height = uint32(created_height + 100)
+
+        puzzle_hash = std_hash(i.to_bytes(4, byteorder="big"))
+
+        coin = Coin(
+            std_hash(b"Parent Coin Id " + i.to_bytes(4, byteorder="big")),
+            puzzle_hash,
+            uint64(1000),
+        )
+
+        if is_hinted:
+            hint = std_hash(b"Hinted " + puzzle_hash)
+            hints.append((coin.name(), hint))
+            puzzle_hashes.append(hint)
+        else:
+            puzzle_hashes.append(puzzle_hash)
+
+        coin_records.append(
+            CoinRecord(
+                coin=coin,
+                confirmed_block_index=created_height,
+                spent_block_index=spent_height if is_spent else uint32(0),
+                coinbase=False,
+                timestamp=uint64(0),
+            )
+        )
+
+    coin_records.sort(key=lambda cr: max(cr.confirmed_block_index, cr.spent_block_index))
+
+    return RandomCoinRecords(coin_records, puzzle_hashes, hints)
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize("include_spent", [True, False])
+@pytest.mark.parametrize("include_unspent", [True, False])
+@pytest.mark.parametrize("include_hinted", [True, False])
+async def test_coin_state_batches(
+    db_version: int,
+    random_coin_records: RandomCoinRecords,
+    include_spent: bool,
+    include_unspent: bool,
+    include_hinted: bool,
+) -> None:
+    async with DBConnection(db_version) as db_wrapper:
+        # Initialize coin and hint stores.
+        coin_store = await CoinStore.create(db_wrapper)
+        hint_store = await HintStore.create(db_wrapper)
+
+        await coin_store._add_coin_records(random_coin_records.items)
+        await hint_store.add_hints(random_coin_records.hints)
+
+        # Make sure all of the coin states are found when batching.
+        ph_set = set(random_coin_records.puzzle_hashes)
+        expected_crs = []
+        for cr in random_coin_records.items:
+            if cr.spent_block_index == 0 and not include_unspent:
+                continue
+            if cr.spent_block_index > 0 and not include_spent:
+                continue
+            if cr.coin.puzzle_hash not in ph_set and not include_hinted:
+                continue
+            expected_crs.append(cr)
+
+        height: Optional[uint32] = uint32(0)
+        all_coin_states: List[CoinState] = []
+        remaining_phs = random_coin_records.puzzle_hashes.copy()
+
+        def height_of(coin_state: CoinState) -> int:
+            return max(coin_state.created_height or 0, coin_state.spent_height or 0)
+
+        while height is not None:
+            (coin_states, height) = await coin_store.batch_coin_states_by_puzzle_hashes(
+                remaining_phs[:15000],
+                min_height=height,
+                include_spent=include_spent,
+                include_unspent=include_unspent,
+                include_hinted=include_hinted,
+            )
+
+            # Ensure that all of the returned coin states are in order.
+            assert all(height_of(coin_states[i]) <= height_of(coin_states[i + 1]) for i in range(len(coin_states) - 1))
+
+            all_coin_states += coin_states
+
+            if height is None:
+                remaining_phs = remaining_phs[15000:]
+
+                if len(remaining_phs) > 0:
+                    height = uint32(0)
+
+        assert len(all_coin_states) == len(expected_crs)
+
+        all_coin_states.sort(key=height_of)
+
+        for i in range(len(expected_crs)):
+            actual = all_coin_states[i]
+            expected = expected_crs[i]
+
+            assert actual.coin == expected.coin, i
+            assert uint32(actual.created_height or 0) == expected.confirmed_block_index, i
+            assert uint32(actual.spent_height or 0) == expected.spent_block_index, i
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize("cut_off_middle", [True, False])
+async def test_batch_many_coin_states(db_version: int, cut_off_middle: bool) -> None:
+    async with DBConnection(db_version) as db_wrapper:
+        ph = bytes32(b"0" * 32)
+
+        # Generate coin records.
+        coin_records: List[CoinRecord] = []
+        count = 50000
+
+        for i in range(count):
+            # Create coin records at either height 10 or 12.
+            created_height = uint32((i % 2) * 2 + 10)
+            coin = Coin(
+                std_hash(b"Parent Coin Id " + i.to_bytes(4, byteorder="big")),
+                ph,
+                uint64(i),
+            )
+            coin_records.append(
+                CoinRecord(
+                    coin=coin,
+                    confirmed_block_index=created_height,
+                    spent_block_index=uint32(0),
+                    coinbase=False,
+                    timestamp=uint64(0),
+                )
+            )
+
+        # Initialize coin and hint stores.
+        coin_store = await CoinStore.create(db_wrapper)
+        await HintStore.create(db_wrapper)
+
+        await coin_store._add_coin_records(coin_records)
+
+        # Make sure all of the coin states are found.
+        (all_coin_states, next_height) = await coin_store.batch_coin_states_by_puzzle_hashes([ph])
+        all_coin_states.sort(key=lambda cs: cs.coin.amount)
+
+        assert next_height is None
+        assert len(all_coin_states) == len(coin_records)
+
+        for i in range(min(len(coin_records), len(all_coin_states))):
+            assert coin_records[i].coin.name().hex() == all_coin_states[i].coin.name().hex(), i
+
+        # For the middle case, insert a coin record between the two heights 10 and 12.
+        await coin_store._add_coin_records(
+            [
+                CoinRecord(
+                    coin=Coin(std_hash(b"extra coin"), ph, uint64(0)),
+                    # Insert a coin record in the middle between heights 10 and 12.
+                    # Or after all of the other coins if testing the batch limit.
+                    confirmed_block_index=uint32(11 if cut_off_middle else 50),
+                    spent_block_index=uint32(0),
+                    coinbase=False,
+                    timestamp=uint64(0),
+                )
+            ]
+        )
+
+        (all_coin_states, next_height) = await coin_store.batch_coin_states_by_puzzle_hashes([ph])
+
+        # Make sure that the extra coin records are not included in the results.
+        assert next_height == (12 if cut_off_middle else 50)
+        assert len(all_coin_states) == (25001 if cut_off_middle else 50000)
+
+
+@pytest.mark.anyio
+async def test_batch_no_puzzle_hashes(db_version: int) -> None:
+    async with DBConnection(db_version) as db_wrapper:
+        coin_store = await CoinStore.create(db_wrapper)
+        await HintStore.create(db_wrapper)
+
+        # An empty list of puzzle hashes must yield an empty, finished batch instead of a SQLite error.
+        for include_hinted in [True, False]:
+            result = await coin_store.batch_coin_states_by_puzzle_hashes([], include_hinted=include_hinted)
+            assert result == ([], None)
+
+
 @pytest.mark.anyio
 async def test_unsupported_version() -> None:
     with pytest.raises(RuntimeError, match="CoinStore does not support database schema v1"):