
Commit

[Datasets] Improve performance of DefaultFileMetaProvider. (ray-project#33117)

This PR improves the performance of the DefaultFileMetaProvider. Previously, DefaultFileMetaProvider would serially expand each directory and serially fetch the size of each file, which is slow for large path lists. This PR parallelizes both directory expansion and file size fetching over Ray tasks. In addition, in the common case that all file paths share the same parent directory (or base directory, if using partitioning), we issue a single ListObjectsV2 call on that directory followed by a client-side filter, which reduces a 90 second parallel file size fetch to a 0.8 second request plus a client-side filter.
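The common-parent fast path can be pictured with the following minimal sketch (plain pyarrow, no Ray); the helper names _common_parent and _fetch_sizes_via_single_listing are illustrative assumptions, not functions this PR adds:

# Illustrative sketch of the common-parent fast path; these helper names are
# hypothetical and do not appear in the PR.
import posixpath
from typing import Dict, List, Optional

from pyarrow.fs import FileSelector, FileSystem


def _common_parent(paths: List[str]) -> Optional[str]:
    """Return the single shared parent directory of all paths, if any."""
    parents = {posixpath.dirname(p) for p in paths}
    return parents.pop() if len(parents) == 1 else None


def _fetch_sizes_via_single_listing(
    paths: List[str], filesystem: FileSystem
) -> Optional[Dict[str, int]]:
    """List the shared parent once (a single ListObjectsV2 call on S3) and
    filter the results client-side, instead of issuing one request per file."""
    parent = _common_parent(paths)
    if parent is None:
        return None  # Fall back to (parallelized) per-file size fetching.
    wanted = set(paths)
    infos = filesystem.get_file_info(FileSelector(parent, recursive=True))
    return {info.path: info.size for info in infos if info.path in wanted}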

Signed-off-by: elliottower <elliot@elliottower.com>
clarkzinzow authored and elliottower committed Apr 22, 2023
1 parent b05448e commit 8e0979f
Showing 6 changed files with 631 additions and 119 deletions.
104 changes: 55 additions & 49 deletions python/ray/data/datasource/file_based_datasource.py
@@ -1,3 +1,4 @@
import itertools
import logging
import pathlib
import posixpath
@@ -13,12 +14,17 @@
Optional,
Tuple,
Union,
TypeVar,
)

import numpy as np

from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme
from ray.data.block import Block, BlockAccessor
from ray.data.context import DatasetContext
@@ -45,6 +51,13 @@
logger = logging.getLogger(__name__)


# We should parallelize file size fetch operations beyond this threshold.
FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD = 16

# 16 file size fetches from S3 takes ~1.5 seconds with Arrow's S3FileSystem.
PATHS_PER_FILE_SIZE_FETCH_TASK = 16


@DeveloperAPI
class BlockWritePathProvider:
"""Abstract callable that provides concrete output paths when writing
@@ -288,9 +301,7 @@ def write(

def write_block(write_path: str, block: Block):
logger.debug(f"Writing {write_path} file.")
fs = filesystem
if isinstance(fs, _S3FileSystemWrapper):
fs = fs.unwrap()
fs = _unwrap_s3_serialization_workaround(filesystem)
if _block_udf is not None:
block = _block_udf(block)

@@ -373,8 +384,9 @@ def __init__(
self._block_udf = _block_udf
self._reader_args = reader_args
paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem)
self._paths, self._file_sizes = meta_provider.expand_paths(
paths, self._filesystem
self._paths, self._file_sizes = map(
list,
zip(*meta_provider.expand_paths(paths, self._filesystem, partitioning)),
)
if self._partition_filter is not None:
# Use partition filter to skip files which are not needed.
@@ -418,8 +430,7 @@ def read_files(
fs: Union["pyarrow.fs.FileSystem", _S3FileSystemWrapper],
) -> Iterable[Block]:
logger.debug(f"Reading {len(read_paths)} files.")
if isinstance(fs, _S3FileSystemWrapper):
fs = fs.unwrap()
fs = _unwrap_s3_serialization_workaround(filesystem)
ctx = DatasetContext.get_current()
output_buffer = BlockOutputBuffer(
block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size
@@ -672,48 +683,6 @@ def _resolve_paths_and_filesystem(
return resolved_paths, filesystem


def _expand_directory(
path: str,
filesystem: "pyarrow.fs.FileSystem",
exclude_prefixes: Optional[List[str]] = None,
) -> List[str]:
"""
Expand the provided directory path to a list of file paths.
Args:
path: The directory path to expand.
filesystem: The filesystem implementation that should be used for
reading these files.
exclude_prefixes: The file relative path prefixes that should be
excluded from the returned file set. Default excluded prefixes are
"." and "_".
Returns:
A list of file paths contained in the provided directory.
"""
if exclude_prefixes is None:
exclude_prefixes = [".", "_"]

from pyarrow.fs import FileSelector

selector = FileSelector(path, recursive=True)
files = filesystem.get_file_info(selector)
base_path = selector.base_dir
filtered_paths = []
for file_ in files:
if not file_.is_file:
continue
file_path = file_.path
if not file_path.startswith(base_path):
continue
relative = file_path[len(base_path) :]
if any(relative.startswith(prefix) for prefix in exclude_prefixes):
continue
filtered_paths.append((file_path, file_))
# We sort the paths to guarantee a stable order.
return zip(*sorted(filtered_paths, key=lambda x: x[0]))


def _is_url(path) -> bool:
return urllib.parse.urlparse(path).scheme != ""

@@ -752,6 +721,15 @@ def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"):
return filesystem


def _unwrap_s3_serialization_workaround(
filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"]
):
if isinstance(filesystem, _S3FileSystemWrapper):
return filesystem.unwrap()
else:
return filesystem


class _S3FileSystemWrapper:
def __init__(self, fs: "pyarrow.fs.S3FileSystem"):
self._fs = fs
@@ -792,3 +770,31 @@ def _resolve_kwargs(
kwarg_overrides = kwargs_fn()
kwargs.update(kwarg_overrides)
return kwargs


Uri = TypeVar("Uri")
Meta = TypeVar("Meta")


def _fetch_metadata_parallel(
uris: List[Uri],
fetch_func: Callable[[List[Uri]], List[Meta]],
desired_uris_per_task: int,
**ray_remote_args,
) -> Iterator[Meta]:
"""Fetch file metadata in parallel using Ray tasks."""
remote_fetch_func = cached_remote_fn(fetch_func, num_cpus=0.5)
if ray_remote_args:
remote_fetch_func = remote_fetch_func.options(**ray_remote_args)
# Choose a parallelism that results in a # of metadata fetches per task that
# dominates the Ray task overhead while ensuring good parallelism.
# Always launch at least 2 parallel fetch tasks.
parallelism = max(len(uris) // desired_uris_per_task, 2)
metadata_fetch_bar = ProgressBar("Metadata Fetch Progress", total=parallelism)
fetch_tasks = []
for uri_chunk in np.array_split(uris, parallelism):
if len(uri_chunk) == 0:
continue
fetch_tasks.append(remote_fetch_func.remote(uri_chunk))
results = metadata_fetch_bar.fetch_until_complete(fetch_tasks)
yield from itertools.chain.from_iterable(results)
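
For context, the new _fetch_metadata_parallel helper would be driven roughly as in the sketch below; the _fetch_file_sizes fetch function and the LocalFileSystem choice are illustrative assumptions, not the actual call sites this PR adds to the file metadata provider:

# Hedged usage sketch, assuming ray.init() has been called and that this
# module's helpers are importable; _fetch_file_sizes is hypothetical.
from typing import List, Tuple

import pyarrow.fs

from ray.data.datasource.file_based_datasource import (
    PATHS_PER_FILE_SIZE_FETCH_TASK,
    _fetch_metadata_parallel,
)


def _fetch_file_sizes(paths: List[str]) -> List[Tuple[str, int]]:
    # Runs inside a Ray task; builds its own filesystem handle so nothing
    # non-serializable is captured from the driver.
    fs = pyarrow.fs.LocalFileSystem()  # e.g. S3FileSystem() for s3:// paths
    infos = fs.get_file_info(list(paths))
    return [(info.path, info.size) for info in infos]


paths = [f"/tmp/data/file_{i}.parquet" for i in range(64)]
sizes = list(
    _fetch_metadata_parallel(
        paths,
        _fetch_file_sizes,
        desired_uris_per_task=PATHS_PER_FILE_SIZE_FETCH_TASK,
    )
)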
