
Commit

Fix tests
al-rigazzi committed Oct 14, 2024
1 parent ce5a306 commit f23c267
Showing 9 changed files with 56 additions and 27 deletions.
13 changes: 11 additions & 2 deletions smartsim/_core/mli/comm/channel/channel.py
@@ -52,19 +52,28 @@ def __init__(
"""A user-friendly identifier for channel-related logging"""

@abstractmethod
def send(self, value: bytes, timeout: t.Optional[float] = 0.001) -> None:
def send(
self,
value: bytes,
timeout: t.Optional[float] = 0.001,
handle_timeout: float = 0.001,
) -> None:
"""Send a message through the underlying communication channel.
:param value: The value to send
:param timeout: Maximum time to wait (in seconds) for messages to send
:param handle_timeout: Maximum time to wait to obtain new send handle
:raises SmartSimError: If sending message fails
"""

@abstractmethod
def recv(self, timeout: t.Optional[float] = 0.001) -> t.List[bytes]:
def recv(
self, timeout: t.Optional[float] = 0.001, handle_timeout: float = 0.001
) -> t.List[bytes]:
"""Receives message(s) through the underlying communication channel.
:param timeout: Maximum time to wait (in seconds) for messages to arrive
:param handle_timeout: Maximum time to wait to obtain new receive handle
:returns: The received message
"""

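For context, here is a minimal sketch of a concrete class satisfying the revised abstract signatures; the name InMemoryChannel and its queue backing are hypothetical (not part of this commit) and only illustrate how the two timeouts are meant to be threaded through.

import queue
import typing as t


class InMemoryChannel:
    """Hypothetical queue-backed stand-in illustrating the new send/recv signatures."""

    def __init__(self) -> None:
        self._queue: "queue.Queue[bytes]" = queue.Queue()

    def send(
        self,
        value: bytes,
        timeout: t.Optional[float] = 0.001,
        handle_timeout: float = 0.001,
    ) -> None:
        # A real channel spends up to handle_timeout acquiring a send handle and up
        # to timeout transferring the bytes; this toy backend collapses both into one put
        self._queue.put(value, timeout=timeout)

    def recv(
        self, timeout: t.Optional[float] = 0.001, handle_timeout: float = 0.001
    ) -> t.List[bytes]:
        # Return at most one message, or an empty list if nothing arrives in time
        try:
            return [self._queue.get(timeout=timeout)]
        except queue.Empty:
            return []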
25 changes: 17 additions & 8 deletions smartsim/_core/mli/comm/channel/dragon_channel.py
@@ -57,33 +57,42 @@ def channel(self) -> "dch.Channel":
"""
return self._channel

def send(self, value: bytes, timeout: t.Optional[float] = 0.001) -> None:
def send(
self,
value: bytes,
timeout: t.Optional[float] = 0.001,
handle_timeout: float = 0.001,
) -> None:
"""Send a message through the underlying communication channel.
:param value: The value to send
:param timeout: Maximum time to wait (in seconds) for messages to send
:param timeout: Maximum time to wait (in seconds) for messages to be sent
:param handle_timeout: Maximum time to wait to obtain new send handle
:raises SmartSimError: If sending message fails
"""
try:
with self._channel.sendh(timeout=timeout) as sendh:
sendh.send_bytes(value, timeout=None)
with self._channel.sendh(timeout=handle_timeout) as sendh:
sendh.send_bytes(value, timeout=timeout)
logger.debug(f"DragonCommChannel {self.descriptor} sent message")
except Exception as e:
raise SmartSimError(
f"Error sending via DragonCommChannel {self.descriptor}"
) from e

def recv(self, timeout: t.Optional[float] = 0.001) -> t.List[bytes]:
def recv(
self, timeout: t.Optional[float] = 0.001, handle_timeout: float = 0.001
) -> t.List[bytes]:
"""Receives message(s) through the underlying communication channel.
:param timeout: Maximum time to wait (in seconds) for messages to arrive
:param timeout: Maximum time to wait (in seconds) for message to arrive
:param handle_timeout: Maximum time to wait to obtain new receive handle
:returns: The received message(s)
"""
with self._channel.recvh(timeout=timeout) as recvh:
with self._channel.recvh(timeout=handle_timeout) as recvh:
messages: t.List[bytes] = []

try:
message_bytes = recvh.recv_bytes(timeout=None)
message_bytes = recvh.recv_bytes(timeout=timeout)
messages.append(message_bytes)
logger.debug(f"DragonCommChannel {self.descriptor} received message")
except dch.ChannelEmpty:
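The net effect of this change is that handle acquisition and the byte transfer are now bounded separately: the new handle_timeout limits how long sendh/recvh creation may take, while the caller-supplied timeout bounds send_bytes/recv_bytes, which previously ran with timeout=None and could block indefinitely. A usage sketch, with an illustrative channel object and timeout values:

# channel is assumed to be an already-constructed DragonCommChannel
channel.send(b"payload", timeout=0.5, handle_timeout=0.01)   # 10 ms for the handle, 500 ms for the bytes
messages = channel.recv(timeout=0.5, handle_timeout=0.01)    # whatever arrived before the timeout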
26 changes: 18 additions & 8 deletions smartsim/_core/mli/comm/channel/dragon_fli.py
@@ -66,21 +66,27 @@ def __init__(
self._buffer_size: int = buffer_size
"""Maximum number of messages that can be buffered before sending"""

def send(self, value: bytes, timeout: t.Optional[float] = 0.001) -> None:
def send(
self,
value: bytes,
timeout: t.Optional[float] = 0.001,
handle_timeout: float = 0.001,
) -> None:
"""Send a message through the underlying communication channel.
:param value: The value to send
:param timeout: Maximum time to wait (in seconds) for messages to send
:param handle_timeout: Maximum time to wait to obtain new send handle
:raises SmartSimError: If sending message fails
"""
try:
if self._channel is None:
self._channel = drg_util.create_local(self._buffer_size)

with self._fli.sendh(
timeout=timeout, stream_channel=self._channel
timeout=handle_timeout, stream_channel=self._channel
) as sendh:
sendh.send_bytes(value, timeout=None)
sendh.send_bytes(value, timeout=timeout)
logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
except Exception as e:
self._channel = None
@@ -92,30 +98,34 @@ def send_multiple(
self,
values: t.Sequence[bytes],
timeout: t.Optional[float] = 0.001,
handle_timeout: float = 0.001,
) -> None:
"""Send a message through the underlying communication channel.
:param values: The values to send
:param timeout: Maximum time to wait (in seconds) for messages to send
:param handle_timeout: Maximum time to wait to obtain new send handle
:raises SmartSimError: If sending message fails
"""
try:
if self._channel is None:
self._channel = drg_util.create_local(self._buffer_size)

with self._fli.sendh(
timeout=timeout, stream_channel=self._channel
timeout=handle_timeout, stream_channel=self._channel
) as sendh:
for value in values:
sendh.send_bytes(value, timeout=None)
sendh.send_bytes(value, timeout=timeout)
logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
except Exception as e:
self._channel = None
raise SmartSimError(
f"Error sending via DragonFLIChannel {self.descriptor} {e}"
) from e

def recv(self, timeout: t.Optional[float] = 0.001) -> t.List[bytes]:
def recv(
self, timeout: t.Optional[float] = 0.001, handle_timeout: float = 0.001
) -> t.List[bytes]:
"""Receives message(s) through the underlying communication channel.
:param timeout: Maximum time to wait (in seconds) for messages to arrive
@@ -124,10 +134,10 @@ def recv(self, timeout: t.Optional[float] = 0.001) -> t.List[bytes]:
"""
messages = []
eot = False
with self._fli.recvh(timeout=timeout) as recvh:
with self._fli.recvh(timeout=handle_timeout) as recvh:
while not eot:
try:
message, _ = recvh.recv_bytes(timeout=None)
message, _ = recvh.recv_bytes(timeout=timeout)
messages.append(message)
logger.debug(f"DragonFLIChannel {self.descriptor} received message")
except fli.FLIEOT:
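The same split applies to the FLI channel: handle_timeout bounds fli.sendh/recvh creation, while timeout now applies to each send_bytes/recv_bytes call instead of the previous hard-coded timeout=None. A usage sketch with placeholder payloads:

# fli_channel is assumed to be an already-constructed DragonFLIChannel;
# header_bytes and tensor_bytes are placeholder payloads
fli_channel.send_multiple([header_bytes, tensor_bytes], timeout=0.5, handle_timeout=0.01)
replies = fli_channel.recv(timeout=0.5, handle_timeout=0.01)  # drains messages until the FLI reports EOT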
smartsim/_core/mli/infrastructure/control/request_dispatcher.py
@@ -243,7 +243,7 @@ def __init__(
raise SmartSimError("No incoming channel for dispatcher")
self._incoming_channel = incoming_channel
"""The channel the dispatcher monitors for new tasks"""
self._outgoing_queue: DragonQueue = mp.Queue(maxsize=10000)
self._outgoing_queue: DragonQueue = mp.Queue(maxsize=1000)

"""The queue on which batched inference requests are placed"""
self._feature_stores: t.Dict[str, FeatureStore] = {}
"""A collection of attached feature stores"""
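The tighter bound above is enforced by multiprocessing.Queue itself; a standalone sketch of the resulting backpressure, with the size shrunk for illustration:

import multiprocessing as mp
import queue

work_queue = mp.Queue(maxsize=2)  # stands in for the dispatcher's outgoing queue
work_queue.put(b"batch-1")
work_queue.put(b"batch-2")
try:
    # With the queue at capacity, a bounded put times out instead of growing without limit
    work_queue.put(b"batch-3", timeout=0.1)
except queue.Full:
    print("outgoing queue full; the dispatcher must wait or shed load")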
8 changes: 4 additions & 4 deletions smartsim/_core/mli/infrastructure/storage/dragon_util.py
@@ -67,7 +67,7 @@ def create_ddict(
:param num_nodes: The number of distributed nodes to distribute the dictionary to.
At least one node is required.
:param mgr_per_node: The number of manager processes per node
:param mem_per_node: The amount of memory (in megabytes) to allocate per node. Total
:param mem_per_node: The amount of memory (in bytes) to allocate per node. Total
memory available will be calculated as `num_nodes * node_mem`
:returns: The instantiated dragon dictionary
@@ -84,18 +84,18 @@
if mem_per_node < dragon_ddict.DDICT_MIN_SIZE:
raise ValueError(
"A dragon dictionary requires at least "
f"{dragon_ddict.DDICT_MIN_SIZE / 1024} MB"
f"{dragon_ddict.DDICT_MIN_SIZE / (1024**2)} MB"
)

mem_total = num_nodes * mem_per_node

logger.debug(
f"Creating dragon dictionary with {num_nodes} nodes, {mem_total} MB memory"
f"Creating dragon dictionary with {num_nodes} nodes, {mem_total} bytes memory"
)

distributed_dict = dragon_ddict.DDict(num_nodes, mgr_per_node, total_mem=mem_total)
logger.debug(
"Successfully created dragon dictionary with "
f"{num_nodes} nodes, {mem_total} MB total memory"
f"{num_nodes} nodes, {mem_total} bytes total memory"
)
return distributed_dict
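To make the unit fix concrete, the arithmetic in bytes (the values below are hypothetical):

num_nodes = 2
mem_per_node = 512 * 1024**2            # 512 MB per node, supplied in bytes
mem_total = num_nodes * mem_per_node    # 1073741824 bytes handed to DDict(total_mem=...)
print(f"{mem_total} bytes total = {mem_total / (1024**2):.0f} MB")  # same 1024**2 divisor as the minimum-size check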
2 changes: 1 addition & 1 deletion tests/dragon/test_device_manager.py
@@ -143,7 +143,7 @@ def test_device_manager_model_in_request():
) as returned_device:

assert returned_device == worker_device
assert worker_device.get_model(model_key.key).model == b'raw model'
assert worker_device.get_model(model_key.key).model == b"raw model"

assert model_key.key not in worker_device

1 change: 0 additions & 1 deletion tests/dragon/test_dragon_backend.py
@@ -109,7 +109,6 @@ def test_dragonbackend_start_listener(the_backend: DragonBackend):
comm_channel.send(event_bytes)

subscriber_list = []
logger.warning(backbone.notification_channels)

# Give the channel time to write the message and the listener time to handle it
for i in range(20):
2 changes: 1 addition & 1 deletion tests/dragon/test_protoclient.py
@@ -196,7 +196,7 @@ def test_protoclient_initialization(
assert client._to_worker_ch is not None

# wrap the channels just to easily verify they produce a descriptor
assert DragonCommChannel(client._from_worker_ch).descriptor
assert DragonCommChannel(client._from_worker_ch.channel).descriptor
assert DragonCommChannel(client._to_worker_ch).descriptor

# confirm a publisher is created
4 changes: 3 additions & 1 deletion tests/dragon/utils/msg_pump.py
@@ -163,7 +163,9 @@ def _mock_messages(

# send the header & body together so they arrive together
try:
request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()])
request_dispatcher_queue.send_multiple(
[request_bytes, tensor.tobytes()], timeout=None, handle_timeout=None
)
logger.info(f"\tenvelope 0: {request_bytes[:5]}...")
logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]}...")
except Exception as ex:
