Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow spawning serialization to threads for large objects #2944

Merged
merged 3 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mars/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from functools import wraps
from typing import Dict

from ..serialization.core import Placeholder, short_id
from ..serialization.core import Placeholder, fast_id
from ..serialization.serializables import Serializable, StringField
from ..serialization.serializables.core import SerializableSerializer
from ..utils import tokenize
Expand Down Expand Up @@ -123,7 +123,7 @@ def buffered_base(func):
def wrapped(self, obj: Base, context: Dict):
obj_id = (obj.key, obj.id)
if obj_id in context:
return Placeholder(short_id(context[obj_id]))
return Placeholder(fast_id(context[obj_id]))
else:
context[obj_id] = obj
return func(self, obj, context)
Expand Down
2 changes: 1 addition & 1 deletion mars/oscar/backends/message.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ class DeserializeMessageFailed(RuntimeError):


cdef class MessageSerializer(Serializer):
serializer_id = 56951
serializer_id = 32105

cpdef serial(self, object obj, dict context):
cdef _MessageBase msg = <_MessageBase>obj
Expand Down
8 changes: 6 additions & 2 deletions mars/oscar/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def _gen_args_kwargs_list(delays):
async def _async_batch(self, *delays):
# when there is only one call in batch, calling one-pass method
# will be more efficient
if len(delays) == 1:
if len(delays) == 0:
return []
elif len(delays) == 1:
d = delays[0]
return [await self._async_call(*d.args, **d.kwargs)]
elif self.batch_func:
Expand All @@ -162,7 +164,9 @@ async def _async_batch(self, *delays):
return await asyncio.gather(*tasks)

def _sync_batch(self, *delays):
if self.batch_func:
if delays == 0:
return []
elif self.batch_func:
args_list, kwargs_list = self._gen_args_kwargs_list(delays)
return self.batch_func(args_list, kwargs_list)
else:
Expand Down
5 changes: 5 additions & 0 deletions mars/oscar/tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def method(self, args_list, kwargs_list):
if use_async:
assert asyncio.iscoroutinefunction(TestClass.method)

test_inst = TestClass()
ret = test_inst.method.batch()
ret = await ret if use_async else ret
assert ret == []

test_inst = TestClass()
ret = test_inst.method.batch(test_inst.method.delay(12))
ret = await ret if use_async else ret
Expand Down
2 changes: 1 addition & 1 deletion mars/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from .aio import AioSerializer, AioDeserializer
from .core import serialize, deserialize, Serializer
from .core import serialize, serialize_with_spawn, deserialize, Serializer

from . import arrow, cuda, numpy, scipy, mars_objects, ray, exception

Expand Down
19 changes: 14 additions & 5 deletions mars/serialization/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import struct
from io import BytesIO
from typing import Any
Expand All @@ -20,12 +21,13 @@
import numpy as np

from ..utils import lazy_import
from .core import serialize, deserialize
from .core import serialize_with_spawn, deserialize

cupy = lazy_import("cupy", globals=globals())
cudf = lazy_import("cudf", globals=globals())

DEFAULT_SERIALIZATION_VERSION = 1
DEFAULT_SPAWN_THRESHOLD = 100
BUFFER_SIZES_NAME = "buf_sizes"


Expand All @@ -34,8 +36,10 @@ def __init__(self, obj: Any, compress=0):
self._obj = obj
self._compress = compress

def _get_buffers(self):
headers, buffers = serialize(self._obj)
async def _get_buffers(self):
headers, buffers = await serialize_with_spawn(
self._obj, spawn_threshold=DEFAULT_SPAWN_THRESHOLD
)

def _is_cuda_buffer(buf): # pragma: no cover
if cupy is not None and cudf is not None:
Expand Down Expand Up @@ -78,7 +82,7 @@ def _is_cuda_buffer(buf): # pragma: no cover
return out_buffers

async def run(self):
return self._get_buffers()
return await self._get_buffers()


MALFORMED_MSG = """\
Expand Down Expand Up @@ -123,8 +127,13 @@ async def _get_obj(self):
buffer_sizes = header[0].pop(BUFFER_SIZES_NAME)
# get buffers
buffers = [await self._readexactly(size) for size in buffer_sizes]
# get num of objs
num_objs = header[0].get("_N", 0)

return deserialize(header, buffers)
if num_objs <= DEFAULT_SPAWN_THRESHOLD:
return deserialize(header, buffers)
else:
return await asyncio.to_thread(deserialize, header, buffers)

async def run(self):
return await self._get_obj()
Expand Down
9 changes: 8 additions & 1 deletion mars/serialization/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from concurrent.futures import Executor
from typing import Any, Callable, Dict, List, Tuple

def buffered(func: Callable) -> Callable: ...
def short_id(obj: Any) -> int: ...
def fast_id(obj: Any) -> int: ...

class Serializer:
serializer_id: int
Expand All @@ -42,4 +43,10 @@ class Placeholder:
def __eq__(self, other): ...

def serialize(obj: Any, context: Dict = None): ...
async def serialize_with_spawn(
obj: Any,
context: Dict = None,
spawn_threshold: int = 100,
executor: Executor = None,
): ...
def deserialize(headers: List, buffers: List, context: Dict = None): ...
Loading