Skip to content

Commit

Permalink
INITIAL pickle serial
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Apr 16, 2022
1 parent 8f55cfd commit 3257e82
Show file tree
Hide file tree
Showing 28 changed files with 821 additions and 724 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ if [ -n "$WITH_CYTHON" ]; then
mkdir -p build
export POOL_START_METHOD=forkserver

coverage run --rcfile=setup.cfg -m pytest $PYTEST_CONFIG_WITHOUT_COV mars/tests mars/core/graph
coverage run --rcfile=setup.cfg -m pytest $PYTEST_CONFIG_WITHOUT_COV \
mars/tests \
mars/core/graph \
mars/serialization
python .github/workflows/remove_tracer_errors.py
coverage combine
mv .coverage build/.coverage.non-oscar.file
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ mars/learn/cluster/*.c*
mars/learn/utils/*.c*
mars/lib/*.c*
mars/oscar/**/*.c*
mars/serialization/*.c*

# web bundle file
mars/services/web/static
Expand Down
39 changes: 24 additions & 15 deletions asv_bench/benchmarks/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class MySerializable(Serializable):
_dict_val = DictField("dict_val", FieldTypes.string, FieldTypes.bytes)


class SerializationSuite:
class SerializeSerializableSuite:
def setup(self):
children = []
for idx in range(1000):
Expand All @@ -90,6 +90,12 @@ def setup(self):
children.append(child)
self.test_data = SerializableParent(children=children)

def time_serialize_deserialize(self):
deserialize(*serialize(self.test_data))


class SerializeSubtaskSuite:
def setup(self):
self.subtasks = []
for i in range(10000):
subtask = Subtask(
Expand All @@ -110,7 +116,13 @@ def setup(self):
)
self.subtasks.append(subtask)

self.test_basic_serializable = []
def time_pickle_serialize_deserialize_subtask(self):
deserialize(*cloudpickle.loads(cloudpickle.dumps(serialize(self.subtasks))))


class SerializePrimitivesSuite:
def setup(self):
self.test_primitive_serializable = []
for i in range(10000):
my_serializable = MySerializable(
_bool_val=True,
Expand All @@ -129,27 +141,24 @@ def setup(self):
_tuple_val=("a", "b"),
_dict_val={"a": b"bytes_value"},
)
self.test_basic_serializable.append(my_serializable)

self.test_list = list(range(100000))
self.test_tuple = tuple(range(100000))
self.test_dict = {i: i for i in range(100000)}

def time_serialize_deserialize(self):
deserialize(*serialize(self.test_data))
self.test_primitive_serializable.append(my_serializable)

def time_serialize_deserialize_basic(self):
deserialize(*serialize(self.test_basic_serializable))
def time_serialize_deserialize_primitive(self):
deserialize(*serialize(self.test_primitive_serializable))

def time_pickle_serialize_deserialize_basic(self):
deserialize(
*cloudpickle.loads(
cloudpickle.dumps(serialize(self.test_basic_serializable))
cloudpickle.dumps(serialize(self.test_primitive_serializable))
)
)

def time_pickle_serialize_deserialize_subtask(self):
deserialize(*cloudpickle.loads(cloudpickle.dumps(serialize(self.subtasks))))

class SerializeContainersSuite:
def setup(self):
self.test_list = list(range(100000))
self.test_tuple = tuple(range(100000))
self.test_dict = {i: i for i in range(100000)}

def time_pickle_serialize_deserialize_list(self):
deserialize(*cloudpickle.loads(cloudpickle.dumps(serialize(self.test_list))))
Expand Down
12 changes: 12 additions & 0 deletions mars/_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.


cdef class TypeDispatcher:
cdef dict _handlers
cdef dict _lazy_handlers
cdef dict _inherit_handlers

cpdef void register(self, object type_, object handler)
cpdef void unregister(self, object type_)
cdef _reload_lazy_handlers(self)
cpdef get_handler(self, object type_)


cpdef str to_str(s, encoding=*)
cpdef bytes to_binary(s, encoding=*)
cpdef unicode to_text(s, encoding=*)
Expand Down
4 changes: 0 additions & 4 deletions mars/_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ cpdef unicode to_text(s, encoding='utf-8'):


cdef class TypeDispatcher:
cdef dict _handlers
cdef dict _lazy_handlers
cdef dict _inherit_handlers

def __init__(self):
self._handlers = dict()
self._lazy_handlers = dict()
Expand Down
1 change: 0 additions & 1 deletion mars/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from mars.lib.aio import stop_isolation
from mars.oscar.backends.router import Router
from mars.oscar.backends.ray.communication import RayServer
from mars.serialization.ray import register_ray_serializers, unregister_ray_serializers
from mars.utils import lazy_import

ray = lazy_import("ray")
Expand Down
15 changes: 6 additions & 9 deletions mars/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from functools import wraps
from typing import Dict

from ..serialization.core import Placeholder, short_id
from ..serialization.serializables import Serializable, StringField
from ..serialization.serializables.core import SerializableSerializer
from ..utils import tokenize
Expand Down Expand Up @@ -117,16 +118,12 @@ def id(self):
return self._id


def buffered(func):
def buffered_base(func):
@wraps(func)
def wrapped(self, obj: Base, context: Dict):
obj_id = (obj.key, obj.id)
if obj_id in context:
return {
"id": id(context[obj_id]),
"serializer": "ref",
"buf_num": 0,
}, []
return Placeholder(short_id(context[obj_id]))
else:
context[obj_id] = obj
return func(self, obj, context)
Expand All @@ -135,9 +132,9 @@ def wrapped(self, obj: Base, context: Dict):


class BaseSerializer(SerializableSerializer):
@buffered
def serialize(self, obj: Serializable, context: Dict):
return (yield from super().serialize(obj, context))
@buffered_base
def serial(self, obj: Base, context: Dict):
return super().serial(obj, context)


BaseSerializer.register(Base)
Expand Down
2 changes: 1 addition & 1 deletion mars/core/entity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


class EntityData(Base):
__slots__ = "__weakref__", "_siblings"
__slots__ = ("_siblings",)
type_name = None

# required fields
Expand Down
2 changes: 0 additions & 2 deletions mars/core/entity/tileables.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,6 @@ def detach(self, entity):


class Tileable(Entity):
__slots__ = ("__weakref__",)

def __init__(self, data: TileableType = None, **kw):
super().__init__(data=data, **kw)
data = self._data
Expand Down
16 changes: 6 additions & 10 deletions mars/core/graph/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABCMeta, abstractmethod
from typing import List, Dict, Union, Iterable
from typing import List, Dict, Tuple, Union, Iterable

from ...core import Tileable, Chunk
from ...serialization.core import buffered
Expand Down Expand Up @@ -133,19 +133,15 @@ def to_graph(self) -> Union[TileableGraph, ChunkGraph]:


class GraphSerializer(SerializableSerializer):
serializer_name = "graph"

@buffered
def serialize(self, obj: Union[TileableGraph, ChunkGraph], context: Dict):
def serial(self, obj: Union[TileableGraph, ChunkGraph], context: Dict):
serializable_graph = SerializableGraph.from_graph(obj)
return (yield from super().serialize(serializable_graph, context))
return (), [serializable_graph], False

def deserialize(
self, header: Dict, buffers: List, context: Dict
def deserial(
self, serialized: Tuple, context: Dict, subs: List
) -> Union[TileableGraph, ChunkGraph]:
serializable_graph: SerializableGraph = (
yield from super().deserialize(header, buffers, context)
)
serializable_graph: SerializableGraph = subs[0]
return serializable_graph.to_graph()


Expand Down
10 changes: 6 additions & 4 deletions mars/core/operand/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ class Operand(Base, OperatorLogicKeyGeneratorMixin, metaclass=OperandMetaclass):
which should be the :class:`mars.tensor.core.TensorData`, :class:`mars.tensor.core.ChunkData` etc.
"""

__slots__ = ("__weakref__",)
attr_tag = "attr"
_init_update_key_ = False
_output_type_ = None
_no_copy_attrs_ = Base._no_copy_attrs_ | {"scheduling_hint"}
_cache_primitive_serial = True

sparse = BoolField("sparse", default=False)
device = Int32Field("device", default=None)
Expand Down Expand Up @@ -328,11 +328,13 @@ def on_input_modify(self, new_input):


class OperandSerializer(SerializableSerializer):
serializer_name = "operand"
def serial(self, obj: Serializable, context: Dict):
res = super().serial(obj, context)
return res

def deserialize(self, header: Dict, buffers: List, context: Dict) -> Operand:
def deserial(self, serialized: Tuple, context: Dict, subs: List) -> Operand:
# convert outputs back to weak-refs
operand: Operand = (yield from super().deserialize(header, buffers, context))
operand: Operand = super().deserial(serialized, context, subs)
for i, out in enumerate(operand._outputs):

def cb(o, index):
Expand Down
22 changes: 3 additions & 19 deletions mars/core/operand/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import cloudpickle

from ... import opcodes
from ...serialization.core import cached_pickle_dumps
from ...serialization.serializables import FieldTypes, StringField, ListField
from .base import Operand
from .core import TileableOperandMixin
Expand Down Expand Up @@ -47,21 +46,6 @@ def execute(cls, ctx, op):
class FetchShuffle(Operand):
_op_type_ = opcodes.FETCH_SHUFFLE

source_keys = ListField(
"source_keys",
FieldTypes.string,
on_serialize=cached_pickle_dumps,
on_deserialize=cloudpickle.loads,
)
source_idxes = ListField(
"source_idxes",
FieldTypes.tuple(FieldTypes.uint64),
on_serialize=cached_pickle_dumps,
on_deserialize=cloudpickle.loads,
)
source_mappers = ListField(
"source_mappers",
FieldTypes.uint16,
on_serialize=cached_pickle_dumps,
on_deserialize=cloudpickle.loads,
)
source_keys = ListField("source_keys", FieldTypes.string)
source_idxes = ListField("source_idxes", FieldTypes.tuple(FieldTypes.uint64))
source_mappers = ListField("source_mappers", FieldTypes.uint16)
4 changes: 2 additions & 2 deletions mars/oscar/backends/communication/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ async def read_buffers(header: Dict, reader: StreamReader):
CPBuffer = CPUnownedMemory = CPMemoryPointer = None

    # construct an empty cuda buffer and copy from host
is_cuda_buffers = header.get("is_cuda_buffers")
buffer_sizes = header.pop(BUFFER_SIZES_NAME)
is_cuda_buffers = header[0].get("is_cuda_buffers")
buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)

buffers = []
for is_cuda_buffer, buf_size in zip(is_cuda_buffers, buffer_sizes):
Expand Down
30 changes: 9 additions & 21 deletions mars/oscar/backends/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,32 +384,20 @@ def __str__(self):


class MessageSerializer(Serializer):
serializer_name = "actor_message"

@buffered
def serialize(self, obj: _MessageBase, context: Dict):
def serial(self, obj: _MessageBase, context: Dict):
assert obj.protocol == 0, "only support protocol 0 for now"

message_class = type(obj)
to_serialize = [getattr(obj, slot) for slot in _get_slots(message_class)]
header, buffers = yield to_serialize
new_header = {
"message_class": message_class,
"message_id": obj.message_id,
"protocol": obj.protocol,
"attributes_header": header,
}
return new_header, buffers

def deserialize(self, header: Dict, buffers: List, context: Dict):
protocol = header["protocol"]
message_cls = type(obj)
to_serialize = [getattr(obj, slot) for slot in _get_slots(message_cls)]
return (message_cls, obj.message_id, obj.protocol), [to_serialize], False

def deserial(self, serialized: Tuple, context: Dict, subs: List):
message_cls, message_id, protocol = serialized
assert protocol == 0, "only support protocol 0 for now"
message_id = header["message_id"]
message_class = header["message_class"]
try:
serialized = yield header["attributes_header"], buffers
message = object.__new__(message_class)
for slot, val in zip(_get_slots(message_class), serialized):
message = object.__new__(message_cls)
for slot, val in zip(_get_slots(message_cls), subs[0]):
setattr(message, slot, val)
return message
except pickle.UnpicklingError as e: # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion mars/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from .aio import AioSerializer, AioDeserializer
from .core import serialize, deserialize
from .core import serialize, deserialize, Serializer

from . import arrow, cuda, numpy, scipy, mars_objects, ray, exception

Expand Down
10 changes: 5 additions & 5 deletions mars/serialization/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
cupy = lazy_import("cupy", globals=globals())
cudf = lazy_import("cudf", globals=globals())

DEFAULT_SERIALIZATION_VERSION = 0
DEFAULT_SERIALIZATION_VERSION = 1
BUFFER_SIZES_NAME = "buf_sizes"


Expand All @@ -51,10 +51,10 @@ def _is_cuda_buffer(buf): # pragma: no cover
return False

is_cuda_buffers = [_is_cuda_buffer(buf) for buf in buffers]
headers["is_cuda_buffers"] = np.array(is_cuda_buffers)
headers[0]["is_cuda_buffers"] = np.array(is_cuda_buffers)

# add buffer lengths into headers
headers[BUFFER_SIZES_NAME] = [
headers[0][BUFFER_SIZES_NAME] = [
getattr(buf, "nbytes", len(buf)) for buf in buffers
]
header = pickle.dumps(headers)
Expand Down Expand Up @@ -113,7 +113,7 @@ async def _get_obj_header_bytes(self):
async def _get_obj(self):
header = pickle.loads(await self._get_obj_header_bytes())
# get buffer size
buffer_sizes = header.pop(BUFFER_SIZES_NAME)
buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
# get buffers
buffers = [await self._readexactly(size) for size in buffer_sizes]

Expand All @@ -127,7 +127,7 @@ async def get_size(self):
header_bytes = await self._get_obj_header_bytes()
header = pickle.loads(header_bytes)
# get buffer size
buffer_sizes = header.pop(BUFFER_SIZES_NAME)
buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
return 11 + len(header_bytes) + sum(buffer_sizes)

async def get_header(self):
Expand Down
Loading

0 comments on commit 3257e82

Please sign in to comment.