Skip to content

Commit

Permalink
Convert tests to pytests (#626)
Browse files Browse the repository at this point in the history
* make test pass

* linter

---------

Co-authored-by: Maxime Desroches <desroches.maxime@gmail.com>
  • Loading branch information
UkuLoskit and maxime-desroches authored Jul 10, 2024
1 parent 74074d6 commit d7b99c4
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
msgq/test_runner && \
msgq/visionipc/test_runner"
- name: python tests
run: $RUN_NAMED "${{ matrix.backend }}=1 coverage run -m unittest discover ."
run: $RUN_NAMED "${{ matrix.backend }}=1 coverage run -m pytest"
- name: Upload coverage
run: |
docker commit msgq msgqci
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*

RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy
RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy pytest

WORKDIR /project/msgq/
RUN cd /tmp/ && \
Expand Down
70 changes: 33 additions & 37 deletions msgq/tests/test_fake.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
import os
import unittest
import multiprocessing
import platform
import msgq
Expand All @@ -9,18 +9,18 @@
WAIT_TIMEOUT = 5


@unittest.skipIf(platform.system() == "Darwin", "Events not supported on macOS")
class TestEvents(unittest.TestCase):
@pytest.mark.skipif(condition=platform.system() == "Darwin", reason="Events not supported on macOS")
class TestEvents:

def test_mutation(self):
handle = msgq.fake_event_handle("carState")
event = handle.recv_called_event

self.assertFalse(event.peek())
assert not event.peek()
event.set()
self.assertTrue(event.peek())
assert event.peek()
event.clear()
self.assertFalse(event.peek())
assert not event.peek()

del event

Expand All @@ -31,9 +31,9 @@ def test_wait(self):
event.set()
try:
event.wait(WAIT_TIMEOUT)
self.assertTrue(event.peek())
assert event.peek()
except RuntimeError:
self.fail("event.wait() timed out")
pytest.fail("event.wait() timed out")

def test_wait_multiprocess(self):
handle = msgq.fake_event_handle("carState")
Expand All @@ -46,9 +46,9 @@ def set_event_run():
p = multiprocessing.Process(target=set_event_run)
p.start()
event.wait(WAIT_TIMEOUT)
self.assertTrue(event.peek())
assert event.peek()
except RuntimeError:
self.fail("event.wait() timed out")
pytest.fail("event.wait() timed out")

p.kill()

Expand All @@ -58,44 +58,44 @@ def test_wait_zero_timeout(self):

try:
event.wait(0)
self.fail("event.wait() did not time out")
pytest.fail("event.wait() did not time out")
except RuntimeError:
self.assertFalse(event.peek())
assert not event.peek()


@unittest.skipIf(platform.system() == "Darwin", "FakeSockets not supported on macOS")
@unittest.skipIf("ZMQ" in os.environ, "FakeSockets not supported on ZMQ")
@pytest.mark.skipif(condition=platform.system() == "Darwin", reason="FakeSockets not supported on macOS")
@pytest.mark.skipif(condition="ZMQ" in os.environ, reason="FakeSockets not supported on ZMQ")
@parameterized_class([{"prefix": None}, {"prefix": "test"}])
class TestFakeSockets(unittest.TestCase):
class TestFakeSockets:
prefix: Optional[str] = None

def setUp(self):
def setup_method(self):
msgq.toggle_fake_events(True)
if self.prefix is not None:
msgq.set_fake_prefix(self.prefix)
else:
msgq.delete_fake_prefix()

def tearDown(self):
def teardown_method(self):
msgq.toggle_fake_events(False)
msgq.delete_fake_prefix()

def test_event_handle_init(self):
handle = msgq.fake_event_handle("controlsState", override=True)

self.assertFalse(handle.enabled)
self.assertGreaterEqual(handle.recv_called_event.fd, 0)
self.assertGreaterEqual(handle.recv_ready_event.fd, 0)
assert not handle.enabled
assert handle.recv_called_event.fd >= 0
assert handle.recv_ready_event.fd >= 0

def test_non_managed_socket_state(self):
# non managed socket should have zero state
_ = msgq.pub_sock("ubloxGnss")

handle = msgq.fake_event_handle("ubloxGnss", override=False)

self.assertFalse(handle.enabled)
self.assertEqual(handle.recv_called_event.fd, 0)
self.assertEqual(handle.recv_ready_event.fd, 0)
assert not handle.enabled
assert handle.recv_called_event.fd == 0
assert handle.recv_ready_event.fd == 0

def test_managed_socket_state(self):
# managed socket should not change anything about the state
Expand All @@ -108,9 +108,9 @@ def test_managed_socket_state(self):

_ = msgq.pub_sock("ubloxGnss")

self.assertEqual(handle.enabled, expected_enabled)
self.assertEqual(handle.recv_called_event.fd, expected_recv_called_fd)
self.assertEqual(handle.recv_ready_event.fd, expected_recv_ready_fd)
assert handle.enabled == expected_enabled
assert handle.recv_called_event.fd == expected_recv_called_fd
assert handle.recv_ready_event.fd == expected_recv_ready_fd

def test_sockets_enable_disable(self):
carState_handle = msgq.fake_event_handle("ubloxGnss", enable=True)
Expand All @@ -125,16 +125,16 @@ def test_sockets_enable_disable(self):
recv_ready.set()
pub_sock.send(b"test")
_ = sub_sock.receive()
self.assertTrue(recv_called.peek())
assert recv_called.peek()
recv_called.clear()

carState_handle.enabled = False
recv_ready.set()
pub_sock.send(b"test")
_ = sub_sock.receive()
self.assertFalse(recv_called.peek())
assert not recv_called.peek()
except RuntimeError:
self.fail("event.wait() timed out")
pytest.fail("event.wait() timed out")

def test_synced_pub_sub(self):
def daemon_repub_process_run():
Expand Down Expand Up @@ -177,16 +177,12 @@ def daemon_repub_process_run():
recv_called.wait(WAIT_TIMEOUT)

msg = sub_sock.receive(non_blocking=True)
self.assertIsNotNone(msg)
self.assertEqual(len(msg), 8)
assert msg is not None
assert len(msg) == 8

frame = int.from_bytes(msg, 'little')
self.assertEqual(frame, i)
assert frame == i
except RuntimeError:
self.fail("event.wait() timed out")
pytest.fail("event.wait() timed out")
finally:
p.kill()


if __name__ == "__main__":
unittest.main()
30 changes: 6 additions & 24 deletions msgq/tests/test_messaging.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#!/usr/bin/env python3
import os
import random
import threading
import time
import string
import unittest
import msgq


Expand All @@ -18,20 +15,9 @@ def zmq_sleep(t=1):
if "ZMQ" in os.environ:
time.sleep(t)

def zmq_expected_failure(func):
if "ZMQ" in os.environ:
return unittest.expectedFailure(func)
else:
return func

def delayed_send(delay, sock, dat):
def send_func():
sock.send(dat)
threading.Timer(delay, send_func).start()

class TestPubSubSockets(unittest.TestCase):
class TestPubSubSockets:

def setUp(self):
def setup_method(self):
# ZMQ pub socket takes too long to die
# sleep to prevent multiple publishers error between tests
zmq_sleep()
Expand All @@ -46,7 +32,7 @@ def test_pub_sub(self):
msg = random_bytes()
pub_sock.send(msg)
recvd = sub_sock.receive()
self.assertEqual(msg, recvd)
assert msg == recvd

def test_conflate(self):
sock = random_sock()
Expand All @@ -65,10 +51,10 @@ def test_conflate(self):
time.sleep(0.1)
recvd_msgs = msgq.drain_sock_raw(sub_sock)
if conflate:
self.assertEqual(len(recvd_msgs), 1)
assert len(recvd_msgs) == 1
else:
# TODO: compare actual data
self.assertEqual(len(recvd_msgs), len(sent_msgs))
assert len(recvd_msgs) == len(sent_msgs)

def test_receive_timeout(self):
sock = random_sock()
Expand All @@ -79,9 +65,5 @@ def test_receive_timeout(self):

start_time = time.monotonic()
recvd = sub_sock.receive()
self.assertLess(time.monotonic() - start_time, 0.2)
assert (time.monotonic() - start_time) < 0.2
assert recvd is None


if __name__ == "__main__":
unittest.main()
18 changes: 7 additions & 11 deletions msgq/tests/test_poller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import unittest
import pytest
import time
import msgq
import concurrent.futures
Expand All @@ -20,7 +20,7 @@ def poller():
return r


class TestPoller(unittest.TestCase):
class TestPoller:
def test_poll_once(self):
context = msgq.Context()

Expand All @@ -41,7 +41,7 @@ def test_poll_once(self):
del pub
context.term()

self.assertEqual(result, [b"a"])
assert result == [b"a"]

def test_poll_and_create_many_subscribers(self):
context = msgq.Context()
Expand All @@ -68,12 +68,12 @@ def test_poll_and_create_many_subscribers(self):
del pub
context.term()

self.assertEqual(result, [b"a"])
assert result == [b"a"]

def test_multiple_publishers_exception(self):
context = msgq.Context()

with self.assertRaises(msgq.MultiplePublishersError):
with pytest.raises(msgq.MultiplePublishersError):
pub1 = msgq.PubSocket()
pub1.connect(context, SERVICE_NAME)

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_multiple_messages(self):
r = sub.receive(non_blocking=True)

if r is not None:
self.assertEqual(b'a'*i, r)
assert b'a'*i == r

msg_seen = True
i += 1
Expand All @@ -131,12 +131,8 @@ def test_conflate(self):
pub.send(b'a')
pub.send(b'b')

self.assertEqual(b'b', sub.receive())
assert b'b' == sub.receive()

del pub
del sub
context.term()


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit d7b99c4

Please sign in to comment.