diff --git a/tests/test_types_basic.py b/tests/test_types_basic.py index 13b6e614..7a5293f9 100644 --- a/tests/test_types_basic.py +++ b/tests/test_types_basic.py @@ -38,6 +38,19 @@ class TE(t.enum_flag_uint16): assert TE(0x8012).serialize() == data +def test_abstract_ints(): + assert issubclass(t.uint8_t, t.uint_t) + assert not issubclass(t.uint8_t, t.int_t) + assert t.int_t._signed is True + assert t.uint_t._signed is False + + with pytest.raises(TypeError): + t.int_t(0) + + with pytest.raises(TypeError): + t.FixedIntType(0) + + def test_int_too_short(): with pytest.raises(ValueError): t.uint8_t.deserialize(b"") @@ -159,3 +172,11 @@ class TestList(t.FixedList, length=3, item_type=t.uint16_t): assert r[0] == 0x1234 assert r[1] == 0xAA55 assert r[2] == 0xAB89 + + +def test_enum_instance_types(): + class TestEnum(t.enum_uint8): + Member = 0x00 + + assert TestEnum._member_type_ is t.uint8_t + assert type(TestEnum.Member.value) is t.uint8_t diff --git a/tests/test_types_named.py b/tests/test_types_named.py index 19f845f8..ae6b56a2 100644 --- a/tests/test_types_named.py +++ b/tests/test_types_named.py @@ -95,13 +95,21 @@ def test_addr_mode_address(): def test_missing_status_enum(): - assert 0x33 not in list(t.Status) - assert isinstance(t.Status(0x33), t.Status) - assert t.Status(0x33).value == 0x33 + class TestEnum(t.MissingEnumMixin, t.enum_uint8): + Member = 0x00 - # Status values that don't fit can't be created + assert 0xFF not in list(TestEnum) + assert isinstance(TestEnum(0xFF), TestEnum) + assert TestEnum(0xFF).value == 0xFF + assert type(TestEnum(0xFF).value) is t.uint8_t + + # Missing members that don't fit can't be created + with pytest.raises(ValueError): + TestEnum(0xFF + 1) + + # Missing members that aren't integers can't be created with pytest.raises(ValueError): - t.Status(0xFF + 1) + TestEnum("0xFF") def test_zdo_nullable_node_descriptor(): diff --git a/zigpy_znp/types/basic.py b/zigpy_znp/types/basic.py index f9cf476e..9a511b1d 100644 --- a/zigpy_znp/types/basic.py +++ b/zigpy_znp/types/basic.py @@ -24,18 +24,32 @@ def serialize_list(objects) -> Bytes: return Bytes(b"".join([o.serialize() for o in objects])) -class int_t(int): - _signed = True +class FixedIntType(int): + _signed = None _size = None + def _concrete_new(cls, value=0): + instance = super().__new__(cls, value) + instance.serialize() + + return instance + def __new__(cls, value): - instance = int.__new__(cls, value) + raise TypeError(f"Instances of abstract type {cls} cannot be created") - if instance._signed is not None and instance._size is not None: - # It's a concrete int_t type, check to make sure it's valid - instance.serialize() + def __init_subclass__(cls, signed=None, size=None, **kwargs) -> None: + if signed is not None: + cls._signed = signed - return instance + if size is not None: + cls._size = size + + # XXX: The enum module uses the first class with `__new__` in its `__dict__` + # as the member type. We have to give each subclass its own `__new__`. + if signed is not None or size is not None: + cls.__new__ = cls._concrete_new + + super().__init_subclass__(**kwargs) def serialize(self) -> bytes: try: @@ -54,72 +68,76 @@ def deserialize(cls, data: bytes) -> typing.Tuple["int_t", bytes]: return r, data -class int8s(int_t): - _size = 1 +class uint_t(FixedIntType, signed=False): + pass -class int16s(int_t): - _size = 2 +class int_t(FixedIntType, signed=True): + pass + + +class int8s(int_t, size=1): + pass -class int24s(int_t): - _size = 3 +class int16s(int_t, size=2): + pass -class int32s(int_t): - _size = 4 +class int24s(int_t, size=3): + pass -class int40s(int_t): - _size = 5 +class int32s(int_t, size=4): + pass -class int48s(int_t): - _size = 6 +class int40s(int_t, size=5): + pass -class int56s(int_t): - _size = 7 +class int48s(int_t, size=6): + pass -class int64s(int_t): - _size = 8 +class int56s(int_t, size=7): + pass -class uint_t(int_t): - _signed = False +class int64s(int_t, size=8): + pass -class uint8_t(uint_t): - _size = 1 +class uint8_t(uint_t, size=1): + pass -class uint16_t(uint_t): - _size = 2 +class uint16_t(uint_t, size=2): + pass -class uint24_t(uint_t): - _size = 3 +class uint24_t(uint_t, size=3): + pass -class uint32_t(uint_t): - _size = 4 +class uint32_t(uint_t, size=4): + pass -class uint40_t(uint_t): - _size = 5 +class uint40_t(uint_t, size=5): + pass -class uint48_t(uint_t): - _size = 6 +class uint48_t(uint_t, size=6): + pass -class uint56_t(uint_t): - _size = 7 +class uint56_t(uint_t, size=7): + pass -class uint64_t(uint_t): - _size = 8 +class uint64_t(uint_t, size=8): + pass class ShortBytes(Bytes): diff --git a/zigpy_znp/types/named.py b/zigpy_znp/types/named.py index d6321fd1..8f7451d8 100644 --- a/zigpy_znp/types/named.py +++ b/zigpy_znp/types/named.py @@ -149,14 +149,12 @@ class Schema: class MissingEnumMixin: @classmethod def _missing_(cls, value): - if not isinstance(value, int) or value < 0 or value > 0xFF: - # `return None` works with Python 3.7.7, breaks with 3.7.1 + if not isinstance(value, int): raise ValueError(f"{value} is not a valid {cls.__name__}") - # XXX: infer type from enum - new_member = basic.uint8_t.__new__(cls, value) + new_member = cls._member_type_.__new__(cls, value) new_member._name_ = f"unknown_0x{value:02X}" - new_member._value_ = value + new_member._value_ = cls._member_type_(value) if sys.version_info >= (3, 8): # Show the warning in the calling code, not in this function