Skip to content

Commit

Permalink
Merge pull request #7364 from kozlovsky/fix/events_endpoint_send_even…
Browse files Browse the repository at this point in the history
…ts_serialized

Send events to GUI only before shutdown and in the proper order
  • Loading branch information
kozlovsky authored Apr 14, 2023
2 parents 32ac3ad + 2b7e261 commit 7e6cb60
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 45 deletions.
63 changes: 46 additions & 17 deletions src/tribler/core/components/restapi/rest/events_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import asyncio
import json
import time
from asyncio import CancelledError
from asyncio import CancelledError, Queue
from dataclasses import asdict
from typing import List, Optional
from typing import Any, Dict, List, Optional

import marshmallow.fields
from aiohttp import web
from aiohttp_apispec import docs
from ipv8.REST.schema import schema
from ipv8.messaging.anonymization.tunnel import Circuit
from marshmallow.fields import Dict, String

from tribler.core import notifications
from tribler.core.components.reporter.reported_error import ReportedError
Expand Down Expand Up @@ -38,6 +38,9 @@ def passthrough(x):
]


MessageDict = Dict[str, Any]


@froze_it
class EventsEndpoint(RESTEndpoint):
"""
Expand All @@ -49,16 +52,17 @@ class EventsEndpoint(RESTEndpoint):
def __init__(self, notifier: Notifier, public_key: str = None):
super().__init__()
self.events_responses: List[RESTStreamResponse] = []
self.undelivered_error: Optional[dict] = None
self.undelivered_error: Optional[MessageDict] = None
self.public_key = public_key
self.notifier = notifier
self.queue = Queue()
self.async_group.add_task(self.process_queue())
notifier.add_observer(notifications.circuit_removed, self.on_circuit_removed)
notifier.add_generic_observer(self.on_notification)

def on_notification(self, topic, *args, **kwargs):
if topic in topics_to_send_to_gui:
data = {"topic": topic.__name__, "args": args, "kwargs": kwargs}
self.async_group.add_task(self.write_data(data))
self.send_event({"topic": topic.__name__, "args": args, "kwargs": kwargs})

def on_circuit_removed(self, circuit: Circuit, additional_info: str):
# The original notification contains non-JSON-serializable argument, so we send another one to GUI
Expand All @@ -75,19 +79,19 @@ async def shutdown(self):
def setup_routes(self):
self.app.add_routes([web.get('', self.get_events)])

def initial_message(self) -> dict:
def initial_message(self) -> MessageDict:
return {
"topic": notifications.events_start.__name__,
"kwargs": {"public_key": self.public_key, "version": version_id}
}

def error_message(self, reported_error: ReportedError) -> dict:
def error_message(self, reported_error: ReportedError) -> MessageDict:
return {
"topic": notifications.tribler_exception.__name__,
"kwargs": {"error": asdict(reported_error)},
}

def encode_message(self, message: dict) -> bytes:
def encode_message(self, message: MessageDict) -> bytes:
try:
message = json.dumps(message)
except UnicodeDecodeError:
Expand All @@ -96,17 +100,43 @@ def encode_message(self, message: dict) -> bytes:
message = json.dumps(fix_unicode_dict(message))
return b'data: ' + message.encode('utf-8') + b'\n\n'

def has_connection_to_gui(self):
def has_connection_to_gui(self) -> bool:
return bool(self.events_responses)

async def write_data(self, message):
def should_skip_message(self, message: MessageDict) -> bool:
"""
Write data over the event socket if it's open.
Returns True if EventsEndpoint should skip sending message to GUI due to a shutdown or no connection to GUI.
Issue an appropriate warning if the message cannot be sent.
"""
if self._shutdown:
self._logger.warning(f"Shutdown is in progress, skip message: {message}")
return True

if not self.has_connection_to_gui():
return
self._logger.warning(f"No connections to GUI, skip message: {message}")
return True

return False

def send_event(self, message: MessageDict):
"""
Put event message to a queue to be sent to GUI
"""
if not self.should_skip_message(message):
self.queue.put_nowait(message)

async def process_queue(self):
while True:
message = await self.queue.get()
if not self.should_skip_message(message):
await self._write_data(message)

async def _write_data(self, message: MessageDict):
"""
Write data over the event socket if it's open.
"""
self._logger.debug(f'Write message: {message}')
try:
self._logger.debug(f'Write message: {message}')
message_bytes = self.encode_message(message)
except Exception as e: # pylint: disable=broad-except
# if a notification arguments contains non-JSON-serializable data, the exception should be logged
Expand All @@ -125,7 +155,7 @@ def on_tribler_exception(self, reported_error: ReportedError):

message = self.error_message(reported_error)
if self.has_connection_to_gui():
self.async_group.add_task(self.write_data(message))
self.send_event(message)
elif not self.undelivered_error:
# If there are several undelivered errors, we store the first error as more important and skip other
self.undelivered_error = message
Expand All @@ -135,8 +165,7 @@ def on_tribler_exception(self, reported_error: ReportedError):
summary="Open an EventStream for receiving Tribler events.",
responses={
200: {
"schema": schema(EventsResponse={'type': String,
'event': Dict})
"schema": schema(EventsResponse={'type': marshmallow.fields.String, 'event': marshmallow.fields.Dict})
}
}
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
from asyncio import CancelledError, Event, create_task
from contextlib import suppress
Expand Down Expand Up @@ -31,8 +32,8 @@ def fixture_notifier(event_loop):
return Notifier(loop=event_loop)


@pytest.fixture
async def endpoint(notifier):
@pytest.fixture(name='events_endpoint')
async def events_endpoint_fixture(notifier):
events_endpoint = EventsEndpoint(notifier)
yield events_endpoint

Expand All @@ -45,12 +46,12 @@ def fixture_reported_error():


@pytest.fixture(name="rest_manager")
async def fixture_rest_manager(api_port, tmp_path, endpoint):
async def fixture_rest_manager(api_port, tmp_path, events_endpoint):
config = TriblerConfig()
config.api.http_enabled = True
config.api.http_port = api_port
root_endpoint = RootEndpoint(middlewares=[ApiKeyMiddleware(config.api.key), error_middleware])
root_endpoint.add_endpoint('/events', endpoint)
root_endpoint.add_endpoint('/events', events_endpoint)
rest_manager = RESTManager(config=config.api, root_endpoint=root_endpoint, state_dir=tmp_path)

await rest_manager.start()
Expand Down Expand Up @@ -114,64 +115,82 @@ async def test_events(rest_manager, notifier: Notifier):
await event_socket_task


@patch.object(EventsEndpoint, 'write_data')
@patch.object(EventsEndpoint, '_write_data')
@patch.object(EventsEndpoint, 'has_connection_to_gui', new=MagicMock(return_value=True))
async def test_on_tribler_exception_has_connection_to_gui(mocked_write_data, endpoint, reported_error):
async def test_on_tribler_exception_has_connection_to_gui(mocked_write_data, events_endpoint, reported_error):
# test that in case of established connection to GUI, `on_tribler_exception` will work
# as a normal endpoint function, that is call `write_data`
endpoint.on_tribler_exception(reported_error)
# as a normal events_endpoint function, that is call `_write_data`
events_endpoint.on_tribler_exception(reported_error)
await asyncio.sleep(0.01)

mocked_write_data.assert_called_once()
assert not endpoint.undelivered_error
assert not events_endpoint.undelivered_error


@patch.object(EventsEndpoint, 'write_data')
@patch.object(EventsEndpoint, '_write_data')
@patch.object(EventsEndpoint, 'has_connection_to_gui', new=MagicMock(return_value=False))
async def test_on_tribler_exception_no_connection_to_gui(mocked_write_data, endpoint, reported_error):
async def test_on_tribler_exception_no_connection_to_gui(mocked_write_data, events_endpoint, reported_error):
# test that if no connection to GUI, then `on_tribler_exception` will store
# reported_error in `self.undelivered_error`
endpoint.on_tribler_exception(reported_error)
events_endpoint.on_tribler_exception(reported_error)

mocked_write_data.assert_not_called()
assert endpoint.undelivered_error == endpoint.error_message(reported_error)
assert events_endpoint.undelivered_error == events_endpoint.error_message(reported_error)


@patch.object(EventsEndpoint, 'write_data', new=MagicMock())
@patch.object(EventsEndpoint, '_write_data', new=MagicMock())
@patch.object(EventsEndpoint, 'has_connection_to_gui', new=MagicMock(return_value=False))
async def test_on_tribler_exception_stores_only_first_error(endpoint, reported_error):
async def test_on_tribler_exception_stores_only_first_error(events_endpoint, reported_error):
# test that if no connection to GUI, then `on_tribler_exception` will store
# only the very first `reported_error`
first_reported_error = reported_error
endpoint.on_tribler_exception(first_reported_error)
events_endpoint.on_tribler_exception(first_reported_error)

second_reported_error = ReportedError('second_type', 'second_text', {})
endpoint.on_tribler_exception(second_reported_error)
events_endpoint.on_tribler_exception(second_reported_error)

assert endpoint.undelivered_error == endpoint.error_message(first_reported_error)
assert events_endpoint.undelivered_error == events_endpoint.error_message(first_reported_error)


@patch('asyncio.sleep', new=AsyncMock(side_effect=CancelledError))
@patch.object(RESTStreamResponse, 'prepare', new=AsyncMock())
@patch.object(RESTStreamResponse, 'write', new_callable=AsyncMock)
@patch.object(EventsEndpoint, 'encode_message')
async def test_get_events_has_undelivered_error(mocked_encode_message, mocked_write, endpoint):
async def test_get_events_has_undelivered_error(mocked_encode_message, mocked_write, events_endpoint):
# test that in case `self.undelivered_error` is not None, then it will be sent
endpoint.undelivered_error = {'undelivered': 'error'}
events_endpoint.undelivered_error = {'undelivered': 'error'}

await endpoint.get_events(MagicMock())
await events_endpoint.get_events(MagicMock())

mocked_write.assert_called()
mocked_encode_message.assert_called_with({'undelivered': 'error'})
assert not endpoint.undelivered_error
assert not events_endpoint.undelivered_error


async def test_on_tribler_exception_shutdown():
# test that `on_tribler_exception` will not send any error message if endpoint is shutting down
endpoint = EventsEndpoint(Mock())
endpoint.error_message = Mock()
# test that `on_tribler_exception` will not send any error message if events_endpoint is shutting down
events_endpoint = EventsEndpoint(Mock())
events_endpoint.error_message = Mock()

await endpoint.shutdown()
await events_endpoint.shutdown()

events_endpoint.on_tribler_exception(ReportedError('', '', {}))

assert not events_endpoint.error_message.called


async def test_should_skip_message(events_endpoint):
assert not events_endpoint._shutdown and not events_endpoint.events_responses # pylint: disable=protected-access
message = Mock()

# Initially the events endpoint is not in shutdown state, but it does not have any connection,
# so it should skip message as nobody is listen to it
assert events_endpoint.should_skip_message(message)

endpoint.on_tribler_exception(ReportedError('', '', {}))
with patch.object(events_endpoint, 'events_responses', new=[Mock()]):
# We add a mocked connection to GUI, and now the events endpoint should not skip a message
assert not events_endpoint.should_skip_message(message)

assert not endpoint.error_message.called
with patch.object(events_endpoint, '_shutdown', new=True):
# But, if it is in shutdown state, it should always skip a message
assert events_endpoint.should_skip_message(message)

0 comments on commit 7e6cb60

Please sign in to comment.