diff --git a/locust/runners.py b/locust/runners.py index 8d1aa08c28..72e8cbac29 100644 --- a/locust/runners.py +++ b/locust/runners.py @@ -495,7 +495,8 @@ class WorkerLocustRunner(DistributedLocustRunner): def __init__(self, *args, master_host, master_port, **kwargs): super().__init__(*args, **kwargs) self.client_id = socket.gethostname() + "_" + uuid4().hex - + self.master_host = master_host + self.master_port = master_port self.client = rpc.Client(master_host, master_port, self.client_id) self.greenlet.spawn(self.heartbeat).link_exception(callback=self.noop) self.greenlet.spawn(self.worker).link_exception(callback=self.noop) diff --git a/locust/test/test_runners.py b/locust/test/test_runners.py index 936f73f8e6..0d7d061f3a 100644 --- a/locust/test/test_runners.py +++ b/locust/test/test_runners.py @@ -18,6 +18,7 @@ from locust.wait_time import between, constant NETWORK_BROKEN = "network broken" +UNHANDLED_EXCEPTION = "unhandled exception" def mocked_rpc(): class MockedRpcServerClient(object): @@ -34,7 +35,12 @@ def mocked_send(cls, message): def recv(self): results = self.queue.get() - return Message.unserialize(results) + msg = Message.unserialize(results) + if msg.data == NETWORK_BROKEN: + raise RPCError() + if msg.data == UNHANDLED_EXCEPTION: + raise HeyAnException() + return msg def send(self, message): self.outbox.append(message) @@ -47,6 +53,8 @@ def recv_from_client(self): msg = Message.unserialize(results) if msg.data == NETWORK_BROKEN: raise RPCError() + if msg.data == UNHANDLED_EXCEPTION: + raise HeyAnException() return msg.node_id, msg def close(self): @@ -73,6 +81,8 @@ def __init__(self): def reset_stats(self): pass +class HeyAnException(Exception): + pass class TestLocustRunner(LocustTestCase): def assert_locust_class_distribution(self, expected_distribution, classes): @@ -603,9 +613,6 @@ def test_spawn_locusts_in_stepload_mode(self): self.assertEqual(10, num_clients, "Total number of locusts that would have been spawned for second step is not 10") def test_exception_in_task(self): - class HeyAnException(Exception): - pass - class MyLocust(Locust): @task def will_error(self): @@ -627,8 +634,6 @@ def will_error(self): def test_exception_is_catched(self): """ Test that exceptions are stored, and execution continues """ - class HeyAnException(Exception): - pass class MyTaskSet(TaskSet): def __init__(self, *a, **kw): @@ -666,13 +671,19 @@ class MyLocust(Locust): self.assertTrue("HeyAnException" in exception["traceback"]) self.assertEqual(2, exception["count"]) - def test_reset_connection(self): + def test_master_reset_connection(self): """ Test that connection will be reset when network issues found """ with mock.patch("locust.rpc.rpc.Server", mocked_rpc()) as server: master = self.get_runner() server.mocked_send(Message("client_ready", NETWORK_BROKEN, "fake_client")) - sleep(6) + sleep(3) assert master.connection_broken == True + server.mocked_send(Message("client_ready", None, "fake_client")) + sleep(3) + assert master.connection_broken == False + server.mocked_send(Message("client_ready", UNHANDLED_EXCEPTION, "fake_client")) + sleep(3) + assert master.connection_broken == False class TestWorkerLocustRunner(LocustTestCase): def setUp(self): @@ -792,7 +803,6 @@ def my_task(self): self.assertEqual(9, len(worker.locusts)) worker.quit() - class TestMessageSerializing(unittest.TestCase): def test_message_serialize(self): msg = Message("client_ready", None, "my_id")