Skip to content

Commit

Permalink
Fix enum values
Browse files Browse the repository at this point in the history
User-defined enums can be used interchangably with edgedb.EnumValue,
like the `__eq__` (`==`) test, `__hash__` (as dict key) or comparing
will pass between user enum and edgedb.EnumValue, but the `is` check
won't pass.

This fixes a typing issue in the codegen. Fixes #419
  • Loading branch information
fantix committed May 26, 2023
1 parent ea53463 commit bb7522c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 16 deletions.
27 changes: 18 additions & 9 deletions edgedb/datatypes/enum.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,42 @@ class EnumValue(enum.Enum):
def __repr__(self):
return f'<edgedb.EnumValue {self._value_!r}>'

@classmethod
def _try_from(cls, value):
if isinstance(value, EnumValue):
return value
elif isinstance(value, enum.Enum):
return cls(value.value)
else:
raise TypeError

def __lt__(self, other):
if not isinstance(other, EnumValue):
return NotImplemented
other = self._try_from(other)
if self.__tid__ != other.__tid__:
return NotImplemented
return self._index_ < other._index_

def __gt__(self, other):
if not isinstance(other, EnumValue):
return NotImplemented
other = self._try_from(other)
if self.__tid__ != other.__tid__:
return NotImplemented
return self._index_ > other._index_

def __le__(self, other):
if not isinstance(other, EnumValue):
return NotImplemented
other = self._try_from(other)
if self.__tid__ != other.__tid__:
return NotImplemented
return self._index_ <= other._index_

def __ge__(self, other):
if not isinstance(other, EnumValue):
return NotImplemented
other = self._try_from(other)
if self.__tid__ != other.__tid__:
return NotImplemented
return self._index_ >= other._index_

def __eq__(self, other):
other = self._try_from(other)
return self is other

def __hash__(self):
return hash((self.__tid__, self._value_))
return hash(self._value_)
10 changes: 7 additions & 3 deletions edgedb/protocol/codecs/enum.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ cdef class EnumCodec(BaseCodec):

cdef encode(self, WriteBuffer buf, object obj):
if not isinstance(obj, (self.cls, str)):
raise TypeError(
f'a str or edgedb.EnumValue(__tid__={self.cls.__tid__}) is '
f'expected as a valid enum argument, got {type(obj).__name__}')
try:
obj = self.cls._try_from(obj)
except (TypeError, ValueError):
raise TypeError(
f'a str or edgedb.EnumValue(__tid__={self.cls.__tid__}) '
f'is expected as a valid enum argument, '
f'got {type(obj).__name__}') from None
pgproto.text_encode(DEFAULT_CODEC_CONTEXT, buf, str(obj))

cdef decode(self, FRBuffer *buf):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import datetime
import decimal
import enum
import json
import random
import unittest
Expand Down Expand Up @@ -912,6 +913,31 @@ async def test_enum_argument_01(self):
edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue'):
await self.client.query_single('SELECT <MyEnum>$0', 123)

async def test_enum_argument_02(self):
class MyEnum(enum.Enum):
A = "A"
B = "B"
C = "C"

A = await self.client.query_single('SELECT <MyEnum>$0', MyEnum.A)
self.assertEqual(str(A), 'A')
self.assertEqual(A, MyEnum.A)
self.assertEqual(MyEnum.A, A)
self.assertLess(A, MyEnum.B)
self.assertGreater(MyEnum.B, A)

mapping = {MyEnum.A: 1, MyEnum.B: 2}
self.assertEqual(mapping[A], 1)

with self.assertRaises(ValueError):
_ = A > MyEnum.C
with self.assertRaises(ValueError):
_ = A < MyEnum.C
with self.assertRaises(ValueError):
_ = A == MyEnum.C
with self.assertRaises(edgedb.InvalidArgumentError):
await self.client.query_single('SELECT <MyEnum>$0', MyEnum.C)

async def test_json(self):
self.assertEqual(
await self.client.query_json('SELECT {"aaa", "bbb"}'),
Expand Down
10 changes: 6 additions & 4 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ async def test_enum_01(self):
self.assertEqual(repr(ct_red), "<edgedb.EnumValue 'red'>")

self.assertEqual(str(ct_red), 'red')
self.assertNotEqual(ct_red, 'red')
self.assertFalse(ct_red == 'red')
with self.assertRaises(TypeError):
_ = ct_red != 'red'
with self.assertRaises(TypeError):
_ = ct_red == 'red'
self.assertFalse(ct_red == c_red)

self.assertEqual(ct_red, ct_red)
Expand Down Expand Up @@ -74,8 +76,8 @@ async def test_enum_01(self):
with self.assertRaises(TypeError):
_ = ct_red >= c_red

self.assertNotEqual(hash(ct_red), hash(c_red))
self.assertNotEqual(hash(ct_red), hash('red'))
self.assertEqual(hash(ct_red), hash(c_red))
self.assertEqual(hash(ct_red), hash('red'))

async def test_enum_02(self):
c_red = await self.client.query_single('SELECT <Color>"red"')
Expand Down

0 comments on commit bb7522c

Please sign in to comment.