Skip to content

Commit

Permalink
Cache schema and test (#37103) (#37201)
Browse files Browse the repository at this point in the history
Cache the computed schema to avoid re-executing the dataset each time the schema is requested.

Closes #37077.
  • Loading branch information
stephanie-wang authored Jul 7, 2023
1 parent 47ec25b commit 12a569f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
15 changes: 14 additions & 1 deletion python/ray/data/_internal/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def __init__(
self._stages_after_snapshot = []
# Cache of optimized stages.
self._last_optimized_stages = None
# Cached schema.
self._schema = None

self._dataset_uuid = dataset_uuid or uuid.uuid4().hex
if not stats.dataset_uuid:
Expand Down Expand Up @@ -381,7 +383,14 @@ def schema(
"""
from ray.data._internal.stage_impl import RandomizeBlocksStage

if self._schema is not None:
return self._schema

if self._stages_after_snapshot:
# TODO(swang): Several other stage types could inherit the schema, or
# have it computed without executing any of the dataset: limit, filter,
# map_batches for add/drop columns, etc.
if fetch_if_missing:
if isinstance(self._stages_after_snapshot[-1], RandomizeBlocksStage):
# TODO(ekl): this is a hack to optimize the case where we have a
Expand Down Expand Up @@ -412,7 +421,11 @@ def schema(
blocks = self._snapshot_blocks
if not blocks:
return None
return self._get_unified_blocks_schema(blocks, fetch_if_missing)
self._schema = self._get_unified_blocks_schema(blocks, fetch_if_missing)
return self._schema

def cache_schema(self, schema: Union[type, "pyarrow.lib.Schema"]):
    """Remember an already-known schema on this object.

    The value is stored in ``self._schema`` exactly as given, replacing
    whatever was held there before. Nothing is validated or computed here.

    Args:
        schema: The schema to remember; either a Python type or a
            ``pyarrow.lib.Schema``.
    """
    self._schema = schema

def _get_unified_blocks_schema(
self, blocks: BlockList, fetch_if_missing: bool = False
Expand Down
1 change: 1 addition & 0 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2304,6 +2304,7 @@ def schema(self, fetch_if_missing: bool = True) -> Optional["Schema"]:
# of this Dataset, which we then execute to get its schema.
base_schema = self.limit(1)._plan.schema(fetch_if_missing=fetch_if_missing)
if base_schema:
self._plan.cache_schema(base_schema)
return Schema(base_schema)
else:
return None
Expand Down
16 changes: 16 additions & 0 deletions python/ray/data/tests/test_consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,22 @@ def test_schema_lazy(ray_start_regular_shared):
assert ds._plan.execute()._num_computed() == 0


def test_schema_cached(ray_start_regular_shared):
    """Check that a schema obtained once can be read back with fetch_if_missing=False."""

    def assert_schema_is_cached(dataset):
        # First request: the schema may be computed here.
        fetched = dataset.schema()
        assert fetched.names == ["a"]
        # Second request with fetch_if_missing=False: a schema must still be
        # returned, and it must equal the one obtained above.
        stored = dataset.schema(fetch_if_missing=False)
        assert stored is not None
        assert fetched == stored

    rows = [{"a": i} for i in range(100)]
    source_ds = ray.data.from_items(rows, parallelism=10)
    assert_schema_is_cached(source_ds)

    # Add a map_batches stage so that we are forced to compute the schema.
    mapped_ds = source_ds.map_batches(lambda x: x)
    assert_schema_is_cached(mapped_ds)


def test_columns(ray_start_regular_shared):
ds = ray.data.range(1)
assert ds.columns() == ds.schema().names
Expand Down

0 comments on commit 12a569f

Please sign in to comment.