Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHIA-1307] Port key management RPCs to @marshal decorator #18593

Merged
merged 11 commits into from
Sep 24, 2024
2 changes: 1 addition & 1 deletion chia/_tests/cmds/test_cmd_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def run(self) -> None:
check_click_parsing(expected_command, "-wp", str(port), "-f", str(fingerprint))

async with expected_command.rpc_info.wallet_rpc(consume_errors=False) as client_info:
assert await client_info.client.get_logged_in_fingerprint() == fingerprint
assert (await client_info.client.get_logged_in_fingerprint()).fingerprint == fingerprint

# We don't care about setting the correct arg type here
test_present_client_info = TempCMD(rpc_info=NeedsWalletRPC(client_info="hello world")) # type: ignore[arg-type]
Expand Down
77 changes: 40 additions & 37 deletions chia/_tests/wallet/rpc/test_wallet_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,15 @@
from chia.rpc.rpc_client import ResponseFailureError
from chia.rpc.rpc_server import RpcServer
from chia.rpc.wallet_request_types import (
AddKey,
CheckDeleteKey,
CombineCoins,
DefaultCAT,
DeleteKey,
DIDGetPubkey,
GetNotifications,
GetPrivateKey,
LogIn,
SplitCoins,
SplitCoinsResponse,
VerifySignature,
Expand Down Expand Up @@ -1709,22 +1714,22 @@ async def _check_delete_key(
save_config(wallet_node.root_path, "config.yaml", test_config)

# Check farmer_fp key
sk_dict = await client.check_delete_key(farmer_fp)
assert sk_dict["fingerprint"] == farmer_fp
assert sk_dict["used_for_farmer_rewards"] is True
assert sk_dict["used_for_pool_rewards"] is False
resp = await client.check_delete_key(CheckDeleteKey(uint32(farmer_fp)))
assert resp.fingerprint == farmer_fp
assert resp.used_for_farmer_rewards is True
assert resp.used_for_pool_rewards is False
Quexington marked this conversation as resolved.
Show resolved Hide resolved

# Check pool_fp key
sk_dict = await client.check_delete_key(pool_fp)
assert sk_dict["fingerprint"] == pool_fp
assert sk_dict["used_for_farmer_rewards"] is False
assert sk_dict["used_for_pool_rewards"] is True
resp = await client.check_delete_key(CheckDeleteKey(uint32(pool_fp)))
assert resp.fingerprint == pool_fp
assert resp.used_for_farmer_rewards is False
assert resp.used_for_pool_rewards is True

# Check unknown key
sk_dict = await client.check_delete_key(123456, 10)
assert sk_dict["fingerprint"] == 123456
assert sk_dict["used_for_farmer_rewards"] is False
assert sk_dict["used_for_pool_rewards"] is False
resp = await client.check_delete_key(CheckDeleteKey(uint32(123456), uint16(10)))
assert resp.fingerprint == 123456
assert resp.used_for_farmer_rewards is False
assert resp.used_for_pool_rewards is False


@pytest.mark.anyio
Expand All @@ -1738,7 +1743,7 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn
address = await client.get_next_address(1, True)
assert len(address) > 10

pks = await client.get_public_keys()
pks = (await client.get_public_keys()).pk_fingerprints
assert len(pks) == 1

await generate_funds(env.full_node.api, env.wallet_1)
Expand All @@ -1756,23 +1761,21 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn
await client.delete_unconfirmed_transactions(1)
assert len(await wallet.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(1)) == 0

sk_dict = await client.get_private_key(pks[0])
assert sk_dict["fingerprint"] == pks[0]
assert sk_dict["sk"] is not None
assert sk_dict["pk"] is not None
assert sk_dict["seed"] is not None
sk_resp = await client.get_private_key(GetPrivateKey(pks[0]))
assert sk_resp.private_key.fingerprint == pks[0]
assert sk_resp.private_key.seed is not None

mnemonic = await client.generate_mnemonic()
assert len(mnemonic) == 24
resp = await client.generate_mnemonic()
assert len(resp.mnemonic) == 24

await client.add_key(mnemonic)
await client.add_key(AddKey(resp.mnemonic))

pks = await client.get_public_keys()
pks = (await client.get_public_keys()).pk_fingerprints
assert len(pks) == 2

await client.log_in(pks[1])
sk_dict = await client.get_private_key(pks[1])
assert sk_dict["fingerprint"] == pks[1]
await client.log_in(LogIn(pks[1]))
sk_resp = await client.get_private_key(GetPrivateKey(pks[1]))
assert sk_resp.private_key.fingerprint == pks[1]

# test hardened keys
await _check_delete_key(client=client, wallet_node=wallet_node, farmer_fp=pks[0], pool_fp=pks[1], observer=False)
Expand All @@ -1786,10 +1789,10 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn
save_config(wallet_node.root_path, "config.yaml", test_config)

# Check key
sk_dict = await client.check_delete_key(pks[1])
assert sk_dict["fingerprint"] == pks[1]
assert sk_dict["used_for_farmer_rewards"] is False
assert sk_dict["used_for_pool_rewards"] is True
delete_key_resp = await client.check_delete_key(CheckDeleteKey(pks[1]))
assert delete_key_resp.fingerprint == pks[1]
assert delete_key_resp.used_for_farmer_rewards is False
assert delete_key_resp.used_for_pool_rewards is True

# set farmer and pool to empty string
with lock_and_load_config(wallet_node.root_path, "config.yaml") as test_config:
Expand All @@ -1798,14 +1801,14 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn
save_config(wallet_node.root_path, "config.yaml", test_config)

# Check key
sk_dict = await client.check_delete_key(pks[0])
assert sk_dict["fingerprint"] == pks[0]
assert sk_dict["used_for_farmer_rewards"] is False
assert sk_dict["used_for_pool_rewards"] is False
delete_key_resp = await client.check_delete_key(CheckDeleteKey(pks[0]))
assert delete_key_resp.fingerprint == pks[0]
assert delete_key_resp.used_for_farmer_rewards is False
assert delete_key_resp.used_for_pool_rewards is False

await client.delete_key(pks[0])
await client.log_in(pks[1])
assert len(await client.get_public_keys()) == 1
await client.delete_key(DeleteKey(pks[0]))
await client.log_in(LogIn(uint32(pks[1])))
assert len((await client.get_public_keys()).pk_fingerprints) == 1

assert not (await client.get_sync_status())

Expand All @@ -1818,7 +1821,7 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn

# Delete all keys
await client.delete_all_keys()
assert len(await client.get_public_keys()) == 0
assert len((await client.get_public_keys()).pk_fingerprints) == 0


@pytest.mark.anyio
Expand Down
2 changes: 1 addition & 1 deletion chia/_tests/wallet/test_wallet_test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_basic_functionality(wallet_environments: WalletTestFramework) ->
env_0: WalletEnvironment = wallet_environments.environments[0]
env_1: WalletEnvironment = wallet_environments.environments[1]

assert await env_0.rpc_client.get_logged_in_fingerprint() is not None
assert (await env_0.rpc_client.get_logged_in_fingerprint()).fingerprint is not None
# assert await env_1.rpc_client.get_logged_in_fingerprint() is not None

assert await env_0.xch_wallet.get_confirmed_balance() == 2_000_000_000_000
Expand Down
12 changes: 7 additions & 5 deletions chia/cmds/cmds_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from chia.rpc.full_node_rpc_client import FullNodeRpcClient
from chia.rpc.harvester_rpc_client import HarvesterRpcClient
from chia.rpc.rpc_client import ResponseFailureError, RpcClient
from chia.rpc.wallet_request_types import LogIn
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.simulator.simulator_full_node_rpc_client import SimulatorFullNodeRpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.mempool_submission_status import MempoolSubmissionStatus
from chia.util.config import load_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.errors import CliRpcConnectionError, InvalidPathError
from chia.util.ints import uint16, uint64
from chia.util.ints import uint16, uint32, uint64
from chia.util.keychain import KeyData
from chia.util.streamable import Streamable, streamable
from chia.wallet.conditions import ConditionValidTimes
Expand Down Expand Up @@ -169,7 +170,7 @@ async def get_wallet(root_path: Path, wallet_client: WalletRpcClient, fingerprin
# if only a single key is available, select it automatically
selected_fingerprint = fingerprints[0]
else:
logged_in_fingerprint: Optional[int] = await wallet_client.get_logged_in_fingerprint()
logged_in_fingerprint: Optional[int] = (await wallet_client.get_logged_in_fingerprint()).fingerprint
logged_in_key: Optional[KeyData] = None
if logged_in_fingerprint is not None:
logged_in_key = next((key for key in all_keys if key.fingerprint == logged_in_fingerprint), None)
Expand Down Expand Up @@ -227,10 +228,11 @@ async def get_wallet(root_path: Path, wallet_client: WalletRpcClient, fingerprin
selected_fingerprint = fp

if selected_fingerprint is not None:
log_in_response = await wallet_client.log_in(selected_fingerprint)
try:
await wallet_client.log_in(LogIn(uint32(selected_fingerprint)))
except ValueError as e:
raise CliRpcConnectionError(f"Login failed for fingerprint {selected_fingerprint}: {e.args[0]}")

if log_in_response["success"] is False:
raise CliRpcConnectionError(f"Login failed for fingerprint {selected_fingerprint}: {log_in_response}")
finally:
# Closing the keychain proxy takes a moment, so we wait until after the login is complete
if keychain_proxy is not None:
Expand Down
12 changes: 6 additions & 6 deletions chia/data_layer/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
write_files_for_root,
)
from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
from chia.rpc.wallet_request_types import LogIn
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer
Expand Down Expand Up @@ -242,13 +243,12 @@ def set_server(self, server: ChiaServer) -> None:
self._server = server

async def wallet_log_in(self, fingerprint: int) -> int:
result = await self.wallet_rpc.log_in(fingerprint)
if not result.get("success", False):
wallet_error = result.get("error", "no error message provided")
raise Exception(f"DataLayer wallet RPC log in request failed: {wallet_error}")
try:
result = await self.wallet_rpc.log_in(LogIn(uint32(fingerprint)))
except ValueError as e:
raise Exception(f"DataLayer wallet RPC log in request failed: {e.args[0]}")

fingerprint = cast(int, result["fingerprint"])
return fingerprint
return result.fingerprint

async def create_store(
self, fee: uint64, root: bytes32 = bytes32([0] * 32)
Expand Down
105 changes: 104 additions & 1 deletion chia/rpc/wallet_request_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type, TypeVar

from chia_rs import G1Element, G2Element
from chia_rs import G1Element, G2Element, PrivateKey
from typing_extensions import dataclass_transform

from chia.types.blockchain_format.sized_bytes import bytes32
Expand Down Expand Up @@ -42,6 +42,109 @@ def default_raise() -> Any: # pragma: no cover
raise RuntimeError("This should be impossible to hit and is just for < 3.10 compatibility")


@streamable
Quexington marked this conversation as resolved.
Show resolved Hide resolved
@dataclass(frozen=True)
class Empty(Streamable):
pass


@streamable
@dataclass(frozen=True)
class LogIn(Streamable):
Quexington marked this conversation as resolved.
Show resolved Hide resolved
fingerprint: uint32


@streamable
@dataclass(frozen=True)
class LogInResponse(Streamable):
fingerprint: uint32


@streamable
@dataclass(frozen=True)
class GetLoggedInFingerprintResponse(Streamable):
fingerprint: Optional[uint32]


@streamable
@dataclass(frozen=True)
class GetPublicKeysResponse(Streamable):
keyring_is_locked: bool
public_key_fingerprints: Optional[List[uint32]] = None

@property
def pk_fingerprints(self) -> List[uint32]:
if self.keyring_is_locked:
raise RuntimeError("get_public_keys cannot return public keys because the keyring is locked")
else:
assert self.public_key_fingerprints is not None
return self.public_key_fingerprints


@streamable
@dataclass(frozen=True)
class GetPrivateKey(Streamable):
fingerprint: uint32


# utility for `GetPrivateKeyResponse`
@streamable
@dataclass(frozen=True)
class GetPrivateKeyFormat(Streamable):
fingerprint: uint32
sk: PrivateKey
pk: G1Element
farmer_pk: G1Element
pool_pk: G1Element
seed: Optional[str]


@streamable
@dataclass(frozen=True)
class GetPrivateKeyResponse(Streamable):
private_key: GetPrivateKeyFormat
altendky marked this conversation as resolved.
Show resolved Hide resolved


@streamable
@dataclass(frozen=True)
class GenerateMnemonicResponse(Streamable):
mnemonic: List[str]


@streamable
@dataclass(frozen=True)
class AddKey(Streamable):
mnemonic: List[str]


@streamable
@dataclass(frozen=True)
class AddKeyResponse(Streamable):
fingerprint: uint32


@streamable
@dataclass(frozen=True)
class DeleteKey(Streamable):
fingerprint: uint32


@streamable
@dataclass(frozen=True)
class CheckDeleteKey(Streamable):
fingerprint: uint32
max_ph_to_search: uint16 = uint16(100)


@streamable
@dataclass(frozen=True)
class CheckDeleteKeyResponse(Streamable):
fingerprint: uint32
used_for_farmer_rewards: bool
used_for_pool_rewards: bool
wallet_balance: bool


@streamable
@dataclass(frozen=True)
class GetNotifications(Streamable):
Expand Down
Loading
Loading