From 317123119434e22e11c273bf422917bfa0f2a626 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 21:40:33 +0200 Subject: [PATCH] Add tests for the logger argument of clients. --- tests/asyncio/test_client.py | 8 ++++++++ tests/asyncio/test_server.py | 2 +- tests/sync/test_client.py | 8 ++++++++ tests/sync/test_server.py | 2 +- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 53a6eaaf..15178f8b 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -1,4 +1,5 @@ import asyncio +import logging import socket import ssl import unittest @@ -78,6 +79,13 @@ async def test_disable_keepalive(self): await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with serve(*args) as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + async def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fe0cafe1..ceb0417a 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -352,7 +352,7 @@ async def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") async with serve(*args, logger=logger) as server: - self.assertIs(server.logger, logger) + self.assertEqual(server.logger.name, logger.name) async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 0d5273d1..81241220 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,3 +1,4 @@ +import logging import socket import ssl import threading @@ -67,6 +68,13 @@ def test_disable_compression(self): with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) + def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server() as server: + with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index a4b537c6..541a1460 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -238,7 +238,7 @@ def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) + self.assertEqual(server.logger.name, logger.name) def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection."""