Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small cleanups to IP connection #338

Merged
merged 5 commits into from
Oct 19, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 58 additions & 33 deletions aiohomekit/controller/ip/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import asyncio
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from async_interrupt import interrupt

Expand Down Expand Up @@ -54,10 +54,12 @@ class ConnectionReady(Exception):


class InsecureHomeKitProtocol(asyncio.Protocol):
def __init__(self, connection):
"""An asyncio.Protocol implementation for HomeKit connections."""

def __init__(self, connection: HomeKitConnection) -> None:
self.connection = connection
self.host = ":".join((connection.host, str(connection.port)))
self.result_cbs = []
self.result_cbs: list[asyncio.Future[HttpResponse]] = []
self.current_response = HttpResponse()
self.loop = asyncio.get_running_loop()

Expand All @@ -68,7 +70,13 @@ def connection_made(self, transport):
def connection_lost(self, exception):
self.connection._connection_lost(exception)

async def send_bytes(self, payload: bytes):
def _handle_timeout(self, fut: asyncio.Future[None]) -> None:
"""Handle a timeout."""
if not fut.done():
fut.set_exception(asyncio.TimeoutError)

async def send_bytes(self, payload: bytes) -> HttpResponse:
"""Send bytes to the device."""
if self.transport.is_closing():
# FIXME: It would be nice to try and wait for the reconnect in future.
# In that case we need to make sure we do it at a layer above send_bytes otherwise
Expand All @@ -83,16 +91,21 @@ async def send_bytes(self, payload: bytes):
# We return a future so that our caller can block on a reply
# We can send many requests and dispatch the results in order
# Should mean we don't need locking around request/reply cycles
result = self.loop.create_future()
loop = self.loop
result: asyncio.Future[HttpResponse] = loop.create_future()
self.result_cbs.append(result)

timeout_handle = loop.call_at(loop.time() + 30, self._handle_timeout, result)
timeout_expired = False
try:
async with asyncio_timeout(30):
return await result
return await result
except asyncio.TimeoutError:
timeout_expired = True
self.transport.write_eof()
self.transport.close()
raise AccessoryDisconnectedError("Timeout while waiting for response")
finally:
if not timeout_expired:
timeout_handle.cancel()

def data_received(self, data):
while data:
Expand Down Expand Up @@ -124,7 +137,11 @@ def close(self):


class SecureHomeKitProtocol(InsecureHomeKitProtocol):
def __init__(self, connection, a2c_key, c2a_key):
"""An asyncio.Protocol implementation for secure HomeKit connections."""

def __init__(
self, connection: HomeKitConnection, a2c_key: bytes, c2a_key: bytes
) -> None:
super().__init__(connection)

self._incoming_buffer: bytearray = bytearray()
Expand All @@ -138,7 +155,7 @@ def __init__(self, connection, a2c_key, c2a_key):
self.encryptor = ChaCha20Poly1305Encryptor(self.c2a_key)
self.decryptor = ChaCha20Poly1305Decryptor(self.a2c_key)

async def send_bytes(self, payload: bytes):
async def send_bytes(self, payload: bytes) -> HttpResponse:
buffer: list[bytes] = []

while len(payload) > 0:
Expand All @@ -157,7 +174,7 @@ async def send_bytes(self, payload: bytes):

return await super().send_bytes(b"".join(buffer))

def data_received(self, data):
def data_received(self, data: bytes) -> None:
"""
Called by asyncio when data is received from a TCP socket.

Expand Down Expand Up @@ -200,7 +217,7 @@ def data_received(self, data):


class HomeKitConnection:
def __init__(self, owner, host, port, concurrency_limit=1):
def __init__(self, owner: IpPairing, host: str, port, concurrency_limit=1) -> None:
self.owner = owner
self.host = host
self.port = port
Expand All @@ -224,17 +241,17 @@ def __init__(self, owner, host, port, concurrency_limit=1):
self._last_connector_error: Exception | None = None

@property
def name(self):
def name(self) -> str:
"""Return the name of the connection."""
if self.owner:
return self.owner.name
return f"{self.host}:{self.port}"

@property
def is_connected(self):
def is_connected(self) -> bool:
return self.transport and self.protocol and not self.closed

def _start_connector(self):
def _start_connector(self) -> None:
"""
Start a reconnect background task.

Expand Down Expand Up @@ -264,7 +281,7 @@ def reconnect_soon(self) -> None:
return
self._start_reconnecting()

def _start_reconnecting(self):
def _start_reconnecting(self) -> bool:
"""Start reconnecting."""
if self.is_connected:
return False
Expand Down Expand Up @@ -292,7 +309,7 @@ async def ensure_connection(self) -> None:
# connector task so it continues to run if the timeout is hit.
await asyncio.shield(self._connector)

async def _stop_connector(self):
async def _stop_connector(self) -> None:
"""
Cancels any active reconnect tasks.

Expand All @@ -313,7 +330,7 @@ async def _stop_connector(self):
except asyncio.CancelledError:
pass

async def get(self, target):
async def get(self, target: str) -> HttpResponse:
"""
Sends a HTTP POST request to the current transport and returns an awaitable
that can be used to wait for a response.
Expand All @@ -323,11 +340,13 @@ async def get(self, target):
target=target,
)

async def get_json(self, target):
async def get_json(self, target: str) -> dict[str, Any]:
response = await self.get(target)
return hkjson.loads(response.body)

async def put(self, target, body, content_type=HttpContentTypes.JSON):
async def put(
self, target: str, body: Any, content_type=HttpContentTypes.JSON
bdraco marked this conversation as resolved.
Show resolved Hide resolved
) -> HttpResponse:
"""
Sends a HTTP POST request to the current transport and returns an awaitable
that can be used to wait for a response.
Expand All @@ -342,7 +361,7 @@ async def put(self, target, body, content_type=HttpContentTypes.JSON):
body=body,
)

async def put_json(self, target, body):
async def put_json(self, target: str, body: Any) -> dict[str, Any]:
response = await self.put(
target,
hkjson.dump_bytes(body),
Expand Down Expand Up @@ -370,7 +389,9 @@ async def put_json(self, target, body):

return parsed

async def post(self, target, body, content_type=HttpContentTypes.TLV):
async def post(
self, target: str, body: Any, content_type=HttpContentTypes.TLV
bdraco marked this conversation as resolved.
Show resolved Hide resolved
) -> HttpResponse:
"""
Sends a HTTP POST request to the current transport and returns an awaitable
that can be used to wait for a response.
Expand All @@ -385,7 +406,7 @@ async def post(self, target, body, content_type=HttpContentTypes.TLV):
body=body,
)

async def post_json(self, target, body):
async def post_json(self, target: str, body: Any) -> dict[str, Any]:
response = await self.post(
target,
hkjson.dump_bytes(body),
Expand All @@ -412,7 +433,7 @@ async def post_json(self, target, body):

return parsed

async def post_tlv(self, target, body, expected=None):
async def post_tlv(self, target: str, body: Any, expected=None) -> list:
try:
response = await self.post(
target,
Expand All @@ -425,7 +446,9 @@ async def post_tlv(self, target, body, expected=None):
body = TLV.decode_bytes(response.body, expected=expected)
return body

async def request(self, method, target, headers=None, body=None):
async def request(
self, method: str, target: str, headers=None, body: Any | None = None
) -> HttpResponse:
"""
Sends a HTTP request to the current transport and returns an awaitable
that can be used to wait for the response.
Expand Down Expand Up @@ -486,7 +509,7 @@ async def request(self, method, target, headers=None, body=None):

return resp

async def close(self):
async def close(self) -> None:
"""
Close the connection transport.
"""
Expand All @@ -501,7 +524,7 @@ async def close(self):
self.transport = None
self.is_secure = None

def _connection_lost(self, exception):
def _connection_lost(self, exception: Exception) -> None:
"""
Called by a Protocol instance when eof_received happens.
"""
Expand All @@ -516,7 +539,7 @@ def _connection_lost(self, exception):
self.transport = None
self.protocol = None

async def _connect_once(self):
async def _connect_once(self) -> None:
"""_connect_once must only ever be called from _reconnect to ensure its done with a lock."""
loop = asyncio.get_event_loop()

Expand All @@ -537,7 +560,7 @@ async def _connect_once(self):
if self.owner:
await self.owner.connection_made(False)

async def _reconnect(self):
async def _reconnect(self) -> None:
# When the device is seen by zeroconf, call reconnect_soon
# to force the reconnect wait to be canceled and _connect_once
# will be called soon.
Expand Down Expand Up @@ -590,7 +613,7 @@ async def _reconnect(self):
finally:
self._reconnect_future = None

def event_received(self, event):
def event_received(self, event: HttpResponse) -> None:
if not self.owner:
return

Expand All @@ -607,12 +630,14 @@ def event_received(self, event):

self.owner.event_received(parsed)

def __repr__(self):
def __repr__(self) -> str:
return f"HomeKitConnection(host={self.host!r}, port={self.port!r})"


class SecureHomeKitConnection(HomeKitConnection):
def __init__(self, owner, pairing_data):
"""A HomeKit connection that negotiates a secure session."""

def __init__(self, owner: IpPairing, pairing_data: dict[str, Any]) -> None:
super().__init__(
owner,
pairing_data["AccessoryIP"],
Expand All @@ -629,7 +654,7 @@ async def _connect_once(self):
self.is_secure = False

if self.owner and self.owner.description:
pairing: IpPairing = self.owner
pairing = self.owner
try:
if self.host != pairing.description.address:
logger.debug(
Expand Down
Loading