Skip to content

Commit

Permalink
Some testing around automatic reconnects
Browse files Browse the repository at this point in the history
  • Loading branch information
Jc2k committed Jul 31, 2019
1 parent 5cd977d commit 0302111
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 39 deletions.
22 changes: 19 additions & 3 deletions homekit/aio/controller/ip/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,17 @@ def connection_lost(self, exception):
self.connection._connection_lost(exception)

def send_bytes(self, payload):
if self.transport.is_closing():
# FIXME: It would be nice to try and wait for the reconnect in future.
# In that case we need to make sure we do it at a layer above send_bytes otherwise
# we might encrypt payloads with the last sessions keys then wait for a new connection
# to send them - and on that connection the keys would be different.
# Also need to make sure that the new connection has chance to pair-verify before
# queued writes can happy.
raise AccessoryDisconnectedError('Transport is closed')

self.transport.write(payload)

# We return a future so that our caller can block on a reply
# We can send many requests and dispatch the results in order
# Should mean we don't need locking around request/reply cycles
Expand Down Expand Up @@ -378,6 +388,15 @@ def _connection_lost(self, exception):
"""
logger.info("Connection %r lost.", self)

if not self.when_connected.done():
self.when_connected.set_exception(
AccessoryDisconnectedError(
'Current connection attempt failed and will be retried',
)
)

self.when_connected = asyncio.Future()

if self.auto_reconnect and not self.closing:
asyncio.ensure_future(self._reconnect())

Expand All @@ -393,9 +412,6 @@ async def _connect_once(self):
)

async def _reconnect(self):
if self.when_connected.done():
self.when_connected = asyncio.Future()

# FIXME: How to integrate discovery here?
# There is aiozeroconf but that doesn't work on Windows until python 3.9
# In HASS, zeroconf is a service provided by HASS itself and want to be able to
Expand Down
45 changes: 9 additions & 36 deletions homekit/aio/controller/ip/pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

from homekit.controller.tools import AbstractPairing, check_convert_value
from homekit.protocol.statuscodes import HapStatusCodes
from homekit.exceptions import UnknownError, UnpairedError, \
AccessoryDisconnectedError, EncryptionError
from homekit.exceptions import UnknownError, UnpairedError
from homekit.protocol.tlv import TLV
from homekit.model.characteristics import CharacteristicsTypes
from homekit.model.services import ServicesTypes
Expand Down Expand Up @@ -79,12 +78,7 @@ async def list_accessories_and_characteristics(self):
"""
await self._ensure_connected()

try:
response = await self.connection.get_json('/accessories')
except (AccessoryDisconnectedError, EncryptionError):
self.session.close()
self.session = None
raise
response = await self.connection.get_json('/accessories')

accessories = response['accessories']

Expand Down Expand Up @@ -123,15 +117,10 @@ async def list_pairings(self):
"""
await self._ensure_connected()

try:
data = await self.connection.post_tlv('/pairings', [
(TLV.kTLVType_State, TLV.M1),
(TLV.kTLVType_Method, TLV.ListPairings)
])
except (AccessoryDisconnectedError, EncryptionError):
self.session.close()
self.session = None
raise
data = await self.connection.post_tlv('/pairings', [
(TLV.kTLVType_State, TLV.M1),
(TLV.kTLVType_Method, TLV.ListPairings)
])

if not (data[0][0] == TLV.kTLVType_State and data[0][1] == TLV.M2):
raise UnknownError('unexpected data received: ' + str(data))
Expand Down Expand Up @@ -188,12 +177,7 @@ async def get_characteristics(self, characteristics, include_meta=False, include
if include_events:
url += '&ev=1'

try:
response = await self.connection.get_json(url)
except (AccessoryDisconnectedError, EncryptionError):
self.session.close()
self.session = None
raise
response = await self.connection.get_json(url)

tmp = {}
for c in response['characteristics']:
Expand Down Expand Up @@ -347,21 +331,10 @@ async def get_events(self, characteristics, callback_fun, max_events=-1, max_sec
event_count = 0
s = time.time()
while (max_events == -1 or event_count < max_events) and (max_seconds == -1 or s + max_seconds >= time.time()):
try:
r = self.session.sec_http.handle_event_response()
body = r.read().decode()
except (AccessoryDisconnectedError, EncryptionError):
self.session.close()
self.session = None
raise
r = self.session.sec_http.handle_event_response()
body = r.read().decode()

if len(body) > 0:
try:
r = json.loads(body)
except JSONDecodeError:
self.session.close()
self.session = None
raise AccessoryDisconnectedError("Session closed after receiving malformed response from device")
tmp = []
for c in r['characteristics']:
tmp.append((c['aid'], c['iid'], c['value']))
Expand Down
12 changes: 12 additions & 0 deletions tests/aio/test_ip_pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from homekit import AccessoryServer
from homekit.exceptions import AccessoryDisconnectedError
from homekit.model import Accessory
from homekit.model.services import LightBulbService
from homekit.model import mixin as model_mixin
Expand Down Expand Up @@ -111,6 +112,17 @@ async def test_get_characteristics_after_failure(pairing):

pairing.connection.transport.close()

# The connection is closed but the reconnection mechanism hasn't kicked in yet.
# Attempts to use the connection should fail.
with pytest.raises(AccessoryDisconnectedError):
characteristics = await pairing.get_characteristics([
(1, 10),
])

# We can't await a close - this lets the coroutine fall into the 'reactor'
# and process queued work which will include the real transport.close work.
await asyncio.sleep(0)

characteristics = await pairing.get_characteristics([
(1, 10),
])
Expand Down

0 comments on commit 0302111

Please sign in to comment.