Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] out of band serialization exception #47544

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# __anti_pattern_start__
import ray
import pickle
from ray._private.internal_api import memory_summary
import ray.exceptions

ray.init()


@ray.remote
def out_of_band_serialization_pickle():
obj_ref = ray.put(1)
import pickle

# object_ref is serialized from user code using a regular pickle.
# Ray can't keep track of the reference, so the underlying object
# can be GC'ed unexpectedly, which can cause unexpected hangs.
return pickle.dumps(obj_ref)


@ray.remote
def out_of_band_serialization_ray_cloudpickle():
obj_ref = ray.put(1)
from ray import cloudpickle

# ray.cloudpickle can serialize only when
# RAY_allow_out_of_band_object_ref_serialization=1 env var is set.
# However, the object_ref is pinned for the lifetime of the worker,
# which can cause Ray object leaks that can cause spilling.
return cloudpickle.dumps(obj_ref)


print("==== serialize object ref with pickle ====")
result = ray.get(out_of_band_serialization_pickle.remote())
try:
ray.get(pickle.loads(result), timeout=5)
except ray.exceptions.GetTimeoutError:
print("Underlying object is unexpectedly GC'ed!\n\n")

print("==== serialize object ref with ray.cloudpickle ====")
# By default, it's allowed to serialize ray.ObjectRef using
# ray.cloudpickle.
ray.get(out_of_band_serialization_ray_cloudpickle.options().remote())
# you can see objects are stil pinned although it's GC'ed and not used anymore.
print(memory_summary())

print(
"==== serialize object ref with ray.cloudpickle with env var "
"RAY_allow_out_of_band_object_ref_serialization=0 for debugging ===="
)
try:
ray.get(
out_of_band_serialization_ray_cloudpickle.options(
runtime_env={
"env_vars": {
"RAY_allow_out_of_band_object_ref_serialization": "0",
}
}
).remote()
)
except Exception as e:
print(f"Exception raised from out_of_band_serialization_ray_cloudpickle {e}\n\n")

# __anti_pattern_end__
2 changes: 2 additions & 0 deletions doc/source/ray-core/objects/serialization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Plasma is used to efficiently transfer objects across different processes and di

Each node has its own object store. When data is put into the object store, it does not get automatically broadcasted to other nodes. Data remains local to the writer until requested by another task or actor on another node.

.. _serialize-object-ref:

Serializing ObjectRefs
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions doc/source/ray-core/patterns/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ This section is a collection of common design patterns and anti-patterns for wri
pass-large-arg-by-value
closure-capture-large-objects
global-variables
out-of-band-object-ref-serialization
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. _ray-out-of-band-object-ref-serialization:

Anti-pattern: Serialize ray.ObjectRef out of band
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
=================================================

**TLDR:** Avoid serializing ``ray.ObjectRef`` because Ray can't know when to garbage collect the underlying object.

Ray's ``ray.ObjectRef`` is distributed reference counted. Ray pins the underlying object until the reference isn't used by the system anymore.
When all references are the pinned object gone, Ray garbage collects the pinned object and cleans it up from the system.
However, if user code serializes ``ray.objectRef``, Ray can't keep track of the reference.

To avoid incorrect behavior, if ``ray.cloudpickle`` serializes``ray.ObjectRef``, Ray pins the object for the lifetime of a worker. "Pin" means that object can't be evicted from the object store
until the corresponding owner worker dies. It's prone to Ray object leaks, which can lead disk spilling. See :ref:`thjs page <serialize-object-ref>` for more details.

To detect if this pattern exists in your code, you can set an environment variable ``RAY_allow_out_of_band_object_ref_serialization=0``. If Ray detects
that ``ray.cloudpickle`` serialized``ray.ObjectRef``, it raises an exception with helpful messages.

Code example
------------

**Anti-pattern:**

.. literalinclude:: ../doc_code/anti_pattern_out_of_band_object_ref_serialization.py
:language: python
:start-after: __anti_pattern_start__
:end-before: __anti_pattern_end__
63 changes: 50 additions & 13 deletions python/ray/_private/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
import threading
import traceback
from typing import Any
from typing import Any, Optional


import google.protobuf.message

Expand Down Expand Up @@ -48,11 +49,15 @@
OutOfMemoryError,
ObjectRefStreamEndOfStreamError,
)
import ray.exceptions
from ray.experimental.compiled_dag_ref import CompiledDAGRef
from ray.util import serialization_addons
from ray.util import inspect_serializability

logger = logging.getLogger(__name__)
ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION = ray_constants.env_bool(
"RAY_allow_out_of_band_object_ref_serialization", True
)


class DeserializationError(Exception):
Expand All @@ -65,11 +70,14 @@ def pickle_dumps(obj: Any, error_msg: str):
"""
try:
return pickle.dumps(obj)
except TypeError as e:
except (TypeError, ray.exceptions.OufOfBandObjectRefSerializationException) as e:
sio = io.StringIO()
inspect_serializability(obj, print_file=sio)
msg = f"{error_msg}:\n{sio.getvalue()}"
raise TypeError(msg) from e
if isinstance(e, TypeError):
raise TypeError(msg) from e
else:
raise ray.exceptions.OufOfBandObjectRefSerializationException(msg)


def _object_ref_deserializer(binary, call_site, owner_address, object_status):
Expand Down Expand Up @@ -127,7 +135,12 @@ def actor_handle_reducer(obj):
serialized, actor_handle_id, weak_ref = obj._serialization_helper()
# Update ref counting for the actor handle
if not weak_ref:
self.add_contained_object_ref(actor_handle_id)
self.add_contained_object_ref(
actor_handle_id,
# Right now, so many tests are failing when this is set.
# Allow it for now, but we should eventually disallow it here.
allow_out_of_band_serialization=True,
)
return _actor_handle_deserializer, (serialized, weak_ref)

self._register_cloudpickle_reducer(ray.actor.ActorHandle, actor_handle_reducer)
Expand All @@ -140,7 +153,13 @@ def compiled_dag_ref_reducer(obj):
def object_ref_reducer(obj):
worker = ray._private.worker.global_worker
worker.check_connected()
self.add_contained_object_ref(obj)
self.add_contained_object_ref(
obj,
allow_out_of_band_serialization=(
ALLOW_OUT_OF_BAND_OBJECT_REF_SERIALIZATION
),
call_site=obj.call_site(),
)
obj, owner_address, object_status = worker.core_worker.serialize_object_ref(
obj
)
Expand Down Expand Up @@ -199,7 +218,13 @@ def get_and_clear_contained_object_refs(self):
self._thread_local.object_refs = set()
return object_refs

def add_contained_object_ref(self, object_ref):
def add_contained_object_ref(
self,
object_ref: "ray.ObjectRef",
*,
allow_out_of_band_serialization: bool,
call_site: Optional[str] = None,
):
if self.is_in_band_serialization():
# This object ref is being stored in an object. Add the ID to the
# list of IDs contained in the object so that we keep the inner
Expand All @@ -208,13 +233,25 @@ def add_contained_object_ref(self, object_ref):
self._thread_local.object_refs = set()
self._thread_local.object_refs.add(object_ref)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray._private.worker.global_worker.core_worker.add_object_ref_reference(
object_ref
)
if not allow_out_of_band_serialization:
raise ray.exceptions.OufOfBandObjectRefSerializationException(
f"It is not allowed to serialize ray.ObjectRef {object_ref.hex()}. "
"If you want to allow serialization, "
"set `RAY_allow_out_of_band_object_ref_serialization=1.` "
"If you set the env var, the object is pinned forever in the "
"lifetime of the worker process and can cause Ray object leaks. "
"See the callsite and trace to find where the serialization "
"occurs.\nCallsite: "
f"{call_site or 'Disabled. Set RAY_record_ref_creation_sites=1'}"
)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray._private.worker.global_worker.core_worker.add_object_ref_reference(
object_ref
)

def _deserialize_pickle5_data(self, data):
try:
Expand Down
25 changes: 16 additions & 9 deletions python/ray/data/_internal/planner/plan_read_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Iterable, List

import ray
import ray.cloudpickle as cloudpickle
from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.execution.interfaces import PhysicalOperator, RefBundle
from ray.data._internal.execution.interfaces.task_context import TaskContext
Expand All @@ -20,15 +19,21 @@
from ray.data._internal.util import _warn_on_high_parallelism
from ray.data.block import Block, BlockMetadata
from ray.data.datasource.datasource import ReadTask
from ray.experimental.locations import get_local_object_locations
from ray.util.debug import log_once

TASK_SIZE_WARN_THRESHOLD_BYTES = 1024 * 1024 # 1 MiB

logger = logging.getLogger(__name__)


def cleaned_metadata(read_task: ReadTask) -> BlockMetadata:
task_size = len(cloudpickle.dumps(read_task))
def cleaned_metadata(read_task: ReadTask, read_task_ref) -> BlockMetadata:
# NOTE: Use the `get_local_object_locations` API to get the size of the
# serialized ReadTask, instead of pickling.
# Because the ReadTask may capture ObjectRef objects, which cannot
# be serialized out-of-band.
locations = get_local_object_locations([read_task_ref])
task_size = locations[read_task_ref]["object_size"]
if task_size > TASK_SIZE_WARN_THRESHOLD_BYTES and log_once(
f"large_read_task_{read_task.read_fn.__name__}"
):
Expand Down Expand Up @@ -68,23 +73,25 @@ def get_input_data(target_max_block_size) -> List[RefBundle]:
read_tasks = op._datasource_or_legacy_reader.get_read_tasks(parallelism)
_warn_on_high_parallelism(parallelism, len(read_tasks))

return [
RefBundle(
ret = []
for read_task in read_tasks:
read_task_ref = ray.put(read_task)
ref_bundle = RefBundle(
[
(
# TODO(chengsu): figure out a better way to pass read
# tasks other than ray.put().
ray.put(read_task),
cleaned_metadata(read_task),
read_task_ref,
cleaned_metadata(read_task, read_task_ref),
)
],
# `owns_blocks` is False, because these refs are the root of the
# DAG. We shouldn't eagerly free them. Otherwise, the DAG cannot
# be reconstructed.
owns_blocks=False,
)
for read_task in read_tasks
]
ret.append(ref_bundle)
return ret

inputs = InputDataBuffer(
input_data_factory=get_input_data,
Expand Down
10 changes: 10 additions & 0 deletions python/ray/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,15 @@ class ObjectRefStreamEndOfStreamError(RayError):
pass


@DeveloperAPI
class OufOfBandObjectRefSerializationException(RayError):
"""Raised when an `ray.ObjectRef` is out of band serialized by
`ray.cloudpickle`. It is an anti pattern.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

link to the doc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't yet (because doc is not there yet). we should do in a follow up

"""

pass


@PublicAPI(stability="alpha")
class RayChannelError(RaySystemError):
"""Indicates that Ray encountered a system error related
Expand Down Expand Up @@ -879,5 +888,6 @@ class RayAdagCapacityExceeded(RaySystemError):
ActorUnavailableError,
RayChannelError,
RayChannelTimeoutError,
OufOfBandObjectRefSerializationException,
RayAdagCapacityExceeded,
]
56 changes: 56 additions & 0 deletions python/ray/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import ray
import ray.cluster_utils
import ray.exceptions
from ray import cloudpickle

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -733,6 +735,60 @@ def test(x, expect):
assert dataclasses.asdict(new_y) == expect_dict


def test_cannot_out_of_band_serialize_object_ref(shutdown_only, monkeypatch):
monkeypatch.setenv("RAY_allow_out_of_band_object_ref_serialization", "0")
ray.init()

# Use ray.remote as a workaround because
# RAY_allow_out_of_band_object_ref_serialization cannot be set dynamically.
@ray.remote
def test():
ref = ray.put(1)

@ray.remote
def f():
ref

with pytest.raises(ray.exceptions.OufOfBandObjectRefSerializationException):
ray.get(f.remote())

@ray.remote
def f():
cloudpickle.dumps(ray.put(1))

with pytest.raises(ray.exceptions.OufOfBandObjectRefSerializationException):
ray.get(f.remote())

return ray.get(test.remote())


def test_can_out_of_band_serialize_object_ref_with_env_var(shutdown_only, monkeypatch):
monkeypatch.setenv("RAY_allow_out_of_band_object_ref_serialization", "1")
ray.init()

# Use ray.remote as a workaround because
# RAY_allow_out_of_band_object_ref_serialization cannot be set dynamically.
@ray.remote
def test():
ref = ray.put(1)

@ray.remote
def f():
ref

ray.get(f.remote())

@ray.remote
def f():
ref = ray.put(1)
cloudpickle.dumps(ref)

ray.get(f.remote())

# It should pass.
ray.get(test.remote())


if __name__ == "__main__":
import os
import pytest
Expand Down
Loading