Add a TFDS data source for Space dataset
coufon committed Dec 26, 2023
1 parent 01bf674 commit 5a6a839
Showing 11 changed files with 330 additions and 30 deletions.
10 changes: 10 additions & 0 deletions python/src/space/core/datasets.py
@@ -20,6 +20,7 @@
import pyarrow as pa

from space.core.runners import LocalRunner
from space.core.serializers.base import DictSerializer
from space.core.storage import Storage


@@ -48,6 +49,15 @@ def load(cls, location: str) -> Dataset:
"""Load an existing dataset from the given location."""
return Dataset(Storage.load(location))

@property
def schema(self) -> pa.Schema:
"""Return the dataset schema."""
return self._storage.logical_schema

def serializer(self) -> DictSerializer:
"""Return a serializer (deserializer) for the dataset."""
return DictSerializer(self.schema)

def local(self) -> LocalRunner:
"""Get a runner that runs operations locally."""
return LocalRunner(self._storage)
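
The new schema property and serializer() method expose the dataset's logical schema and a matching DictSerializer directly on a loaded dataset. A minimal usage sketch, assuming an existing dataset at a hypothetical location:

from space.core.datasets import Dataset

ds = Dataset.load("/data/space/my_dataset")  # hypothetical location
print(ds.schema)       # logical (user-facing) Arrow schema
ser = ds.serializer()  # DictSerializer built from the logical schema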
2 changes: 1 addition & 1 deletion python/src/space/core/ops/__init__.py
@@ -16,4 +16,4 @@

from space.core.ops.append import LocalAppendOp
from space.core.ops.delete import FileSetDeleteOp
from space.core.ops.read import FileSetReadOp
from space.core.ops.read import FileSetReadOp, ReadOptions
39 changes: 33 additions & 6 deletions python/src/space/core/ops/read.py
@@ -18,6 +18,7 @@
from abc import abstractmethod
from typing import Iterator, Dict, List, Tuple, Optional

from dataclasses import dataclass
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
@@ -35,6 +36,18 @@
_RECORD_KEY_FIELD = "__RECORD_KEY"


@dataclass
class ReadOptions:
"""Options of reading data."""
# Filters on index fields.
filter_: Optional[pc.Expression] = None
# When specified, only read the given fields instead of all fields.
fields: Optional[List[str]] = None
# If true, read the references (e.g., addresses) of record fields instead
# of their values.
reference_read: bool = False


class BaseReadOp(BaseOp):
"""Abstract base read operation class."""

@@ -56,36 +69,50 @@ def __init__(self,
location: str,
metadata: meta.StorageMetadata,
file_set: runtime.FileSet,
filter_: Optional[pc.Expression] = None):
options: Optional[ReadOptions] = None):
StoragePaths.__init__(self, location)

# TODO: to validate that filter_ does not contain record fields.

self._metadata = metadata
self._file_set = file_set

# TODO: to validate options, e.g., fields are valid.
self._options = ReadOptions() if options is None else options

record_fields = set(self._metadata.schema.record_fields)
self._physical_schema = arrow.arrow_schema(self._metadata.schema.fields,
record_fields,
physical=True)

if self._options.fields is None:
self._selected_fields = [f.name for f in self._physical_schema]
else:
self._selected_fields = self._options.fields

self._index_fields, self._record_fields = arrow.classify_fields(
self._physical_schema, record_fields, selected_fields=None)
self._physical_schema,
record_fields,
selected_fields=set(self._selected_fields))

self._index_field_ids = set(schema_utils.field_ids(self._index_fields))

self._record_fields_dict: Dict[int, schema_utils.Field] = {}
for f in self._record_fields:
self._record_fields_dict[f.field_id] = f

self._filter = filter_

def __iter__(self) -> Iterator[pa.Table]:
for file in self._file_set.index_files:
yield self._read_index_and_record(file.path)

def _read_index_and_record(self, index_path: str) -> pa.Table:
index_data = pq.read_table(self.full_path(index_path),
filters=self._filter) # type: ignore[arg-type]
index_data = pq.read_table(
self.full_path(index_path),
columns=self._selected_fields,
filters=self._options.filter_) # type: ignore[arg-type]

if self._options.reference_read:
return index_data

index_column_ids: List[int] = []
record_columns: List[Tuple[int, pa.Field]] = []
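
ReadOptions bundles the index-field filter, the field projection, and the reference-read flag into one object passed to FileSetReadOp. A construction sketch; the field names below are hypothetical:

import pyarrow.compute as pc
from space.core.ops import ReadOptions

options = ReadOptions(
    filter_=pc.field("label") == "cat",  # filter on an index field
    fields=["id", "image"],              # read only these fields
    reference_read=True)                 # return record-field references (addresses), not values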
28 changes: 23 additions & 5 deletions python/src/space/core/runners.py
@@ -16,13 +16,16 @@

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterator, Optional
from typing import Iterator, List, Optional

from absl import logging # type: ignore[import-untyped]
import pyarrow as pa
import pyarrow.compute as pc

from space.core.ops import FileSetDeleteOp, FileSetReadOp, LocalAppendOp
from space.core.ops import FileSetDeleteOp
from space.core.ops import FileSetReadOp
from space.core.ops import LocalAppendOp
from space.core.ops import ReadOptions
from space.core.ops.base import InputData
import space.core.proto.runtime_pb2 as runtime
from space.core.storage import Storage
@@ -38,9 +41,20 @@ def __init__(self, storage: Storage):
@abstractmethod
def read(self,
filter_: Optional[pc.Expression] = None,
snapshot_id: Optional[int] = None) -> Iterator[pa.Table]:
fields: Optional[List[str]] = None,
snapshot_id: Optional[int] = None,
reference_read: bool = False) -> Iterator[pa.Table]:
"""Read data from the dataset as an iterator."""

def read_all(self,
filter_: Optional[pc.Expression] = None,
fields: Optional[List[str]] = None,
snapshot_id: Optional[int] = None,
reference_read: bool = False) -> pa.Table:
"""Read data from the dataset as an Arrow table."""
return pa.concat_tables(
list(self.read(filter_, fields, snapshot_id, reference_read)))

@abstractmethod
def append(self, data: InputData) -> runtime.JobResult:
"""Append data into the dataset."""
@@ -72,12 +86,16 @@ class LocalRunner(BaseRunner):

def read(self,
filter_: Optional[pc.Expression] = None,
snapshot_id: Optional[int] = None) -> Iterator[pa.Table]:
fields: Optional[List[str]] = None,
snapshot_id: Optional[int] = None,
reference_read: bool = False) -> Iterator[pa.Table]:
return iter(
FileSetReadOp(
self._storage.location, self._storage.metadata,
self._storage.data_files(filter_, snapshot_id=snapshot_id),
filter_))
ReadOptions(filter_=filter_,
fields=fields,
reference_read=reference_read)))

def append(self, data: InputData) -> runtime.JobResult:
op = LocalAppendOp(self._storage.location, self._storage.metadata)
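
With the extended read signature, a local runner can project fields and return record references, and read_all concatenates the per-file results into a single Arrow table. A sketch, assuming a hypothetical dataset location and hypothetical field names:

import pyarrow.compute as pc
from space.core.datasets import Dataset

ds = Dataset.load("/data/space/my_dataset")  # hypothetical location
runner = ds.local()

# Stream matching rows one index file at a time.
for table in runner.read(filter_=pc.field("split") == "train", fields=["id", "image"]):
    print(table.num_rows)

# Or materialize everything at once, reading only record-field references.
refs = runner.read_all(reference_read=True)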
35 changes: 25 additions & 10 deletions python/src/space/core/serializers/base.py
@@ -15,7 +15,7 @@
"""Serializers (and deserializers) for unstructured record fields."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from typing_extensions import TypeAlias

import pyarrow as pa
@@ -63,26 +63,41 @@ def __init__(self, logical_schema: pa.Schema):
if isinstance(field.type, FieldSerializer):
self._serializers[field.name] = field.type

def field_serializer(self, field: str) -> Optional[FieldSerializer]:
"""Return the FieldSerializer of a given field, or None if not found."""
if field not in self._serializers:
return None

return self._serializers[field]

def serialize(self, value: DictData) -> DictData:
"""Serialize a value.
Args:
value: a dict of numpy-like nested dicts.
"""
for name, ser in self._serializers.items():
if name in value:
value[name] = [ser.serialize(d) for d in value[name]]
result = {}
for field_name, value_batch in value.items():
if field_name in self._serializers:
ser = self._serializers[field_name]
result[field_name] = [ser.serialize(v) for v in value_batch]
else:
result[field_name] = value_batch

return value
return result

def deserialize(self, value_bytes: DictData) -> DictData:
"""Deserialize a dict of bytes to a dict of values.
Returns:
A dict of numpy-like nested dicts.
"""
for name, ser in self._serializers.items():
if name in value_bytes:
value_bytes[name] = [ser.deserialize(d) for d in value_bytes[name]]

return value_bytes
result = {}
for field_name, value_batch in value_bytes.items():
if field_name in self._serializers:
ser = self._serializers[field_name]
result[field_name] = [ser.deserialize(v) for v in value_batch]
else:
result[field_name] = value_batch

return result
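
serialize() and deserialize() now build and return a new dict, passing through any field that has no registered FieldSerializer rather than mutating the input in place. A round-trip sketch, assuming "image" is a record field with a serializer and "id" is an index field (both names hypothetical):

import numpy as np
from space.core.datasets import Dataset

ds = Dataset.load("/data/space/my_dataset")  # hypothetical location
ser = ds.serializer()

batch = {"id": [1, 2], "image": [np.zeros((4, 4)), np.ones((4, 4))]}
encoded = ser.serialize(batch)      # record fields become serialized bytes; "id" passes through
decoded = ser.deserialize(encoded)  # serialized bytes back to numpy-like values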
15 changes: 12 additions & 3 deletions python/src/space/core/storage.py
@@ -46,9 +46,12 @@ def __init__(self, location: str, metadata: meta.StorageMetadata):
self._fs = create_fs(location)

record_fields = set(self._metadata.schema.record_fields)
self._physical_schema = arrow.arrow_schema(self._metadata.schema.fields,
record_fields,
physical=True)
self._logical_schema = arrow.arrow_schema(self._metadata.schema.fields,
record_fields,
physical=False)
self._physical_schema = arrow.logical_to_physical_schema(
self._logical_schema, record_fields)

self._field_name_ids: Dict[str, int] = arrow.field_name_to_id_dict(
self._physical_schema)

@@ -57,6 +60,11 @@ def metadata(self) -> meta.StorageMetadata:
"""Return the storage metadata."""
return self._metadata

@property
def logical_schema(self) -> pa.Schema:
"""Return the user specified schema."""
return self._logical_schema

@property
def physical_schema(self) -> pa.Schema:
"""Return the physcal schema that uses reference for record fields."""
@@ -89,6 +97,7 @@ def create(
# TODO: to verify that location is an empty directory.
# TODO: to verify primary key fields and record_fields (and types) are
# valid.
# TODO: to auto infer record_fields.

field_id_mgr = FieldIdManager()
schema = field_id_mgr.assign_field_ids(schema)
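
Storage now builds the logical schema first and derives the physical schema from it, so both views are available for inspection. A sketch, assuming a hypothetical dataset location:

from space.core.storage import Storage

storage = Storage.load("/data/space/my_dataset")  # hypothetical location
print(storage.logical_schema)   # user-specified schema, including record field types
print(storage.physical_schema)  # record fields replaced by reference (address) columns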