Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Implement support for new type descriptor protocol #427

Merged
merged 3 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion edgedb/protocol/codecs/array.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ cdef class ArrayCodec(BaseArrayCodec):
# Describe this array codec as a describe.ArrayType: desc_id is self.tid
# passed to uuid.UUID as raw bytes, and element_type is whatever the element
# sub-codec's own make_type() returns for the same describe_context.
# NOTE(review): this is a rendered diff hunk ("1 addition & 1 deletion"), so
# both sides of the change appear below -- presumably `name=None` is the
# removed line and `name=self.type_name` is its replacement; confirm against
# the PR before treating this as source.
def make_type(self, describe_context):
return describe.ArrayType(
desc_id=uuid.UUID(bytes=self.tid),
name=None,
name=self.type_name,
element_type=self.sub_codec.make_type(describe_context),
)

Expand Down
195 changes: 184 additions & 11 deletions edgedb/protocol/codecs/codecs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ DEF CTYPE_ARRAY = 6
DEF CTYPE_ENUM = 7
DEF CTYPE_INPUT_SHAPE = 8
DEF CTYPE_RANGE = 9
DEF CTYPE_OBJECT = 10
DEF CTYPE_COMPOUND = 11
DEF CTYPE_ANNO_TYPENAME = 255

DEF _CODECS_BUILD_CACHE_SIZE = 200
Expand Down Expand Up @@ -94,8 +96,9 @@ cdef class CodecsRegistry:
cdef BaseCodec _build_codec(self, FRBuffer *spec, list codecs_list,
protocol_version):
cdef:
uint8_t t = <uint8_t>(frb_read(spec, 1)[0])
bytes tid = frb_read(spec, 16)[:16]
uint32_t desc_len = 0
uint8_t t
bytes tid
uint16_t els
uint16_t i
uint32_t str_len
Expand All @@ -104,12 +107,21 @@ cdef class CodecsRegistry:
BaseCodec res
BaseCodec sub_codec

if protocol_version >= (2, 0):
desc_len = frb_get_len(spec) - 16 - 1

t = <uint8_t>(frb_read(spec, 1)[0])
tid = frb_read(spec, 16)[:16]

res = self.codecs.get(tid, None)
if res is None:
res = self.codecs_build_cache.get(tid, None)
if res is not None:
# We have a codec for this "tid"; advance the buffer
# so that we can process the next codec.
if desc_len > 0:
frb_read(spec, desc_len)
return res

if t == CTYPE_SET:
frb_read(spec, 2)
Expand Down Expand Up @@ -182,7 +194,50 @@ cdef class CodecsRegistry:
sub_codec = <BaseCodec>codecs_list[pos]
res = SetCodec.new(tid, sub_codec)

elif t == CTYPE_SHAPE or t == CTYPE_INPUT_SHAPE:
elif t == CTYPE_SHAPE:
if protocol_version >= (2, 0):
ephemeral_free_shape = <bint>frb_read(spec, 1)[0]
objtype_pos = <uint16_t>hton.unpack_int16(frb_read(spec, 2))

els = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
codecs = cpython.PyTuple_New(els)
names = cpython.PyTuple_New(els)
flags = cpython.PyTuple_New(els)
cards = cpython.PyTuple_New(els)
for i in range(els):
flag = hton.unpack_uint32(frb_read(spec, 4)) # flags
cardinality = <uint8_t>frb_read(spec, 1)[0]

str_len = hton.unpack_uint32(frb_read(spec, 4))
name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
pos = <uint16_t>hton.unpack_int16(frb_read(spec, 2))

if flag & datatypes._EDGE_POINTER_IS_LINKPROP:
name = "@" + name
cpython.Py_INCREF(name)
cpython.PyTuple_SetItem(names, i, name)

sub_codec = codecs_list[pos]
cpython.Py_INCREF(sub_codec)
cpython.PyTuple_SetItem(codecs, i, sub_codec)

cpython.Py_INCREF(flag)
cpython.PyTuple_SetItem(flags, i, flag)

cpython.Py_INCREF(cardinality)
cpython.PyTuple_SetItem(cards, i, cardinality)

if protocol_version >= (2, 0):
source_type_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
source_type = codecs_list[source_type_pos]

res = ObjectCodec.new(
tid, names, flags, cards, codecs, t == CTYPE_INPUT_SHAPE
)

elif t == CTYPE_INPUT_SHAPE:
els = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
codecs = cpython.PyTuple_New(els)
names = cpython.PyTuple_New(els)
Expand Down Expand Up @@ -223,15 +278,60 @@ cdef class CodecsRegistry:
res = <BaseCodec>BASE_SCALAR_CODECS[tid]

elif t == CTYPE_SCALAR:
pos = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
codec = codecs_list[pos]
if type(codec) is not ScalarCodec:
raise RuntimeError(
f'a scalar codec expected for base scalar type, '
f'got {type(codec).__name__}')
res = (<ScalarCodec>codecs_list[pos]).derive(tid)
if protocol_version >= (2, 0):
str_len = hton.unpack_uint32(frb_read(spec, 4))
type_name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
schema_defined = <bint>frb_read(spec, 1)[0]

ancestor_count = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
ancestors = []
for _ in range(ancestor_count):
ancestor_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
ancestor_codec = codecs_list[ancestor_pos]
if type(ancestor_codec) is not ScalarCodec:
raise RuntimeError(
f'a scalar codec expected for base scalar type, '
f'got {type(ancestor_codec).__name__}')
ancestors.append(ancestor_codec)

if ancestor_count == 0:
if tid in self.base_codec_overrides:
res = self.base_codec_overrides[tid]
else:
res = <BaseCodec>BASE_SCALAR_CODECS[tid]
else:
fundamental_codec = ancestors[-1]
if type(fundamental_codec) is not ScalarCodec:
raise RuntimeError(
f'a scalar codec expected for base scalar type, '
f'got {type(fundamental_codec).__name__}')
res = (<ScalarCodec>fundamental_codec).derive(tid)
res.type_name = type_name
else:
fundamental_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
fundamental_codec = codecs_list[fundamental_pos]
if type(fundamental_codec) is not ScalarCodec:
raise RuntimeError(
f'a scalar codec expected for base scalar type, '
f'got {type(fundamental_codec).__name__}')
res = (<ScalarCodec>fundamental_codec).derive(tid)

elif t == CTYPE_TUPLE:
if protocol_version >= (2, 0):
str_len = hton.unpack_uint32(frb_read(spec, 4))
type_name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
schema_defined = <bint>frb_read(spec, 1)[0]
ancestor_count = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
for _ in range(ancestor_count):
ancestor_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
ancestor_codec = codecs_list[ancestor_pos]
else:
type_name = None
els = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
codecs = cpython.PyTuple_New(els)
for i in range(els):
Expand All @@ -242,8 +342,21 @@ cdef class CodecsRegistry:
cpython.PyTuple_SetItem(codecs, i, sub_codec)

res = TupleCodec.new(tid, codecs)
res.type_name = type_name

elif t == CTYPE_NAMEDTUPLE:
if protocol_version >= (2, 0):
str_len = hton.unpack_uint32(frb_read(spec, 4))
type_name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
schema_defined = <bint>frb_read(spec, 1)[0]
ancestor_count = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
for _ in range(ancestor_count):
ancestor_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
ancestor_codec = codecs_list[ancestor_pos]
else:
type_name = None
els = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
codecs = cpython.PyTuple_New(els)
names = cpython.PyTuple_New(els)
Expand All @@ -261,8 +374,21 @@ cdef class CodecsRegistry:
cpython.PyTuple_SetItem(codecs, i, sub_codec)

res = NamedTupleCodec.new(tid, names, codecs)
res.type_name = type_name

elif t == CTYPE_ENUM:
if protocol_version >= (2, 0):
str_len = hton.unpack_uint32(frb_read(spec, 4))
type_name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
schema_defined = <bint>frb_read(spec, 1)[0]
ancestor_count = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
for _ in range(ancestor_count):
ancestor_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
ancestor_codec = codecs_list[ancestor_pos]
else:
type_name = None
els = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
names = cpython.PyTuple_New(els)
for i in range(els):
Expand All @@ -274,8 +400,21 @@ cdef class CodecsRegistry:
cpython.PyTuple_SetItem(names, i, name)

res = EnumCodec.new(tid, names)
res.type_name = type_name

elif t == CTYPE_ARRAY:
if protocol_version >= (2, 0):
str_len = hton.unpack_uint32(frb_read(spec, 4))
type_name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
schema_defined = <bint>frb_read(spec, 1)[0]
ancestor_count = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
for _ in range(ancestor_count):
ancestor_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
ancestor_codec = codecs_list[ancestor_pos]
else:
type_name = None
pos = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
els = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
if els != 1:
Expand All @@ -285,11 +424,35 @@ cdef class CodecsRegistry:
dim_len = hton.unpack_int32(frb_read(spec, 4))
sub_codec = <BaseCodec>codecs_list[pos]
res = ArrayCodec.new(tid, sub_codec, dim_len)
res.type_name = type_name

elif t == CTYPE_RANGE:
if protocol_version >= (2, 0):
str_len = hton.unpack_uint32(frb_read(spec, 4))
type_name = cpythonx.PyUnicode_FromStringAndSize(
frb_read(spec, str_len), str_len)
schema_defined = <bint>frb_read(spec, 1)[0]
ancestor_count = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
for _ in range(ancestor_count):
ancestor_pos = <uint16_t>hton.unpack_int16(
frb_read(spec, 2))
ancestor_codec = codecs_list[ancestor_pos]
else:
type_name = None
pos = <uint16_t>hton.unpack_int16(frb_read(spec, 2))
sub_codec = <BaseCodec>codecs_list[pos]
res = RangeCodec.new(tid, sub_codec)
res.type_name = type_name

elif t == CTYPE_OBJECT and protocol_version >= (2, 0):
# Ignore
frb_read(spec, desc_len)
res = NULL_CODEC

elif t == CTYPE_COMPOUND and protocol_version >= (2, 0):
# Ignore
frb_read(spec, desc_len)
res = NULL_CODEC

else:
raise NotImplementedError(
Expand Down Expand Up @@ -321,6 +484,7 @@ cdef class CodecsRegistry:
cdef BaseCodec build_codec(self, bytes spec, protocol_version):
cdef:
FRBuffer buf
FRBuffer elem_buf
BaseCodec res
list codecs_list

Expand All @@ -331,7 +495,16 @@ cdef class CodecsRegistry:

codecs_list = []
while frb_get_len(&buf):
res = self._build_codec(&buf, codecs_list, protocol_version)
if protocol_version >= (2, 0):
desc_len = <uint32_t>hton.unpack_int32(frb_read(&buf, 4))
frb_slice_from(&elem_buf, &buf, desc_len)
res = self._build_codec(
&elem_buf, codecs_list, protocol_version)
if frb_get_len(&elem_buf):
raise RuntimeError(
f'unexpected trailing data in type descriptor datum')
else:
res = self._build_codec(&buf, codecs_list, protocol_version)
if res is None:
# An annotation; ignore.
continue
Expand Down
2 changes: 1 addition & 1 deletion edgedb/protocol/codecs/namedtuple.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ cdef class NamedTupleCodec(BaseNamedRecordCodec):
def make_type(self, describe_context):
return describe.NamedTupleType(
desc_id=uuid.UUID(bytes=self.tid),
name=None,
name=self.type_name,
element_types={
field: codec.make_type(describe_context)
for field, codec in zip(
Expand Down
2 changes: 1 addition & 1 deletion edgedb/protocol/codecs/range.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,6 @@ cdef class RangeCodec(BaseCodec):
# Describe this range codec as a describe.RangeType: desc_id is self.tid
# passed to uuid.UUID as raw bytes, and value_type is whatever the sub-codec's
# own make_type() returns for the same describe_context.
# NOTE(review): this is a rendered diff hunk ("1 addition & 1 deletion"), so
# both sides of the change appear below -- presumably `name=None` is the
# removed line and `name=self.type_name` is its replacement; confirm against
# the PR before treating this as source.
def make_type(self, describe_context):
return describe.RangeType(
desc_id=uuid.UUID(bytes=self.tid),
name=None,
name=self.type_name,
value_type=self.sub_codec.make_type(describe_context),
)
2 changes: 1 addition & 1 deletion edgedb/protocol/codecs/tuple.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ cdef class TupleCodec(BaseRecordCodec):
def make_type(self, describe_context):
return describe.TupleType(
desc_id=uuid.UUID(bytes=self.tid),
name=None,
name=self.type_name,
element_types=tuple(
codec.make_type(describe_context)
for codec in self.fields_codecs
Expand Down
6 changes: 3 additions & 3 deletions edgedb/protocol/consts.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ DEF TRANS_STATUS_IDLE = b'I'
DEF TRANS_STATUS_INTRANS = b'T'
DEF TRANS_STATUS_ERROR = b'E'

DEF PROTO_VER_MAJOR = 1
DEF PROTO_VER_MAJOR = 2
DEF PROTO_VER_MINOR = 0

DEF LEGACY_PROTO_VER_MAJOR = 0
DEF LEGACY_PROTO_VER_MINOR_MIN = 13
DEF MIN_PROTO_VER_MAJOR = 0
DEF MIN_PROTO_VER_MINOR = 13
13 changes: 5 additions & 8 deletions edgedb/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ cdef class SansIOProtocol:
self.internal_reg = CodecsRegistry()
self.server_settings = {}
self.reset_status()
self.protocol_version = (PROTO_VER_MAJOR, 0)
self.protocol_version = (PROTO_VER_MAJOR, PROTO_VER_MINOR)

self.state_type_id = NULL_CODEC_ID
self.state_codec = None
Expand Down Expand Up @@ -873,22 +873,19 @@ cdef class SansIOProtocol:
minor = self.buffer.read_int16()

# TODO: drop this branch when dropping protocol_v0
if major == LEGACY_PROTO_VER_MAJOR:
if major == 0:
self.is_legacy = True
self.ignore_headers()

self.buffer.finish_message()

if major != PROTO_VER_MAJOR and not (
major == LEGACY_PROTO_VER_MAJOR and
minor >= LEGACY_PROTO_VER_MINOR_MIN
):
if (major, minor) < (MIN_PROTO_VER_MAJOR, MIN_PROTO_VER_MINOR):
raise errors.ClientConnectionError(
f'the server requested an unsupported version of '
f'the protocol: {major}.{minor}'
)

self.protocol_version = (major, minor)
else:
self.protocol_version = (major, minor)

elif mtype == AUTH_REQUEST_MSG:
# Authentication...
Expand Down