diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index 9944b20f..7aca834f 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -435,3 +435,93 @@ def tearDown(self): self.con = None finally: super().tearDown() + + +class HotStandbyTestCase(ClusterTestCase): + + @classmethod + def setup_cluster(cls): + cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) + cls.start_cluster( + cls.master_cluster, + server_settings={ + 'max_wal_senders': 10, + 'wal_level': 'hot_standby' + } + ) + + con = None + + try: + con = cls.loop.run_until_complete( + cls.master_cluster.connect( + database='postgres', user='postgres', loop=cls.loop)) + + cls.loop.run_until_complete( + con.execute(''' + CREATE ROLE replication WITH LOGIN REPLICATION + ''')) + + cls.master_cluster.trust_local_replication_by('replication') + + conn_spec = cls.master_cluster.get_connection_spec() + + cls.standby_cluster = cls.new_cluster( + pg_cluster.HotStandbyCluster, + cluster_kwargs={ + 'master': conn_spec, + 'replication_user': 'replication' + } + ) + cls.start_cluster( + cls.standby_cluster, + server_settings={ + 'hot_standby': True + } + ) + + finally: + if con is not None: + cls.loop.run_until_complete(con.close()) + + @classmethod + def get_cluster_connection_spec(cls, cluster, kwargs={}): + conn_spec = cluster.get_connection_spec() + if kwargs.get('dsn'): + conn_spec.pop('host') + conn_spec.update(kwargs) + if not os.environ.get('PGHOST') and not kwargs.get('dsn'): + if 'database' not in conn_spec: + conn_spec['database'] = 'postgres' + if 'user' not in conn_spec: + conn_spec['user'] = 'postgres' + return conn_spec + + @classmethod + def get_connection_spec(cls, kwargs={}): + primary_spec = cls.get_cluster_connection_spec( + cls.master_cluster, kwargs + ) + standby_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, kwargs + ) + return { + 'host': [primary_spec['host'], standby_spec['host']], + 'port': [primary_spec['port'], standby_spec['port']], + 'database': primary_spec['database'], + 'user': primary_spec['user'], + **kwargs + } + + @classmethod + def connect_primary(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs) + return pg_connection.connect(**conn_spec, loop=cls.loop) + + @classmethod + def connect_standby(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, + kwargs + ) + return pg_connection.connect(**conn_spec, loop=cls.loop) diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py index 0999e41c..4467cc2a 100644 --- a/asyncpg/cluster.py +++ b/asyncpg/cluster.py @@ -626,7 +626,7 @@ def init(self, **settings): 'pg_basebackup init exited with status {:d}:\n{}'.format( process.returncode, output.decode())) - if self._pg_version <= (11, 0): + if self._pg_version < (12, 0): with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f: f.write(textwrap.dedent("""\ standby_mode = 'on' diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 5f6423c2..ad5fabc2 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -13,6 +13,7 @@ import os import pathlib import platform +import random import re import socket import ssl as ssl_module @@ -56,6 +57,7 @@ def parse(cls, sslmode): 'direct_tls', 'connect_timeout', 'server_settings', + 'target_session_attrs', ]) @@ -260,7 +262,8 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, - direct_tls, connect_timeout, server_settings): + direct_tls, connect_timeout, server_settings, + target_session_attrs): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -607,10 +610,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, 'server_settings is expected to be None or ' 'a Dict[str, str]') + if target_session_attrs is None: + + target_session_attrs = os.getenv( + "PGTARGETSESSIONATTRS", SessionAttribute.any + ) + try: + + target_session_attrs = SessionAttribute(target_session_attrs) + except ValueError as exc: + raise exceptions.InterfaceError( + "target_session_attrs is expected to be one of " + "{!r}" + ", got {!r}".format( + SessionAttribute.__members__.values, target_session_attrs + ) + ) from exc + params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, sslmode=sslmode, direct_tls=direct_tls, - connect_timeout=connect_timeout, server_settings=server_settings) + connect_timeout=connect_timeout, server_settings=server_settings, + target_session_attrs=target_session_attrs) return addrs, params @@ -620,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, statement_cache_size, max_cached_statement_lifetime, max_cacheable_statement_size, - ssl, direct_tls, server_settings): - + ssl, direct_tls, server_settings, + target_session_attrs): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -649,7 +670,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, - connect_timeout=timeout, server_settings=server_settings) + connect_timeout=timeout, server_settings=server_settings, + target_session_attrs=target_session_attrs) config = _ClientConfiguration( command_timeout=command_timeout, @@ -882,18 +904,84 @@ async def __connect_addr( return con +class SessionAttribute(str, enum.Enum): + any = 'any' + primary = 'primary' + standby = 'standby' + prefer_standby = 'prefer-standby' + read_write = "read-write" + read_only = "read-only" + + +def _accept_in_hot_standby(should_be_in_hot_standby: bool): + """ + If the server didn't report "in_hot_standby" at startup, we must determine + the state by checking "SELECT pg_catalog.pg_is_in_recovery()". + If the server allows a connection and states it is in recovery it must + be a replica/standby server. + """ + async def can_be_used(connection): + settings = connection.get_settings() + hot_standby_status = getattr(settings, 'in_hot_standby', None) + if hot_standby_status is not None: + is_in_hot_standby = hot_standby_status == 'on' + else: + is_in_hot_standby = await connection.fetchval( + "SELECT pg_catalog.pg_is_in_recovery()" + ) + return is_in_hot_standby == should_be_in_hot_standby + + return can_be_used + + +def _accept_read_only(should_be_read_only: bool): + """ + Verify the server has not set default_transaction_read_only=True + """ + async def can_be_used(connection): + settings = connection.get_settings() + is_readonly = getattr(settings, 'default_transaction_read_only', 'off') + + if is_readonly == "on": + return should_be_read_only + + return await _accept_in_hot_standby(should_be_read_only)(connection) + return can_be_used + + +async def _accept_any(_): + return True + + +target_attrs_check = { + SessionAttribute.any: _accept_any, + SessionAttribute.primary: _accept_in_hot_standby(False), + SessionAttribute.standby: _accept_in_hot_standby(True), + SessionAttribute.prefer_standby: _accept_in_hot_standby(True), + SessionAttribute.read_write: _accept_read_only(False), + SessionAttribute.read_only: _accept_read_only(True), +} + + +async def _can_use_connection(connection, attr: SessionAttribute): + can_use = target_attrs_check[attr] + return await can_use(connection) + + async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): if loop is None: loop = asyncio.get_event_loop() addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs) + target_attr = params.target_session_attrs + candidates = [] + chosen_connection = None last_error = None - addr = None for addr in addrs: before = time.monotonic() try: - return await _connect_addr( + conn = await _connect_addr( addr=addr, loop=loop, timeout=timeout, @@ -902,12 +990,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): connection_class=connection_class, record_class=record_class, ) + candidates.append(conn) + if await _can_use_connection(conn, target_attr): + chosen_connection = conn + break except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex finally: timeout -= time.monotonic() - before + else: + if target_attr == SessionAttribute.prefer_standby and candidates: + chosen_connection = random.choice(candidates) + + await asyncio.gather( + (c.close() for c in candidates if c is not chosen_connection), + return_exceptions=True + ) + + if chosen_connection: + return chosen_connection - raise last_error + raise last_error or exceptions.TargetServerAttributeNotMatched( + 'None of the hosts match the target attribute requirement ' + '{!r}'.format(target_attr) + ) async def _cancel(*, loop, addr, params: _ConnectionParameters, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 73cb6e66..0b13d356 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -257,9 +257,9 @@ def transaction(self, *, isolation=None, readonly=False, :param isolation: Transaction isolation mode, can be one of: `'serializable'`, `'repeatable_read'`, - `'read_committed'`. If not specified, the behavior - is up to the server and session, which is usually - ``read_committed``. + `'read_uncommitted'`, `'read_committed'`. If not + specified, the behavior is up to the server and + session, which is usually ``read_committed``. :param readonly: Specifies whether or not this transaction is read-only. @@ -1792,7 +1792,8 @@ async def connect(dsn=None, *, direct_tls=False, connection_class=Connection, record_class=protocol.Record, - server_settings=None): + server_settings=None, + target_session_attrs=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2003,6 +2004,22 @@ async def connect(dsn=None, *, this connection object. Must be a subclass of :class:`~asyncpg.Record`. + :param SessionAttribute target_session_attrs: + If specified, check that the host has the correct attribute. + Can be one of: + "any": the first successfully connected host + "primary": the host must NOT be in hot standby mode + "standby": the host must be in hot standby mode + "read-write": the host must allow writes + "read-only": the host most NOT allow writes + "prefer-standby": first try to find a standby host, but if + none of the listed hosts is a standby server, + return any of them. + + If not specified will try to use PGTARGETSESSIONATTRS + from the environment. + Defaults to "any" if no value is set. + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2109,6 +2126,7 @@ async def connect(dsn=None, *, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size, + target_session_attrs=target_session_attrs ) diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 783b5eb5..de981d25 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -13,7 +13,7 @@ __all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', - 'UnsupportedClientFeatureError') + 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched') def _is_asyncpg_class(cls): @@ -244,6 +244,10 @@ class ProtocolError(InternalClientError): """Unexpected condition in the handling of PostgreSQL protocol input.""" +class TargetServerAttributeNotMatched(InternalClientError): + """Could not find a host that satisfies the target attribute requirement""" + + class OutdatedSchemaCacheError(InternalClientError): """A value decoding error caused by a schema change before row fetching.""" diff --git a/asyncpg/protocol/scram.pyx b/asyncpg/protocol/scram.pyx index 765ddd46..9b485aee 100644 --- a/asyncpg/protocol/scram.pyx +++ b/asyncpg/protocol/scram.pyx @@ -156,12 +156,12 @@ cdef class SCRAMAuthentication: if not self.server_nonce.startswith(self.client_nonce): raise Exception("invalid nonce") try: - self.password_salt = re.search(b's=([^,]+),', + self.password_salt = re.search(b',s=([^,]+),', self.server_first_message).group(1) except IndexError: raise Exception("could not get salt") try: - self.password_iterations = int(re.search(b'i=(\d+),?', + self.password_iterations = int(re.search(b',i=(\d+),?', self.server_first_message).group(1)) except (IndexError, TypeError, ValueError): raise Exception("could not get iterations") diff --git a/asyncpg/transaction.py b/asyncpg/transaction.py index 2d7ba49f..562811e6 100644 --- a/asyncpg/transaction.py +++ b/asyncpg/transaction.py @@ -19,9 +19,15 @@ class TransactionState(enum.Enum): FAILED = 4 -ISOLATION_LEVELS = {'read_committed', 'serializable', 'repeatable_read'} +ISOLATION_LEVELS = { + 'read_committed', + 'read_uncommitted', + 'serializable', + 'repeatable_read', +} ISOLATION_LEVELS_BY_VALUE = { 'read committed': 'read_committed', + 'read uncommitted': 'read_uncommitted', 'serializable': 'serializable', 'repeatable read': 'repeatable_read', } @@ -124,6 +130,8 @@ async def start(self): query = 'BEGIN' if self._isolation == 'read_committed': query += ' ISOLATION LEVEL READ COMMITTED' + elif self._isolation == 'read_uncommitted': + query += ' ISOLATION LEVEL READ UNCOMMITTED' elif self._isolation == 'repeatable_read': query += ' ISOLATION LEVEL REPEATABLE READ' elif self._isolation == 'serializable': diff --git a/tests/test_connect.py b/tests/test_connect.py index d5cdb18f..31dffd24 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -24,7 +24,7 @@ import asyncpg from asyncpg import _testbase as tb -from asyncpg import connection +from asyncpg import connection as pg_connection from asyncpg import connect_utils from asyncpg import cluster as pg_cluster from asyncpg import exceptions @@ -392,7 +392,8 @@ class TestConnectParams(tb.TestCase): 'password': 'passw', 'database': 'testdb', 'ssl': True, - 'sslmode': SSLMode.prefer}) + 'sslmode': SSLMode.prefer, + 'target_session_attrs': 'any'}) }, { @@ -414,7 +415,8 @@ class TestConnectParams(tb.TestCase): 'result': ([('host2', 456)], { 'user': 'user2', 'password': 'passw2', - 'database': 'db2'}) + 'database': 'db2', + 'target_session_attrs': 'any'}) }, { @@ -442,7 +444,8 @@ class TestConnectParams(tb.TestCase): 'password': 'passw2', 'database': 'db2', 'sslmode': SSLMode.disable, - 'ssl': False}) + 'ssl': False, + 'target_session_attrs': 'any'}) }, { @@ -463,7 +466,8 @@ class TestConnectParams(tb.TestCase): 'password': '123123', 'database': 'abcdef', 'ssl': True, - 'sslmode': SSLMode.allow}) + 'sslmode': SSLMode.allow, + 'target_session_attrs': 'any'}) }, { @@ -491,7 +495,8 @@ class TestConnectParams(tb.TestCase): 'password': 'passw2', 'database': 'db2', 'sslmode': SSLMode.disable, - 'ssl': False}) + 'ssl': False, + 'target_session_attrs': 'any'}) }, { @@ -512,7 +517,8 @@ class TestConnectParams(tb.TestCase): 'password': '123123', 'database': 'abcdef', 'ssl': True, - 'sslmode': SSLMode.prefer}) + 'sslmode': SSLMode.prefer, + 'target_session_attrs': 'any'}) }, { @@ -521,7 +527,8 @@ class TestConnectParams(tb.TestCase): 'result': ([('localhost', 5555)], { 'user': 'user3', 'password': '123123', - 'database': 'abcdef'}) + 'database': 'abcdef', + 'target_session_attrs': 'any'}) }, { @@ -530,6 +537,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -539,6 +547,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -548,6 +557,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('2001:db8::1234%eth0', 5432), ('::1', 5432)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -557,6 +567,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('2001:db8::1234', 1111), ('::1', 2222)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -566,6 +577,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('2001:db8::1234', 5432), ('::1', 5432)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -580,6 +592,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'foo', + 'target_session_attrs': 'any', }) }, @@ -592,6 +605,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'foo', + 'target_session_attrs': 'any', }) }, @@ -605,6 +619,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', 'user': 'foo', + 'target_session_attrs': 'any', }) }, @@ -637,7 +652,8 @@ class TestConnectParams(tb.TestCase): 'password': 'ask', 'database': 'db', 'ssl': True, - 'sslmode': SSLMode.require}) + 'sslmode': SSLMode.require, + 'target_session_attrs': 'any'}) }, { @@ -658,7 +674,8 @@ class TestConnectParams(tb.TestCase): 'password': 'ask', 'database': 'db', 'sslmode': SSLMode.verify_full, - 'ssl': True}) + 'ssl': True, + 'target_session_attrs': 'any'}) }, { @@ -666,7 +683,8 @@ class TestConnectParams(tb.TestCase): 'dsn': 'postgresql:///dbname?host=/unix_sock/test&user=spam', 'result': ([os.path.join('/unix_sock/test', '.s.PGSQL.5432')], { 'user': 'spam', - 'database': 'dbname'}) + 'database': 'dbname', + 'target_session_attrs': 'any'}) }, { @@ -678,6 +696,7 @@ class TestConnectParams(tb.TestCase): 'user': 'us@r', 'password': 'p@ss', 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -691,6 +710,7 @@ class TestConnectParams(tb.TestCase): 'user': 'user', 'password': 'p', 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -703,6 +723,7 @@ class TestConnectParams(tb.TestCase): { 'user': 'us@r', 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -730,7 +751,8 @@ class TestConnectParams(tb.TestCase): 'user': 'user', 'database': 'user', 'sslmode': SSLMode.disable, - 'ssl': None + 'ssl': None, + 'target_session_attrs': 'any', } ) }, @@ -744,7 +766,8 @@ class TestConnectParams(tb.TestCase): '.s.PGSQL.5432' )], { 'user': 'spam', - 'database': 'db' + 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -765,6 +788,7 @@ class TestConnectParams(tb.TestCase): 'database': 'db', 'ssl': True, 'sslmode': SSLMode.prefer, + 'target_session_attrs': 'any', } ) }, @@ -809,6 +833,7 @@ def run_testcase(self, testcase): database = testcase.get('database') sslmode = testcase.get('ssl') server_settings = testcase.get('server_settings') + target_session_attrs = testcase.get('target_session_attrs') expected = testcase.get('result') expected_error = testcase.get('error') @@ -832,7 +857,8 @@ def run_testcase(self, testcase): dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, database=database, ssl=sslmode, direct_tls=False, connect_timeout=None, - server_settings=server_settings) + server_settings=server_settings, + target_session_attrs=target_session_attrs) params = { k: v for k, v in params._asdict().items() @@ -893,7 +919,9 @@ def test_test_connect_params_run_testcase(self): 'host': 'abc', 'result': ( [('abc', 5432)], - {'user': '__test__', 'database': '__test__'} + {'user': '__test__', + 'database': '__test__', + 'target_session_attrs': 'any'} ) }) @@ -931,6 +959,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for user@abc', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -947,6 +976,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for user@abc', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -961,6 +991,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for user@abc', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -976,6 +1007,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for localhost', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -993,6 +1025,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for localhost', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1010,6 +1043,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for cde:5433', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1026,6 +1060,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for testuser', 'user': 'testuser', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1042,6 +1077,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for testdb', 'user': 'user', 'database': 'testdb', + 'target_session_attrs': 'any', } ) }) @@ -1058,6 +1094,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass with escapes', 'user': R'test\\', 'database': R'test\:db', + 'target_session_attrs': 'any', } ) }) @@ -1085,6 +1122,7 @@ def test_connect_pgpass_badness_mode(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1105,6 +1143,7 @@ def test_connect_pgpass_badness_non_file(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1121,6 +1160,7 @@ def test_connect_pgpass_nonexistent(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1141,6 +1181,7 @@ def test_connect_pgpass_inaccessible_file(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1163,6 +1204,7 @@ def test_connect_pgpass_inaccessible_directory(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1185,7 +1227,7 @@ async def test_connect_args_validation(self): class TestConnection(tb.ConnectedTestCase): async def test_connection_isinstance(self): - self.assertTrue(isinstance(self.con, connection.Connection)) + self.assertTrue(isinstance(self.con, pg_connection.Connection)) self.assertTrue(isinstance(self.con, object)) self.assertFalse(isinstance(self.con, list)) @@ -1778,8 +1820,96 @@ async def test_no_explicit_close_with_debug(self): r'unclosed connection') as rw: await self._run_no_explicit_close_test() - msg = rw.warning.args[0] + msg = " ".join(rw.warning.args) self.assertIn(' created at:\n', msg) self.assertIn('in test_no_explicit_close_with_debug', msg) finally: self.loop.set_debug(olddebug) + + +class TestConnectionAttributes(tb.HotStandbyTestCase): + + async def _run_connection_test( + self, connect, target_attribute, expected_port + ): + conn = await connect(target_session_attrs=target_attribute) + self.assertTrue(_get_connected_host(conn).endswith(expected_port)) + await conn.close() + + async def test_target_server_attribute_port(self): + master_port = self.master_cluster.get_connection_spec()['port'] + standby_port = self.standby_cluster.get_connection_spec()['port'] + tests = [ + (self.connect_primary, 'primary', master_port), + (self.connect_standby, 'standby', standby_port), + ] + + for connect, target_attr, expected_port in tests: + await self._run_connection_test( + connect, target_attr, expected_port + ) + if self.master_cluster.get_pg_version()[0] < 14: + self.skipTest("PostgreSQL<14 does not support these features") + tests = [ + (self.connect_primary, 'read-write', master_port), + (self.connect_standby, 'read-only', standby_port), + ] + + for connect, target_attr, expected_port in tests: + await self._run_connection_test( + connect, target_attr, expected_port + ) + + async def test_target_attribute_not_matched(self): + tests = [ + (self.connect_standby, 'primary'), + (self.connect_primary, 'standby'), + ] + + for connect, target_attr in tests: + with self.assertRaises(exceptions.TargetServerAttributeNotMatched): + await connect(target_session_attrs=target_attr) + + if self.master_cluster.get_pg_version()[0] < 14: + self.skipTest("PostgreSQL<14 does not support these features") + tests = [ + (self.connect_standby, 'read-write'), + (self.connect_primary, 'read-only'), + ] + + for connect, target_attr in tests: + with self.assertRaises(exceptions.TargetServerAttributeNotMatched): + await connect(target_session_attrs=target_attr) + + async def test_prefer_standby_when_standby_is_up(self): + con = await self.connect(target_session_attrs='prefer-standby') + standby_port = self.standby_cluster.get_connection_spec()['port'] + connected_host = _get_connected_host(con) + self.assertTrue(connected_host.endswith(standby_port)) + await con.close() + + async def test_prefer_standby_picks_master_when_standby_is_down(self): + primary_spec = self.get_cluster_connection_spec(self.master_cluster) + connection_spec = { + 'host': [ + primary_spec['host'], + 'unlocalhost', + ], + 'port': [primary_spec['port'], 15345], + 'database': primary_spec['database'], + 'user': primary_spec['user'], + 'target_session_attrs': 'prefer-standby' + } + + con = await self.connect(**connection_spec) + master_port = self.master_cluster.get_connection_spec()['port'] + connected_host = _get_connected_host(con) + self.assertTrue(connected_host.endswith(master_port)) + await con.close() + + +def _get_connected_host(con): + peername = con._transport.get_extra_info('peername') + if isinstance(peername, tuple): + peername = "".join((str(s) for s in peername if s)) + return peername diff --git a/tests/test_pool.py b/tests/test_pool.py index 5577632c..540efb08 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -17,7 +17,6 @@ import asyncpg from asyncpg import _testbase as tb from asyncpg import connection as pg_connection -from asyncpg import cluster as pg_cluster from asyncpg import pool as pg_pool _system = platform.uname().system @@ -971,52 +970,7 @@ async def worker(): @unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing') -class TestHotStandby(tb.ClusterTestCase): - @classmethod - def setup_cluster(cls): - cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) - cls.start_cluster( - cls.master_cluster, - server_settings={ - 'max_wal_senders': 10, - 'wal_level': 'hot_standby' - } - ) - - con = None - - try: - con = cls.loop.run_until_complete( - cls.master_cluster.connect( - database='postgres', user='postgres', loop=cls.loop)) - - cls.loop.run_until_complete( - con.execute(''' - CREATE ROLE replication WITH LOGIN REPLICATION - ''')) - - cls.master_cluster.trust_local_replication_by('replication') - - conn_spec = cls.master_cluster.get_connection_spec() - - cls.standby_cluster = cls.new_cluster( - pg_cluster.HotStandbyCluster, - cluster_kwargs={ - 'master': conn_spec, - 'replication_user': 'replication' - } - ) - cls.start_cluster( - cls.standby_cluster, - server_settings={ - 'hot_standby': True - } - ) - - finally: - if con is not None: - cls.loop.run_until_complete(con.close()) - +class TestHotStandby(tb.HotStandbyTestCase): def create_pool(self, **kwargs): conn_spec = self.standby_cluster.get_connection_spec() conn_spec.update(kwargs) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 8b7ffd95..f84cf7c0 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -188,6 +188,7 @@ async def test_isolation_level(self): isolation_levels = { None: default_isolation, 'read_committed': 'read committed', + 'read_uncommitted': 'read uncommitted', 'repeatable_read': 'repeatable read', 'serializable': 'serializable', } @@ -214,6 +215,7 @@ async def test_nested_isolation_level(self): set_sql = 'SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL ' isolation_levels = { 'read_committed': 'read committed', + 'read_uncommitted': 'read uncommitted', 'repeatable_read': 'repeatable read', 'serializable': 'serializable', }