Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract ExecuteContext as in/out argument #500

Merged
merged 10 commits into from
Jun 19, 2024
6 changes: 6 additions & 0 deletions edgedb/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,16 @@ def make_test_client(
database='edgedb',
user='edgedb',
password='test',
host=...,
port=...,
connection_class=...,
):
conargs = cls.get_connect_args(
cluster=cluster, database=database, user=user, password=password)
if host is not ...:
conargs['host'] = host
if port is not ...:
conargs['port'] = port
if connection_class is ...:
connection_class = (
asyncio_client.AsyncIOConnection
Expand Down
32 changes: 31 additions & 1 deletion edgedb/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class QueryWithArgs(typing.NamedTuple):

class QueryCache(typing.NamedTuple):
codecs_registry: protocol.CodecsRegistry
query_cache: protocol.QueryCodecsCache
query_cache: protocol.LRUMapping


class QueryOptions(typing.NamedTuple):
Expand All @@ -65,12 +65,42 @@ class QueryContext(typing.NamedTuple):
retry_options: typing.Optional[options.RetryOptions]
state: typing.Optional[options.State]

def lower(
self, *, allow_capabilities: enums.Capability
) -> protocol.ExecuteContext:
return protocol.ExecuteContext(
query=self.query.query,
args=self.query.args,
kwargs=self.query.kwargs,
reg=self.cache.codecs_registry,
qc=self.cache.query_cache,
output_format=self.query_options.output_format,
expect_one=self.query_options.expect_one,
required_one=self.query_options.required_one,
allow_capabilities=allow_capabilities,
state=self.state.as_dict() if self.state else None,
)


class ExecuteContext(typing.NamedTuple):
query: QueryWithArgs
cache: QueryCache
state: typing.Optional[options.State]

def lower(
self, *, allow_capabilities: enums.Capability
) -> protocol.ExecuteContext:
return protocol.ExecuteContext(
query=self.query.query,
args=self.query.args,
kwargs=self.query.kwargs,
reg=self.cache.codecs_registry,
qc=self.cache.query_cache,
output_format=protocol.OutputFormat.NONE,
allow_capabilities=allow_capabilities,
state=self.state.as_dict() if self.state else None,
)


@dataclasses.dataclass
class DescribeContext:
Expand Down
68 changes: 13 additions & 55 deletions edgedb/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@


BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection')
QUERY_CACHE_SIZE = 1000


class BaseConnection(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -183,17 +184,7 @@ async def privileged_execute(
)
else:
await self._protocol.execute(
query=execute_context.query.query,
args=execute_context.query.args,
kwargs=execute_context.query.kwargs,
reg=execute_context.cache.codecs_registry,
qc=execute_context.cache.query_cache,
output_format=protocol.OutputFormat.NONE,
allow_capabilities=enums.Capability.ALL,
state=(
execute_context.state.as_dict()
if execute_context.state else None
),
execute_context.lower(allow_capabilities=enums.Capability.ALL)
)

def is_in_transaction(self) -> bool:
Expand All @@ -211,56 +202,31 @@ async def raw_query(self, query_context: abstract.QueryContext):
await self.connect()

reconnect = False
capabilities = None
i = 0
args = dict(
query=query_context.query.query,
args=query_context.query.args,
kwargs=query_context.query.kwargs,
reg=query_context.cache.codecs_registry,
qc=query_context.cache.query_cache,
output_format=query_context.query_options.output_format,
expect_one=query_context.query_options.expect_one,
required_one=query_context.query_options.required_one,
)
if self._protocol.is_legacy:
args["allow_capabilities"] = enums.Capability.LEGACY_EXECUTE
allow_capabilities = enums.Capability.LEGACY_EXECUTE
else:
args["allow_capabilities"] = enums.Capability.EXECUTE
if query_context.state is not None:
args["state"] = query_context.state.as_dict()
allow_capabilities = enums.Capability.EXECUTE
ctx = query_context.lower(allow_capabilities=allow_capabilities)
while True:
i += 1
try:
if reconnect:
await self.connect(single_attempt=True)
if self._protocol.is_legacy:
return await self._protocol.legacy_execute_anonymous(
**args
)
return await self._protocol.legacy_execute_anonymous(ctx)
else:
return await self._protocol.query(**args)
return await self._protocol.query(ctx)
except errors.EdgeDBError as e:
if query_context.retry_options is None:
raise
if not e.has_tag(errors.SHOULD_RETRY):
raise e
if capabilities is None:
cache_item = query_context.cache.query_cache.get(
query_context.query.query,
query_context.query_options.output_format,
implicit_limit=0,
inline_typenames=False,
inline_typeids=False,
expect_one=query_context.query_options.expect_one,
)
if cache_item is not None:
_, _, _, capabilities = cache_item
# A query is read-only if it has no capabilities i.e.
# capabilities == 0. Read-only queries are safe to retry.
# Explicit transaction conflicts as well.
if (
capabilities != 0
ctx.capabilities != 0
and not isinstance(e, errors.TransactionConflictError)
):
raise e
Expand All @@ -281,17 +247,9 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
)
else:
await self._protocol.execute(
query=execute_context.query.query,
args=execute_context.query.args,
kwargs=execute_context.query.kwargs,
reg=execute_context.cache.codecs_registry,
qc=execute_context.cache.query_cache,
output_format=protocol.OutputFormat.NONE,
allow_capabilities=enums.Capability.EXECUTE,
state=(
execute_context.state.as_dict()
if execute_context.state else None
),
execute_context.lower(
allow_capabilities=enums.Capability.EXECUTE
)
)

async def describe(
Expand Down Expand Up @@ -473,7 +431,7 @@ def __init__(
self._connection_factory = connection_factory
self._connect_args = connect_args
self._codecs_registry = protocol.CodecsRegistry()
self._query_cache = protocol.QueryCodecsCache()
self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE)

if max_concurrency is not None and max_concurrency <= 0:
raise ValueError(
Expand Down Expand Up @@ -570,7 +528,7 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
connect_kwargs["dsn"] = dsn
self._connect_args = connect_kwargs
self._codecs_registry = protocol.CodecsRegistry()
self._query_cache = protocol.QueryCodecsCache()
self._query_cache = protocol.LRUMapping(maxsize=QUERY_CACHE_SIZE)
self._working_addr = None
self._working_config = None
self._working_params = None
Expand Down
47 changes: 27 additions & 20 deletions edgedb/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,32 @@ cdef enum AuthenticationStatuses:
AUTH_SASL_FINAL = 12


cdef class QueryCodecsCache:

cdef class ExecuteContext:
cdef:
LRUMapping queries

cdef set(self, str query, OutputFormat output_format,
int implicit_limit, bint inline_typenames, bint inline_typeids,
bint expect_one, bint has_na_cardinality,
BaseCodec in_type, BaseCodec out_type, int capabilities)
# Input arguments
str query
object args
object kwargs
CodecsRegistry reg
LRUMapping qc
OutputFormat output_format
bint expect_one
bint required_one
int implicit_limit
bint inline_typenames
bint inline_typeids
uint64_t allow_capabilities
object state

# Contextual variables
bytes cardinality
BaseCodec in_dc
BaseCodec out_dc
readonly uint64_t capabilities

cdef inline bint has_na_cardinality(self)
cdef bint load_from_cache(self)
cdef inline store_to_cache(self)


cdef class SansIOProtocol:
Expand Down Expand Up @@ -113,7 +130,7 @@ cdef class SansIOProtocol:
cdef parse_data_messages(self, BaseCodec out_dc, result)
cdef parse_sync_message(self)
cdef parse_command_complete_message(self)
cdef parse_describe_type_message(self, CodecsRegistry reg)
cdef parse_describe_type_message(self, ExecuteContext ctx)
cdef parse_describe_state_message(self)
cdef parse_type_data(self, CodecsRegistry reg)
cdef _amend_parse_error(
Expand Down Expand Up @@ -141,17 +158,7 @@ cdef class SansIOProtocol:

cdef ensure_connected(self)

cdef WriteBuffer encode_parse_params(
self,
str query,
object output_format,
bint expect_one,
int implicit_limit,
bint inline_typenames,
bint inline_typeids,
uint64_t allow_capabilities,
object state,
)
cdef WriteBuffer encode_parse_params(self, ExecuteContext ctx)


include "protocol_v0.pxd"
Loading
Loading