Skip to content

Commit

Permalink
Allows client to discard object data (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
rafa-be authored Jan 23, 2025
1 parent 6b5269b commit 1d36288
Show file tree
Hide file tree
Showing 20 changed files with 203 additions and 34 deletions.
2 changes: 1 addition & 1 deletion scaler/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.8.14"
__version__ = "1.8.17"
2 changes: 1 addition & 1 deletion scaler/client/agent/client_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def __get_loops(self):
finally:
self._stop_event.set() # always set the stop event before setting futures' exceptions

await self._object_manager.clean_all_objects()
await self._object_manager.clear_all_objects(clear_serializer=True)

self._connector_external.destroy()
self._connector_internal.destroy()
Expand Down
10 changes: 8 additions & 2 deletions scaler/client/agent/future_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.utility import concat_list_of_bytes
from scaler.protocol.python.common import TaskStatus
from scaler.protocol.python.message import ObjectResponse, TaskResult
from scaler.protocol.python.message import ObjectResponse, TaskCancel, TaskResult
from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError
from scaler.utility.metadata.profile_result import retrieve_profiling_result_from_task_result
from scaler.utility.object_utility import deserialize_failure
Expand All @@ -34,6 +34,8 @@ def cancel_all_futures(self):
for task_id, future in self._task_id_to_future.items():
future.cancel()

self._task_id_to_future.clear()

def set_all_futures_with_exception(self, exception: Exception):
with self._lock:
for future in self._task_id_to_future.values():
Expand All @@ -42,7 +44,7 @@ def set_all_futures_with_exception(self, exception: Exception):
except InvalidStateError:
continue # Future got canceled

self._task_id_to_future = dict()
self._task_id_to_future.clear()

def on_task_result(self, result: TaskResult):
with self._lock:
Expand Down Expand Up @@ -94,6 +96,10 @@ def on_task_result(self, result: TaskResult):
except InvalidStateError:
return # Future got canceled

def on_cancel_task(self, task_cancel: TaskCancel):
    """Stop tracking the future of a cancelled task, if it is still tracked."""
    with self._lock:
        if task_cancel.task_id in self._task_id_to_future:
            del self._task_id_to_future[task_cancel.task_id]

def on_object_response(self, response: ObjectResponse):
for object_id, object_name, object_bytes in zip(
response.object_content.object_ids,
Expand Down
9 changes: 7 additions & 2 deletions scaler/client/agent/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ObjectRequest,
ObjectResponse,
Task,
TaskCancel,
TaskResult,
)

Expand Down Expand Up @@ -40,11 +41,11 @@ async def on_object_request(self, request: ObjectRequest):
raise NotImplementedError()

@abc.abstractmethod
def record_task_result(self, task_id: bytes, object_id: bytes):
def on_task_result(self, result: TaskResult):
raise NotImplementedError()

@abc.abstractmethod
async def clean_all_objects(self):
async def clear_all_objects(self, clear_serializer: bool):
raise NotImplementedError()


Expand Down Expand Up @@ -79,6 +80,10 @@ def set_all_futures_with_exception(self, exception: Exception):
def on_task_result(self, result: TaskResult):
raise NotImplementedError()

@abc.abstractmethod
def on_cancel_task(self, task_cancel: TaskCancel):
raise NotImplementedError()

@abc.abstractmethod
def on_object_response(self, response: ObjectResponse):
raise NotImplementedError()
Expand Down
46 changes: 40 additions & 6 deletions scaler/client/agent/object_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@
from scaler.client.agent.mixins import ObjectManager
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.common import ObjectContent
from scaler.protocol.python.message import ObjectInstruction, ObjectRequest
from scaler.protocol.python.message import (
ObjectInstruction,
ObjectRequest,
TaskResult,
)


class ClientObjectManager(ObjectManager):
def __init__(self, identity: bytes):
self._sent_object_ids: Set[bytes] = set()
self._sent_serializer_id: Optional[bytes] = None

self._identity = identity

self._connector_internal: Optional[AsyncConnector] = None
Expand All @@ -23,23 +29,38 @@ async def on_object_instruction(self, instruction: ObjectInstruction):
await self.__send_object_creation(instruction)
elif instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete:
await self.__delete_objects(instruction)
elif instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Clear:
await self.clear_all_objects(clear_serializer=False)

async def on_object_request(self, object_request: ObjectRequest):
assert object_request.request_type == ObjectRequest.ObjectRequestType.Get
await self._connector_external.send(object_request)

def record_task_result(self, task_id: bytes, object_id: bytes):
self._sent_object_ids.add(object_id)
def on_task_result(self, task_result: TaskResult):
    """Record the task's result object IDs so they can be cleared later."""
    # TODO: received result objects should be deleted from the scheduler when no longer needed.
    # This requires to not delete objects that are required by not-yet-computed dependent graph tasks.
    # For now, we just remove the objects when the client makes a clear request, or on client shutdown.
    # https://github.com/Citi/scaler/issues/43

    for result_object_id in task_result.results:
        self._sent_object_ids.add(result_object_id)

async def clear_all_objects(self, clear_serializer):
cleared_object_ids = self._sent_object_ids.copy()

if clear_serializer:
self._sent_serializer_id = None
elif self._sent_serializer_id is not None:
cleared_object_ids.remove(self._sent_serializer_id)

self._sent_object_ids.difference_update(cleared_object_ids)

async def clean_all_objects(self):
await self._connector_external.send(
ObjectInstruction.new_msg(
ObjectInstruction.ObjectInstructionType.Delete,
self._identity,
ObjectContent.new_msg(tuple(self._sent_object_ids)),
ObjectContent.new_msg(tuple(cleared_object_ids)),
)
)
self._sent_object_ids = set()

async def __send_object_creation(self, instruction: ObjectInstruction):
assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create
Expand All @@ -48,12 +69,20 @@ async def __send_object_creation(self, instruction: ObjectInstruction):
if not new_object_ids:
return

if ObjectContent.ObjectContentType.Serializer in instruction.object_content.object_types:
if self._sent_serializer_id is not None:
raise ValueError("trying to send multiple serializers.")

serializer_index = instruction.object_content.object_types.index(ObjectContent.ObjectContentType.Serializer)
self._sent_serializer_id = instruction.object_content.object_ids[serializer_index]

new_object_content = ObjectContent.new_msg(
*zip(
*filter(
lambda object_pack: object_pack[0] in new_object_ids,
zip(
instruction.object_content.object_ids,
instruction.object_content.object_types,
instruction.object_content.object_names,
instruction.object_content.object_bytes,
),
Expand All @@ -71,5 +100,10 @@ async def __send_object_creation(self, instruction: ObjectInstruction):

async def __delete_objects(self, instruction: ObjectInstruction):
assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete

if self._sent_serializer_id in instruction.object_content.object_ids:
raise ValueError("trying to delete serializer.")

self._sent_object_ids.difference_update(instruction.object_content.object_ids)

await self._connector_external.send(instruction)
12 changes: 7 additions & 5 deletions scaler/client/agent/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from scaler.client.agent.future_manager import ClientFutureManager
from scaler.client.agent.mixins import ObjectManager, TaskManager
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.common import TaskStatus
from scaler.protocol.python.message import GraphTask, GraphTaskCancel, Task, TaskCancel, TaskResult


Expand Down Expand Up @@ -39,6 +38,8 @@ async def on_cancel_task(self, task_cancel: TaskCancel):
return

self._task_ids.remove(task_cancel.task_id)
self._future_manager.on_cancel_task(task_cancel)

await self._connector_external.send(task_cancel)

async def on_new_graph_task(self, task: GraphTask):
Expand All @@ -54,13 +55,14 @@ async def on_cancel_graph_task(self, task_cancel: GraphTaskCancel):
await self._connector_external.send(task_cancel)

async def on_task_result(self, result: TaskResult):
# All task result objects must be propagated to the object manager, even if we do not track the task anymore
# (e.g. if it got cancelled). If we don't, we might lose track of these result objects and not properly clear
# them.
self._object_manager.on_task_result(result)

if result.task_id not in self._task_ids:
return

self._task_ids.remove(result.task_id)

if result.status != TaskStatus.Canceled:
for result_object_id in result.results:
self._object_manager.record_task_result(result.task_id, result_object_id)

self._future_manager.on_task_result(result)
9 changes: 9 additions & 0 deletions scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,15 @@ def send_object(self, obj: Any, name: Optional[str] = None) -> ObjectReference:
cache = self._object_buffer.buffer_send_object(obj, name)
return ObjectReference(cache.object_name, cache.object_id, sum(map(len, cache.object_bytes)))

def clear(self):
    """
    Release every resource held by the client: cancel all running futures, then
    invalidate all existing object references.
    """

    future_manager = self._future_manager
    object_buffer = self._object_buffer

    future_manager.cancel_all_futures()
    object_buffer.clear()

def disconnect(self):
"""
disconnect from connected scheduler, this will not shut down the scheduler
Expand Down
37 changes: 33 additions & 4 deletions scaler/client/object_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
@dataclasses.dataclass
class ObjectCache:
object_id: bytes
object_type: ObjectContent.ObjectContentType
object_name: bytes
object_bytes: List[bytes]

Expand Down Expand Up @@ -54,7 +55,8 @@ def commit_send_objects(self):
return

objects_to_send = [
(obj_cache.object_id, obj_cache.object_name, obj_cache.object_bytes) for obj_cache in self._pending_objects
(obj_cache.object_id, obj_cache.object_type, obj_cache.object_name, obj_cache.object_bytes)
for obj_cache in self._pending_objects
]

self._connector.send(
Expand All @@ -65,7 +67,7 @@ def commit_send_objects(self):
)
)

self._pending_objects = list()
self._pending_objects.clear()

def commit_delete_objects(self):
if not self._pending_delete_objects:
Expand All @@ -81,16 +83,38 @@ def commit_delete_objects(self):

self._pending_delete_objects.clear()

def clear(self):
    """
    Drop all pending objects and ask the scheduler to clear the committed ones.
    """

    self._pending_objects.clear()
    self._pending_delete_objects.clear()

    clear_instruction = ObjectInstruction.new_msg(
        ObjectInstruction.ObjectInstructionType.Clear,
        self._identity,
        ObjectContent.new_msg(tuple()),
    )
    self._connector.send(clear_instruction)

def __construct_serializer(self) -> ObjectCache:
    """Pickle the client's serializer and wrap it as a Serializer-typed ObjectCache."""
    payload = cloudpickle.dumps(self._serializer, protocol=pickle.HIGHEST_PROTOCOL)
    serializer_object_id = generate_serializer_object_id(self._identity)
    return ObjectCache(
        serializer_object_id,
        ObjectContent.ObjectContentType.Serializer,
        b"serializer",
        chunk_to_list_of_bytes(payload),
    )

def __construct_function(self, fn: Callable) -> ObjectCache:
function_bytes = self._serializer.serialize(fn)
object_id = generate_object_id(self._identity, function_bytes)
function_cache = ObjectCache(
object_id,
ObjectContent.ObjectContentType.Object,
getattr(fn, "__name__", f"<func {object_id.hex()[:6]}>").encode(),
chunk_to_list_of_bytes(function_bytes),
)
Expand All @@ -100,4 +124,9 @@ def __construct_object(self, obj: Any, name: Optional[str] = None) -> ObjectCach
object_payload = self._serializer.serialize(obj)
object_id = generate_object_id(self._identity, object_payload)
name_bytes = name.encode() if name else f"<obj {object_id.hex()[-6:]}>".encode()
return ObjectCache(object_id, name_bytes, chunk_to_list_of_bytes(object_payload))
return ObjectCache(
object_id,
ObjectContent.ObjectContentType.Object,
name_bytes,
chunk_to_list_of_bytes(object_payload)
)
10 changes: 8 additions & 2 deletions scaler/protocol/capnp/common.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ enum TaskStatus {

struct ObjectContent {
objectIds @0 :List(Data);
objectNames @1 :List(Data);
objectBytes @2 :List(List(Data));
objectTypes @1 :List(ObjectContentType);
objectNames @2 :List(Data);
objectBytes @3 :List(List(Data));

enum ObjectContentType {
serializer @0;
object @1;
}
}
1 change: 1 addition & 0 deletions scaler/protocol/capnp/message.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct ObjectInstruction {
enum ObjectInstructionType {
create @0;
delete @1;
clear @2;
}
}

Expand Down
18 changes: 17 additions & 1 deletion scaler/protocol/python/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,25 @@ class TaskStatus(enum.Enum):

@dataclasses.dataclass
class ObjectContent(Message):
class ObjectContentType(enum.Enum):
# FIXME: Pycapnp does not support assignment of raw enum values when the enum is itself declared within a list.
# However, assigning the enum's string value works.
# See https://github.com/capnproto/pycapnp/issues/374

Serializer = "serializer"
Object = "object"

def __init__(self, msg):
super().__init__(msg)

@property
def object_ids(self) -> Tuple[bytes, ...]:
return tuple(self._msg.objectIds)

@property
def object_types(self) -> Tuple[ObjectContentType, ...]:
return tuple(ObjectContent.ObjectContentType(object_type._as_str()) for object_type in self._msg.objectTypes)

@property
def object_names(self) -> Tuple[bytes, ...]:
return tuple(self._msg.objectNames)
Expand All @@ -43,12 +55,16 @@ def object_bytes(self) -> Tuple[List[bytes], ...]:
@staticmethod
def new_msg(
object_ids: Tuple[bytes, ...],
object_types: Tuple[ObjectContentType, ...] = tuple(),
object_names: Tuple[bytes, ...] = tuple(),
object_bytes: Tuple[List[bytes], ...] = tuple(),
) -> "ObjectContent":
return ObjectContent(
_common.ObjectContent(
objectIds=list(object_ids), objectNames=list(object_names), objectBytes=tuple(object_bytes)
objectIds=list(object_ids),
objectTypes=[object_type.value for object_type in object_types],
objectNames=list(object_names),
objectBytes=tuple(object_bytes),
)
)

Expand Down
1 change: 1 addition & 0 deletions scaler/protocol/python/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ class ObjectInstruction(Message):
class ObjectInstructionType(enum.Enum):
Create = _message.ObjectInstruction.ObjectInstructionType.create
Delete = _message.ObjectInstruction.ObjectInstructionType.delete
Clear = _message.ObjectInstruction.ObjectInstructionType.clear

def __init__(self, msg):
super().__init__(msg)
Expand Down
Loading

0 comments on commit 1d36288

Please sign in to comment.