Skip to content

Commit

Permalink
Fix encode_parse_params() to use ExecuteContext
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed May 31, 2024
1 parent bf8f328 commit b00c5ec
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 82 deletions.
12 changes: 1 addition & 11 deletions edgedb/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,7 @@ cdef class SansIOProtocol:

cdef ensure_connected(self)

cdef WriteBuffer encode_parse_params(
self,
str query,
object output_format,
bint expect_one,
int implicit_limit,
bint inline_typenames,
bint inline_typeids,
uint64_t allow_capabilities,
object state,
)
cdef WriteBuffer encode_parse_params(self, ExecuteContext ctx)


include "protocol_v0.pxd"
94 changes: 23 additions & 71 deletions edgedb/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -246,54 +246,31 @@ cdef class SansIOProtocol:
raise errors.ClientConnectionClosedError(
'the connection has been closed')

cdef WriteBuffer encode_parse_params(
self,
str query,
object output_format,
bint expect_one,
int implicit_limit,
bint inline_typenames,
bint inline_typeids,
uint64_t allow_capabilities,
object state,
):
cdef WriteBuffer encode_parse_params(self, ExecuteContext ctx):
cdef:
WriteBuffer buf

compilation_flags = enums.CompilationFlag.INJECT_OUTPUT_OBJECT_IDS
if inline_typenames:
if ctx.inline_typenames:
compilation_flags |= enums.CompilationFlag.INJECT_OUTPUT_TYPE_NAMES
if inline_typeids:
if ctx.inline_typeids:
compilation_flags |= enums.CompilationFlag.INJECT_OUTPUT_TYPE_IDS

buf = WriteBuffer.new()
buf.write_int64(<int64_t>allow_capabilities)
buf.write_int64(<int64_t>ctx.allow_capabilities)
buf.write_int64(<int64_t><uint64_t>compilation_flags)
buf.write_int64(<int64_t>implicit_limit)
buf.write_byte(output_format)
buf.write_byte(CARDINALITY_ONE if expect_one else CARDINALITY_MANY)
buf.write_len_prefixed_utf8(query)
buf.write_int64(<int64_t>ctx.implicit_limit)
buf.write_byte(ctx.output_format)
buf.write_byte(CARDINALITY_ONE if ctx.expect_one else CARDINALITY_MANY)
buf.write_len_prefixed_utf8(ctx.query)

state_type_id, state_data = self.encode_state(state)
state_type_id, state_data = self.encode_state(ctx.state)
buf.write_bytes(state_type_id)
buf.write_bytes(state_data)

return buf

async def _parse(
self,
query: str,
*,
reg: CodecsRegistry,
output_format: OutputFormat=OutputFormat.BINARY,
expect_one: bint=False,
required_one: bool=False,
implicit_limit: int=0,
inline_typenames: bool=False,
inline_typeids: bool=False,
allow_capabilities: enums.Capability = enums.Capability.ALL,
state: typing.Optional[dict] = None,
):
async def _parse(self, ctx: ExecuteContext):
cdef:
WriteBuffer buf, params
char mtype
Expand All @@ -310,16 +287,7 @@ cdef class SansIOProtocol:
buf = WriteBuffer.new_message(PREPARE_MSG)
buf.write_int16(0) # no headers

params = self.encode_parse_params(
query=query,
output_format=output_format,
expect_one=expect_one,
implicit_limit=implicit_limit,
inline_typenames=inline_typenames,
inline_typeids=inline_typeids,
allow_capabilities=allow_capabilities,
state=state,
)
params = self.encode_parse_params(ctx)

buf.write_buffer(params)
buf.end_message()
Expand All @@ -335,16 +303,20 @@ cdef class SansIOProtocol:
try:
if mtype == STMT_DATA_DESC_MSG:
capabilities, cardinality, in_dc, out_dc = \
self.parse_describe_type_message(reg)
self.parse_describe_type_message(ctx.reg)

elif mtype == STATE_DATA_DESC_MSG:
self.parse_describe_state_message()

elif mtype == ERROR_RESPONSE_MSG:
exc = self.parse_error_message()
exc._query = query
exc._query = ctx.query
exc = self._amend_parse_error(
exc, output_format, expect_one, required_one)
exc,
ctx.output_format,
ctx.expect_one,
ctx.required_one,
)

elif mtype == READY_FOR_COMMAND_MSG:
self.parse_sync_message()
Expand All @@ -358,9 +330,9 @@ cdef class SansIOProtocol:
if exc is not None:
raise exc

if required_one and cardinality == CARDINALITY_NOT_APPLICABLE:
assert output_format != OutputFormat.NONE
methname = _QUERY_SINGLE_METHOD[required_one][output_format]
if ctx.required_one and cardinality == CARDINALITY_NOT_APPLICABLE:
assert ctx.output_format != OutputFormat.NONE
methname = _QUERY_SINGLE_METHOD[ctx.required_one][ctx.output_format]
raise errors.InterfaceError(
f'query cannot be executed with {methname}() as it '
f'does not return any data')
Expand Down Expand Up @@ -392,16 +364,7 @@ cdef class SansIOProtocol:
BaseCodec out_dc = ctx.out_dc
object state = ctx.state

params = self.encode_parse_params(
query=query,
output_format=output_format,
expect_one=expect_one,
implicit_limit=implicit_limit,
inline_typenames=inline_typenames,
inline_typeids=inline_typeids,
allow_capabilities=allow_capabilities,
state=state,
)
params = self.encode_parse_params(ctx)

buf = WriteBuffer.new_message(EXECUTE_MSG)
buf.write_int16(0) # no headers
Expand Down Expand Up @@ -565,18 +528,7 @@ cdef class SansIOProtocol:
# command is already executed.
in_dc = out_dc = NULL_CODEC
else:
parsed = await self._parse(
query,
reg=reg,
output_format=output_format,
expect_one=expect_one,
required_one=required_one,
implicit_limit=implicit_limit,
inline_typenames=inline_typenames,
inline_typeids=inline_typeids,
allow_capabilities=allow_capabilities,
state=state,
)
parsed = await self._parse(ctx)

cardinality = parsed[0]
in_dc = <BaseCodec>parsed[1]
Expand Down

0 comments on commit b00c5ec

Please sign in to comment.