Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VIRTS-4695] Fix Deprecated Manx Socket Issues #2852

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions app/contacts/contact_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from app.utility.base_world import BaseWorld
from plugins.manx.app.c_session import Session
from plugins.manx.app.c_connection import Connection


class Contact(BaseWorld):
Expand Down Expand Up @@ -60,7 +61,7 @@ async def refresh(self):
session = self.sessions[index]

try:
session.connection.send(str.encode(' '))
await session.connection.send(str.encode(' '))
except socket.error:
self.log.debug('Error occurred when refreshing session %s. Removing from session pool.', session.id)
del self.sessions[index]
Expand All @@ -73,20 +74,19 @@ async def accept(self, reader, writer):
except Exception as e:
self.log.debug('Handshake failed: %s' % e)
return
connection = writer.get_extra_info('socket')
profile['executors'] = [e for e in profile['executors'].split(',') if e]
profile['contact'] = 'tcp'
agent, _ = await self.services.get('contact_svc').handle_heartbeat(**profile)
new_session = Session(id=self.generate_number(size=6), paw=agent.paw, connection=connection)
new_session = Session(id=self.generate_number(size=6), paw=agent.paw, connection=Connection(reader, writer))
self.sessions.append(new_session)
await self.send(new_session.id, agent.paw, timeout=5)

async def send(self, session_id: int, cmd: str, timeout: int = 60) -> Tuple[int, str, str, str]:
try:
conn = next(i.connection for i in self.sessions if i.id == int(session_id))
conn.send(str.encode(' '))
await conn.send(str.encode(' '))
time.sleep(0.01)
conn.send(str.encode('%s\n' % cmd))
await conn.send(str.encode('%s\n' % cmd))
response = await self._attempt_connection(session_id, conn, timeout=timeout)
response = json.loads(response)
return response['status'], response['pwd'], response['response'], response.get('agent_reported_time', '')
Expand All @@ -106,7 +106,7 @@ async def _attempt_connection(self, session_id, connection, timeout):
time.sleep(0.1) # initial wait for fast operations.
while True:
try:
part = connection.recv(buffer)
part = await connection.recv(buffer)
data += part
if len(part) < buffer:
break
Expand Down
74 changes: 68 additions & 6 deletions tests/contacts/test_contact_tcp.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,99 @@
import json
import logging
import socket
from unittest import mock
from tests.conftest import async_return

from app.contacts.contact_tcp import TcpSessionHandler
from plugins.manx.app.c_session import Session

logger = logging.getLogger(__name__)


class TestTcpSessionHandler:

def test_refresh_with_socket_errors(self, event_loop):
def test_refresh_with_socket_errors(self, event_loop, async_return):
handler = TcpSessionHandler(services=None, log=logger)

session_with_socket_error = mock.Mock()
session_with_socket_error.connection.send.side_effect = socket.error()

standard_session = mock.Mock()
standard_session.connection.send.return_value = async_return(True)

handler.sessions = [
session_with_socket_error,
session_with_socket_error,
mock.Mock()
standard_session
]

event_loop.run_until_complete(handler.refresh())
assert len(handler.sessions) == 1
assert all(x is not session_with_socket_error for x in handler.sessions)

def test_refresh_without_socket_errors(self, event_loop):
def test_refresh_without_socket_errors(self, event_loop, async_return):
standard_session = mock.Mock()
standard_session.connection.send.return_value = async_return(True)

handler = TcpSessionHandler(services=None, log=logger)
handler.sessions = [
mock.Mock(),
mock.Mock(),
mock.Mock()
standard_session,
standard_session,
standard_session
]

event_loop.run_until_complete(handler.refresh())
assert len(handler.sessions) == 3

async def test_send_with_connection_errors(self, async_return):
test_session_id = 123
test_paw = 'paw123'
test_cmd = 'whoami'
test_exception = Exception('Exception Raised')

mock_connection = mock.Mock()
mock_connection.send.return_value = async_return(True)
standard_session = Session(id=test_session_id, paw=test_paw, connection=mock_connection)

handler = TcpSessionHandler(services=None, log=logger)
handler.sessions = [
standard_session,
standard_session
]

handler._attempt_connection = mock.Mock()
handler._attempt_connection.side_effect = test_exception
response = await handler.send(test_session_id, test_cmd)
expected_response = (1, '~$ ', str(test_exception), '')

assert len(handler.sessions) == 2
assert response == expected_response

async def test_send_without_connection_error(self, async_return):
test_session_id = 123
test_paw = 'paw123'
test_cmd = 'whoami'
json_response = {
'status': 0,
'pwd': '/test',
'response': ''
}
expected_response = (json_response['status'], json_response['pwd'], json_response['response'],
json_response.get('agent_reported_time', ''))

mock_connection = mock.Mock()
mock_connection.send.return_value = async_return(True)
standard_session = Session(id=test_session_id, paw=test_paw, connection=mock_connection)

handler = TcpSessionHandler(services=None, log=logger)
handler.sessions = [
standard_session,
standard_session
]

handler._attempt_connection = mock.Mock()
handler._attempt_connection.return_value = async_return(json.dumps(json_response))
received_response = await handler.send(test_session_id, test_cmd)

assert len(handler.sessions) == 2
assert received_response == expected_response
Loading