Skip to content

Commit

Permalink
Fix: Improve code structure in pair-programming with Lyam
Browse files Browse the repository at this point in the history
  • Loading branch information
nesitor committed Oct 2, 2024
1 parent 2b25fe5 commit e272479
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 408 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ MANIFEST
**/device.key

# environment variables
.env
.config.json
.env.local

.gitsigners
166 changes: 29 additions & 137 deletions src/aleph/sdk/account.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,30 @@
import asyncio
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Type, TypeVar, Union, overload
from typing import Dict, Optional, Type, TypeVar

import base58
from aleph_message.models import Chain

from aleph.sdk.chains.common import get_fallback_private_key
from aleph.sdk.chains.ethereum import ETHAccount
from aleph.sdk.chains.remote import RemoteAccount
from aleph.sdk.chains.solana import (
SOLAccount,
parse_solana_private_key,
solana_private_key_from_bytes,
)
from aleph.sdk.conf import settings
from aleph.sdk.chains.solana import SOLAccount
from aleph.sdk.conf import settings, load_main_configuration
from aleph.sdk.types import AccountFromPrivateKey
from aleph.sdk.utils import load_account_key_context

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=AccountFromPrivateKey)

CHAIN_TO_ACCOUNT_MAP: Dict[Chain, Type[AccountFromPrivateKey]] = {
Chain.ETH: ETHAccount,
Chain.AVAX: ETHAccount,
Chain.SOL: SOLAccount,
Chain.BASE: ETHAccount,
}


def detect_chain_from_private_key(private_key: Union[str, List[int], bytes]) -> Chain:
"""
Detect the blockchain chain based on the private key format.
- Chain.ETH for Ethereum (EVM) private keys
- Chain.SOL for Solana private keys (base58 or uint8 format).
Raises:
ValueError: If the private key format is invalid or not recognized.
"""
if isinstance(private_key, (str, bytes)) and is_valid_private_key(
private_key, ETHAccount
):
return Chain.ETH

elif is_valid_private_key(private_key, SOLAccount):
return Chain.SOL

else:
raise ValueError("Unsupported private key format. Unable to detect chain.")


@overload
def is_valid_private_key(
private_key: Union[str, bytes], account_type: Type[ETHAccount]
) -> bool: ...


@overload
def is_valid_private_key(
private_key: Union[str, List[int], bytes], account_type: Type[SOLAccount]
) -> bool: ...


def is_valid_private_key(
private_key: Union[str, List[int], bytes], account_type: Type[T]
) -> bool:
"""
Check if the private key is valid for either Ethereum or Solana based on the account type.
"""
try:
if account_type == ETHAccount:
# Handle Ethereum private key validation
if isinstance(private_key, str):
if private_key.startswith("0x"):
private_key = private_key[2:]
private_key = bytes.fromhex(private_key)
elif isinstance(private_key, list):
raise ValueError("Ethereum keys cannot be a list of integers")

account_type(private_key)

elif account_type == SOLAccount:
# Handle Solana private key validation
if isinstance(private_key, bytes):
return len(private_key) == 64
elif isinstance(private_key, str):
decoded_key = base58.b58decode(private_key)
return len(decoded_key) == 64
elif isinstance(private_key, list):
return len(private_key) == 64 and all(
isinstance(i, int) and 0 <= i <= 255 for i in private_key
)

return True
except Exception:
return False
def load_chain_account_type(chain: Chain):
chain_account_map: Dict[Chain, Type[AccountFromPrivateKey]] = {
Chain.ETH: ETHAccount,
Chain.AVAX: ETHAccount,
Chain.SOL: SOLAccount,
Chain.BASE: ETHAccount,
}
return chain_account_map.get(chain) or ETHAccount


def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T:
Expand All @@ -107,72 +35,36 @@ def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T:

def account_from_file(private_key_path: Path, account_type: Type[T]) -> T:
private_key = private_key_path.read_bytes()
if account_type == SOLAccount:
private_key = parse_solana_private_key(
solana_private_key_from_bytes(private_key)
)

return account_type(private_key)


def _load_account(
private_key_str: Optional[str] = None,
private_key_path: Optional[Path] = None,
account_type: Type[AccountFromPrivateKey] = ETHAccount,
account_type: Optional[Type[AccountFromPrivateKey]] = None,
) -> AccountFromPrivateKey:
"""Load private key from a string or a file. takes the string argument in priority"""

default_account_type = ETHAccount # Default account type
if private_key_str:
# Check Account type based on private-key string format (base58 / uint for solana)
private_key_chain = detect_chain_from_private_key(private_key=private_key_str)
if private_key_chain == Chain.SOL:
account_type = SOLAccount
logger.debug("Solana private key is detected")
parsed_key = parse_solana_private_key(private_key_str)
return account_type(parsed_key)
logger.debug("Using account from string")
return account_from_hex_string(private_key_str, account_type)
if not account_type:
account_type = default_account_type
return account_from_hex_string(private_key_str, account_type)
else:
return account_from_hex_string(private_key_str, account_type)
elif private_key_path and private_key_path.is_file():
if private_key_path:
account_type = ETHAccount # Default account type

try:
account_data = load_account_key_context(settings.CONFIG_FILE)

if account_data:
chain = Chain(account_data.chain)
account_type = (
CHAIN_TO_ACCOUNT_MAP.get(chain, ETHAccount) or ETHAccount
)
logger.debug(
f"Detected {chain} account for path {private_key_path}"
)
else:
logger.warning(
f"No account data found in {private_key_path}, defaulting to {account_type.__name__}"
)

except FileNotFoundError:
logger.warning(
f"{private_key_path} not found, using default account type {account_type.__name__}"
)
except json.JSONDecodeError:
logger.error(
f"Invalid format in {private_key_path}, unable to load account info."
if not account_type:
account_type = default_account_type
main_configuration = load_main_configuration(settings.CONFIG_FILE)
if main_configuration:
account_type = load_chain_account_type(main_configuration.chain)
logger.debug(
f"Detected {main_configuration.chain} account for path {settings.CONFIG_FILE}"
)
raise ValueError(f"Invalid format in {private_key_path}.")
except KeyError as e:
logger.error(f"Missing key in account config: {e}")
raise ValueError(
f"Invalid account data in {private_key_path}. Key {e} is missing."
)
except Exception as e:
logger.error(f"Error loading account from {private_key_path}: {e}")
raise ValueError(
f"Could not load account data from {private_key_path}."
else:
logger.warning(
f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type.__name__}"
)

return account_from_file(private_key_path, account_type)
return account_from_file(private_key_path, account_type)
elif settings.REMOTE_CRYPTO_HOST:
logger.debug("Using remote account")
loop = asyncio.get_event_loop()
Expand Down
8 changes: 5 additions & 3 deletions src/aleph/sdk/chains/solana.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class SOLAccount(BaseAccount):
_private_key: PrivateKey

def __init__(self, private_key: bytes):
self.private_key = private_key
self.private_key = parse_private_key(
private_key_from_bytes(private_key)
)
self._signing_key = SigningKey(self.private_key)
self._private_key = self._signing_key.to_curve25519_private_key()

Expand Down Expand Up @@ -93,7 +95,7 @@ def verify_signature(
raise BadSignatureError from e


def solana_private_key_from_bytes(
def private_key_from_bytes(
private_key_bytes: bytes, output_format: str = "base58"
) -> Union[str, List[int], bytes]:
"""
Expand Down Expand Up @@ -132,7 +134,7 @@ def solana_private_key_from_bytes(
raise ValueError("Invalid output format. Choose 'base58', 'list', or 'bytes'.")


def parse_solana_private_key(private_key: Union[str, List[int], bytes]) -> bytes:
def parse_private_key(private_key: Union[str, List[int], bytes]) -> bytes:
"""
Parse the private key which could be either:
- a base58-encoded string (which may contain both private and public key)
Expand Down
55 changes: 51 additions & 4 deletions src/aleph/sdk/conf.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import json
import logging
import os
from pathlib import Path
from shutil import which
from typing import Dict, Optional, Union

from aleph_message.models import Chain
from aleph_message.models.execution.environment import HypervisorType
from pydantic import BaseSettings, Field
from pydantic import BaseSettings, Field, BaseModel

from aleph.sdk.types import ChainInfo


logger = logging.getLogger(__name__)


class Settings(BaseSettings):
CONFIG_HOME: Optional[str] = None

CONFIG_FILE: Path = Field(
default=Path("chains_config.json"),
default=Path("config.json"),
description="Path to the JSON file containing chain account configurations",
)

Expand Down Expand Up @@ -145,6 +149,17 @@ class Config:
env_file = ".env"


class MainConfiguration(BaseModel):
"""
Intern Chain Management with Account.
"""
path: Path
chain: Chain

class Config:
use_enum_values = True


# Settings singleton
settings = Settings()

Expand All @@ -168,8 +183,8 @@ class Config:
settings.PRIVATE_MNEMONIC_FILE = Path(
settings.CONFIG_HOME, "private-keys", "substrate.mnemonic"
)
if str(settings.CONFIG_FILE) == "chains_config.json":
settings.CONFIG_FILE = Path(settings.CONFIG_HOME, "chains_config.json")
if str(settings.CONFIG_FILE) == "config.json":
settings.CONFIG_FILE = Path(settings.CONFIG_HOME, "config.json")
# If Config file exist and well filled we update the PRIVATE_KEY_FILE default
if settings.CONFIG_FILE.exists():
try:
Expand All @@ -191,3 +206,35 @@ class Config:
field = field.lower()
settings.CHAINS[chain].__dict__[field] = value
settings.__delattr__(f"CHAINS_{fields}")


def save_main_configuration(file_path: Path, data: MainConfiguration):
"""
Synchronously save a single ChainAccount object as JSON to a file.
"""
with file_path.open("w") as file:
data_serializable = data.dict()
data_serializable["path"] = str(data_serializable["path"])
json.dump(data_serializable, file, indent=4)


def load_main_configuration(file_path: Path) -> Optional[MainConfiguration]:
"""
Synchronously load the private key and chain type from a file.
If the file does not exist or is empty, return None.
"""
if not file_path.exists() or file_path.stat().st_size == 0:
logger.debug(f"File {file_path} does not exist or is empty. Returning None.")
return None

try:
with file_path.open("rb") as file:
content = file.read()
data = json.loads(content.decode("utf-8"))
return MainConfiguration(**data)
except UnicodeDecodeError as e:
logger.error(f"Unable to decode {file_path} as UTF-8: {e}")
except json.JSONDecodeError:
logger.error(f"Invalid JSON format in {file_path}.")

return None
15 changes: 1 addition & 14 deletions src/aleph/sdk/types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Protocol, TypeVar

from pydantic import BaseModel

__all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage")

from aleph_message.models import AlephMessage, Chain
from aleph_message.models import AlephMessage


class StorageEnum(str, Enum):
Expand Down Expand Up @@ -77,15 +76,3 @@ class ChainInfo(BaseModel):
token: str
super_token: Optional[str] = None
active: bool = True


class ChainAccount(BaseModel):
"""
Intern Chain Management with Account.
"""

path: Path
chain: Chain

class Config:
use_enum_values = True
Loading

0 comments on commit e272479

Please sign in to comment.