From bf8f3281f6da68b243ce63ad89fef3e691926959 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 29 May 2024 16:19:51 -0400 Subject: [PATCH] Extract ExecuteContext as in/out argument So that we can pass out the parsed capabilities to control retries. This also allows further code optimization. --- edgedb/_testbase.py | 6 ++ edgedb/abstract.py | 30 ++++++ edgedb/base_client.py | 63 ++--------- edgedb/protocol/protocol.pxd | 24 +++++ edgedb/protocol/protocol.pyx | 181 ++++++++++++++++---------------- edgedb/protocol/protocol_v0.pyx | 94 ++++++++--------- tests/test_sync_retry.py | 75 +++++++++++++ tests/test_sync_tx.py | 2 +- 8 files changed, 280 insertions(+), 195 deletions(-) diff --git a/edgedb/_testbase.py b/edgedb/_testbase.py index 5680036c..e5b59508 100644 --- a/edgedb/_testbase.py +++ b/edgedb/_testbase.py @@ -372,10 +372,16 @@ def make_test_client( database='edgedb', user='edgedb', password='test', + host=..., + port=..., connection_class=..., ): conargs = cls.get_connect_args( cluster=cluster, database=database, user=user, password=password) + if host is not ...: + conargs['host'] = host + if port is not ...: + conargs['port'] = port if connection_class is ...: connection_class = ( asyncio_client.AsyncIOConnection diff --git a/edgedb/abstract.py b/edgedb/abstract.py index 7e8f4f6a..a1c6cfa4 100644 --- a/edgedb/abstract.py +++ b/edgedb/abstract.py @@ -65,12 +65,42 @@ class QueryContext(typing.NamedTuple): retry_options: typing.Optional[options.RetryOptions] state: typing.Optional[options.State] + def lower( + self, *, allow_capabilities: enums.Capability + ) -> protocol.ExecuteContext: + return protocol.ExecuteContext( + query=self.query.query, + args=self.query.args, + kwargs=self.query.kwargs, + reg=self.cache.codecs_registry, + qc=self.cache.query_cache, + output_format=self.query_options.output_format, + expect_one=self.query_options.expect_one, + required_one=self.query_options.required_one, + allow_capabilities=allow_capabilities, + state=self.state.as_dict() if self.state else None, + ) + class ExecuteContext(typing.NamedTuple): query: QueryWithArgs cache: QueryCache state: typing.Optional[options.State] + def lower( + self, *, allow_capabilities: enums.Capability + ) -> protocol.ExecuteContext: + return protocol.ExecuteContext( + query=self.query.query, + args=self.query.args, + kwargs=self.query.kwargs, + reg=self.cache.codecs_registry, + qc=self.cache.query_cache, + output_format=protocol.OutputFormat.NONE, + allow_capabilities=allow_capabilities, + state=self.state.as_dict() if self.state else None, + ) + @dataclasses.dataclass class DescribeContext: diff --git a/edgedb/base_client.py b/edgedb/base_client.py index d8c33767..0272cc13 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -183,17 +183,7 @@ async def privileged_execute( ) else: await self._protocol.execute( - query=execute_context.query.query, - args=execute_context.query.args, - kwargs=execute_context.query.kwargs, - reg=execute_context.cache.codecs_registry, - qc=execute_context.cache.query_cache, - output_format=protocol.OutputFormat.NONE, - allow_capabilities=enums.Capability.ALL, - state=( - execute_context.state.as_dict() - if execute_context.state else None - ), + execute_context.lower(allow_capabilities=enums.Capability.ALL) ) def is_in_transaction(self) -> bool: @@ -211,56 +201,31 @@ async def raw_query(self, query_context: abstract.QueryContext): await self.connect() reconnect = False - capabilities = None i = 0 - args = dict( - query=query_context.query.query, - args=query_context.query.args, - kwargs=query_context.query.kwargs, - reg=query_context.cache.codecs_registry, - qc=query_context.cache.query_cache, - output_format=query_context.query_options.output_format, - expect_one=query_context.query_options.expect_one, - required_one=query_context.query_options.required_one, - ) if self._protocol.is_legacy: - args["allow_capabilities"] = enums.Capability.LEGACY_EXECUTE + allow_capabilities = enums.Capability.LEGACY_EXECUTE else: - args["allow_capabilities"] = enums.Capability.EXECUTE - if query_context.state is not None: - args["state"] = query_context.state.as_dict() + allow_capabilities = enums.Capability.EXECUTE + ctx = query_context.lower(allow_capabilities=allow_capabilities) while True: i += 1 try: if reconnect: await self.connect(single_attempt=True) if self._protocol.is_legacy: - return await self._protocol.legacy_execute_anonymous( - **args - ) + return await self._protocol.legacy_execute_anonymous(ctx) else: - return await self._protocol.query(**args) + return await self._protocol.query(ctx) except errors.EdgeDBError as e: if query_context.retry_options is None: raise if not e.has_tag(errors.SHOULD_RETRY): raise e - if capabilities is None: - cache_item = query_context.cache.query_cache.get( - query_context.query.query, - query_context.query_options.output_format, - implicit_limit=0, - inline_typenames=False, - inline_typeids=False, - expect_one=query_context.query_options.expect_one, - ) - if cache_item is not None: - _, _, _, capabilities = cache_item # A query is read-only if it has no capabilities i.e. # capabilities == 0. Read-only queries are safe to retry. # Explicit transaction conflicts as well. if ( - capabilities != 0 + ctx.capabilities != 0 and not isinstance(e, errors.TransactionConflictError) ): raise e @@ -281,17 +246,9 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None: ) else: await self._protocol.execute( - query=execute_context.query.query, - args=execute_context.query.args, - kwargs=execute_context.query.kwargs, - reg=execute_context.cache.codecs_registry, - qc=execute_context.cache.query_cache, - output_format=protocol.OutputFormat.NONE, - allow_capabilities=enums.Capability.EXECUTE, - state=( - execute_context.state.as_dict() - if execute_context.state else None - ), + execute_context.lower( + allow_capabilities=enums.Capability.EXECUTE + ) ) async def describe( diff --git a/edgedb/protocol/protocol.pxd b/edgedb/protocol/protocol.pxd index 6dccf624..3befd59b 100644 --- a/edgedb/protocol/protocol.pxd +++ b/edgedb/protocol/protocol.pxd @@ -78,6 +78,30 @@ cdef class QueryCodecsCache: BaseCodec in_type, BaseCodec out_type, int capabilities) +cdef class ExecuteContext: + cdef: + # Input arguments + str query + object args + object kwargs + CodecsRegistry reg + QueryCodecsCache qc + OutputFormat output_format + bint expect_one + bint required_one + int implicit_limit + bint inline_typenames + bint inline_typeids + uint64_t allow_capabilities + object state + + # Contextual variables + bytes cardinality + BaseCodec in_dc + BaseCodec out_dc + readonly uint64_t capabilities + + cdef class SansIOProtocol: cdef: diff --git a/edgedb/protocol/protocol.pyx b/edgedb/protocol/protocol.pyx index 505f92ce..0a022dd9 100644 --- a/edgedb/protocol/protocol.pyx +++ b/edgedb/protocol/protocol.pyx @@ -132,6 +132,43 @@ cdef class QueryCodecsCache: ) +cdef class ExecuteContext: + def __init__( + self, + *, + query: str, + args, + kwargs, + reg: CodecsRegistry, + qc: QueryCodecsCache, + output_format: OutputFormat, + expect_one: bool = False, + required_one: bool = False, + implicit_limit: int = 0, + inline_typenames: bool = False, + inline_typeids: bool = False, + allow_capabilities: enums.Capability = enums.Capability.ALL, + state: typing.Optional[dict] = None, + ): + self.query = query + self.args = args + self.kwargs = kwargs + self.reg = reg + self.qc = qc + self.output_format = output_format + self.expect_one = bool(expect_one) + self.required_one = bool(required_one) + self.implicit_limit = implicit_limit + self.inline_typenames = bool(inline_typenames) + self.inline_typeids = bool(inline_typeids) + self.allow_capabilities = allow_capabilities + self.state = state + + self.cardinality = None + self.in_dc = self.out_dc = None + self.capabilities = 0 + + cdef class SansIOProtocol: def __init__(self, con_params): @@ -330,25 +367,7 @@ cdef class SansIOProtocol: return cardinality, in_dc, out_dc, capabilities - async def _execute( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint, - required_one: bint, - implicit_limit: int, - inline_typenames: bint, - inline_typeids: bint, - allow_capabilities: enums.Capability = enums.Capability.ALL, - in_dc: BaseCodec, - out_dc: BaseCodec, - state: typing.Optional[dict] = None, - ): + async def _execute(self, ctx: ExecuteContext): cdef: WriteBuffer packet WriteBuffer buf @@ -357,6 +376,22 @@ cdef class SansIOProtocol: object result bytes new_cardinality = None + str query = ctx.query + object args = ctx.args + object kwargs = ctx.kwargs + CodecsRegistry reg = ctx.reg + QueryCodecsCache qc = ctx.qc + OutputFormat output_format = ctx.output_format + bint expect_one = ctx.expect_one + bint required_one = ctx.required_one + int implicit_limit = ctx.implicit_limit + bint inline_typenames = ctx.inline_typenames + bint inline_typeids = ctx.inline_typeids + uint64_t allow_capabilities = ctx.allow_capabilities + BaseCodec in_dc = ctx.in_dc + BaseCodec out_dc = ctx.out_dc + object state = ctx.state + params = self.encode_parse_params( query=query, output_format=output_format, @@ -407,6 +442,10 @@ cdef class SansIOProtocol: expect_one, new_cardinality, in_dc, out_dc, capabilities) + ctx.cardinality = new_cardinality + ctx.in_dc = in_dc + ctx.out_dc = out_dc + ctx.capabilities = capabilities elif mtype == STATE_DATA_DESC_MSG: self.parse_describe_state_message() @@ -481,28 +520,26 @@ cdef class SansIOProtocol: else: return NULL_CODEC_ID, EMPTY_NULL_DATA - async def execute( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint = False, - required_one: bool = False, - implicit_limit: int = 0, - inline_typenames: bool = False, - inline_typeids: bool = False, - allow_capabilities: enums.Capability = enums.Capability.ALL, - state: typing.Optional[dict] = None, - ): + async def execute(self, ctx: ExecuteContext): cdef: BaseCodec in_dc BaseCodec out_dc bytes cardinality + str query = ctx.query + object args = ctx.args + object kwargs = ctx.kwargs + CodecsRegistry reg = ctx.reg + QueryCodecsCache qc = ctx.qc + OutputFormat output_format = ctx.output_format + bint expect_one = ctx.expect_one + bint required_one = ctx.required_one + int implicit_limit = ctx.implicit_limit + bint inline_typenames = ctx.inline_typenames + bint inline_typeids = ctx.inline_typeids + uint64_t allow_capabilities = ctx.allow_capabilities + object state = ctx.state + self.ensure_connected() self.reset_status() @@ -515,8 +552,10 @@ cdef class SansIOProtocol: expect_one) if codecs is not None: + ctx.cardinality = codecs[0] in_dc = codecs[1] out_dc = codecs[2] + ctx.capabilities = codecs[3] elif not args and not kwargs and not required_one: # We don't have knowledge about the in/out desc of the command, but # the caller didn't provide any arguments, so let's try using NULL @@ -556,79 +595,39 @@ cdef class SansIOProtocol: out_dc, capabilities, ) + ctx.cardinality = cardinality + ctx.capabilities = capabilities - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - reg=reg, - qc=qc, - output_format=output_format, - expect_one=expect_one, - required_one=required_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - in_dc=in_dc, - out_dc=out_dc, - state=state, - ) + ctx.in_dc = in_dc + ctx.out_dc = out_dc - async def query( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint = False, - required_one: bool = False, - implicit_limit: int = 0, - inline_typenames: bool = False, - inline_typeids: bool = False, - allow_capabilities: enums.Capability = enums.Capability.ALL, - state: typing.Optional[dict] = None, - ): - ret = await self.execute( - query=query, - args=args, - kwargs=kwargs, - reg=reg, - qc=qc, - output_format=output_format, - expect_one=expect_one, - required_one=required_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - state=state, - ) + return await self._execute(ctx) - if expect_one: - if ret or not required_one: + async def query(self, ctx: ExecuteContext): + ret = await self.execute(ctx) + if ctx.expect_one: + if ret or not ctx.required_one: if ret: return ret[0] else: - if output_format == OutputFormat.JSON: + if ctx.output_format == OutputFormat.JSON: return 'null' else: return None else: - methname = _QUERY_SINGLE_METHOD[required_one][output_format] + methname = ( + _QUERY_SINGLE_METHOD[ctx.required_one][ctx.output_format] + ) raise errors.NoDataError( f'query executed via {methname}() returned no data') else: if ret: - if output_format == OutputFormat.JSON: + if ctx.output_format == OutputFormat.JSON: return ret[0] else: return ret else: - if output_format == OutputFormat.JSON: + if ctx.output_format == OutputFormat.JSON: return '[]' else: return ret diff --git a/edgedb/protocol/protocol_v0.pyx b/edgedb/protocol/protocol_v0.pyx index 574adda6..bcde081c 100644 --- a/edgedb/protocol/protocol_v0.pyx +++ b/edgedb/protocol/protocol_v0.pyx @@ -225,24 +225,7 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): return result - async def _legacy_optimistic_execute( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint, - required_one: bint, - implicit_limit: int, - inline_typenames: bint, - inline_typeids: bint, - allow_capabilities: typing.Optional[int] = None, - in_dc: BaseCodec, - out_dc: BaseCodec, - ): + async def _legacy_optimistic_execute(self, ctx: ExecuteContext): cdef: WriteBuffer packet WriteBuffer buf @@ -251,6 +234,21 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): object result bytes new_cardinality = None + str query = ctx.query + object args = ctx.args + object kwargs = ctx.kwargs + CodecsRegistry reg = ctx.reg + QueryCodecsCache qc = ctx.qc + OutputFormat output_format = ctx.output_format + bint expect_one = ctx.expect_one + bint required_one = ctx.required_one + int implicit_limit = ctx.implicit_limit + bint inline_typenames = ctx.inline_typenames + bint inline_typeids = ctx.inline_typeids + uint64_t allow_capabilities = ctx.allow_capabilities + BaseCodec in_dc = ctx.in_dc + BaseCodec out_dc = ctx.out_dc + buf = WriteBuffer.new_message(EXECUTE_MSG) self.legacy_write_execute_headers( buf, implicit_limit, inline_typenames, inline_typeids, @@ -296,6 +294,11 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): expect_one, new_cardinality, in_dc, out_dc, capabilities) + ctx.cardinality = new_cardinality + ctx.in_dc = in_dc + ctx.out_dc = out_dc + ctx.capabilities = capabilities + re_exec = True elif mtype == DATA_MSG: @@ -351,26 +354,24 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): else: return result - async def legacy_execute_anonymous( - self, - *, - query: str, - args, - kwargs, - reg: CodecsRegistry, - qc: QueryCodecsCache, - output_format: object, - expect_one: bint = False, - required_one: bool = False, - implicit_limit: int = 0, - inline_typenames: bool = False, - inline_typeids: bool = False, - allow_capabilities: enums.Capability = enums.Capability.ALL, - ): + async def legacy_execute_anonymous(self, ctx: ExecuteContext): cdef: BaseCodec in_dc BaseCodec out_dc + str query = ctx.query + object args = ctx.args + object kwargs = ctx.kwargs + CodecsRegistry reg = ctx.reg + QueryCodecsCache qc = ctx.qc + OutputFormat output_format = ctx.output_format + bint expect_one = ctx.expect_one + bint required_one = ctx.required_one + int implicit_limit = ctx.implicit_limit + bint inline_typenames = ctx.inline_typenames + bint inline_typeids = ctx.inline_typeids + uint64_t allow_capabilities = ctx.allow_capabilities + self.ensure_connected() self.reset_status() @@ -417,6 +418,10 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): out_dc, capabilities, ) + ctx.cardinality = cardinality + ctx.in_dc = in_dc + ctx.out_dc = out_dc + ctx.capabilities = capabilities ret = await self._legacy_execute(in_dc, out_dc, args, kwargs) @@ -424,6 +429,10 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): cardinality = codecs[0] in_dc = codecs[1] out_dc = codecs[2] + ctx.cardinality = cardinality + ctx.in_dc = in_dc + ctx.out_dc = out_dc + ctx.capabilities = codecs[3] if required_one and cardinality == CARDINALITY_NOT_APPLICABLE: methname = _QUERY_SINGLE_METHOD[required_one][output_format] @@ -431,22 +440,7 @@ cdef class SansIOProtocolBackwardsCompatible(SansIOProtocol): f'query cannot be executed with {methname}() as it ' f'does not return any data') - ret = await self._legacy_optimistic_execute( - query=query, - args=args, - kwargs=kwargs, - reg=reg, - qc=qc, - output_format=output_format, - expect_one=expect_one, - required_one=required_one, - implicit_limit=implicit_limit, - inline_typenames=inline_typenames, - inline_typeids=inline_typeids, - allow_capabilities=allow_capabilities, - in_dc=in_dc, - out_dc=out_dc, - ) + ret = await self._legacy_optimistic_execute(ctx) if expect_one: if ret or not required_one: diff --git a/tests/test_sync_retry.py b/tests/test_sync_retry.py index 831f0964..ae32c633 100644 --- a/tests/test_sync_retry.py +++ b/tests/test_sync_retry.py @@ -17,7 +17,9 @@ # +import asyncio import threading +import queue import unittest.mock from concurrent import futures @@ -254,3 +256,76 @@ def test_sync_transaction_interface_errors(self): with tx: with tx: pass + + def test_sync_retry_parse(self): + loop = asyncio.new_event_loop() + q = queue.Queue() + + async def init(): + return asyncio.Event(), asyncio.Event() + + reconnect, terminate = loop.run_until_complete(init()) + + async def proxy(r, w): + try: + while True: + buf = await r.read(65536) + if not buf: + w.close() + break + w.write(buf) + except asyncio.CancelledError: + pass + + async def cb(ri, wi): + try: + args = self.get_connect_args() + ro, wo = await asyncio.open_connection( + args["host"], args["port"] + ) + try: + fs = [ + asyncio.create_task(proxy(ri, wo)), + asyncio.create_task(proxy(ro, wi)), + asyncio.create_task(terminate.wait()), + ] + if not reconnect.is_set(): + fs.append(asyncio.create_task(reconnect.wait())) + _, pending = await asyncio.wait( + fs, return_when=asyncio.FIRST_COMPLETED + ) + for f in pending: + f.cancel() + finally: + wo.close() + finally: + wi.close() + + async def proxy_server(): + srv = await asyncio.start_server(cb, host="127.0.0.1", port=0) + try: + q.put(srv.sockets[0].getsockname()[1]) + await terminate.wait() + finally: + srv.close() + await srv.wait_closed() + + with futures.ThreadPoolExecutor(1) as pool: + pool.submit(loop.run_until_complete, proxy_server()) + try: + client = self.make_test_client( + host="127.0.0.1", + port=q.get(), + database=self.get_database_name(), + ) + + # Fill the connection pool with a healthy connection + self.assertEqual(client.query_single("SELECT 42"), 42) + + # Cut the connection to simulate an Internet interruption + loop.call_soon_threadsafe(reconnect.set) + + # Run a new query that was never compiled, retry should work + self.assertEqual(client.query_single("SELECT 1*2+3-4"), 1) + finally: + loop.call_soon_threadsafe(terminate.set) diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index 3ed2fc55..497af782 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -102,7 +102,7 @@ def test_sync_transaction_commit_failure(self): def test_sync_transaction_exclusive(self): for tx in self.client.transaction(): with tx: - query = "select sys::_sleep(0.01)" + query = "select sys::_sleep(0.5)" with ThreadPoolExecutor(max_workers=2) as executor: f1 = executor.submit(tx.execute, query) f2 = executor.submit(tx.execute, query)