diff --git a/app/contacts/contact_tcp.py b/app/contacts/contact_tcp.py index d6ff0b754..65cf66a15 100644 --- a/app/contacts/contact_tcp.py +++ b/app/contacts/contact_tcp.py @@ -7,6 +7,7 @@ from app.utility.base_world import BaseWorld from plugins.manx.app.c_session import Session +from plugins.manx.app.c_connection import Connection class Contact(BaseWorld): @@ -60,7 +61,7 @@ async def refresh(self): session = self.sessions[index] try: - session.connection.send(str.encode(' ')) + await session.connection.send(str.encode(' ')) except socket.error: self.log.debug('Error occurred when refreshing session %s. Removing from session pool.', session.id) del self.sessions[index] @@ -73,20 +74,19 @@ async def accept(self, reader, writer): except Exception as e: self.log.debug('Handshake failed: %s' % e) return - connection = writer.get_extra_info('socket') profile['executors'] = [e for e in profile['executors'].split(',') if e] profile['contact'] = 'tcp' agent, _ = await self.services.get('contact_svc').handle_heartbeat(**profile) - new_session = Session(id=self.generate_number(size=6), paw=agent.paw, connection=connection) + new_session = Session(id=self.generate_number(size=6), paw=agent.paw, connection=Connection(reader, writer)) self.sessions.append(new_session) await self.send(new_session.id, agent.paw, timeout=5) async def send(self, session_id: int, cmd: str, timeout: int = 60) -> Tuple[int, str, str, str]: try: conn = next(i.connection for i in self.sessions if i.id == int(session_id)) - conn.send(str.encode(' ')) + await conn.send(str.encode(' ')) time.sleep(0.01) - conn.send(str.encode('%s\n' % cmd)) + await conn.send(str.encode('%s\n' % cmd)) response = await self._attempt_connection(session_id, conn, timeout=timeout) response = json.loads(response) return response['status'], response['pwd'], response['response'], response.get('agent_reported_time', '') @@ -106,7 +106,7 @@ async def _attempt_connection(self, session_id, connection, timeout): time.sleep(0.1) # initial wait for fast operations. while True: try: - part = connection.recv(buffer) + part = await connection.recv(buffer) data += part if len(part) < buffer: break diff --git a/tests/contacts/test_contact_tcp.py b/tests/contacts/test_contact_tcp.py index cb420a64f..c96b407e6 100644 --- a/tests/contacts/test_contact_tcp.py +++ b/tests/contacts/test_contact_tcp.py @@ -1,37 +1,99 @@ +import json import logging import socket from unittest import mock +from tests.conftest import async_return from app.contacts.contact_tcp import TcpSessionHandler +from plugins.manx.app.c_session import Session logger = logging.getLogger(__name__) class TestTcpSessionHandler: - def test_refresh_with_socket_errors(self, event_loop): + def test_refresh_with_socket_errors(self, event_loop, async_return): handler = TcpSessionHandler(services=None, log=logger) session_with_socket_error = mock.Mock() session_with_socket_error.connection.send.side_effect = socket.error() + standard_session = mock.Mock() + standard_session.connection.send.return_value = async_return(True) + handler.sessions = [ session_with_socket_error, session_with_socket_error, - mock.Mock() + standard_session ] event_loop.run_until_complete(handler.refresh()) assert len(handler.sessions) == 1 assert all(x is not session_with_socket_error for x in handler.sessions) - def test_refresh_without_socket_errors(self, event_loop): + def test_refresh_without_socket_errors(self, event_loop, async_return): + standard_session = mock.Mock() + standard_session.connection.send.return_value = async_return(True) + handler = TcpSessionHandler(services=None, log=logger) handler.sessions = [ - mock.Mock(), - mock.Mock(), - mock.Mock() + standard_session, + standard_session, + standard_session ] event_loop.run_until_complete(handler.refresh()) assert len(handler.sessions) == 3 + + async def test_send_with_connection_errors(self, async_return): + test_session_id = 123 + test_paw = 'paw123' + test_cmd = 'whoami' + test_exception = Exception('Exception Raised') + + mock_connection = mock.Mock() + mock_connection.send.return_value = async_return(True) + standard_session = Session(id=test_session_id, paw=test_paw, connection=mock_connection) + + handler = TcpSessionHandler(services=None, log=logger) + handler.sessions = [ + standard_session, + standard_session + ] + + handler._attempt_connection = mock.Mock() + handler._attempt_connection.side_effect = test_exception + response = await handler.send(test_session_id, test_cmd) + expected_response = (1, '~$ ', str(test_exception), '') + + assert len(handler.sessions) == 2 + assert response == expected_response + + async def test_send_without_connection_error(self, async_return): + test_session_id = 123 + test_paw = 'paw123' + test_cmd = 'whoami' + json_response = { + 'status': 0, + 'pwd': '/test', + 'response': '' + } + expected_response = (json_response['status'], json_response['pwd'], json_response['response'], + json_response.get('agent_reported_time', '')) + + mock_connection = mock.Mock() + mock_connection.send.return_value = async_return(True) + standard_session = Session(id=test_session_id, paw=test_paw, connection=mock_connection) + + handler = TcpSessionHandler(services=None, log=logger) + handler.sessions = [ + standard_session, + standard_session + ] + + handler._attempt_connection = mock.Mock() + handler._attempt_connection.return_value = async_return(json.dumps(json_response)) + received_response = await handler.send(test_session_id, test_cmd) + + assert len(handler.sessions) == 2 + assert received_response == expected_response