Batched sync for wallet transactions #2995

Merged · 10 commits · Jul 15, 2020
277 changes: 206 additions & 71 deletions lbry/wallet/ledger.py
@@ -3,7 +3,6 @@
import time
import asyncio
import logging
from io import StringIO
from datetime import datetime
from functools import partial
from operator import itemgetter
@@ -164,6 +163,7 @@ def __init__(self, config=None):
self._utxo_reservation_lock = asyncio.Lock()
self._header_processing_lock = asyncio.Lock()
self._address_update_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._history_lock = asyncio.Lock()

self.coin_selection_strategy = None
self._known_addresses_out_of_sync = set()
@@ -489,10 +489,10 @@ def process_status_update(self, update):
address, remote_status = update
self._update_tasks.add(self.update_history(address, remote_status))

async def update_history(self, address, remote_status, address_manager: AddressManager = None):
async def update_history(self, address, remote_status, address_manager: AddressManager = None,
reattempt_update: bool = True):
async with self._address_update_locks[address]:
self._known_addresses_out_of_sync.discard(address)

local_status, local_history = await self.get_local_status_and_history(address)

if local_status == remote_status:
@@ -502,69 +502,111 @@ async def update_history(self, address, remote_status, address_manager: AddressManager
remote_history = list(map(itemgetter('tx_hash', 'height'), remote_history))
we_need = set(remote_history) - set(local_history)
if not we_need:
remote_missing = set(local_history) - set(remote_history)
if remote_missing:
log.warning(
"%i transactions we have for %s are not in the remote address history",
len(remote_missing), address
)
return True

cache_tasks: List[asyncio.Task[Transaction]] = []
synced_history = StringIO()
loop = asyncio.get_running_loop()
acquire_lock_tasks = []
synced_txs = []
to_request = {}
pending_synced_history = {}
updated_cached_items = {}
already_synced = set()

already_synced_offset = 0
for i, (txid, remote_height) in enumerate(remote_history):
if i < len(local_history) and local_history[i] == (txid, remote_height) and not cache_tasks:
synced_history.write(f'{txid}:{remote_height}:')
else:
check_local = (txid, remote_height) not in we_need
cache_tasks.append(loop.create_task(
self.cache_transaction(txid, remote_height, check_local=check_local)
))
if i == already_synced_offset and i < len(local_history) and local_history[i] == (txid, remote_height):
pending_synced_history[i] = f'{txid}:{remote_height}:'
already_synced.add((txid, remote_height))
already_synced_offset += 1
continue
cache_item = self._tx_cache.get(txid)
if cache_item is None:
cache_item = TransactionCacheItem()
self._tx_cache[txid] = cache_item

synced_txs = []
for task in cache_tasks:
tx = await task
for txid, remote_height in remote_history[already_synced_offset:]:
cache_item = self._tx_cache[txid]
acquire_lock_tasks.append(asyncio.create_task(cache_item.lock.acquire()))

check_db_for_txos = []
for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
if cache_item is not None:
if cache_item.tx is None:
await cache_item.has_tx.wait()
assert cache_item.tx is not None
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.id)
if acquire_lock_tasks:
await asyncio.wait(acquire_lock_tasks)

referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos(
txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True
)
}
tx_indexes = {}

for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
referenced_txo = referenced_txos.get(txi.txo_ref.id)
if referenced_txo is not None:
txi.txo_ref = referenced_txo.ref
for i, (txid, remote_height) in enumerate(remote_history):
tx_indexes[txid] = i
if (txid, remote_height) in already_synced:
continue
cache_item = self._tx_cache.get(txid)
cache_item.pending_verifications += 1
updated_cached_items[txid] = cache_item

assert cache_item is not None, 'cache item is none'
assert cache_item.lock.locked(), 'cache lock is not held?'
# tx = cache_item.tx
# if cache_item.tx is not None and \
# cache_item.tx.height >= remote_height and \
# (cache_item.tx.is_verified or remote_height < 1):
# synced_txs.append(cache_item.tx) # cached tx is already up-to-date
# pending_synced_history[i] = f'{tx.id}:{tx.height}:'
# continue
to_request[i] = (txid, remote_height)

synced_history.write(f'{tx.id}:{tx.height}:')
log.debug(
"request %i transactions, %i/%i for %s are already synced", len(to_request), len(synced_txs),
len(remote_history), address
)
requested_txes = await self._request_transaction_batch(to_request, len(remote_history), address)
for tx in requested_txes:
pending_synced_history[tx_indexes[tx.id]] = f"{tx.id}:{tx.height}:"
synced_txs.append(tx)

assert len(pending_synced_history) == len(remote_history), \
f"{len(pending_synced_history)} vs {len(remote_history)}"
synced_history = ""
for remote_i, i in zip(range(len(remote_history)), sorted(pending_synced_history.keys())):
assert i == remote_i, f"{i} vs {remote_i}"
txid, height = remote_history[remote_i]
if f"{txid}:{height}:" != pending_synced_history[i]:
log.warning("history mismatch: %s vs %s", remote_history[remote_i], pending_synced_history[i])
synced_history += pending_synced_history[i]

cache_size = self.config.get("tx_cache_size", 100_000)
for txid, cache_item in updated_cached_items.items():
cache_item.pending_verifications -= 1
if cache_item.pending_verifications < 0:
log.warning("config value tx cache size %i needs to be increased", cache_size)
cache_item.pending_verifications = 0
try:
cache_item.lock.release()
except RuntimeError:
log.warning("lock was already released?")

await self.db.save_transaction_io_batch(
synced_txs, address, self.address_to_hash160(address), synced_history.getvalue()
[], address, self.address_to_hash160(address), synced_history
)
await asyncio.wait([
self._on_transaction_controller.add(TransactionEvent(address, tx))
for tx in synced_txs
])

if address_manager is None:
address_manager = await self.get_address_manager_for_address(address)

if address_manager is not None:
await address_manager.ensure_address_gap()

for txid, cache_item in updated_cached_items.items():
if self._tx_cache.get(txid) is not cache_item:
log.warning("tx cache corrupted while syncing %s, reattempt sync=%s", address, reattempt_update)
if reattempt_update:
return await self.update_history(address, remote_status, address_manager, False)
return False

local_status, local_history = \
await self.get_local_status_and_history(address, synced_history.getvalue())
await self.get_local_status_and_history(address, synced_history)

if local_status != remote_status:
if local_history == remote_history:
log.warning(
@@ -590,6 +632,7 @@ async def update_history(self, address, remote_status, address_manager: AddressManager
self._known_addresses_out_of_sync.add(address)
return False
else:
log.debug("finished syncing transaction history for %s, %i known txs", address, len(local_history))
return True
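
The synced_history string assembled above is the remote history re-serialized as
txid:height: entries in server order. Under the Electrum-style address subscription
protocol, an address status is the sha256 hex digest of exactly that concatenation,
which is what lets get_local_status_and_history recompute a status from synced_history
and compare it against remote_status. A minimal sketch of that convention (stated as an
assumption; this helper is not part of the PR):

import hashlib

def address_status(history: str) -> str:
    # `history` is the concatenation of "txid:height:" entries in server
    # order, e.g. "abc...:120:def...:121:"; the status is its sha256 hex.
    return hashlib.sha256(history.encode()).hexdigest()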

async def cache_transaction(self, txid, remote_height, check_local=True):
@@ -601,41 +644,39 @@ async def cache_transaction(self, txid, remote_height, check_local=True):
(cache_item.tx.is_verified or remote_height < 1):
return cache_item.tx # cached tx is already up-to-date

cache_item.pending_verifications += 1
try:
cache_item.pending_verifications += 1
return await self._update_cache_item(cache_item, txid, remote_height, check_local)
async with cache_item.lock:
tx = cache_item.tx
if tx is None and check_local:
# check local db
tx = cache_item.tx = await self.db.get_transaction(txid=txid)
merkle = None
if tx is None:
# fetch from network
_raw, merkle = await self.network.retriable_call(
self.network.get_transaction_and_merkle, txid, remote_height
)
tx = Transaction(unhexlify(_raw), height=merkle['block_height'])
cache_item.tx = tx # make sure it's saved before caching it
tx.height = remote_height
if merkle and 0 < remote_height < len(self.headers):
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = await self.headers.get(remote_height)
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
return tx
finally:
cache_item.pending_verifications -= 1

async def _update_cache_item(self, cache_item, txid, remote_height, check_local=True):

async with cache_item.lock:

tx = cache_item.tx

if tx is None and check_local:
# check local db
tx = cache_item.tx = await self.db.get_transaction(txid=txid)

merkle = None
if tx is None:
# fetch from network
_raw, merkle = await self.network.retriable_call(
self.network.get_transaction_and_merkle, txid, remote_height
)
tx = Transaction(unhexlify(_raw), height=merkle.get('block_height'))
cache_item.tx = tx # make sure it's saved before caching it
await self.maybe_verify_transaction(tx, remote_height, merkle)
return tx

async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
tx.height = remote_height
cached = self._tx_cache.get(tx.id)
if not cached:
# cache txs looked up by transaction_show too
cached = TransactionCacheItem()
cached.tx = tx
self._tx_cache[tx.id] = cached
cached.tx = tx
if 0 < remote_height < len(self.headers) and cached.pending_verifications <= 1:
# can't be tx.pending_verifications == 1 because we have to handle the transaction_show case
if not merkle:
@@ -645,6 +686,100 @@ async def maybe_verify_transaction(self, tx, remote_height, merkle=None):
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
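
maybe_verify_transaction only marks a transaction verified when the merkle root
recomputed from the server's branch matches the header at that height.
get_root_of_merkle_tree is assumed here to be the usual Electrum-style branch fold over
double-sha256; a self-contained sketch under that assumption (byte-order details may
differ from the real helper):

import hashlib

def sha256d(data: bytes) -> bytes:
    return hashlib.sha256(hashlib.sha256(data).digest()).digest()

def root_from_branch(branch_hashes, tx_pos: int, tx_hash: bytes) -> bytes:
    # Fold the branch up to the root; bit i of the transaction's position
    # in the block says whether the sibling at level i sits on the left.
    working = tx_hash
    for sibling_hex in branch_hashes:
        sibling = bytes.fromhex(sibling_hex)[::-1]  # displayed hex is byte-reversed
        if tx_pos & 1:
            working = sha256d(sibling + working)
        else:
            working = sha256d(working + sibling)
        tx_pos >>= 1
    return working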

async def _request_transaction_batch(self, to_request, remote_history_size, address):
header_cache = {}
batches = [[]]
remote_heights = {}
synced_txs = []
heights_in_batch = 0
last_height = 0
for idx in sorted(to_request):
txid = to_request[idx][0]
height = to_request[idx][1]
remote_heights[txid] = height
if height != last_height:
heights_in_batch += 1
last_height = height
if len(batches[-1]) == 100 or heights_in_batch == 20:
batches.append([])
heights_in_batch = 1
batches[-1].append(txid)
if not batches[-1]:
batches.pop()

last_showed_synced_count = 0

async def _single_batch(batch):
this_batch_synced = []
batch_result = await self.network.retriable_call(self.network.get_transaction_batch, batch)
for txid, (raw, merkle) in batch_result.items():
remote_height = remote_heights[txid]
merkle_height = merkle['block_height']
cache_item = self._tx_cache.get(txid)
if cache_item is None:
cache_item = TransactionCacheItem()
self._tx_cache[txid] = cache_item
tx = cache_item.tx or Transaction(unhexlify(raw), height=remote_height)
tx.height = remote_height
cache_item.tx = tx
if 'merkle' in merkle and remote_heights[txid] > 0:
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
try:
header = header_cache.get(remote_heights[txid]) or (await self.headers.get(merkle_height))
except IndexError:
log.warning("failed to verify %s at height %i", tx.id, merkle_height)
else:
header_cache[remote_heights[txid]] = header
tx.position = merkle['pos']
tx.is_verified = merkle_root == header['merkle_root']
check_db_for_txos = []

for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
cache_item = self._tx_cache.get(txi.txo_ref.tx_ref.id)
if cache_item is not None:
if cache_item.tx is not None:
txi.txo_ref = cache_item.tx.outputs[txi.txo_ref.position].ref
else:
check_db_for_txos.append(txi.txo_ref.id)

referenced_txos = {} if not check_db_for_txos else {
txo.id: txo for txo in await self.db.get_txos(
txoid__in=check_db_for_txos, order_by='txo.txoid', no_tx=True
)
}

for txi in tx.inputs:
if txi.txo_ref.txo is not None:
continue
referenced_txo = referenced_txos.get(txi.txo_ref.id)
if referenced_txo is not None:
txi.txo_ref = referenced_txo.ref
continue
cache_item = self._tx_cache.get(txi.txo_ref.id)
if cache_item is None:
cache_item = self._tx_cache[txi.txo_ref.id] = TransactionCacheItem()
if cache_item.tx is not None:
txi.txo_ref = cache_item.tx.ref

synced_txs.append(tx)
this_batch_synced.append(tx)
await self.db.save_transaction_io_batch(
this_batch_synced, address, self.address_to_hash160(address), ""
)
await asyncio.wait([
self._on_transaction_controller.add(TransactionEvent(address, tx))
for tx in this_batch_synced
])
nonlocal last_showed_synced_count
if last_showed_synced_count + 100 < len(synced_txs):
log.info("synced %i/%i transactions for %s", len(synced_txs), remote_history_size, address)
last_showed_synced_count = len(synced_txs)
for batch in batches:
await _single_batch(batch)
return synced_txs
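
_request_transaction_batch first partitions the sorted requests into batches of at most
100 txids spanning at most 20 distinct block heights, presumably to bound the header and
merkle work a server does per call. The same partitioning rule, pulled out as a pure
function for illustration (the function name is hypothetical):

def partition_batches(to_request, max_txs=100, max_heights=20):
    # `to_request` maps history index -> (txid, height); walk it in history
    # order and open a new batch once the current one holds max_txs txids
    # or spans max_heights distinct heights.
    batches, heights_in_batch, last_height = [[]], 0, 0
    for idx in sorted(to_request):
        txid, height = to_request[idx]
        if height != last_height:
            heights_in_batch += 1
            last_height = height
        if len(batches[-1]) == max_txs or heights_in_batch == max_heights:
            batches.append([])
            heights_in_batch = 1
        batches[-1].append(txid)
    if not batches[-1]:
        batches.pop()
    return batches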

async def get_address_manager_for_address(self, address) -> Optional[AddressManager]:
details = await self.db.get_address(address=address)
for account in self.accounts:
@@ -697,7 +832,7 @@ async def _wait_round(self, tx: Transaction, height: int, addresses: Iterable[str]
local_height, height
)
return False
log.debug(
log.warning(
"local history does not contain %s, requested height %i", tx.id, height
)
return False
4 changes: 4 additions & 0 deletions lbry/wallet/network.py
@@ -255,6 +255,10 @@ def get_transaction(self, tx_hash, known_height=None):
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
return self.rpc('blockchain.transaction.get', [tx_hash], restricted)

def get_transaction_batch(self, txids):
# the batch call is always restricted to the server that gave us the history
return self.rpc('blockchain.transaction.get_batch', txids, True)

def get_transaction_and_merkle(self, tx_hash, known_height=None):
# use any server if it's old, otherwise restrict to who gave us the history
restricted = known_height in (None, -1, 0) or 0 > known_height > self.remote_height - 10
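
get_transaction_batch passes restricted=True unconditionally, pinning the request to the
session that supplied the address history. Judging by the consumer in
_request_transaction_batch, each txid maps to its raw hex plus merkle metadata; a hedged
sketch of the call and the assumed response shape:

# Assumed response shape, inferred from the ledger-side consumer
# (not a documented contract):
#
# {
#     "<txid hex>": (
#         "<raw transaction hex>",
#         {"block_height": 1234, "merkle": ["<branch hash hex>", ...], "pos": 5},
#     ),
#     ...
# }
#
# 'merkle' and 'pos' may be absent for unconfirmed transactions, which is
# why the ledger guards with `'merkle' in merkle`.

async def fetch_batch(network, txids):
    # retriable_call retries the RPC on transient failures; the batch call
    # itself stays pinned because get_transaction_batch passes True.
    return await network.retriable_call(network.get_transaction_batch, txids)
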
14 changes: 14 additions & 0 deletions lbry/wallet/server/mempool.py
@@ -389,3 +389,17 @@ async def unordered_UTXOs(self, hashX):
if hX == hashX:
utxos.append(UTXO(-1, pos, tx_hash, 0, value))
return utxos

def get_mempool_height(self, tx_hash):
# Height Progression
# -2: not broadcast
# -1: in mempool but has unconfirmed inputs
# 0: in mempool and all inputs confirmed
# +num: confirmed in a specific block (height)
if tx_hash not in self.txs:
return -2
tx = self.txs[tx_hash]
unspent_inputs = sum(1 if hash in self.txs else 0 for hash, idx in tx.prevouts)
if unspent_inputs:
return -1
return 0
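
The height progression comment above is easy to check against a toy mempool: a
transaction is unknown (-2), known but funded by another mempool transaction (-1), or
known with every input already confirmed (0). A self-contained sketch of the same rule,
with a plain dict standing in for self.txs:

from collections import namedtuple

MempoolTx = namedtuple('MempoolTx', 'prevouts')

def mempool_height(txs, tx_hash):
    # Mirrors get_mempool_height: -2 unknown, -1 has inputs still in the
    # mempool (unconfirmed), 0 all inputs confirmed; confirmed transactions
    # are not in `txs` at all and report their block height elsewhere.
    if tx_hash not in txs:
        return -2
    unconfirmed = sum(1 for prev_hash, _ in txs[tx_hash].prevouts if prev_hash in txs)
    return -1 if unconfirmed else 0

txs = {
    b'aa': MempoolTx(prevouts=[(b'00', 0)]),  # spends an already-confirmed txo
    b'bb': MempoolTx(prevouts=[(b'aa', 0)]),  # spends a mempool txo
}
assert mempool_height(txs, b'zz') == -2
assert mempool_height(txs, b'aa') == 0
assert mempool_height(txs, b'bb') == -1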