Skip to content

Commit

Permalink
refator(cardano): validate map key order in HashBuilderDict
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmisiak committed Mar 11, 2022
1 parent 4eb1db1 commit b3b406d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 54 deletions.
30 changes: 26 additions & 4 deletions core/src/apps/cardano/helpers/hash_builder_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

if TYPE_CHECKING:
from typing import Any, Generic, TypeVar
from trezor import wire
from trezor.utils import HashContext

T = TypeVar("T")
Expand Down Expand Up @@ -43,7 +44,13 @@ def _do_enter_item(self) -> None:

self.remaining -= 1

def _hash_item(self, item: Any) -> None:
def _hash_item(self, item: Any) -> bytes:
assert self.hash_fn is not None
encoded_item = cbor.encode(item)
self.hash_fn.update(encoded_item)
return encoded_item

def _hash_item_streamed(self, item: Any) -> None:
assert self.hash_fn is not None
for chunk in cbor.encode_streamed(item):
self.hash_fn.update(chunk)
Expand Down Expand Up @@ -74,7 +81,7 @@ def append(self, item: T) -> T:
if isinstance(item, HashBuilderCollection):
self._insert_child(item)
else:
self._hash_item(item)
self._hash_item_streamed(item)

return item

Expand All @@ -83,16 +90,31 @@ def _header_bytes(self) -> bytes:


class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
key_order_error: wire.ProcessError
previous_encoded_key: bytes

def __init__(self, size: int, key_order_error: wire.ProcessError):
super().__init__(size)
self.key_order_error = key_order_error
self.previous_encoded_key = b""

def add(self, key: K, value: V) -> V:
self._do_enter_item()

# enter key, this must not nest
assert not isinstance(key, HashBuilderCollection)
self._hash_item(key)
encoded_key = self._hash_item(key)

# check key ordering
if not cbor.precedes(self.previous_encoded_key, encoded_key):
raise self.key_order_error
self.previous_encoded_key = encoded_key

# enter value, this can nest
if isinstance(value, HashBuilderCollection):
self._insert_child(value)
else:
self._hash_item(value)
self._hash_item_streamed(value)

return value

Expand Down
66 changes: 20 additions & 46 deletions core/src/apps/cardano/sign_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ async def sign_tx(
account_path_checker = AccountPathChecker()

hash_fn = hashlib.blake2b(outlen=32)
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(tx_body_map_item_count)
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(
tx_body_map_item_count, INVALID_TX_SIGNING_REQUEST
)
tx_dict.start(hash_fn)
with tx_dict:
await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker)
Expand Down Expand Up @@ -296,7 +298,7 @@ async def _process_transaction(

if msg.withdrawals_count > 0:
withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict(
msg.withdrawals_count
msg.withdrawals_count, INVALID_WITHDRAWAL
)
with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict):
await _process_withdrawals(
Expand Down Expand Up @@ -324,7 +326,7 @@ async def _process_transaction(

if msg.minting_asset_groups_count > 0:
minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict(
msg.minting_asset_groups_count
msg.minting_asset_groups_count, INVALID_TOKEN_BUNDLE_MINT
)
with tx_dict.add(TX_BODY_KEY_MINT, minting_dict):
await _process_minting(ctx, minting_dict)
Expand Down Expand Up @@ -468,7 +470,9 @@ async def _process_outputs(
output_value_list.append(output.amount)
asset_groups_dict: HashBuilderDict[
bytes, HashBuilderDict[bytes, int]
] = HashBuilderDict(output.asset_groups_count)
] = HashBuilderDict(
output.asset_groups_count, INVALID_TOKEN_BUNDLE_OUTPUT
)
with output_value_list.append(asset_groups_dict):
await _process_asset_groups(
ctx,
Expand All @@ -492,15 +496,15 @@ async def _process_asset_groups(
should_show_tokens: bool,
) -> None:
"""Read, validate and serialize the asset groups of an output."""
previous_policy_id: bytes = b""
for _ in range(asset_groups_count):
asset_group: CardanoAssetGroup = await ctx.call(
CardanoTxItemAck(), CardanoAssetGroup
)
_validate_asset_group(asset_group, previous_policy_id)
previous_policy_id = asset_group.policy_id
_validate_asset_group(asset_group)

tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_OUTPUT
)
with asset_groups_dict.add(asset_group.policy_id, tokens):
await _process_tokens(
ctx,
Expand All @@ -519,11 +523,9 @@ async def _process_tokens(
should_show_tokens: bool,
) -> None:
"""Read, validate, confirm and serialize the tokens of an asset group."""
previous_asset_name_bytes: bytes = b""
for _ in range(tokens_count):
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
_validate_token(token, previous_asset_name_bytes)
previous_asset_name_bytes = token.asset_name_bytes
_validate_token(token)
if should_show_tokens:
await confirm_sending_token(ctx, policy_id, token)

Expand Down Expand Up @@ -641,24 +643,18 @@ async def _process_withdrawals(
if withdrawals_count == 0:
return

previous_reward_address_bytes: bytes = b""
for _ in range(withdrawals_count):
withdrawal: CardanoTxWithdrawal = await ctx.call(
CardanoTxItemAck(), CardanoTxWithdrawal
)
_validate_withdrawal(
keychain,
withdrawal,
signing_mode,
protocol_magic,
network_id,
account_path_checker,
previous_reward_address_bytes,
)
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
keychain, withdrawal, protocol_magic, network_id
)
previous_reward_address_bytes = reward_address_bytes

await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id)

Expand Down Expand Up @@ -707,15 +703,15 @@ async def _process_minting(

await show_warning_tx_contains_mint(ctx)

previous_policy_id: bytes = b""
for _ in range(token_minting.asset_groups_count):
asset_group: CardanoAssetGroup = await ctx.call(
CardanoTxItemAck(), CardanoAssetGroup
)
_validate_asset_group(asset_group, previous_policy_id, is_mint=True)
previous_policy_id = asset_group.policy_id
_validate_asset_group(asset_group, is_mint=True)

tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_MINT
)
with minting_dict.add(asset_group.policy_id, tokens):
await _process_minting_tokens(
ctx,
Expand All @@ -732,11 +728,9 @@ async def _process_minting_tokens(
tokens_count: int,
) -> None:
"""Read, validate, confirm and serialize the tokens of an asset group."""
previous_asset_name_bytes: bytes = b""
for _ in range(tokens_count):
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
_validate_token(token, previous_asset_name_bytes, is_mint=True)
previous_asset_name_bytes = token.asset_name_bytes
_validate_token(token, is_mint=True)
await confirm_token_minting(ctx, policy_id, token)

assert token.mint_amount is not None # _validate_token
Expand Down Expand Up @@ -1005,7 +999,7 @@ async def _show_output(


def _validate_asset_group(
asset_group: CardanoAssetGroup, previous_policy_id: bytes, is_mint: bool = False
asset_group: CardanoAssetGroup, is_mint: bool = False
) -> None:
INVALID_TOKEN_BUNDLE = (
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
Expand All @@ -1015,13 +1009,9 @@ def _validate_asset_group(
raise INVALID_TOKEN_BUNDLE
if asset_group.tokens_count == 0:
raise INVALID_TOKEN_BUNDLE
if not cbor.are_canonically_ordered(previous_policy_id, asset_group.policy_id):
raise INVALID_TOKEN_BUNDLE


def _validate_token(
token: CardanoToken, previous_asset_name_bytes: bytes, is_mint: bool = False
) -> None:
def _validate_token(token: CardanoToken, is_mint: bool = False) -> None:
INVALID_TOKEN_BUNDLE = (
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
)
Expand All @@ -1035,10 +1025,6 @@ def _validate_token(

if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH:
raise INVALID_TOKEN_BUNDLE
if not cbor.are_canonically_ordered(
previous_asset_name_bytes, token.asset_name_bytes
):
raise INVALID_TOKEN_BUNDLE


async def _show_certificate(
Expand All @@ -1065,13 +1051,9 @@ async def _show_certificate(


def _validate_withdrawal(
keychain: seed.Keychain,
withdrawal: CardanoTxWithdrawal,
signing_mode: CardanoTxSigningMode,
protocol_magic: int,
network_id: int,
account_path_checker: AccountPathChecker,
previous_reward_address_bytes: bytes,
) -> None:
validate_stake_credential(
withdrawal.path,
Expand All @@ -1084,14 +1066,6 @@ def _validate_withdrawal(
if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY:
raise INVALID_WITHDRAWAL

reward_address_bytes = _derive_withdrawal_reward_address_bytes(
keychain, withdrawal, protocol_magic, network_id
)
if not cbor.are_canonically_ordered(
previous_reward_address_bytes, reward_address_bytes
):
raise INVALID_WITHDRAWAL

account_path_checker.add_withdrawal(withdrawal)


Expand Down
8 changes: 4 additions & 4 deletions core/src/apps/common/cbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,11 @@ def create_map_header(size: int) -> bytes:
return _header(_CBOR_MAP, size)


def are_canonically_ordered(previous: Value, current: Value) -> bool:
def precedes(prev: bytes, curr: bytes) -> bool:
"""
Returns True if `previous` is smaller than `current` with regards to
Returns True if `prev` is smaller than `curr` with regards to
the cbor map key ordering as defined in
https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
Note that `prev` and `curr` must already be cbor-encoded.
"""
u, v = encode(previous), encode(current)
return len(u) < len(v) or (len(u) == len(v) and u < v)
return len(prev) < len(curr) or (len(prev) == len(curr) and prev < curr)

0 comments on commit b3b406d

Please sign in to comment.