diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index 25add1d6247b9..e2cb556674d2c 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -92,25 +92,35 @@ def from_items(items: List[Any], *, parallelism: int = -1) -> Dataset[Any]:
 
     Returns:
         Dataset holding the items.
     """
+    import builtins
+
+    if parallelism == 0:
+        raise ValueError(f"parallelism must be -1 or > 0, got: {parallelism}")
     detected_parallelism, _ = _autodetect_parallelism(
         parallelism,
         ray.util.get_current_placement_group(),
         DatasetContext.get_current(),
     )
-    block_size = max(
-        1,
-        len(items) // detected_parallelism,
-    )
+    # Truncate parallelism to number of items to avoid empty blocks.
+    detected_parallelism = min(len(items), detected_parallelism)
+    if detected_parallelism > 0:
+        block_size, remainder = divmod(len(items), detected_parallelism)
+    else:
+        block_size, remainder = 0, 0
+    # NOTE: We need to explicitly use the builtins range since we override range below,
+    # with the definition of ray.data.range.
 
     blocks: List[ObjectRef[Block]] = []
     metadata: List[BlockMetadata] = []
-    i = 0
-    while i < len(items):
+    for i in builtins.range(detected_parallelism):
         stats = BlockExecStats.builder()
         builder = DelegatingBlockBuilder()
-        for item in items[i : i + block_size]:
-            builder.add(item)
+        # Evenly distribute remainder across block slices while preserving record order.
+        block_start = i * block_size + min(i, remainder)
+        block_end = (i + 1) * block_size + min(i + 1, remainder)
+        for j in builtins.range(block_start, block_end):
+            builder.add(items[j])
         block = builder.build()
         blocks.append(ray.put(block))
         metadata.append(
@@ -118,7 +128,6 @@ def from_items(items: List[Any], *, parallelism: int = -1) -> Dataset[Any]:
                 input_files=None, exec_stats=stats.build()
             )
         )
-        i += block_size
 
     return Dataset(
         ExecutionPlan(
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 33624bd25d03c..0032ff2529248 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -1550,6 +1550,29 @@ def test_from_items(ray_start_regular_shared):
     assert ds.take() == ["hello", "world"]
 
 
+@pytest.mark.parametrize("parallelism", list(range(1, 21)))
+def test_from_items_parallelism(ray_start_regular_shared, parallelism):
+    # Test that specifying parallelism yields the expected number of blocks.
+    n = 20
+    records = [{"a": i} for i in range(n)]
+    ds = ray.data.from_items(records, parallelism=parallelism)
+    out = ds.take_all()
+    assert out == records
+    assert ds.num_blocks() == parallelism
+
+
+def test_from_items_parallelism_truncated(ray_start_regular_shared):
+    # Test that specifying parallelism greater than the number of items is truncated to
+    # the number of items.
+    n = 10
+    parallelism = 20
+    records = [{"a": i} for i in range(n)]
+    ds = ray.data.from_items(records, parallelism=parallelism)
+    out = ds.take_all()
+    assert out == records
+    assert ds.num_blocks() == n
+
+
 def test_repartition_shuffle(ray_start_regular_shared):
     ds = ray.data.range(20, parallelism=10)
     assert ds.num_blocks() == 10