Skip to content

Commit

Permalink
Update unit tests and use a constant for default region
Browse files Browse the repository at this point in the history
  • Loading branch information
mill1000 committed Dec 11, 2024
1 parent 0d9fced commit d8383a8
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 29 deletions.
10 changes: 5 additions & 5 deletions msmart/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from msmart import __version__
from msmart.cloud import Cloud, CloudError
from msmart.const import CLOUD_CREDENTIALS
from msmart.const import CLOUD_CREDENTIALS, DEFAULT_CLOUD_REGION
from msmart.device import AirConditioner as AC
from msmart.discover import Discover
from msmart.lan import AuthenticationError
Expand All @@ -15,7 +15,7 @@
_LOGGER = logging.getLogger(__name__)


DEFAULT_CLOUD_ACCOUNT, DEFAULT_CLOUD_PASSWORD = CLOUD_CREDENTIALS["US"]
DEFAULT_CLOUD_ACCOUNT, DEFAULT_CLOUD_PASSWORD = CLOUD_CREDENTIALS[DEFAULT_CLOUD_REGION]


async def _discover(args) -> None:
Expand Down Expand Up @@ -219,7 +219,7 @@ async def _download(args) -> None:

# Use discovery to to find device information
_LOGGER.info("Discovering %s on local network.", args.host)
device = await Discover.discover_single(args.host, region = args.region, account=args.account, password=args.password, auto_connect=False)
device = await Discover.discover_single(args.host, region=args.region, account=args.account, password=args.password, auto_connect=False)

if device is None:
_LOGGER.error("Device not found.")
Expand All @@ -235,7 +235,7 @@ async def _download(args) -> None:
exit(1)

# Get cloud connection
cloud = Cloud(args.region, account = args.account, password = args.password)
cloud = Cloud(args.region, account=args.account, password=args.password)
try:
await cloud.login()
except CloudError as e:
Expand Down Expand Up @@ -305,7 +305,7 @@ def main() -> NoReturn:
common_parser.add_argument("--region",
help="Country/region for built-in cloud credential selection.",
choices=CLOUD_CREDENTIALS.keys(),
default="US")
default=DEFAULT_CLOUD_REGION)
common_parser.add_argument("--account",
help="Manually specify a MSmart username for cloud authentication.",
default=None)
Expand Down
4 changes: 2 additions & 2 deletions msmart/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from Crypto.Cipher import AES
from Crypto.Util import Padding

from msmart.const import DeviceType, CLOUD_CREDENTIALS
from msmart.const import CLOUD_CREDENTIALS, DEFAULT_CLOUD_REGION, DeviceType

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,7 +55,7 @@ class Cloud:
RETRIES = 3

def __init__(self,
region: str = "US",
region: str = DEFAULT_CLOUD_REGION,
*,
account: Optional[str] = None,
password: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion msmart/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
0xb7, 0xe4, 0x2d, 0x53, 0x49, 0x47, 0x62, 0xbe
])


DEFAULT_CLOUD_REGION = "US"
CLOUD_CREDENTIALS = {
"DE": ("midea_eu@mailinator.com", "das_ist_passwort1"),
"KR": ("midea_sea@mailinator.com", "password_for_sea1"),
Expand Down
21 changes: 8 additions & 13 deletions msmart/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Optional, Type, cast

from msmart.cloud import Cloud, CloudError
from msmart.const import (CLOUD_CREDENTIALS, DEVICE_INFO_MSG, DISCOVERY_MSG,
from msmart.const import (DEFAULT_CLOUD_REGION, DEVICE_INFO_MSG, DISCOVERY_MSG,
DeviceType)
from msmart.device import AirConditioner, Device
from msmart.lan import AuthenticationError, Security
Expand Down Expand Up @@ -134,7 +134,7 @@ def connection_lost(self, exc) -> None:
class Discover:
"""Discover Midea smart devices on the local network."""

_region = "US"
_region = DEFAULT_CLOUD_REGION
_account = None
_password = None
_lock = None
Expand All @@ -147,12 +147,12 @@ async def discover(
*,
target=_IPV4_BROADCAST,
timeout=5,
discovery_packets:int=3,
discovery_packets: int = 3,
interface=None,
region: str = "US",
region: str = DEFAULT_CLOUD_REGION,
account: Optional[str] = None,
password: Optional[str] = None,
auto_connect:bool =True
auto_connect: bool = True
) -> list[Device]:
"""Discover devices via broadcast."""

Expand All @@ -163,13 +163,7 @@ async def discover(
# Always use a new cloud connection
cls._cloud = None

# # Validate incoming credentials and region
# if (account is None and password is None) and region not in CLOUD_CREDENTIALS:
# raise ValueError(f"Unknown cloud region '{region}'.")
# elif account or password:
# raise ValueError("Account and password must be specified.")

# Save cloud credentials
# Save cloud region and credentials
cls._region = region
cls._account = account
cls._password = password
Expand Down Expand Up @@ -231,7 +225,8 @@ async def _get_cloud(cls) -> Optional[Cloud]:
async with cls._lock:
# Create cloud connection if nonexistent
if cls._cloud is None:
cloud = Cloud(cls._region, account=cls._account, password=cls._password)
cloud = Cloud(cls._region, account=cls._account,
password=cls._password)
try:
await cloud.login()
cls._cloud = cloud
Expand Down
35 changes: 27 additions & 8 deletions msmart/tests/test_cloud.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import unittest
from typing import Any, Optional

from msmart.cloud import ApiError, Cloud, CloudError
from msmart.const import CLOUD_CREDENTIALS

CLOUD_ACCOUNT, CLOUD_PASSWORD = CLOUD_CREDENTIALS["US"]
from msmart.const import DEFAULT_CLOUD_REGION


class TestCloud(unittest.IsolatedAsyncioTestCase):
# pylint: disable=protected-access

async def _login(self, account: str = CLOUD_ACCOUNT,
password: str = CLOUD_PASSWORD) -> Cloud:
client = Cloud(account, password)
async def _login(self,
region: str = DEFAULT_CLOUD_REGION,
*,
account: Optional[str] = None,
password: Optional[str] = None
) -> Cloud:
client = Cloud(region, account=account, password=password)
await client.login()

return client
Expand All @@ -25,11 +28,27 @@ async def test_login(self) -> None:
self.assertIsNotNone(client._access_token)

async def test_login_exception(self) -> None:
"""Test that we can login to the cloud."""
"""Test that bad credentials raise an exception."""

with self.assertRaises(ApiError):
await self._login(account="bad@account.com", password="not_a_password")

async def test_invalid_region(self) -> None:
"""Test that an invalid region raise an exception."""

with self.assertRaises(ValueError):
await self._login("NOT_A_REGION")

async def test_invalid_credentials(self) -> None:
"""Test that invalid credentials raise an exception."""

# Check that specifying only an account or password raises an error
with self.assertRaises(ValueError):
await self._login(account=None, password="some_password")

with self.assertRaises(ValueError):
await self._login(account="some_account", password=None)

async def test_get_token(self) -> None:
"""Test that a token and key can be obtained from the cloud."""

Expand All @@ -55,7 +74,7 @@ async def test_get_token_exception(self) -> None:
async def test_connect_exception(self) -> None:
"""Test that an exception is thrown when the cloud connection fails."""

client = Cloud(CLOUD_ACCOUNT, CLOUD_PASSWORD)
client = Cloud(DEFAULT_CLOUD_REGION)

# Override URL to an invalid domain
client._base_url = "https://fake_server.invalid."
Expand Down

0 comments on commit d8383a8

Please sign in to comment.