diff --git a/edgedb/__init__.py b/edgedb/__init__.py index 353bdfcd..5ba24795 100644 --- a/edgedb/__init__.py +++ b/edgedb/__init__.py @@ -25,7 +25,7 @@ Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory ) from edgedb.datatypes.datatypes import Set, Object, Array, Link, LinkSet -from edgedb.datatypes.range import Range +from edgedb.datatypes.range import Range, MultiRange from .abstract import ( Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor, diff --git a/edgedb/datatypes/range.py b/edgedb/datatypes/range.py index eaeb4bcb..e3fd3d1e 100644 --- a/edgedb/datatypes/range.py +++ b/edgedb/datatypes/range.py @@ -16,8 +16,8 @@ # limitations under the License. # -from typing import Generic, Optional, TypeVar - +from typing import (TypeVar, Any, Generic, Optional, Iterable, Iterator, + Sequence) T = TypeVar("T") @@ -78,8 +78,10 @@ def is_empty(self) -> bool: def __bool__(self): return not self.is_empty() - def __eq__(self, other): - if not isinstance(other, Range): + def __eq__(self, other) -> bool: + if isinstance(other, Range): + o = other + else: return NotImplemented return ( @@ -87,13 +89,13 @@ def __eq__(self, other): self._upper, self._inc_lower, self._inc_upper, - self._empty - ) == ( - other._lower, - other._upper, - other._inc_lower, - other._inc_upper, self._empty, + ) == ( + o._lower, + o._upper, + o._inc_lower, + o._inc_upper, + o._empty, ) def __hash__(self) -> int: @@ -125,3 +127,39 @@ def __str__(self) -> str: return f"" __repr__ = __str__ + + +# TODO: maybe we should implement range and multirange operations as well as +# normalization of the sub-ranges? +class MultiRange(Iterable[T]): + + _ranges: Sequence[T] + + def __init__(self, iterable: Optional[Iterable[T]] = None) -> None: + if iterable is not None: + self._ranges = tuple(iterable) + else: + self._ranges = tuple() + + def __len__(self) -> int: + return len(self._ranges) + + def __iter__(self) -> Iterator[T]: + return iter(self._ranges) + + def __reversed__(self) -> Iterator[T]: + return reversed(self._ranges) + + def __str__(self) -> str: + return f'' + + __repr__ = __str__ + + def __eq__(self, other: Any) -> bool: + if isinstance(other, MultiRange): + return set(self._ranges) == set(other._ranges) + else: + return NotImplemented + + def __hash__(self) -> int: + return hash(self._ranges) diff --git a/edgedb/describe.py b/edgedb/describe.py index 05c16398..92e38854 100644 --- a/edgedb/describe.py +++ b/edgedb/describe.py @@ -91,3 +91,8 @@ class SparseObjectType(ObjectType): @dataclasses.dataclass(frozen=True) class RangeType(AnyType): value_type: AnyType + + +@dataclasses.dataclass(frozen=True) +class MultiRangeType(AnyType): + value_type: AnyType diff --git a/edgedb/protocol/codecs/array.pyx b/edgedb/protocol/codecs/array.pyx index ef64fadc..2906f1e8 100644 --- a/edgedb/protocol/codecs/array.pyx +++ b/edgedb/protocol/codecs/array.pyx @@ -39,7 +39,8 @@ cdef class BaseArrayCodec(BaseCodec): if not isinstance( self.sub_codec, - (ScalarCodec, TupleCodec, NamedTupleCodec, RangeCodec, EnumCodec) + (ScalarCodec, TupleCodec, NamedTupleCodec, EnumCodec, + RangeCodec, MultiRangeCodec) ): raise TypeError( 'only arrays of scalars are supported (got type {!r})'.format( diff --git a/edgedb/protocol/codecs/base.pyx b/edgedb/protocol/codecs/base.pyx index 3bd52bb0..a40f6e57 100644 --- a/edgedb/protocol/codecs/base.pyx +++ b/edgedb/protocol/codecs/base.pyx @@ -149,7 +149,7 @@ cdef class BaseRecordCodec(BaseCodec): if not isinstance( codec, (ScalarCodec, ArrayCodec, TupleCodec, NamedTupleCodec, - EnumCodec, RangeCodec), + EnumCodec, RangeCodec, MultiRangeCodec), ): self.encoder_flags |= RECORD_ENCODER_INVALID break diff --git a/edgedb/protocol/codecs/codecs.pyx b/edgedb/protocol/codecs/codecs.pyx index bb669350..1218d3e6 100644 --- a/edgedb/protocol/codecs/codecs.pyx +++ b/edgedb/protocol/codecs/codecs.pyx @@ -54,6 +54,7 @@ DEF CTYPE_INPUT_SHAPE = 8 DEF CTYPE_RANGE = 9 DEF CTYPE_OBJECT = 10 DEF CTYPE_COMPOUND = 11 +DEF CTYPE_MULTIRANGE = 12 DEF CTYPE_ANNO_TYPENAME = 255 DEF _CODECS_BUILD_CACHE_SIZE = 200 @@ -165,6 +166,9 @@ cdef class CodecsRegistry: elif t == CTYPE_RANGE: frb_read(spec, 2) + elif t == CTYPE_MULTIRANGE: + frb_read(spec, 2) + elif t == CTYPE_ENUM: els = hton.unpack_int16(frb_read(spec, 2)) for i in range(els): @@ -444,6 +448,24 @@ cdef class CodecsRegistry: res = RangeCodec.new(tid, sub_codec) res.type_name = type_name + elif t == CTYPE_MULTIRANGE: + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + else: + type_name = None + pos = hton.unpack_int16(frb_read(spec, 2)) + sub_codec = codecs_list[pos] + res = MultiRangeCodec.new(tid, sub_codec) + res.type_name = type_name + elif t == CTYPE_OBJECT and protocol_version >= (2, 0): # Ignore frb_read(spec, desc_len) diff --git a/edgedb/protocol/codecs/range.pxd b/edgedb/protocol/codecs/range.pxd index 13d642f2..9b232b10 100644 --- a/edgedb/protocol/codecs/range.pxd +++ b/edgedb/protocol/codecs/range.pxd @@ -25,3 +25,19 @@ cdef class RangeCodec(BaseCodec): @staticmethod cdef BaseCodec new(bytes tid, BaseCodec sub_codec) + + @staticmethod + cdef encode_range(WriteBuffer buf, object obj, BaseCodec sub_codec) + + @staticmethod + cdef decode_range(FRBuffer *buf, BaseCodec sub_codec) + + +@cython.final +cdef class MultiRangeCodec(BaseCodec): + + cdef: + BaseCodec sub_codec + + @staticmethod + cdef BaseCodec new(bytes tid, BaseCodec sub_codec) diff --git a/edgedb/protocol/codecs/range.pyx b/edgedb/protocol/codecs/range.pyx index 3d608fd0..ea573b89 100644 --- a/edgedb/protocol/codecs/range.pyx +++ b/edgedb/protocol/codecs/range.pyx @@ -46,7 +46,8 @@ cdef class RangeCodec(BaseCodec): return codec - cdef encode(self, WriteBuffer buf, object obj): + @staticmethod + cdef encode_range(WriteBuffer buf, object obj, BaseCodec sub_codec): cdef: uint8_t flags = 0 WriteBuffer sub_data @@ -56,10 +57,10 @@ cdef class RangeCodec(BaseCodec): bint inc_upper = obj.inc_upper bint empty = obj.is_empty() - if not isinstance(self.sub_codec, ScalarCodec): + if not isinstance(sub_codec, ScalarCodec): raise TypeError( 'only scalar ranges are supported (got type {!r})'.format( - type(self.sub_codec).__name__ + type(sub_codec).__name__ ) ) @@ -78,14 +79,14 @@ cdef class RangeCodec(BaseCodec): sub_data = WriteBuffer.new() if lower is not None: try: - self.sub_codec.encode(sub_data, lower) + sub_codec.encode(sub_data, lower) except TypeError as e: raise ValueError( 'invalid range lower bound: {}'.format( e.args[0])) from None if upper is not None: try: - self.sub_codec.encode(sub_data, upper) + sub_codec.encode(sub_data, upper) except TypeError as e: raise ValueError( 'invalid range upper bound: {}'.format( @@ -95,7 +96,8 @@ cdef class RangeCodec(BaseCodec): buf.write_byte(flags) buf.write_buffer(sub_data) - cdef decode(self, FRBuffer *buf): + @staticmethod + cdef decode_range(FRBuffer *buf, BaseCodec sub_codec): cdef: uint8_t flags = frb_read(buf, 1)[0] bint empty = (flags & RANGE_EMPTY) != 0 @@ -107,7 +109,6 @@ cdef class RangeCodec(BaseCodec): object upper = None int32_t sub_len FRBuffer sub_buf - BaseCodec sub_codec = self.sub_codec if has_lower: sub_len = hton.unpack_int32(frb_read(buf, 4)) @@ -137,6 +138,12 @@ cdef class RangeCodec(BaseCodec): empty=empty, ) + cdef encode(self, WriteBuffer buf, object obj): + RangeCodec.encode_range(buf, obj, self.sub_codec) + + cdef decode(self, FRBuffer *buf): + return RangeCodec.decode_range(buf, self.sub_codec) + cdef dump(self, int level = 0): return f'{level * " "}{self.name}\n{self.sub_codec.dump(level + 1)}' @@ -146,3 +153,104 @@ cdef class RangeCodec(BaseCodec): name=self.type_name, value_type=self.sub_codec.make_type(describe_context), ) + + +@cython.final +cdef class MultiRangeCodec(BaseCodec): + + def __cinit__(self): + self.sub_codec = None + + @staticmethod + cdef BaseCodec new(bytes tid, BaseCodec sub_codec): + cdef: + MultiRangeCodec codec + + codec = MultiRangeCodec.__new__(MultiRangeCodec) + + codec.tid = tid + codec.name = 'MultiRange' + codec.sub_codec = sub_codec + + return codec + + cdef encode(self, WriteBuffer buf, object obj): + cdef: + WriteBuffer elem_data + Py_ssize_t objlen + Py_ssize_t elem_data_len + + if not isinstance(self.sub_codec, ScalarCodec): + raise TypeError( + f'only scalar multiranges are supported (got type ' + f'{type(self.sub_codec).__name__!r})' + ) + + if not _is_array_iterable(obj): + raise TypeError( + f'a sized iterable container expected (got type ' + f'{type(obj).__name__!r})' + ) + + objlen = len(obj) + if objlen > _MAXINT32: + raise ValueError('too many elements in multirange value') + + elem_data = WriteBuffer.new() + for item in obj: + try: + RangeCodec.encode_range(elem_data, item, self.sub_codec) + except TypeError as e: + raise ValueError( + f'invalid multirange element: {e.args[0]}') from None + + elem_data_len = elem_data.len() + if elem_data_len > _MAXINT32 - 4: + raise OverflowError( + f'size of encoded multirange datum exceeds the maximum ' + f'allowed {_MAXINT32 - 4} bytes') + + # Datum length + buf.write_int32(4 + elem_data_len) + # Number of elements in multirange + buf.write_int32(objlen) + buf.write_buffer(elem_data) + + cdef decode(self, FRBuffer *buf): + cdef: + Py_ssize_t elem_count = hton.unpack_int32( + frb_read(buf, 4)) + object result + Py_ssize_t i + int32_t elem_len + FRBuffer elem_buf + + result = cpython.PyList_New(elem_count) + for i in range(elem_count): + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + raise RuntimeError( + 'unexpected NULL element in multirange value') + else: + frb_slice_from(&elem_buf, buf, elem_len) + elem = RangeCodec.decode_range(&elem_buf, self.sub_codec) + if frb_get_len(&elem_buf): + raise RuntimeError( + f'unexpected trailing data in buffer after ' + f'multirange element decoding: ' + f'{frb_get_len(&elem_buf)}') + + cpython.Py_INCREF(elem) + cpython.PyList_SET_ITEM(result, i, elem) + + return range_mod.MultiRange(result) + + cdef dump(self, int level = 0): + return f'{level * " "}{self.name}\n{self.sub_codec.dump(level + 1)}' + + def make_type(self, describe_context): + return describe.MultiRangeType( + desc_id=uuid.UUID(bytes=self.tid), + name=self.type_name, + value_type=self.sub_codec.make_type(describe_context), + ) \ No newline at end of file diff --git a/tests/datatypes/test_datatypes.py b/tests/datatypes/test_datatypes.py index eaff8aff..741489d4 100644 --- a/tests/datatypes/test_datatypes.py +++ b/tests/datatypes/test_datatypes.py @@ -1003,3 +1003,165 @@ def test_array_6(self): self.assertNotEqual( edgedb.Array([1, 2, 3]), False) + + +class TestRange(unittest.TestCase): + + def test_range_empty_1(self): + t = edgedb.Range(empty=True) + self.assertEqual(t.lower, None) + self.assertEqual(t.upper, None) + self.assertFalse(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertTrue(t.is_empty()) + self.assertFalse(t) + + self.assertEqual(t, edgedb.Range(1, 1, empty=True)) + + with self.assertRaisesRegex(ValueError, 'conflicting arguments'): + edgedb.Range(1, 2, empty=True) + + def test_range_2(self): + t = edgedb.Range(1, 2) + self.assertEqual(repr(t), "") + self.assertEqual(str(t), "") + + self.assertEqual(t.lower, 1) + self.assertEqual(t.upper, 2) + self.assertTrue(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + self.assertTrue(t) + + def test_range_3(self): + t = edgedb.Range(1) + self.assertEqual(t.lower, 1) + self.assertEqual(t.upper, None) + self.assertTrue(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + + t = edgedb.Range(None, 1) + self.assertEqual(t.lower, None) + self.assertEqual(t.upper, 1) + self.assertFalse(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + + t = edgedb.Range(None, None) + self.assertEqual(t.lower, None) + self.assertEqual(t.upper, None) + self.assertFalse(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + + def test_range_4(self): + for il in (False, True): + for iu in (False, True): + t = edgedb.Range(1, 2, inc_lower=il, inc_upper=iu) + self.assertEqual(t.lower, 1) + self.assertEqual(t.upper, 2) + self.assertEqual(t.inc_lower, il) + self.assertEqual(t.inc_upper, iu) + self.assertFalse(t.is_empty()) + + def test_range_5(self): + # test hash + self.assertEqual( + { + edgedb.Range(None, 2, inc_upper=True), + edgedb.Range(1, 2), + edgedb.Range(1, 2), + edgedb.Range(1, 2), + edgedb.Range(None, 2, inc_upper=True), + }, + { + edgedb.Range(1, 2), + edgedb.Range(None, 2, inc_upper=True), + } + ) + + +class TestMultiRange(unittest.TestCase): + + def test_multirange_empty_1(self): + t = edgedb.MultiRange() + self.assertEqual(len(t), 0) + self.assertEqual(t, edgedb.MultiRange([])) + + def test_multirange_2(self): + t = edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]) + self.assertEqual( + repr(t), ", ]>") + self.assertEqual( + str(t), ", ]>") + + self.assertEqual( + t, + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]) + ) + + def test_multirange_3(self): + ranges = [ + edgedb.Range(None, 0), + edgedb.Range(1, 2), + edgedb.Range(4), + ] + t = edgedb.MultiRange([ + edgedb.Range(None, 0), + edgedb.Range(1, 2), + edgedb.Range(4), + ]) + + for el, r in zip(t, ranges): + self.assertEqual(el, r) + + def test_multirange_4(self): + # test hash + self.assertEqual( + { + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + edgedb.MultiRange([edgedb.Range(None, 2, inc_upper=True)]), + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + edgedb.MultiRange([edgedb.Range(None, 2, inc_upper=True)]), + }, + { + edgedb.MultiRange([edgedb.Range(None, 2, inc_upper=True)]), + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + } + ) + + def test_multirange_5(self): + # test hash + self.assertEqual( + edgedb.MultiRange([ + edgedb.Range(None, 2, inc_upper=True), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(None, 2, inc_upper=True), + ]), + edgedb.MultiRange([ + edgedb.Range(5, 9), + edgedb.Range(None, 2, inc_upper=True), + ]), + ) diff --git a/tests/test_async_query.py b/tests/test_async_query.py index d9924324..16a72d91 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -799,6 +799,83 @@ async def test_async_range_02(self): ) self.assertEqual([edgedb.Range(1, 2)], result) + async def test_async_multirange_01(self): + has_range = await self.client.query( + "select schema::ObjectType filter .name = 'schema::MultiRange'") + if not has_range: + raise unittest.SkipTest( + "server has no support for std::multirange") + + samples = [ + ('multirange', [ + edgedb.MultiRange(), + dict( + input=edgedb.MultiRange([edgedb.Range(empty=True)]), + output=edgedb.MultiRange(), + ), + edgedb.MultiRange([ + edgedb.Range(None, 0), + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + dict( + input=edgedb.MultiRange([ + edgedb.Range(None, 2, inc_upper=True), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(None, 2, inc_upper=True), + ]), + output=edgedb.MultiRange([ + edgedb.Range(5, 9), + edgedb.Range(None, 3), + ]), + ), + dict( + input=edgedb.MultiRange([ + edgedb.Range(None, 2), + edgedb.Range(-5, 9), + edgedb.Range(13), + ]), + output=edgedb.MultiRange([ + edgedb.Range(None, 9), + edgedb.Range(13), + ]), + ), + ]), + ] + + for typename, sample_data in samples: + for sample in sample_data: + with self.subTest(sample=sample, typname=typename): + stmt = f"SELECT <{typename}>$0" + if isinstance(sample, dict): + inputval = sample['input'] + outputval = sample['output'] + else: + inputval = outputval = sample + + result = await self.client.query_single(stmt, inputval) + err_msg = ( + "unexpected result for {} when passing {!r}: " + "received {!r}, expected {!r}".format( + typename, inputval, result, outputval)) + + self.assertEqual(result, outputval, err_msg) + + async def test_async_multirange_02(self): + has_range = await self.client.query( + "select schema::ObjectType filter .name = 'schema::MultiRange'") + if not has_range: + raise unittest.SkipTest( + "server has no support for std::multirange") + + result = await self.client.query_single( + "SELECT >>$0", + [edgedb.MultiRange([edgedb.Range(1, 2)])] + ) + self.assertEqual([edgedb.MultiRange([edgedb.Range(1, 2)])], result) + async def test_async_wait_cancel_01(self): underscored_lock = await self.client.query_single(""" SELECT EXISTS(