Skip to content

Commit

Permalink
More robust handling of ZMQ/RPC errors (#2120)
Browse files Browse the repository at this point in the history
* More robust RPC error handling on msg from worker

* Use dedicated exceptions, fewer nested try blocks

* Fix test_zmqrpc.py

* Undo function split since added new exceptions

* Fix more tests

* Fix some tests

* Fix typo

* Fix scoping of variables

* Add tests for RPC/ZMQ changes

* flake and black fixes

* Remove debug print line

Co-authored-by: Ryan Warner <ryan.warner@edgecast.com>
  • Loading branch information
solowalker27 and Ryan Warner authored Jun 27, 2022
1 parent 481e066 commit 357f106
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 22 deletions.
16 changes: 16 additions & 0 deletions locust/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ class RPCError(Exception):
"""


class RPCSendError(Exception):
"""
Exception when sending message to client.
When raised from zmqrpc, sending can be retried or RPC can be reestablished.
"""


class RPCReceiveError(Exception):
"""
Exception when receiving message from client is interrupted or message is corrupted.
When raised from zmqrpc, client connection should be reestablished.
"""


class AuthCredentialsError(ValueError):
"""
Exception when the auth credentials provided
Expand Down
21 changes: 12 additions & 9 deletions locust/rpc/zmqrpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import zmq.green as zmq
from .protocol import Message
from locust.util.exception_handler import retry
from locust.exception import RPCError
from locust.exception import RPCError, RPCSendError, RPCReceiveError
import zmq.error as zmqerr
import msgpack.exceptions as msgerr

Expand All @@ -19,21 +19,21 @@ def send(self, msg):
try:
self.socket.send(msg.serialize(), zmq.NOBLOCK)
except zmqerr.ZMQError as e:
raise RPCError("ZMQ sent failure") from e
raise RPCSendError("ZMQ sent failure") from e

@retry()
def send_to_client(self, msg):
try:
self.socket.send_multipart([msg.node_id.encode(), msg.serialize()])
except zmqerr.ZMQError as e:
raise RPCError("ZMQ sent failure") from e
raise RPCSendError("ZMQ sent failure") from e

def recv(self):
try:
data = self.socket.recv()
msg = Message.unserialize(data)
except msgerr.ExtraData as e:
raise RPCError("ZMQ interrupted message") from e
raise RPCReceiveError("ZMQ interrupted message") from e
except zmqerr.ZMQError as e:
raise RPCError("ZMQ network broken") from e
return msg
Expand All @@ -42,15 +42,18 @@ def recv_from_client(self):
try:
data = self.socket.recv_multipart()
addr = data[0].decode()
msg = Message.unserialize(data[1])
except (UnicodeDecodeError, msgerr.ExtraData) as e:
raise RPCError("ZMQ interrupted message") from e
except UnicodeDecodeError as e:
raise RPCReceiveError("ZMQ interrupted or corrupted message") from e
except zmqerr.ZMQError as e:
raise RPCError("ZMQ network broken") from e
try:
msg = Message.unserialize(data[1])
except (UnicodeDecodeError, msgerr.ExtraData) as e:
raise RPCReceiveError("ZMQ interrupted or corrupted message") from e
return addr, msg

def close(self):
self.socket.close()
def close(self, linger=None):
self.socket.close(linger=linger)


class Server(BaseSocket):
Expand Down
29 changes: 23 additions & 6 deletions locust/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from . import User
from locust import __version__
from .dispatch import UsersDispatcher
from .exception import RPCError
from .exception import RPCError, RPCReceiveError, RPCSendError
from .log import greenlet_exception_logger
from .rpc import (
Message,
Expand Down Expand Up @@ -946,9 +946,9 @@ def heartbeat_worker(self) -> NoReturn:
self.start(user_count=self.target_user_count, spawn_rate=self.spawn_rate)

def reset_connection(self) -> None:
logger.info("Reset connection to worker")
logger.info("Resetting RPC server and all client connections.")
try:
self.server.close()
self.server.close(linger=0)
self.server = rpc.Server(self.master_bind_host, self.master_bind_port)
self.connection_broken = False
except RPCError as e:
Expand All @@ -958,12 +958,26 @@ def client_listener(self) -> NoReturn:
while True:
try:
client_id, msg = self.server.recv_from_client()
except RPCReceiveError as e:
logger.error(f"RPCError when receiving from client: {e}. Will reset client {client_id}.")
try:
self.server.send_to_client(Message("reconnect", None, client_id))
except Exception as e:
logger.error(f"Error sending reconnect message to client: {e}. Will reset RPC server.")
self.connection_broken = True
gevent.sleep(FALLBACK_INTERVAL)
continue
except RPCSendError as e:
logger.error(f"Error sending reconnect message to client: {e}. Will reset RPC server.")
self.connection_broken = True
gevent.sleep(FALLBACK_INTERVAL)
continue
except RPCError as e:
if self.clients.ready:
logger.error(f"RPCError found when receiving from client: {e}")
if self.clients.ready or self.clients.spawning or self.clients.running:
logger.error(f"RPCError: {e}. Will reset RPC server.")
else:
logger.debug(
"RPCError found when receiving from client: %s (but no clients were expected to be connected anyway)"
"RPCError when receiving from client: %s (but no clients were expected to be connected anyway)"
% (e)
)
self.connection_broken = True
Expand Down Expand Up @@ -1285,6 +1299,9 @@ def worker(self) -> NoReturn:
self.stop()
self._send_stats() # send a final report, in case there were any samples not yet reported
self.greenlet.kill(block=True)
elif msg.type == "reconnect":
logger.warning("Received reconnect message from master. Resetting RPC connection.")
self.reset_connection()
elif msg.type in self.custom_messages:
logger.debug(f"Received {msg.type} message from master")
self.custom_messages[msg.type](environment=self.environment, msg=msg)
Expand Down
86 changes: 81 additions & 5 deletions locust/test/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
)
from locust.argument_parser import parse_options
from locust.env import Environment
from locust.exception import (
RPCError,
StopUser,
)
from locust.exception import RPCError, StopUser, RPCReceiveError
from locust.main import create_environment
from locust.rpc import Message
from locust.runners import (
Expand All @@ -49,6 +46,7 @@
from .util import patch_env

NETWORK_BROKEN = "network broken"
BAD_MESSAGE = "bad message"


def mocked_rpc(raise_on_close=True):
Expand Down Expand Up @@ -83,9 +81,11 @@ def recv_from_client(self):
msg = Message.unserialize(results)
if msg.data == NETWORK_BROKEN:
raise RPCError()
if msg.data == BAD_MESSAGE:
raise RPCReceiveError("Bad message")
return msg.node_id, msg

def close(self):
def close(self, linger=None):
if self.raise_error_on_close:
raise RPCError()
else:
Expand Down Expand Up @@ -2923,6 +2923,34 @@ def test_master_discard_first_client_ready(self):
self.assertEqual("ack", server.outbox[0][1].type)
self.assertEqual(1, len(server.outbox))

def test_worker_sends_bad_message_to_master(self):
"""
Validate master sends reconnect message to worker when it receives a bad message.
"""

class TestUser(User):
@task
def my_task(self):
pass

with mock.patch("locust.rpc.rpc.Server", mocked_rpc()) as server:
master = self.get_runner(user_classes=[TestUser])
server.mocked_send(Message("client_ready", __version__, "zeh_fake_client1"))
self.assertEqual(1, len(master.clients))
self.assertTrue(
"zeh_fake_client1" in master.clients, "Could not find fake client in master instance's clients dict"
)

master.start(10, 10)
sleep(0.1)
server.mocked_send(Message("stats", BAD_MESSAGE, "zeh_fake_client1"))
self.assertEqual(4, len(server.outbox))

# Expected message order in outbox: ack, spawn, reconnect, ack
self.assertEqual(
"reconnect", server.outbox[2][1].type, "Master didn't send worker reconnect message when expected."
)


class TestWorkerRunner(LocustTestCase):
def setUp(self):
Expand Down Expand Up @@ -3201,6 +3229,54 @@ def my_task(self):

worker.quit()

def test_reset_rpc_connection_to_master(self):
"""
Validate worker resets RPC connection to master on "reconnect" message.
"""

class MyUser(User):
wait_time = constant(1)

@task
def my_task(self):
pass

with mock.patch("locust.rpc.rpc.Client", mocked_rpc(raise_on_close=False)) as client:
client_id = id(client)
worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client)
client.mocked_send(
Message(
"spawn",
{
"timestamp": 1605538584,
"user_classes_count": {"MyUser": 10},
"host": "",
"stop_timeout": None,
"parsed_options": {},
},
"dummy_client_id",
)
)
sleep(0.6)
self.assertEqual(STATE_RUNNING, worker.state)
with self.assertLogs("locust.runners") as capture:
with mock.patch("locust.rpc.rpc.Client.close") as close:
client.mocked_send(
Message(
"reconnect",
None,
"dummy_client_id",
)
)
sleep(0)
worker.spawning_greenlet.join()
worker.quit()
close.assert_called_once()
self.assertIn(
"WARNING:locust.runners:Received reconnect message from master. Resetting RPC connection.",
capture.output,
)

def test_change_user_count_during_spawning(self):
class MyUser(User):
wait_time = constant(1)
Expand Down
4 changes: 2 additions & 2 deletions locust/test/test_zmqrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import zmq
from locust.rpc import zmqrpc, Message
from locust.test.testcases import LocustTestCase
from locust.exception import RPCError
from locust.exception import RPCError, RPCSendError, RPCReceiveError


class ZMQRPC_tests(LocustTestCase):
Expand Down Expand Up @@ -50,5 +50,5 @@ def test_rpc_error(self):
with self.assertRaises(RPCError):
server = zmqrpc.Server("127.0.0.1", server.port)
server.close()
with self.assertRaises(RPCError):
with self.assertRaises(RPCSendError):
server.send_to_client(Message("test", "message", "identity"))

0 comments on commit 357f106

Please sign in to comment.