Skip to content

Commit

Permalink
Broaden the Keychain definition of public keys to include more than…
Browse files Browse the repository at this point in the history
… BLS (#16953)

This PR adds a new dimension to the keychain in which it can store more
than one type of public information. It converts most instances of
`G1Element` to a new Protocol called `ObservationRoot` which must simply
implement the ability to serialize to bytes and generate a fingerprint.
This is crucial for upcoming applications in which the public
information used to sync a wallet is not a BLS public key but rather a
launcher ID, or perhaps other key types like SECP, etc.

The private information is left constrained to BLS secret keys for now
as it is less immediately useful.
  • Loading branch information
Quexington committed Mar 26, 2024
2 parents 061ee98 + eb62c81 commit 23c9ee8
Show file tree
Hide file tree
Showing 17 changed files with 224 additions and 104 deletions.
3 changes: 2 additions & 1 deletion chia/_tests/core/daemon/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from chia.simulator.setup_services import setup_full_node
from chia.util.config import load_config
from chia.util.json_util import dict_to_json_str
from chia.util.keychain import Keychain, KeyData, supports_os_passphrase_storage
from chia.util.keychain import Keychain, KeyData, KeyTypes, supports_os_passphrase_storage
from chia.util.keyring_wrapper import DEFAULT_PASSPHRASE_IF_NO_MASTER_PASSPHRASE, KeyringWrapper
from chia.util.ws_message import create_payload, create_payload_dict
from chia.wallet.derive_keys import master_sk_to_farmer_sk, master_sk_to_pool_sk
Expand Down Expand Up @@ -236,6 +236,7 @@ async def get_keys_for_plotting(self, request: Dict[str, Any]) -> Dict[str, Any]
def add_private_key_response_data(fingerprint: int) -> Dict[str, object]:
return {
"success": True,
"key_type": KeyTypes.G1_ELEMENT.value,
"fingerprint": fingerprint,
}

Expand Down
16 changes: 9 additions & 7 deletions chia/_tests/core/daemon/test_keychain_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, AsyncGenerator

import pytest
from chia_rs import G1Element

from chia.daemon.keychain_proxy import KeychainProxy, connect_to_keychain_and_validate
from chia.simulator.block_tools import BlockTools
Expand Down Expand Up @@ -47,21 +48,22 @@ async def test_add_private_key(keychain_proxy: KeychainProxy) -> None:
@pytest.mark.anyio
async def test_add_public_key(keychain_proxy: KeychainProxy) -> None:
keychain = keychain_proxy
await keychain.add_key(bytes(TEST_KEY_3.public_key).hex(), TEST_KEY_3.label, private=False)
assert isinstance(TEST_KEY_3.observation_root, G1Element)
await keychain.add_key(TEST_KEY_3.public_key.hex(), TEST_KEY_3.label, private=False)
with pytest.raises(Exception, match="already exists"):
await keychain.add_key(bytes(TEST_KEY_3.public_key).hex(), "", private=False)
await keychain.add_key(TEST_KEY_3.public_key.hex(), "", private=False)
key = await keychain.get_key(TEST_KEY_3.fingerprint, include_secrets=False)
assert key is not None
assert key.public_key == TEST_KEY_3.public_key
assert key.observation_root == TEST_KEY_3.observation_root
assert key.secrets is None

pk = await keychain.get_key_for_fingerprint(TEST_KEY_3.fingerprint, private=False)
assert pk is not None
assert pk == TEST_KEY_3.public_key
assert pk == TEST_KEY_3.observation_root

pk = await keychain.get_key_for_fingerprint(None, private=False)
assert pk is not None
assert pk == TEST_KEY_3.public_key
assert pk == TEST_KEY_3.observation_root

with pytest.raises(KeychainKeyNotFound):
pk = await keychain.get_key_for_fingerprint(1234567890, private=False)
Expand All @@ -82,8 +84,8 @@ async def test_get_key_for_fingerprint(keychain_proxy: KeychainProxy) -> None:
with pytest.raises(KeychainIsEmpty):
await keychain.get_key_for_fingerprint(None, private=False)
await keychain_proxy.add_key(TEST_KEY_1.mnemonic_str(), TEST_KEY_1.label)
assert await keychain.get_key_for_fingerprint(TEST_KEY_1.fingerprint, private=False) == TEST_KEY_1.public_key
assert await keychain.get_key_for_fingerprint(None, private=False) == TEST_KEY_1.public_key
assert await keychain.get_key_for_fingerprint(TEST_KEY_1.fingerprint, private=False) == TEST_KEY_1.observation_root
assert await keychain.get_key_for_fingerprint(None, private=False) == TEST_KEY_1.observation_root
with pytest.raises(KeychainKeyNotFound):
await keychain.get_key_for_fingerprint(1234567890, private=False)

Expand Down
2 changes: 1 addition & 1 deletion chia/_tests/core/data_layer/test_data_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2288,7 +2288,7 @@ async def test_wallet_log_in_changes_active_fingerprint(

mnemonic = create_mnemonic()
assert wallet_rpc_api.service.local_keychain is not None
private_key = wallet_rpc_api.service.local_keychain.add_key(mnemonic_or_pk=mnemonic)
private_key, _ = wallet_rpc_api.service.local_keychain.add_key(mnemonic_or_pk=mnemonic)
secondary_fingerprint: int = private_key.get_g1().get_fingerprint()

await wallet_rpc_api.log_in(request={"fingerprint": primary_fingerprint})
Expand Down
41 changes: 33 additions & 8 deletions chia/_tests/core/util/test_keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import random
from dataclasses import replace
from typing import Callable, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

import pkg_resources
import pytest
Expand All @@ -25,12 +25,14 @@
Keychain,
KeyData,
KeyDataSecrets,
KeyTypes,
bytes_from_mnemonic,
bytes_to_mnemonic,
generate_mnemonic,
mnemonic_from_short_words,
mnemonic_to_seed,
)
from chia.util.observation_root import ObservationRoot

mnemonic = (
"rapid this oven common drive ribbon bulb urban uncover napkin kitten usage enforce uncle unveil scene "
Expand Down Expand Up @@ -148,7 +150,7 @@ def test_bip39_eip2333_test_vector(self, empty_temp_file_keyring: TempKeyring):

mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
print("entropy to seed:", mnemonic_to_seed(mnemonic).hex())
master_sk = kc.add_key(mnemonic)
master_sk, _ = kc.add_key(mnemonic)
tv_master_int = 8075452428075949470768183878078858156044736575259233735633523546099624838313
tv_child_int = 18507161868329770878190303689452715596635858303241878571348190917018711023613
assert master_sk == PrivateKey.from_bytes(tv_master_int.to_bytes(32, "big"))
Expand Down Expand Up @@ -229,7 +231,7 @@ def test_key_data_generate(label: Optional[str]) -> None:
key_data = KeyData.generate(label)
assert key_data.private_key == AugSchemeMPL.key_gen(mnemonic_to_seed(key_data.mnemonic_str()))
assert key_data.entropy == bytes_from_mnemonic(key_data.mnemonic_str())
assert key_data.public_key == key_data.private_key.get_g1()
assert key_data.observation_root == key_data.private_key.get_g1()
assert key_data.fingerprint == key_data.private_key.get_g1().get_fingerprint()
assert key_data.label == label

Expand All @@ -241,7 +243,7 @@ def test_key_data_generate(label: Optional[str]) -> None:
def test_key_data_creation(input_data: object, from_method: Callable[..., KeyData], label: Optional[str]) -> None:
key_data = from_method(input_data, label)
assert key_data.fingerprint == fingerprint
assert key_data.public_key == public_key
assert key_data.observation_root == public_key
assert key_data.mnemonic == mnemonic.split()
assert key_data.mnemonic_str() == mnemonic
assert key_data.entropy == entropy
Expand All @@ -250,7 +252,7 @@ def test_key_data_creation(input_data: object, from_method: Callable[..., KeyDat


def test_key_data_without_secrets() -> None:
key_data = KeyData(fingerprint, public_key, None, None)
key_data = KeyData(fingerprint, bytes(public_key), None, None, KeyTypes.G1_ELEMENT.value)
assert key_data.secrets is None

with pytest.raises(KeychainSecretsMissing):
Expand Down Expand Up @@ -282,12 +284,21 @@ def test_key_data_secrets_post_init(input_data: Tuple[List[str], bytes, PrivateK
@pytest.mark.parametrize(
"input_data, data_type",
[
((fingerprint, G1Element(), None, KeyDataSecrets(mnemonic.split(), entropy, private_key)), "public_key"),
((fingerprint, G1Element(), None, None), "fingerprint"),
(
(
fingerprint,
bytes(G1Element()),
None,
KeyDataSecrets(mnemonic.split(), entropy, private_key),
KeyTypes.G1_ELEMENT.value,
),
"public_key",
),
((fingerprint, bytes(G1Element()), None, None, KeyTypes.G1_ELEMENT.value), "fingerprint"),
],
)
def test_key_data_post_init(
input_data: Tuple[uint32, G1Element, Optional[str], Optional[KeyDataSecrets]], data_type: str
input_data: Tuple[uint32, bytes, Optional[str], Optional[KeyDataSecrets], str], data_type: str
) -> None:
with pytest.raises(KeychainKeyDataMismatch, match=data_type):
KeyData(*input_data)
Expand Down Expand Up @@ -460,3 +471,17 @@ async def test_delete_drops_labels(get_temp_keyring: Keychain, delete_all: bool)
for key_data in keys:
keychain.delete_key_by_fingerprint(key_data.fingerprint)
assert keychain.keyring_wrapper.get_label(key_data.fingerprint) is None


@pytest.mark.parametrize("key_type", [e.value for e in KeyTypes])
def test_key_type_support(key_type: str) -> None:
"""
The purpose of this test is to make sure that whenever KeyTypes is updated, all relevant functionality is
also updated with it.
"""
generate_test_key_for_key_type: Dict[str, Tuple[int, bytes, ObservationRoot]] = {
KeyTypes.G1_ELEMENT.value: (G1Element().get_fingerprint(), bytes(G1Element()), G1Element())
}
obr_fingerprint, obr_bytes, obr = generate_test_key_for_key_type[key_type]
assert KeyData(uint32(obr_fingerprint), obr_bytes, None, None, key_type).observation_root == obr
assert KeyTypes.parse_observation_root(obr_bytes, KeyTypes(key_type)) == obr
23 changes: 13 additions & 10 deletions chia/_tests/wallet/test_wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from chia.util.config import load_config
from chia.util.errors import Err
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.util.keychain import Keychain, KeyData, generate_mnemonic
from chia.util.keychain import Keychain, KeyData, KeyTypes, generate_mnemonic
from chia.util.misc import to_batches
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
from chia.wallet.wallet_node import Balance, WalletNode
Expand All @@ -39,7 +39,7 @@ async def test_get_private_key(root_path_populated_with_config: Path, get_temp_k
keychain = get_temp_keyring
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants, keychain)
sk = keychain.add_key(generate_mnemonic())
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()

key = await node.get_key(fingerprint)
Expand All @@ -55,7 +55,7 @@ async def test_get_private_key_default_key(root_path_populated_with_config: Path
keychain = get_temp_keyring
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants, keychain)
sk = keychain.add_key(generate_mnemonic())
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()

# Add a couple more keys
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_get_private_key_missing_key_use_default(
keychain = get_temp_keyring
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants, keychain)
sk = keychain.add_key(generate_mnemonic())
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()

# Stupid sanity check that the fingerprint we're going to use isn't actually in the keychain
Expand All @@ -114,7 +114,7 @@ async def test_get_public_key(root_path_populated_with_config: Path, get_temp_ke
keychain: Keychain = get_temp_keyring
config: Dict[str, Any] = load_config(root_path, "config.yaml", "wallet")
node: WalletNode = WalletNode(config, root_path, test_constants, keychain)
pk: G1Element = keychain.add_key(
pk, key_type = keychain.add_key(
"c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
None,
private=False,
Expand All @@ -126,6 +126,7 @@ async def test_get_public_key(root_path_populated_with_config: Path, get_temp_ke
assert key is not None
assert isinstance(key, G1Element)
assert key.get_fingerprint() == fingerprint
assert key_type == KeyTypes.G1_ELEMENT


@pytest.mark.anyio
Expand All @@ -134,7 +135,7 @@ async def test_get_public_key_default_key(root_path_populated_with_config: Path,
keychain: Keychain = get_temp_keyring
config: Dict[str, Any] = load_config(root_path, "config.yaml", "wallet")
node: WalletNode = WalletNode(config, root_path, test_constants, keychain)
pk: G1Element = keychain.add_key(
pk, key_type = keychain.add_key(
"c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
None,
private=False,
Expand All @@ -159,6 +160,7 @@ async def test_get_public_key_default_key(root_path_populated_with_config: Path,
assert key is not None
assert isinstance(key, G1Element)
assert key.get_fingerprint() == fingerprint
assert key_type == KeyTypes.G1_ELEMENT


@pytest.mark.anyio
Expand All @@ -185,7 +187,7 @@ async def test_get_public_key_missing_key_use_default(
keychain: Keychain = get_temp_keyring
config: Dict[str, Any] = load_config(root_path, "config.yaml", "wallet")
node: WalletNode = WalletNode(config, root_path, test_constants, keychain)
pk: G1Element = keychain.add_key(
pk, key_type = keychain.add_key(
"c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
None,
private=False,
Expand All @@ -201,14 +203,15 @@ async def test_get_public_key_missing_key_use_default(
assert key is not None
assert isinstance(key, G1Element)
assert key.get_fingerprint() == fingerprint
assert key_type == KeyTypes.G1_ELEMENT


def test_log_in(root_path_populated_with_config: Path, get_temp_keyring: Keychain) -> None:
root_path = root_path_populated_with_config
keychain = get_temp_keyring
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants)
sk = keychain.add_key(generate_mnemonic())
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()

node.log_in(fingerprint)
Expand All @@ -234,7 +237,7 @@ def patched_update_last_used_fingerprint(self: Self) -> None:
keychain = get_temp_keyring
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants)
sk = keychain.add_key(generate_mnemonic())
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()

# Expect log_in to succeed, even though we can't write the last used fingerprint
Expand All @@ -251,7 +254,7 @@ def test_log_out(root_path_populated_with_config: Path, get_temp_keyring: Keycha
keychain = get_temp_keyring
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants)
sk = keychain.add_key(generate_mnemonic())
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()

node.log_in(fingerprint)
Expand Down
54 changes: 36 additions & 18 deletions chia/cmds/keys_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ def add_key_info(mnemonic_or_pk: str, label: Optional[str]) -> None:
unlock_keyring()
try:
if check_mnemonic_validity(mnemonic_or_pk):
sk = Keychain().add_key(mnemonic_or_pk, label, private=True)
sk, _ = Keychain().add_key(mnemonic_or_pk, label, private=True)
fingerprint = sk.get_g1().get_fingerprint()
print(f"Added private key with public key fingerprint {fingerprint}")
else:
pk = Keychain().add_key(mnemonic_or_pk, label, private=False)
pk, _ = Keychain().add_key(mnemonic_or_pk, label, private=False)
fingerprint = pk.get_fingerprint()
print(f"Added public key with fingerprint {fingerprint}")

Expand Down Expand Up @@ -172,27 +172,32 @@ def process_key_data(key_data: KeyData) -> Dict[str, Any]:
key["label"] = key_data.label

key["fingerprint"] = key_data.fingerprint
key["master_pk"] = bytes(key_data.public_key).hex()
if isinstance(key_data.observation_root, G1Element):
key["master_pk"] = key_data.public_key.hex()
else: # pragma: no cover
# TODO: Add test coverage once vault wallet exists
key["observation_root"] = key_data.public_key.hex()
if sk is not None:
key["farmer_pk"] = bytes(master_sk_to_farmer_sk(sk).get_g1()).hex()
key["pool_pk"] = bytes(master_sk_to_pool_sk(sk).get_g1()).hex()
else:
key["farmer_pk"] = None
key["pool_pk"] = None

if non_observer_derivation:
if sk is None:
first_wallet_pk: Optional[G1Element] = None
if isinstance(key_data.observation_root, G1Element):
if non_observer_derivation:
if sk is None:
first_wallet_pk: Optional[G1Element] = None
else:
first_wallet_pk = master_sk_to_wallet_sk(sk, uint32(0)).get_g1()
else:
first_wallet_pk = master_sk_to_wallet_sk(sk, uint32(0)).get_g1()
else:
first_wallet_pk = master_pk_to_wallet_pk_unhardened(key_data.public_key, uint32(0))
first_wallet_pk = master_pk_to_wallet_pk_unhardened(key_data.observation_root, uint32(0))

if first_wallet_pk is not None:
wallet_address: str = encode_puzzle_hash(create_puzzlehash_for_pk(first_wallet_pk), prefix)
key["wallet_address"] = wallet_address
else:
key["wallet_address"] = None
if first_wallet_pk is not None:
wallet_address: str = encode_puzzle_hash(create_puzzlehash_for_pk(first_wallet_pk), prefix)
key["wallet_address"] = wallet_address
else:
key["wallet_address"] = None

key["non_observer"] = non_observer_derivation

Expand Down Expand Up @@ -514,8 +519,13 @@ def search_derive(
private_keys = [private_key]
else:
master_key_data = Keychain().get_key(fingerprint, include_secrets=True)
public_keys = [master_key_data.public_key]
private_keys = [master_key_data.private_key if master_key_data.secrets is not None else None]
if isinstance(master_key_data.observation_root, G1Element):
public_keys = [master_key_data.observation_root]
private_keys = [master_key_data.private_key if master_key_data.secrets is not None else None]
else: # pragma: no cover
# TODO: Add test coverage once vault wallet exists
print("Cannot currently derive paths from non-BLS keys")
return True

for pk, sk in zip(public_keys, private_keys):
if sk is None and non_observer_derivation:
Expand Down Expand Up @@ -648,14 +658,18 @@ def derive_wallet_address(
"""
if fingerprint is not None:
key_data: KeyData = Keychain().get_key(fingerprint, include_secrets=non_observer_derivation)
if not isinstance(key_data.observation_root, G1Element): # pragma: no cover
# TODO: Add test coverage once vault wallet exists
print("Cannot currently derive from non-BLS keys")
return
if non_observer_derivation and key_data.secrets is None:
print("Need a private key for non observer derivation of wallet addresses")
return
elif non_observer_derivation:
sk = key_data.private_key
else:
sk = None
pk = key_data.public_key
pk: G1Element = key_data.observation_root
else:
assert private_key is not None
sk = private_key
Expand Down Expand Up @@ -712,7 +726,11 @@ def derive_child_key(

if fingerprint is not None:
key_data: KeyData = Keychain().get_key(fingerprint, include_secrets=True)
current_pk: G1Element = key_data.public_key
if not isinstance(key_data.observation_root, G1Element): # pragma: no cover
# TODO: Add coverage when vault wallet exists
print("Cannot currently derive from non-BLS keys")
return
current_pk: G1Element = key_data.observation_root
current_sk: Optional[PrivateKey] = key_data.private_key if key_data.secrets is not None else None
else:
assert private_key is not None
Expand Down
Loading

0 comments on commit 23c9ee8

Please sign in to comment.