From 1105e68ce28dced39e551eaabe477a6dc769d21e Mon Sep 17 00:00:00 2001 From: Uku Loskit Date: Tue, 9 Jul 2024 21:55:04 +0300 Subject: [PATCH] Convert tests to pytests --- .github/workflows/tests.yml | 2 +- Dockerfile | 2 +- msgq/tests/test_fake.py | 65 +++++++++++++------------- msgq/tests/test_messaging.py | 33 ++++++------- msgq/tests/test_poller.py | 20 ++++---- msgq/visionipc/tests/test_visionipc.py | 45 ++++++++---------- 6 files changed, 80 insertions(+), 87 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 16219b0bc..83694d59a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,7 +43,7 @@ jobs: msgq/test_runner && \ msgq/visionipc/test_runner" - name: python tests - run: $RUN_NAMED "${{ matrix.backend }}=1 coverage run -m unittest discover ." + run: $RUN_NAMED "${{ matrix.backend }}=1 pytest" - name: Upload coverage run: | docker commit msgq msgqci diff --git a/Dockerfile b/Dockerfile index 982d8fa5b..77ef04c29 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,7 +35,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ zlib1g-dev \ && rm -rf /var/lib/apt/lists/* -RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy +RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy pytest WORKDIR /project/msgq/ RUN cd /tmp/ && \ diff --git a/msgq/tests/test_fake.py b/msgq/tests/test_fake.py index b5ed297ab..98cb618b0 100644 --- a/msgq/tests/test_fake.py +++ b/msgq/tests/test_fake.py @@ -1,3 +1,4 @@ +import pytest import os import unittest import multiprocessing @@ -9,18 +10,18 @@ WAIT_TIMEOUT = 5 -@unittest.skipIf(platform.system() == "Darwin", "Events not supported on macOS") -class TestEvents(unittest.TestCase): +@pytest.mark.skipif(platform.system() == "Darwin", reason="Events not supported on macOS") +class TestEvents: def test_mutation(self): handle = msgq.fake_event_handle("carState") event = handle.recv_called_event - self.assertFalse(event.peek()) + assert not event.peek() event.set() - self.assertTrue(event.peek()) + assert event.peek() event.clear() - self.assertFalse(event.peek()) + assert not event.peek() del event @@ -31,9 +32,9 @@ def test_wait(self): event.set() try: event.wait(WAIT_TIMEOUT) - self.assertTrue(event.peek()) + assert event.peek() except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") def test_wait_multiprocess(self): handle = msgq.fake_event_handle("carState") @@ -46,9 +47,9 @@ def set_event_run(): p = multiprocessing.Process(target=set_event_run) p.start() event.wait(WAIT_TIMEOUT) - self.assertTrue(event.peek()) + assert event.peek() except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") p.kill() @@ -58,34 +59,34 @@ def test_wait_zero_timeout(self): try: event.wait(0) - self.fail("event.wait() did not time out") + pytest.fail("event.wait() did not time out") except RuntimeError: - self.assertFalse(event.peek()) + assert not event.peek() -@unittest.skipIf(platform.system() == "Darwin", "FakeSockets not supported on macOS") -@unittest.skipIf("ZMQ" in os.environ, "FakeSockets not supported on ZMQ") +@pytest.mark.skipif(platform.system() == "Darwin", reason="FakeSockets not supported on macOS") +@pytest.mark.skipif("ZMQ" in os.environ, reason="FakeSockets not supported on ZMQ") @parameterized_class([{"prefix": None}, {"prefix": "test"}]) -class TestFakeSockets(unittest.TestCase): +class TestFakeSockets: prefix: Optional[str] = None - def setUp(self): + def setup_method(self): msgq.toggle_fake_events(True) if self.prefix is not None: msgq.set_fake_prefix(self.prefix) else: msgq.delete_fake_prefix() - def tearDown(self): + def teardown_method(self): msgq.toggle_fake_events(False) msgq.delete_fake_prefix() def test_event_handle_init(self): handle = msgq.fake_event_handle("controlsState", override=True) - self.assertFalse(handle.enabled) - self.assertGreaterEqual(handle.recv_called_event.fd, 0) - self.assertGreaterEqual(handle.recv_ready_event.fd, 0) + assert not handle.enabled + assert handle.recv_called_event.fd >= 0 + assert handle.recv_ready_event.fd >= 0 def test_non_managed_socket_state(self): # non managed socket should have zero state @@ -93,9 +94,9 @@ def test_non_managed_socket_state(self): handle = msgq.fake_event_handle("ubloxGnss", override=False) - self.assertFalse(handle.enabled) - self.assertEqual(handle.recv_called_event.fd, 0) - self.assertEqual(handle.recv_ready_event.fd, 0) + assert not handle.enabled + assert handle.recv_called_event.fd == 0 + assert handle.recv_ready_event.fd == 0 def test_managed_socket_state(self): # managed socket should not change anything about the state @@ -108,9 +109,9 @@ def test_managed_socket_state(self): _ = msgq.pub_sock("ubloxGnss") - self.assertEqual(handle.enabled, expected_enabled) - self.assertEqual(handle.recv_called_event.fd, expected_recv_called_fd) - self.assertEqual(handle.recv_ready_event.fd, expected_recv_ready_fd) + assert handle.enabled == expected_enabled + assert handle.recv_called_event.fd == expected_recv_called_fd + assert handle.recv_ready_event.fd == expected_recv_ready_fd def test_sockets_enable_disable(self): carState_handle = msgq.fake_event_handle("ubloxGnss", enable=True) @@ -125,16 +126,16 @@ def test_sockets_enable_disable(self): recv_ready.set() pub_sock.send(b"test") _ = sub_sock.receive() - self.assertTrue(recv_called.peek()) + assert recv_called.peek() recv_called.clear() carState_handle.enabled = False recv_ready.set() pub_sock.send(b"test") _ = sub_sock.receive() - self.assertFalse(recv_called.peek()) + assert not recv_called.peek() except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") def test_synced_pub_sub(self): def daemon_repub_process_run(): @@ -177,13 +178,13 @@ def daemon_repub_process_run(): recv_called.wait(WAIT_TIMEOUT) msg = sub_sock.receive(non_blocking=True) - self.assertIsNotNone(msg) - self.assertEqual(len(msg), 8) + assert msg is not None + assert len(msg) == 8 frame = int.from_bytes(msg, 'little') - self.assertEqual(frame, i) + assert frame == i except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") finally: p.kill() diff --git a/msgq/tests/test_messaging.py b/msgq/tests/test_messaging.py index bbeeb3d84..e87208a31 100755 --- a/msgq/tests/test_messaging.py +++ b/msgq/tests/test_messaging.py @@ -5,6 +5,9 @@ import time import string import unittest + +import pytest + import msgq @@ -20,23 +23,25 @@ def zmq_sleep(t=1): def zmq_expected_failure(func): if "ZMQ" in os.environ: - return unittest.expectedFailure(func) + with pytest.raises(Exception): + func() else: - return func + return func() def delayed_send(delay, sock, dat): def send_func(): sock.send(dat) threading.Timer(delay, send_func).start() -class TestPubSubSockets(unittest.TestCase): - - def setUp(self): +@pytest.fixture +def do_sleep(): # ZMQ pub socket takes too long to die # sleep to prevent multiple publishers error between tests zmq_sleep() - def test_pub_sub(self): +class TestPubSubSockets: + + def test_pub_sub(self, do_sleep): sock = random_sock() pub_sock = msgq.pub_sock(sock) sub_sock = msgq.sub_sock(sock, conflate=False, timeout=None) @@ -46,9 +51,9 @@ def test_pub_sub(self): msg = random_bytes() pub_sock.send(msg) recvd = sub_sock.receive() - self.assertEqual(msg, recvd) + assert msg == recvd - def test_conflate(self): + def test_conflate(self, do_sleep): sock = random_sock() pub_sock = msgq.pub_sock(sock) for conflate in [True, False]: @@ -65,12 +70,12 @@ def test_conflate(self): time.sleep(0.1) recvd_msgs = msgq.drain_sock_raw(sub_sock) if conflate: - self.assertEqual(len(recvd_msgs), 1) + assert len(recvd_msgs) == 1 else: # TODO: compare actual data - self.assertEqual(len(recvd_msgs), len(sent_msgs)) + assert len(recvd_msgs) == len(sent_msgs) - def test_receive_timeout(self): + def test_receive_timeout(self, do_sleep): sock = random_sock() for _ in range(10): timeout = random.randrange(200) @@ -79,9 +84,5 @@ def test_receive_timeout(self): start_time = time.monotonic() recvd = sub_sock.receive() - self.assertLess(time.monotonic() - start_time, 0.2) + assert time.monotonic() - start_time < 0.2 assert recvd is None - - -if __name__ == "__main__": - unittest.main() diff --git a/msgq/tests/test_poller.py b/msgq/tests/test_poller.py index a68ff4fe7..2bbac98aa 100644 --- a/msgq/tests/test_poller.py +++ b/msgq/tests/test_poller.py @@ -1,4 +1,4 @@ -import unittest +import pytest import time import msgq import concurrent.futures @@ -20,7 +20,7 @@ def poller(): return r -class TestPoller(unittest.TestCase): +class TestPoller: def test_poll_once(self): context = msgq.Context() @@ -41,7 +41,7 @@ def test_poll_once(self): del pub context.term() - self.assertEqual(result, [b"a"]) + assert result == [b"a"] def test_poll_and_create_many_subscribers(self): context = msgq.Context() @@ -68,12 +68,12 @@ def test_poll_and_create_many_subscribers(self): del pub context.term() - self.assertEqual(result, [b"a"]) + assert result == [b"a"] def test_multiple_publishers_exception(self): context = msgq.Context() - with self.assertRaises(msgq.MultiplePublishersError): + with pytest.raises(msgq.MultiplePublishersError): pub1 = msgq.PubSocket() pub1.connect(context, SERVICE_NAME) @@ -106,7 +106,7 @@ def test_multiple_messages(self): r = sub.receive(non_blocking=True) if r is not None: - self.assertEqual(b'a'*i, r) + assert b'a'*i == r msg_seen = True i += 1 @@ -131,12 +131,8 @@ def test_conflate(self): pub.send(b'a') pub.send(b'b') - self.assertEqual(b'b', sub.receive()) + assert b'b' == sub.receive() del pub del sub - context.term() - - -if __name__ == "__main__": - unittest.main() + context.term() \ No newline at end of file diff --git a/msgq/visionipc/tests/test_visionipc.py b/msgq/visionipc/tests/test_visionipc.py index 1c34613dd..aa7c4c2bc 100755 --- a/msgq/visionipc/tests/test_visionipc.py +++ b/msgq/visionipc/tests/test_visionipc.py @@ -2,7 +2,6 @@ import os import time import random -import unittest import numpy as np from msgq.visionipc import VisionIpcServer, VisionIpcClient, VisionStreamType @@ -11,7 +10,7 @@ def zmq_sleep(t=1): time.sleep(t) -class TestVisionIpc(unittest.TestCase): +class TestVisionIpc: def setup_vipc(self, name, *stream_types, num_buffers=1, rgb=False, width=100, height=100, conflate=False): self.server = VisionIpcServer(name) @@ -21,7 +20,7 @@ def setup_vipc(self, name, *stream_types, num_buffers=1, rgb=False, width=100, h if len(stream_types): self.client = VisionIpcClient(name, stream_types[0], conflate) - self.assertTrue(self.client.connect(True)) + assert self.client.connect(True) else: self.client = None @@ -30,28 +29,28 @@ def setup_vipc(self, name, *stream_types, num_buffers=1, rgb=False, width=100, h def test_connect(self): self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD) - self.assertTrue(self.client.is_connected) + assert self.client.is_connected def test_available_streams(self): for k in range(4): stream_types = set(random.choices([x.value for x in VisionStreamType], k=k)) self.setup_vipc("camerad", *stream_types) available_streams = VisionIpcClient.available_streams("camerad", True) - self.assertEqual(available_streams, stream_types) + assert available_streams == stream_types def test_buffers(self): width, height, num_buffers = 100, 200, 5 self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, num_buffers=num_buffers, width=width, height=height) - self.assertEqual(self.client.width, width) - self.assertEqual(self.client.height, height) - self.assertGreater(self.client.buffer_len, 0) - self.assertEqual(self.client.num_buffers, num_buffers) + assert self.client.width == width + assert self.client.height == height + assert self.client.buffer_len > 0 + assert self.client.num_buffers == num_buffers def test_yuv_rgb(self): _, client_yuv = self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, rgb=False) _, client_rgb = self.setup_vipc("navd", VisionStreamType.VISION_STREAM_MAP, rgb=True) - self.assertTrue(client_rgb.rgb) - self.assertFalse(client_yuv.rgb) + assert client_rgb.rgb + assert not client_yuv.rgb def test_send_single_buffer(self): self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD) @@ -61,9 +60,9 @@ def test_send_single_buffer(self): self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=1337) recv_buf = self.client.recv() - self.assertIsNot(recv_buf, None) - self.assertEqual(recv_buf.data.view('