[Data] Yield remaining results from async map_batches (#47696)
## Why are these changes needed?

When using an async actor with `map_batches()`, there is currently an
unhandled edge case: if tasks are scheduled very close together and all
remaining futures complete at the same time, some results still sitting
in the internal output queue are never yielded. This PR ensures that the
internal queue is fully drained so that all expected results are
returned.

Concretely, this issue came up while using async actors to yield results
from the vLLM async engine.
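
The edge case described above can be reproduced outside Ray with a minimal sketch that uses only the standard library. The names `run_pipeline`, `produce_all`, `out_queue`, and `drain_queue` are hypothetical and are not taken from Ray's code; the call to `future.result()` is an assumption that forces the unlucky interleaving in which the producer finishes before the consumer's first check.

```python
import asyncio
import queue
import threading


def run_pipeline(n, drain_queue):
    """Produce n items on a background event loop and collect them here."""
    out_queue = queue.Queue()
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()

    async def produce_all():
        # Every item is ready immediately, so the queue fills in one burst.
        for i in range(n):
            out_queue.put(i)

    future = asyncio.run_coroutine_threadsafe(produce_all(), loop)
    # Assumption for the demo: force the producer to finish before the
    # consumer loop below performs its first check.
    future.result()

    def keep_going():
        if drain_queue:
            # Fixed condition: also keep looping while items remain queued.
            return not future.done() or not out_queue.empty()
        # Original condition: stop as soon as the producer is done.
        return not future.done()

    results = []
    while keep_going():
        try:
            results.append(out_queue.get(timeout=0.01))
        except queue.Empty:
            pass

    # Shut the background loop down cleanly.
    loop.call_soon_threadsafe(loop.stop)
    thread.join()
    loop.close()
    return results


lost = run_pipeline(8, drain_queue=False)
kept = run_pipeline(8, drain_queue=True)
print(len(lost), len(kept))  # prints "0 8"
```

With the original condition the consumer exits as soon as the future is done, even though all eight items are still queued; with the extra `not out_queue.empty()` check every item is collected.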

## Related issue number

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: Scott Lee <sjl@anyscale.com>
scottjlee authored Sep 18, 2024
1 parent bc2b26e commit ceceb68
Showing 2 changed files with 42 additions and 3 deletions.
4 changes: 3 additions & 1 deletion python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -348,7 +348,9 @@ async def process_all_batches():
     future = asyncio.run_coroutine_threadsafe(process_all_batches(), loop)

     # Yield results as they become available.
-    while not future.done():
+    # After all futures are completed, drain the queue to
+    # yield any remaining results.
+    while not future.done() or not output_batch_queue.empty():
         # Here, `out_batch` is a one-row output batch
         # from the async generator, corresponding to a
         # single row from the input batch.
41 changes: 39 additions & 2 deletions python/ray/data/tests/test_map.py
@@ -1068,9 +1068,7 @@ def test_map_batches_async_generator(shutdown_only):
     ray.init(num_cpus=10)

     async def sleep_and_yield(i):
-        print("sleep", i)
         await asyncio.sleep(i % 5)
-        print("yield", i)
         return {"input": [i], "output": [2**i]}

     class AsyncActor:
@@ -1119,6 +1117,45 @@ async def __call__(self, batch):
     assert "assert False" in str(exc_info.value)


+def test_map_batches_async_generator_fast_yield(shutdown_only):
+    # Tests the case where the async generator yields immediately,
+    # with a high number of tasks in flight, which results in
+    # the internal queue being almost instantaneously filled.
+    # This test ensures that the internal queue is completely drained in this scenario.
+
+    ray.shutdown()
+    ray.init(num_cpus=4)
+
+    async def task_yield(row):
+        return row
+
+    class AsyncActor:
+        def __init__(self):
+            pass
+
+        async def __call__(self, batch):
+            rows = [{"id": np.array([i])} for i in batch["id"]]
+            tasks = [asyncio.create_task(task_yield(row)) for row in rows]
+            for task in tasks:
+                yield await task
+
+    n = 8
+    ds = ray.data.range(n, override_num_blocks=n)
+    ds = ds.map_batches(
+        AsyncActor,
+        batch_size=n,
+        compute=ray.data.ActorPoolStrategy(size=1, max_tasks_in_flight_per_actor=n),
+        concurrency=1,
+        max_concurrency=n,
+    )
+
+    output = ds.take_all()
+    expected_output = [{"id": i} for i in range(n)]
+    # Because all tasks are submitted almost simultaneously,
+    # the output order may be different compared to the original input.
+    assert len(output) == len(expected_output), (len(output), len(expected_output))
+
+
 if __name__ == "__main__":
     import sys

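
The new test compares only the number of rows, because the output order may differ from the input order. A stricter, order-insensitive comparison could be sketched as follows. This is not part of the commit: `same_rows_any_order` is a hypothetical helper, and `shuffled_output` is a stand-in for the result of `ds.take_all()`, assumed here to be a list of `{"id": int}` rows.

```python
def same_rows_any_order(output, expected, key="id"):
    """Return True if both row lists hold the same key values, ignoring order."""
    return sorted(row[key] for row in output) == sorted(row[key] for row in expected)


expected_output = [{"id": i} for i in range(8)]
# Stand-in for ds.take_all(): the same rows in a different order.
shuffled_output = list(reversed(expected_output))

assert same_rows_any_order(shuffled_output, expected_output)
# A dropped row is detected, which is the failure this commit fixes.
assert not same_rows_any_order(shuffled_output[:-1], expected_output)
```

Comparing sorted values would catch a duplicated or substituted row as well as a missing one, which a length check alone would not.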