Skip to content

Commit

Permalink
[serve] add blocking getter to messagequeue (ray-project#47764)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Add async `get_one_message` to `MessageQueue`, which is a blocking call.
It will wait for either a new message, at which point it will return it,
or if there are no more messages in the queue:
- Raise `StopAsyncIteration` if the queue is closed
- Raise error if `set_error` is called.

Also add unit tests.

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
  • Loading branch information
zcin authored Sep 20, 2024
1 parent 0bc22d0 commit 8173a1e
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 9 deletions.
50 changes: 42 additions & 8 deletions python/ray/serve/_private/http_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(self):
self._message_queue = deque()
self._new_message_event = asyncio.Event()
self._closed = False
self._error = None

def close(self):
"""Close the queue, rejecting new messages.
Expand All @@ -168,6 +169,9 @@ def close(self):
self._closed = True
self._new_message_event.set()

def set_error(self, e: BaseException):
self._error = e

def put_nowait(self, message: Message):
self._message_queue.append(message)
self._new_message_event.set()
Expand All @@ -182,6 +186,18 @@ async def __call__(self, message: Message):

self.put_nowait(message)

async def wait_for_message(self):
"""Wait until at least one new message is available.
If a message is available, this method will return immediately on each call
until `get_messages_nowait` is called.
After the queue is closed using `.close()`, this will always return
immediately.
"""
if not self._closed:
await self._new_message_event.wait()

def get_messages_nowait(self) -> List[Message]:
"""Returns all messages that are currently available (non-blocking).
Expand All @@ -196,17 +212,35 @@ def get_messages_nowait(self) -> List[Message]:
self._new_message_event.clear()
return messages

async def wait_for_message(self):
"""Wait until at least one new message is available.
async def get_one_message(self) -> Message:
"""This blocks until a message is ready.
If a message is available, this method will return immediately on each call
until `get_messages_nowait` is called.
This method should not be used together with get_messages_nowait.
Please use either `get_one_message` or `get_messages_nowait`.
After the queue is closed using `.close()`, this will always return
immediately.
Raises:
StopAsyncIteration: if the queue is closed and there are no
more messages.
Exception (self._error): if there are no more messages in
the queue and an error has been set.
"""
if not self._closed:
await self._new_message_event.wait()

if self._error:
raise self._error

await self._new_message_event.wait()

if len(self._message_queue) > 0:
msg = self._message_queue.popleft()

if len(self._message_queue) == 0 and not self._closed:
self._new_message_event.clear()

return msg
elif len(self._message_queue) == 0 and self._error:
raise self._error
elif len(self._message_queue) == 0 and self._closed:
raise StopAsyncIteration


class ASGIReceiveProxy:
Expand Down
74 changes: 73 additions & 1 deletion python/ray/serve/tests/unit/test_http_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@pytest.mark.asyncio
async def test_message_queue():
async def test_message_queue_nowait():
queue = MessageQueue()

# Check that wait_for_message hangs until a message is sent.
Expand Down Expand Up @@ -64,6 +64,78 @@ async def test_message_queue():
assert queue.get_messages_nowait() == []


@pytest.mark.asyncio
async def test_message_queue_wait():
queue = MessageQueue()

with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(queue.get_one_message(), 0.001)

queue.put_nowait("A")
assert await queue.get_one_message() == "A"

# Check that messages are cleared after being consumed.
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(queue.get_one_message(), 0.001)

# Check that consecutive messages are returned in order.
queue.put_nowait("B")
queue.put_nowait("C")
assert await queue.get_one_message() == "B"
assert await queue.get_one_message() == "C"

# Check that messages are cleared after being consumed.
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(queue.get_one_message(), 0.001)

# Check that a concurrent waiter is notified when a message is available.
loop = asyncio.get_running_loop()
fetch_task = loop.create_task(queue.get_one_message())
for _ in range(1000):
assert not fetch_task.done()
queue.put_nowait("D")
assert await fetch_task == "D"


@pytest.mark.asyncio
async def test_message_queue_wait_closed():
queue = MessageQueue()

queue.put_nowait("A")
assert await queue.get_one_message() == "A"

# Check that once the queue is closed, ongoing and subsequent calls
# to get_one_message should raise an exception
loop = asyncio.get_running_loop()
fetch_task = loop.create_task(queue.get_one_message())
queue.close()
with pytest.raises(StopAsyncIteration):
await fetch_task

for _ in range(10):
with pytest.raises(StopAsyncIteration):
await queue.get_one_message()


@pytest.mark.asyncio
async def test_message_queue_wait_error():
queue = MessageQueue()

queue.put_nowait("A")
assert await queue.get_one_message() == "A"

# Check setting an error
loop = asyncio.get_running_loop()
fetch_task = loop.create_task(queue.get_one_message())
queue.set_error(TypeError("uh oh! something went wrong."))
with pytest.raises(TypeError, match="uh oh! something went wrong"):
await fetch_task

for _ in range(10):
with pytest.raises(TypeError, match="uh oh! something went wrong"):
await queue.get_one_message()


@pytest.fixture
@pytest.mark.asyncio
def setup_receive_proxy(
Expand Down

0 comments on commit 8173a1e

Please sign in to comment.