diff --git a/ibis/common/bases.py b/ibis/common/bases.py index de7c2bc6d9d2..9e1b5da5d80d 100644 --- a/ibis/common/bases.py +++ b/ibis/common/bases.py @@ -195,6 +195,13 @@ def __eq__(self, other) -> bool: return NotImplemented return all(getattr(self, n) == getattr(other, n) for n in self.__slots__) + def __getstate__(self): + return {k: getattr(self, k) for k in self.__slots__} + + def __setstate__(self, state): + for name, value in state.items(): + object.__setattr__(self, name, value) + def __repr__(self): fields = {k: getattr(self, k) for k in self.__slots__} fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items()) @@ -221,5 +228,11 @@ def __init__(self, **kwargs) -> None: hashvalue = hash(tuple(kwargs.values())) object.__setattr__(self, "__precomputed_hash__", hashvalue) + def __setstate__(self, state): + for name, value in state.items(): + object.__setattr__(self, name, value) + hashvalue = hash(tuple(state.values())) + object.__setattr__(self, "__precomputed_hash__", hashvalue) + def __hash__(self) -> int: return self.__precomputed_hash__ diff --git a/ibis/common/tests/test_bases.py b/ibis/common/tests/test_bases.py index fb55b66ff6a0..72460a23ff3d 100644 --- a/ibis/common/tests/test_bases.py +++ b/ibis/common/tests/test_bases.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import pickle import weakref from abc import ABCMeta, abstractmethod @@ -11,8 +12,10 @@ AbstractMeta, Comparable, Final, + FrozenSlotted, Immutable, Singleton, + Slotted, ) from ibis.common.caching import WeakCache @@ -258,3 +261,55 @@ class A(Final): class B(A): pass + + +class MyObj(Slotted): + __slots__ = ("a", "b") + + def __init__(self, a, b): + super().__init__(a=a, b=b) + + +def test_slotted(): + obj = MyObj(1, 2) + assert obj.a == 1 + assert obj.b == 2 + assert obj.__slots__ == ("a", "b") + with pytest.raises(AttributeError): + obj.c = 3 + + obj2 = MyObj(1, 2) + assert obj == obj2 + assert obj is not obj2 + + obj3 = MyObj(1, 3) + assert obj != obj3 + + assert pickle.loads(pickle.dumps(obj)) == obj + + +class MyFrozenObj(FrozenSlotted): + __slots__ = ("a", "b") + + def __init__(self, a, b): + super().__init__(a=a, b=b) + + +def test_frozen_slotted(): + obj = MyFrozenObj(1, 2) + assert obj.a == 1 + assert obj.b == 2 + assert obj.__slots__ == ("a", "b") + with pytest.raises(AttributeError): + obj.b = 3 + with pytest.raises(AttributeError): + obj.c = 3 + + obj2 = MyFrozenObj(1, 2) + assert obj == obj2 + assert obj is not obj2 + assert hash(obj) == hash(obj2) + + restored = pickle.loads(pickle.dumps(obj)) + assert restored == obj + assert hash(restored) == hash(obj)