Skip to content

Commit

Permalink
[CHIA-1128] Generalize Keychain to support multiple key types (#18466)
Browse files Browse the repository at this point in the history
With the advent of vaults, in order to preserve existing UX where the
wallet can perform all actions by itself, you need a secp private key
available to the wallet. In order to store these private keys in the
keychain, this PR generalizes the keychain code to support keys of any
type.
  • Loading branch information
Quexington authored Sep 16, 2024
2 parents df53def + 2f30b6a commit d2f5d0a
Show file tree
Hide file tree
Showing 22 changed files with 394 additions and 215 deletions.
4 changes: 3 additions & 1 deletion chia/_tests/core/daemon/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from aiohttp import WSMessage
from aiohttp.web_ws import WebSocketResponse
from chia_rs import G1Element
from chia_rs import G1Element, PrivateKey
from pytest_mock import MockerFixture

from chia._tests.util.misc import Marks, datacases
Expand Down Expand Up @@ -205,6 +205,8 @@ async def get_keys_for_plotting(self, request: Dict[str, Any]) -> Dict[str, Any]
"hammer stable page grunt venture purse canyon discover "
"egg vivid spare immune awake code announce message"
)
assert isinstance(test_key_data.private_key, PrivateKey)
assert isinstance(test_key_data_2.private_key, PrivateKey)

success_response_data = {
"success": True,
Expand Down
12 changes: 12 additions & 0 deletions chia/_tests/core/daemon/test_keychain_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,15 @@ async def test_get_keys(keychain_proxy_with_keys: KeychainProxy, include_secrets
else:
expected_keys = [replace(TEST_KEY_1, secrets=None), replace(TEST_KEY_2, secrets=None)]
assert keys == expected_keys


@pytest.mark.anyio
async def test_get_first_private_key(keychain_proxy_with_keys: KeychainProxy) -> None:
assert TEST_KEY_1.private_key == await keychain_proxy_with_keys.get_first_private_key()


@pytest.mark.anyio
async def test_get_all_private_keys(keychain_proxy_with_keys: KeychainProxy) -> None:
assert [TEST_KEY_1.private_key, TEST_KEY_2.private_key] == [
k for k, e in await keychain_proxy_with_keys.get_all_private_keys()
]
2 changes: 1 addition & 1 deletion chia/_tests/core/data_layer/test_data_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,7 +2373,7 @@ async def test_wallet_log_in_changes_active_fingerprint(
mnemonic = create_mnemonic()
assert wallet_rpc_api.service.local_keychain is not None
private_key, _ = wallet_rpc_api.service.local_keychain.add_key(mnemonic_or_pk=mnemonic)
secondary_fingerprint: int = private_key.get_g1().get_fingerprint()
secondary_fingerprint: int = private_key.public_key().get_fingerprint()

await wallet_rpc_api.log_in(request={"fingerprint": primary_fingerprint})

Expand Down
19 changes: 11 additions & 8 deletions chia/_tests/core/util/test_keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_basic_add_delete(
entropy = bytes_from_mnemonic(mnemonic)
assert bytes_to_mnemonic(entropy) == mnemonic
mnemonic_2 = generate_mnemonic()
fingerprint_2 = AugSchemeMPL.key_gen(mnemonic_to_seed(mnemonic_2)).get_g1().get_fingerprint()
fingerprint_2 = AugSchemeMPL.key_gen(mnemonic_to_seed(mnemonic_2)).public_key().get_fingerprint()

# misspelled words in the mnemonic
bad_mnemonic = mnemonic.split(" ")
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_basic_add_delete(

seed_2 = mnemonic_to_seed(mnemonic)
seed_key_2 = AugSchemeMPL.key_gen(seed_2)
kc.delete_key_by_fingerprint(seed_key_2.get_g1().get_fingerprint())
kc.delete_key_by_fingerprint(seed_key_2.public_key().get_fingerprint())
assert kc._get_free_private_key_index() == 0
assert len(kc.get_all_private_keys()) == 1

Expand Down Expand Up @@ -201,7 +201,7 @@ def test_bip39_eip2333_test_vector(self, empty_temp_file_keyring: TempKeyring):
tv_master_int = 8075452428075949470768183878078858156044736575259233735633523546099624838313
tv_child_int = 18507161868329770878190303689452715596635858303241878571348190917018711023613
assert master_sk == PrivateKey.from_bytes(tv_master_int.to_bytes(32, "big"))
child_sk = AugSchemeMPL.derive_child_sk(master_sk, 0)
child_sk = master_sk.derive_hardened(0)
assert child_sk == PrivateKey.from_bytes(tv_child_int.to_bytes(32, "big"))

def test_bip39_test_vectors(self):
Expand Down Expand Up @@ -279,8 +279,8 @@ def test_key_data_generate(label: Optional[str]) -> None:
key_data = KeyData.generate(label)
assert key_data.private_key == AugSchemeMPL.key_gen(mnemonic_to_seed(key_data.mnemonic_str()))
assert key_data.entropy == bytes_from_mnemonic(key_data.mnemonic_str())
assert key_data.observation_root == key_data.private_key.get_g1()
assert key_data.fingerprint == key_data.private_key.get_g1().get_fingerprint()
assert key_data.observation_root == key_data.private_key.public_key()
assert key_data.fingerprint == key_data.private_key.public_key().get_fingerprint()
assert key_data.label == label


Expand Down Expand Up @@ -323,10 +323,10 @@ def test_key_data_without_secrets(key_info: KeyInfo) -> None:
[
((_24keyinfo.mnemonic.split()[:-1], _24keyinfo.entropy, _24keyinfo.private_key), "mnemonic"),
((_24keyinfo.mnemonic.split(), KeyDataSecrets.generate().entropy, _24keyinfo.private_key), "entropy"),
((_24keyinfo.mnemonic.split(), _24keyinfo.entropy, KeyDataSecrets.generate().private_key), "private_key"),
((_24keyinfo.mnemonic.split(), _24keyinfo.entropy, KeyDataSecrets.generate().secret_info_bytes), "private_key"),
],
)
def test_key_data_secrets_post_init(input_data: Tuple[List[str], bytes, PrivateKey], data_type: str) -> None:
def test_key_data_secrets_post_init(input_data: Tuple[List[str], bytes, bytes], data_type: str) -> None:
with pytest.raises(KeychainKeyDataMismatch, match=data_type):
KeyDataSecrets(*input_data)

Expand All @@ -339,7 +339,7 @@ def test_key_data_secrets_post_init(input_data: Tuple[List[str], bytes, PrivateK
_24keyinfo.fingerprint,
bytes(G1Element()),
None,
KeyDataSecrets(_24keyinfo.mnemonic.split(), _24keyinfo.entropy, _24keyinfo.private_key),
KeyDataSecrets(_24keyinfo.mnemonic.split(), _24keyinfo.entropy, bytes(_24keyinfo.private_key)),
KeyTypes.G1_ELEMENT.value,
),
"public_key",
Expand Down Expand Up @@ -548,3 +548,6 @@ def test_key_type_support(key_type: str, key_info: KeyInfo) -> None:
assert KeyTypes.parse_observation_root(bytes(obr), KeyTypes(key_type)) == obr
if secret_info is not None:
assert KeyTypes.parse_secret_info(bytes(secret_info), KeyTypes(key_type)) == secret_info
assert (
KeyTypes.parse_secret_info_from_seed(mnemonic_to_seed(key_info.mnemonic), KeyTypes(key_type)) == secret_info
)
4 changes: 3 additions & 1 deletion chia/_tests/wallet/rpc/test_wallet_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import aiosqlite
import pytest
from chia_rs import G2Element
from chia_rs import G2Element, PrivateKey

from chia._tests.conftest import ConsensusMode
from chia._tests.environments.wallet import WalletStateTransition, WalletTestFramework
Expand Down Expand Up @@ -1666,10 +1666,12 @@ async def _check_delete_key(

sk = await wallet_node.get_key_for_fingerprint(farmer_fp, private=True)
assert sk is not None
assert isinstance(sk, PrivateKey)
farmer_ph = create_puzzlehash_for_pk(create_sk(sk, uint32(0)).get_g1())

sk = await wallet_node.get_key_for_fingerprint(pool_fp, private=True)
assert sk is not None
assert isinstance(sk, PrivateKey)
pool_ph = create_puzzlehash_for_pk(create_sk(sk, uint32(0)).get_g1())

with lock_and_load_config(wallet_node.root_path, "config.yaml") as test_config:
Expand Down
30 changes: 22 additions & 8 deletions chia/_tests/wallet/test_wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_get_private_key(root_path_populated_with_config: Path, get_temp_k
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants, keychain)
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()
fingerprint = sk.public_key().get_fingerprint()

key = await node.get_key(fingerprint)

Expand All @@ -57,7 +57,7 @@ async def test_get_private_key_default_key(root_path_populated_with_config: Path
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants, keychain)
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()
fingerprint = sk.public_key().get_fingerprint()

# Add a couple more keys
keychain.add_key(generate_mnemonic())
Expand All @@ -70,6 +70,20 @@ async def test_get_private_key_default_key(root_path_populated_with_config: Path
assert isinstance(key, PrivateKey)
assert key.get_g1().get_fingerprint() == fingerprint

# We should get the same result with a bogus fingerprint
key = await node.get_key(123456789)

assert key is not None
assert isinstance(key, PrivateKey)
assert key.get_g1().get_fingerprint() == fingerprint

# Test coverage
key = await node.get_key(123456789, private=False)

assert key is not None
assert isinstance(key, G1Element)
assert key.get_fingerprint() == fingerprint


@pytest.mark.anyio
@pytest.mark.parametrize("fingerprint", [None, 1234567890])
Expand Down Expand Up @@ -164,7 +178,7 @@ def test_log_in(root_path_populated_with_config: Path, get_temp_keyring: Keychai
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants)
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()
fingerprint = sk.public_key().get_fingerprint()

node.log_in(fingerprint)

Expand All @@ -190,7 +204,7 @@ def patched_update_last_used_fingerprint(self: Self) -> None:
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants)
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()
fingerprint = sk.public_key().get_fingerprint()

# Expect log_in to succeed, even though we can't write the last used fingerprint
node.log_in(fingerprint)
Expand All @@ -207,7 +221,7 @@ def test_log_out(root_path_populated_with_config: Path, get_temp_keyring: Keycha
config = load_config(root_path, "config.yaml", "wallet")
node = WalletNode(config, root_path, test_constants)
sk, _ = keychain.add_key(generate_mnemonic())
fingerprint = sk.get_g1().get_fingerprint()
fingerprint = sk.public_key().get_fingerprint()

node.log_in(fingerprint)

Expand Down Expand Up @@ -652,7 +666,7 @@ async def test_get_last_used_fingerprint_if_exists(
assert await node.get_last_used_fingerprint_if_exists() is None

sk_2, _ = await node.keychain_proxy.add_key(generate_mnemonic())
fingerprint_2: int = sk_2.get_g1().get_fingerprint()
fingerprint_2: int = sk_2.public_key().get_fingerprint()

node._close()
await node._await_closed(shutting_down=False)
Expand Down Expand Up @@ -747,7 +761,7 @@ async def restart_with_fingerprint(fingerprint: Optional[int]) -> None:

initial_sk = wallet_node.wallet_state_manager.private_key

sk_2: PrivateKey = (
sk_2 = (
await wallet_node.keychain_proxy.add_key(
(
"cup smoke miss park baby say island tomorrow segment lava bitter easily settle gift "
Expand All @@ -757,7 +771,7 @@ async def restart_with_fingerprint(fingerprint: Optional[int]) -> None:
private=True,
)
)[0]
fingerprint_2: int = sk_2.get_g1().get_fingerprint()
fingerprint_2: int = sk_2.public_key().get_fingerprint()

await restart_with_fingerprint(fingerprint_2)
assert wallet_node.wallet_state_manager.private_key == sk_2
Expand Down
5 changes: 4 additions & 1 deletion chia/_tests/wallet/vault/test_vault_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from chia.rpc.wallet_request_types import VaultCreate, VaultRecovery
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint32, uint64
from chia.util.keychain import KeyTypes
from chia.wallet.payment import Payment
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
from chia.wallet.vault.vault_info import VaultInfo
Expand Down Expand Up @@ -86,7 +87,9 @@ async def vault_setup(wallet_environments: WalletTestFramework, with_recovery: b
),
]
)
await env.node.keychain_proxy.add_key(launcher_id.hex(), label="vault", private=False)
await env.node.keychain_proxy.add_key(
launcher_id.hex(), label="vault", private=False, key_type=KeyTypes.VAULT_LAUNCHER
)
await env.restart(vault_root.get_fingerprint())
await wallet_environments.full_node.wait_for_wallet_synced(env.node, 20)

Expand Down
9 changes: 5 additions & 4 deletions chia/cmds/init_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Dict, List, Optional

import yaml
from chia_rs import PrivateKey

from chia.cmds.configure import configure
from chia.consensus.coinbase import create_puzzlehash_for_pk
Expand Down Expand Up @@ -64,13 +65,13 @@ def dict_add_new_default(updated: Dict[str, Any], default: Dict[str, Any], do_no
def check_keys(new_root: Path, keychain: Optional[Keychain] = None) -> None:
if keychain is None:
keychain = Keychain()
all_sks = keychain.get_all_private_keys()
all_sks: List[PrivateKey] = [sk for sk, _ in keychain.get_all_private_keys() if isinstance(sk, PrivateKey)]
if len(all_sks) == 0:
print("No keys are present in the keychain. Generate them with 'chia keys generate'")
return None

with lock_and_load_config(new_root, "config.yaml") as config:
pool_child_pubkeys = [master_sk_to_pool_sk(sk).get_g1() for sk, _ in all_sks]
pool_child_pubkeys = [master_sk_to_pool_sk(sk).get_g1() for sk in all_sks]
all_targets = []
stop_searching_for_farmer = "xch_target_address" not in config["farmer"]
stop_searching_for_pool = "xch_target_address" not in config["pool"]
Expand All @@ -79,7 +80,7 @@ def check_keys(new_root: Path, keychain: Optional[Keychain] = None) -> None:
prefix = config["network_overrides"]["config"][selected]["address_prefix"]

intermediates = {}
for sk, _ in all_sks:
for sk in all_sks:
intermediates[bytes(sk)] = {
"observer": master_sk_to_wallet_sk_unhardened_intermediate(sk),
"non-observer": master_sk_to_wallet_sk_intermediate(sk),
Expand All @@ -88,7 +89,7 @@ def check_keys(new_root: Path, keychain: Optional[Keychain] = None) -> None:
for i in range(number_of_ph_to_search):
if stop_searching_for_farmer and stop_searching_for_pool and i > 0:
break
for sk, _ in all_sks:
for sk in all_sks:
intermediate_n = intermediates[bytes(sk)]["non-observer"]
intermediate_o = intermediates[bytes(sk)]["observer"]

Expand Down
13 changes: 13 additions & 0 deletions chia/cmds/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Tuple

import click
from chia_rs import PrivateKey

from chia.cmds import options

Expand Down Expand Up @@ -334,6 +335,10 @@ def search_cmd(
if fingerprint is None and filename is not None:
sk = resolve_derivation_master_key(filename)

if sk is not None and not isinstance(sk, PrivateKey):
print("Cannot derive from non-BLS keys")
return

found: bool = search_derive(
ctx.obj["root_path"],
fingerprint,
Expand Down Expand Up @@ -385,6 +390,10 @@ def wallet_address_cmd(
if fingerprint is None and filename is not None:
sk = resolve_derivation_master_key(filename)

if sk is not None and not isinstance(sk, PrivateKey):
print("Cannot derive from non-BLS keys")
return

derive_wallet_address(
ctx.obj["root_path"], fingerprint, index, count, prefix, non_observer_derivation, show_hd_path, sk
)
Expand Down Expand Up @@ -462,6 +471,10 @@ def child_key_cmd(
if fingerprint is None and filename is not None:
sk = resolve_derivation_master_key(filename)

if sk is not None and not isinstance(sk, PrivateKey):
print("Cannot derive from non-BLS keys")
return

derive_child_key(
fingerprint,
key_type,
Expand Down
Loading

0 comments on commit d2f5d0a

Please sign in to comment.