Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add way to get coin states in batches #17300

Merged
merged 17 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions chia/full_node/coin_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,97 @@ async def get_coin_states_by_ids(

return coins

async def batch_coin_states_by_puzzle_hashes(
self,
puzzle_hashes: List[bytes32],
*,
min_height: uint32 = uint32(0),
include_spent: bool = True,
Rigidity marked this conversation as resolved.
Show resolved Hide resolved
include_unspent: bool = True,
include_hinted: bool = True,
max_items: int = 50000,
) -> Tuple[List[CoinState], Optional[uint32]]:
"""
Returns the coin states, as well as the next block height (or `None` if finished).
Note that the maximum number of puzzle hashes is currently set to 15000.
"""

# This number is chosen such that it's below half of the Python 3.8+ SQLite variable limit.
# It can be changed later without breaking the protocol, but this is a practical limit for now.
assert len(puzzle_hashes) <= 15000
emlowe marked this conversation as resolved.
Show resolved Hide resolved

coin_states: List[CoinState] = []

async with self.db_wrapper.reader_no_transaction() as conn:
puzzle_hashes_db = tuple(puzzle_hashes)
puzzle_hash_count = len(puzzle_hashes_db)

if include_hinted:
require_spent = "cr.spent_index>0"
require_unspent = "cr.spent_index=0"
else:
require_spent = "spent_index>0"
require_unspent = "spent_index=0"

if include_spent and include_unspent:
height_filter = ""
elif include_spent:
height_filter = f"AND {require_spent}"
elif include_unspent:
height_filter = f"AND {require_unspent}"
else:
# There are no coins which are both spent and unspent, so we're finished.
return [], None

if include_hinted:
cursor = await conn.execute(
f"SELECT cr.confirmed_index, cr.spent_index, cr.coinbase, cr.puzzle_hash, "
f"cr.coin_parent, cr.amount, cr.timestamp FROM coin_record cr "
f"LEFT JOIN hints h ON cr.coin_name = h.coin_id "
f'WHERE (cr.puzzle_hash in ({"?," * (puzzle_hash_count - 1)}?) '
f'OR h.hint in ({"?," * (puzzle_hash_count - 1)}?)) '
Rigidity marked this conversation as resolved.
Show resolved Hide resolved
f"AND (cr.confirmed_index>=? OR cr.spent_index>=?) "
f"{height_filter} "
f"ORDER BY MAX(cr.confirmed_index, cr.spent_index) ASC "
f"LIMIT ?",
puzzle_hashes_db + puzzle_hashes_db + (min_height, min_height, max_items + 1),
)
else:
cursor = await conn.execute(
f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash "
f'WHERE puzzle_hash in ({"?," * (puzzle_hash_count - 1)}?) '
f"AND (confirmed_index>=? OR spent_index>=?) "
f"{height_filter} "
f"ORDER BY MAX(confirmed_index, spent_index) ASC "
f"LIMIT ?",
puzzle_hashes_db + (min_height, min_height, max_items + 1),
)
Rigidity marked this conversation as resolved.
Show resolved Hide resolved

for row in await cursor.fetchall():
coin_states.append(self.row_to_coin_state(row))

# If there aren't too many coin states, we've finished syncing these hashes.
# There is no next height to start from, so return `None`.
if len(coin_states) <= max_items:
return coin_states, None

# The last item is the start of the next batch of coin states.
next_coin_state = coin_states.pop()
next_height = uint32(max(next_coin_state.created_height or 0, next_coin_state.spent_height or 0))

# In order to prevent blocks from being split up between batches, remove
# all coin states whose max height is the same as the last coin state's height.
while len(coin_states) > 0:
last_coin_state = coin_states[-1]
height = uint32(max(last_coin_state.created_height or 0, last_coin_state.spent_height or 0))
if height != next_height:
break
Rigidity marked this conversation as resolved.
Show resolved Hide resolved

coin_states.pop()

return coin_states, next_height

async def rollback_to_block(self, block_index: int) -> List[CoinRecord]:
"""
Note that block_index can be negative, in which case everything is rolled back
Expand Down
189 changes: 189 additions & 0 deletions tests/core/full_node/stores/test_coin_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Set, Tuple

Expand All @@ -11,7 +12,9 @@
from chia.consensus.coinbase import create_farmer_coin, create_pool_coin
from chia.full_node.block_store import BlockStore
from chia.full_node.coin_store import CoinStore
from chia.full_node.hint_store import HintStore
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions
from chia.protocols.wallet_protocol import CoinState
from chia.simulator.block_tools import BlockTools, test_constants
from chia.simulator.wallet_tools import WalletTool
from chia.types.blockchain_format.coin import Coin
Expand Down Expand Up @@ -493,6 +496,192 @@ async def test_get_coin_states(db_version: int) -> None:
assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0), max_items=10000)) == 600


@dataclass(frozen=True)
class RandomCoinRecords:
items: List[CoinRecord]
puzzle_hashes: List[bytes32]
hints: List[Tuple[bytes32, bytes]]


@pytest.fixture(scope="session")
def random_coin_records() -> RandomCoinRecords:
coin_records: List[CoinRecord] = []
puzzle_hashes: List[bytes32] = []
hints: List[Tuple[bytes32, bytes]] = []

for i in range(50000):
is_spent = i % 2 == 0
is_hinted = i % 7 == 0
created_height = uint32(i)
spent_height = uint32(created_height + 100)

puzzle_hash = std_hash(i.to_bytes(4, byteorder="big"))

coin = Coin(
std_hash(b"Parent Coin Id " + i.to_bytes(4, byteorder="big")),
puzzle_hash,
uint64(1000),
)

if is_hinted:
hint = std_hash(b"Hinted " + puzzle_hash)
hints.append((coin.name(), hint))
puzzle_hashes.append(hint)
else:
puzzle_hashes.append(puzzle_hash)

coin_records.append(
CoinRecord(
coin=coin,
confirmed_block_index=created_height,
spent_block_index=spent_height if is_spent else uint32(0),
coinbase=False,
timestamp=uint64(0),
)
)

coin_records.sort(key=lambda cr: max(cr.confirmed_block_index, cr.spent_block_index))

return RandomCoinRecords(coin_records, puzzle_hashes, hints)


@pytest.mark.anyio
@pytest.mark.parametrize("include_spent", [True, False])
@pytest.mark.parametrize("include_unspent", [True, False])
@pytest.mark.parametrize("include_hinted", [True, False])
async def test_coin_state_batches(
db_version: int,
random_coin_records: RandomCoinRecords,
include_spent: bool,
include_unspent: bool,
include_hinted: bool,
) -> None:
async with DBConnection(db_version) as db_wrapper:
# Initialize coin and hint stores.
coin_store = await CoinStore.create(db_wrapper)
hint_store = await HintStore.create(db_wrapper)

await coin_store._add_coin_records(random_coin_records.items)
await hint_store.add_hints(random_coin_records.hints)

# Make sure all of the coin states are found when batching.
ph_set = set(random_coin_records.puzzle_hashes)
expected_crs = []
for cr in random_coin_records.items:
if cr.spent_block_index == 0 and not include_unspent:
continue
if cr.spent_block_index > 0 and not include_spent:
continue
if cr.coin.puzzle_hash not in ph_set and not include_hinted:
continue
expected_crs.append(cr)

height: Optional[uint32] = uint32(0)
all_coin_states: List[CoinState] = []
remaining_phs = random_coin_records.puzzle_hashes.copy()

def height_of(coin_state: CoinState) -> int:
return max(coin_state.created_height or 0, coin_state.spent_height or 0)

while height is not None:
(coin_states, height) = await coin_store.batch_coin_states_by_puzzle_hashes(
remaining_phs[:15000],
min_height=height,
include_spent=include_spent,
include_unspent=include_unspent,
include_hinted=include_hinted,
)

# Ensure that all of the returned coin states are in order.
assert all(height_of(coin_states[i]) <= height_of(coin_states[i + 1]) for i in range(len(coin_states) - 1))

all_coin_states += coin_states

if height is None:
remaining_phs = remaining_phs[15000:]

if len(remaining_phs) > 0:
height = uint32(0)

assert len(all_coin_states) == len(expected_crs)

all_coin_states.sort(key=height_of)

for i in range(len(expected_crs)):
actual = all_coin_states[i]
expected = expected_crs[i]

assert actual.coin == expected.coin, i
assert uint32(actual.created_height or 0) == expected.confirmed_block_index, i
assert uint32(actual.spent_height or 0) == expected.spent_block_index, i


@pytest.mark.anyio
@pytest.mark.parametrize("cut_off_middle", [True, False])
async def test_batch_many_coin_states(db_version: int, cut_off_middle: bool) -> None:
async with DBConnection(db_version) as db_wrapper:
ph = bytes32(b"0" * 32)

# Generate coin records.
coin_records: List[CoinRecord] = []
count = 50000

for i in range(count):
# Create coin records at either height 10 or 12.
created_height = uint32((i % 2) * 2 + 10)
coin = Coin(
std_hash(b"Parent Coin Id " + i.to_bytes(4, byteorder="big")),
ph,
uint64(i),
)
coin_records.append(
CoinRecord(
coin=coin,
confirmed_block_index=created_height,
spent_block_index=uint32(0),
coinbase=False,
timestamp=uint64(0),
)
)

# Initialize coin and hint stores.
coin_store = await CoinStore.create(db_wrapper)
await HintStore.create(db_wrapper)

await coin_store._add_coin_records(coin_records)

# Make sure all of the coin states are found.
(all_coin_states, next_height) = await coin_store.batch_coin_states_by_puzzle_hashes([ph])
all_coin_states.sort(key=lambda cs: cs.coin.amount)

assert next_height is None
assert len(all_coin_states) == len(coin_records)

for i in range(min(len(coin_records), len(all_coin_states))):
assert coin_records[i].coin.name().hex() == all_coin_states[i].coin.name().hex(), i

# For the middle case, insert a coin record between the two heights 10 and 12.
await coin_store._add_coin_records(
[
CoinRecord(
coin=Coin(std_hash(b"extra coin"), ph, 0),
# Insert a coin record in the middle between heights 10 and 12.
# Or after all of the other coins if testing the batch limit.
confirmed_block_index=uint32(11 if cut_off_middle else 50),
spent_block_index=uint32(0),
coinbase=False,
timestamp=uint64(0),
)
]
)

(all_coin_states, next_height) = await coin_store.batch_coin_states_by_puzzle_hashes([ph])

# Make sure that the extra coin records are not included in the results.
assert next_height == (12 if cut_off_middle else 50)
assert len(all_coin_states) == (25001 if cut_off_middle else 50000)


@pytest.mark.anyio
async def test_unsupported_version() -> None:
with pytest.raises(RuntimeError, match="CoinStore does not support database schema v1"):
Expand Down
Loading