Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[data] Read->SplitBlocks to ensure requested read parallelism is always met #36352

Merged
merged 28 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/ray/data/_internal/block_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
# Whether the block list is owned by consuming APIs, and if so it can be
# eagerly deleted after read by the consumer.
self._owned_by_consumer = owned_by_consumer
self._estimated_num_blocks = None
ericl marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
return f"BlockList(owned_by_consumer={self._owned_by_consumer})"
Expand Down Expand Up @@ -217,6 +218,10 @@ def initial_num_blocks(self) -> int:
"""Returns the number of blocks of this BlockList."""
return self._num_blocks

def estimated_num_blocks(self) -> int:
""""""
ericl marked this conversation as resolved.
Show resolved Hide resolved
return self._estimated_num_blocks or self._num_blocks

def executed_num_blocks(self) -> int:
"""Returns the number of output blocks after execution.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
self,
input_data: Optional[List[RefBundle]] = None,
input_data_factory: Callable[[], List[RefBundle]] = None,
override_num_blocks: Optional[int] = None,
ericl marked this conversation as resolved.
Show resolved Hide resolved
):
"""Create an InputDataBuffer.

Expand All @@ -37,6 +38,7 @@ def __init__(
assert input_data_factory is not None
self._input_data_factory = input_data_factory
self._is_input_initialized = False
self._override_num_blocks = override_num_blocks
super().__init__("Input", [])

def start(self, options: ExecutionOptions) -> None:
Expand All @@ -53,7 +55,7 @@ def get_next(self) -> RefBundle:
return self._input_data.pop(0)

def num_outputs_total(self) -> Optional[int]:
return self._num_outputs
return self._override_num_blocks or self._num_outputs

def get_stats(self) -> StatsDict:
return {}
Expand Down
11 changes: 10 additions & 1 deletion python/ray/data/_internal/logical/operators/read_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,17 @@ def __init__(
self,
datasource: Datasource,
read_tasks: List[ReadTask],
estimated_num_blocks: int,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(f"Read{datasource.get_name()}", None, ray_remote_args)
if len(read_tasks) == estimated_num_blocks:
suffix = ""
else:
suffix = f"->SplitBlocks({int(estimated_num_blocks / len(read_tasks))})"
ericl marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like SplitBlocks is a separate op. What about Read(spit_blocks=N)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, +1 for ReadXXX(split_blocks=N), otherwise Dataset.__repr__ would become confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand this--- the original proposal is that SplitBlock is supposed to be a logical operator, since it only applies to the output of the read. It seems more clear therefore using the chaining syntax of -> instead of making it part of the Read.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if I miss any context, why don't we implement SplitBlock as a separate logical & physical operator?

The current implementation is inside Datasource, so it looks like part of Read & InputDataBuffer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should, but it would get fused with Read anyways. So here we only implement it as part of Read since we have yet to decide whether it should be a general operator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g., for dynamic_repartition() or such.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, +1 to make it a general operator.

super().__init__(f"Read{datasource.get_name()}{suffix}", None, ray_remote_args)
self._datasource = datasource
self._estimated_num_blocks = estimated_num_blocks
self._read_tasks = read_tasks

def fusable(self) -> bool:
ericl marked this conversation as resolved.
Show resolved Hide resolved
return self._estimated_num_blocks == len(self._read_tasks)
4 changes: 4 additions & 0 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Repartition,
)
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.stats import StatsDict
from ray.data.block import Block

Expand Down Expand Up @@ -130,6 +131,9 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
down_logical_op = self._op_map[down_op]
up_logical_op = self._op_map[up_op]

if isinstance(up_logical_op, Read) and not up_logical_op.fusable():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the extra check if it's a Read op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fusable method is part of the Read class only.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can define fusable in the base LogicalOperator class. Other op may need it as well in the future.

return False

# If the downstream operator takes no input, it cannot be fused with
# the upstream operator.
if not down_logical_op._input_dependencies:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def get_plan_as_string(self, classname: str) -> str:
if dataset_blocks is None:
num_blocks = "?"
else:
num_blocks = dataset_blocks.initial_num_blocks()
num_blocks = dataset_blocks.estimated_num_blocks()
ericl marked this conversation as resolved.
Show resolved Hide resolved
dataset_str = "{}(num_blocks={}, num_rows={}, schema={})".format(
classname, num_blocks, count, schema_str
)
Expand Down
4 changes: 3 additions & 1 deletion python/ray/data/_internal/planner/plan_read_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def get_input_data() -> List[RefBundle]:
for read_task in read_tasks
]

inputs = InputDataBuffer(input_data_factory=get_input_data)
inputs = InputDataBuffer(
input_data_factory=get_input_data, override_num_blocks=op._estimated_num_blocks
)

def do_read(blocks: Iterator[ReadTask], _: TaskContext) -> Iterator[Block]:
for read_task in blocks:
Expand Down
9 changes: 5 additions & 4 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _autodetect_parallelism(
ctx: DataContext,
reader: Optional["Reader"] = None,
avail_cpus: Optional[int] = None,
) -> (int, int):
) -> (int, int, Optional[int]):
"""Returns parallelism to use and the min safe parallelism to avoid OOMs.

This detects parallelism using the following heuristics, applied in order:
Expand All @@ -112,8 +112,9 @@ def _autodetect_parallelism(
avail_cpus: Override avail cpus detection (for testing only).

Returns:
Tuple of detected parallelism (only if -1 was specified), and the min safe
parallelism (which can be used to generate warnings about large blocks).
Tuple of detected parallelism (only if -1 was specified), the min safe
parallelism (which can be used to generate warnings about large blocks),
and the estimated inmemory size of the dataset.
"""
min_safe_parallelism = 1
max_reasonable_parallelism = sys.maxsize
Expand Down Expand Up @@ -141,7 +142,7 @@ def _autodetect_parallelism(
f"estimated_available_cpus={avail_cpus} and "
f"estimated_data_size={mem_size}."
)
return parallelism, min_safe_parallelism
return parallelism, min_safe_parallelism, mem_size


def _estimate_avail_cpus(cur_pg: Optional["PlacementGroup"]) -> int:
Expand Down
20 changes: 19 additions & 1 deletion python/ray/data/datasource/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class ReadTask(Callable[[], Iterable[Block]]):
def __init__(self, read_fn: Callable[[], Iterable[Block]], metadata: BlockMetadata):
self._metadata = metadata
self._read_fn = read_fn
self._additional_output_splits = 1

def get_metadata(self) -> BlockMetadata:
return self._metadata
Expand All @@ -211,13 +212,30 @@ def __call__(self) -> Iterable[Block]:

if context.block_splitting_enabled:
for block in result:
yield block
yield from self._do_additional_splits(block)
else:
builder = DelegatingBlockBuilder()
for block in result:
builder.add_block(block)
yield builder.build()

def _set_additional_split_factor(self, k: int) -> None:
self._additional_output_splits = k

def _do_additional_splits(self, block: Block) -> Iterable[Block]:
if self._additional_output_splits > 1:
block = BlockAccessor.for_block(block)
offset = 0
split_sizes = np.array_split(
ericl marked this conversation as resolved.
Show resolved Hide resolved
range(block.num_rows()), self._additional_output_splits
)
for split in split_sizes:
size = len(split)
yield block.slice(offset, offset + size, copy=True)
ericl marked this conversation as resolved.
Show resolved Hide resolved
offset += size
else:
yield block


@PublicAPI
class RangeDatasource(Datasource):
Expand Down
18 changes: 16 additions & 2 deletions python/ray/data/datasource/parquet_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:

if meta.size_bytes is not None:
meta.size_bytes = int(meta.size_bytes * self._encoding_ratio)

ericl marked this conversation as resolved.
Show resolved Hide resolved
if meta.num_rows is not None and meta.size_bytes is not None:
row_size = meta.size_bytes / meta.num_rows
default_read_batch_size = min(
PARQUET_READER_ROW_BATCH_SIZE, 64e6 / row_size
)
else:
default_read_batch_size = PARQUET_READER_ROW_BATCH_SIZE
ericl marked this conversation as resolved.
Show resolved Hide resolved
block_udf, reader_args, columns, schema = (
self._block_udf,
self._reader_args,
Expand All @@ -299,6 +307,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
lambda p=serialized_pieces: _read_pieces(
block_udf,
reader_args,
default_read_batch_size,
columns,
schema,
p,
Expand Down Expand Up @@ -363,7 +372,12 @@ def _estimate_files_encoding_ratio(self) -> float:


def _read_pieces(
block_udf, reader_args, columns, schema, serialized_pieces: List[_SerializedPiece]
block_udf,
reader_args,
default_read_batch_size,
columns,
schema,
serialized_pieces: List[_SerializedPiece],
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa
Expand All @@ -387,7 +401,7 @@ def _read_pieces(

logger.debug(f"Reading {len(pieces)} parquet pieces")
use_threads = reader_args.pop("use_threads", False)
batch_size = reader_args.pop("batch_size", PARQUET_READER_ROW_BATCH_SIZE)
batch_size = reader_args.pop("batch_size", default_read_batch_size)
for piece in pieces:
part = _get_partition_keys(piece.partition_expression)
batches = piece.to_batches(
Expand Down
96 changes: 65 additions & 31 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import logging
import math
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -133,7 +134,7 @@ def from_items(
if parallelism == 0:
raise ValueError(f"parallelism must be -1 or > 0, got: {parallelism}")

detected_parallelism, _ = _autodetect_parallelism(
detected_parallelism, _, _ = _autodetect_parallelism(
ericl marked this conversation as resolved.
Show resolved Hide resolved
parallelism,
ray.util.get_current_placement_group(),
DataContext.get_current(),
Expand Down Expand Up @@ -350,9 +351,12 @@ def read_datasource(
force_local = True

if force_local:
requested_parallelism, min_safe_parallelism, read_tasks = _get_read_tasks(
datasource, ctx, cur_pg, parallelism, local_uri, read_args
)
(
requested_parallelism,
min_safe_parallelism,
inmemory_size,
read_tasks,
) = _get_read_tasks(datasource, ctx, cur_pg, parallelism, local_uri, read_args)
else:
# Prepare read in a remote task at same node.
# NOTE: in Ray client mode, this is expected to be run on head node.
Expand All @@ -365,7 +369,12 @@ def read_datasource(
_get_read_tasks, retry_exceptions=False, num_cpus=0
).options(scheduling_strategy=scheduling_strategy)

requested_parallelism, min_safe_parallelism, read_tasks = ray.get(
(
requested_parallelism,
min_safe_parallelism,
inmemory_size,
read_tasks,
) = ray.get(
get_read_tasks.remote(
datasource,
ctx,
Expand All @@ -376,28 +385,51 @@ def read_datasource(
)
)

if read_tasks and len(read_tasks) < min_safe_parallelism * 0.7:
perc = 1 + round((min_safe_parallelism - len(read_tasks)) / len(read_tasks), 1)
logger.warning(
f"{WARN_PREFIX} The blocks of this dataset are estimated to be {perc}x "
"larger than the target block size "
f"of {int(ctx.target_max_block_size / 1024 / 1024)} MiB. This may lead to "
"out-of-memory errors during processing. Consider reducing the size of "
"input files or using `.repartition(n)` to increase the number of "
"dataset blocks."
)
elif len(read_tasks) < requested_parallelism and (
len(read_tasks) < ray.available_resources().get("CPU", 1) // 2
):
logger.warning(
f"{WARN_PREFIX} The number of blocks in this dataset "
f"({len(read_tasks)}) "
f"limits its parallelism to {len(read_tasks)} concurrent tasks. "
"This is much less than the number "
"of available CPU slots in the cluster. Use `.repartition(n)` to "
"increase the number of "
"dataset blocks."
)
# if read_tasks and len(read_tasks) < min_safe_parallelism * 0.7:
# perc = 1 + round((min_safe_parallelism - len(read_tasks)) / len(read_tasks), 1)
# logger.warning(
# f"{WARN_PREFIX} The blocks of this dataset are estimated to be {perc}x "
# "larger than the target block size "
# f"of {int(ctx.target_max_block_size / 1024 / 1024)} MiB. This may lead to "
# "out-of-memory errors during processing. Consider reducing the size of "
# "input files or using `.repartition(n)` to increase the number of "
# "dataset blocks."
# )
# elif len(read_tasks) < requested_parallelism and (
# len(read_tasks) < ray.available_resources().get("CPU", 1) // 2
# ):
# logger.warning(
# f"{WARN_PREFIX} The number of blocks in this dataset "
# f"({len(read_tasks)}) "
# f"limits its parallelism to {len(read_tasks)} concurrent tasks. "
# "This is much less than the number "
# "of available CPU slots in the cluster. Use `.repartition(n)` to "
# "increase the number of "
# "dataset blocks."
# )

# TODO update the warnings above
if len(read_tasks) < requested_parallelism:
ericl marked this conversation as resolved.
Show resolved Hide resolved
desired_splits_per_file = requested_parallelism / len(read_tasks)
print("Desired splits per file", desired_splits_per_file)
if inmemory_size:
ericl marked this conversation as resolved.
Show resolved Hide resolved
expected_block_size = inmemory_size / len(read_tasks)
print("Expected block size", expected_block_size)
size_based_splits = math.floor(
max(1, expected_block_size / ctx.target_max_block_size)
)
print("Size based splits", size_based_splits)
else:
size_based_splits = 1
k = math.ceil(desired_splits_per_file / size_based_splits)
estimated_num_blocks = len(read_tasks) * size_based_splits * k
ericl marked this conversation as resolved.
Show resolved Hide resolved
print("Additional split factor", k)
for r in read_tasks:
r._set_additional_split_factor(k)
print("Estimated num blocks", estimated_num_blocks)
else:
print("No additional splits are needed")
estimated_num_blocks = len(read_tasks)

read_stage_name = f"Read{datasource.get_name()}"
available_cpu_slots = ray.available_resources().get("CPU", 1)
Expand All @@ -423,10 +455,11 @@ def read_datasource(
ray_remote_args=ray_remote_args,
owned_by_consumer=False,
)
block_list._estimated_num_blocks = estimated_num_blocks

# TODO(hchen): move _get_read_tasks and related code to the Read physical operator,
# after removing LazyBlockList code path.
read_op = Read(datasource, read_tasks, ray_remote_args)
read_op = Read(datasource, read_tasks, estimated_num_blocks, ray_remote_args)
logical_plan = LogicalPlan(read_op)

return Dataset(
Expand Down Expand Up @@ -1947,7 +1980,7 @@ def _get_read_tasks(
parallelism: int,
local_uri: bool,
kwargs: dict,
) -> Tuple[int, int, List[ReadTask]]:
) -> Tuple[int, int, Optional[int], List[ReadTask]]:
"""Generates read tasks.

Args:
Expand All @@ -1959,19 +1992,20 @@ def _get_read_tasks(

Returns:
Request parallelism from the datasource, the min safe parallelism to avoid
OOM, and the list of read tasks generated.
OOM, the estimated inmemory data size, and list of read tasks generated.
"""
kwargs = _unwrap_arrow_serialization_workaround(kwargs)
if local_uri:
kwargs["local_uri"] = local_uri
DataContext._set_current(ctx)
reader = ds.create_reader(**kwargs)
requested_parallelism, min_safe_parallelism = _autodetect_parallelism(
requested_parallelism, min_safe_parallelism, mem_size = _autodetect_parallelism(
parallelism, cur_pg, DataContext.get_current(), reader
)
return (
requested_parallelism,
min_safe_parallelism,
mem_size,
reader.get_read_tasks(requested_parallelism),
)

Expand Down