diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index e61f356afd4a..268af0d5b95d 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -125,6 +125,8 @@ def __init__( self._stages_after_snapshot = [] # Cache of optimized stages. self._last_optimized_stages = None + # Cached schema. + self._schema = None self._dataset_uuid = dataset_uuid or uuid.uuid4().hex if not stats.dataset_uuid: @@ -381,7 +383,14 @@ def schema( """ from ray.data._internal.stage_impl import RandomizeBlocksStage + if self._schema is not None: + return self._schema + if self._stages_after_snapshot: + # TODO(swang): There are several other stage types that could + # inherit the schema or we can compute the schema without having to + # execute any of the dataset: limit, filter, map_batches for + # add/drop columns, etc. if fetch_if_missing: if isinstance(self._stages_after_snapshot[-1], RandomizeBlocksStage): # TODO(ekl): this is a hack to optimize the case where we have a @@ -412,7 +421,11 @@ def schema( blocks = self._snapshot_blocks if not blocks: return None - return self._get_unified_blocks_schema(blocks, fetch_if_missing) + self._schema = self._get_unified_blocks_schema(blocks, fetch_if_missing) + return self._schema + + def cache_schema(self, schema: Union[type, "pyarrow.lib.Schema"]): + self._schema = schema def _get_unified_blocks_schema( self, blocks: BlockList, fetch_if_missing: bool = False diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index e9e8bedef77c..b101a41eeff5 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2304,6 +2304,7 @@ def schema(self, fetch_if_missing: bool = True) -> Optional["Schema"]: # of this Dataset, which we then execute to get its schema. base_schema = self.limit(1)._plan.schema(fetch_if_missing=fetch_if_missing) if base_schema: + self._plan.cache_schema(base_schema) return Schema(base_schema) else: return None diff --git a/python/ray/data/tests/test_consumption.py b/python/ray/data/tests/test_consumption.py index 5e4b9c41bf14..1950ba59767d 100644 --- a/python/ray/data/tests/test_consumption.py +++ b/python/ray/data/tests/test_consumption.py @@ -241,6 +241,22 @@ def test_schema_lazy(ray_start_regular_shared): assert ds._plan.execute()._num_computed() == 0 +def test_schema_cached(ray_start_regular_shared): + def check_schema_cached(ds): + schema = ds.schema() + assert schema.names == ["a"] + cached_schema = ds.schema(fetch_if_missing=False) + assert cached_schema is not None + assert schema == cached_schema + + ds = ray.data.from_items([{"a": i} for i in range(100)], parallelism=10) + check_schema_cached(ds) + + # Add a map_batches stage so that we are forced to compute the schema. + ds = ds.map_batches(lambda x: x) + check_schema_cached(ds) + + def test_columns(ray_start_regular_shared): ds = ray.data.range(1) assert ds.columns() == ds.schema().names