custom load input in deltalake-polars
ion-elgreco committed Jan 20, 2024
1 parent 7fa6bcd commit a643d87
Showing 5 changed files with 143 additions and 55 deletions.
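The change gives the Polars type handler its own load_input: pl.LazyFrame inputs are now backed by a lazily scanned pyarrow dataset, with any requested columns applied through .select(), while pl.DataFrame inputs keep the eager RecordBatchReader path. A minimal usage sketch of the asset pattern this enables, modeled on the tests further down; the resource wiring (DeltaLakePolarsIOManager and its root_uri, storage_options, and schema parameters) is recalled from the package docs rather than taken from this diff, so treat those names as assumptions:

import polars as pl
from dagster import AssetIn, Definitions, asset
from dagster_deltalake import LocalConfig
from dagster_deltalake_polars import DeltaLakePolarsIOManager


@asset
def numbers() -> pl.DataFrame:
    # Written to a Delta table by the I/O manager.
    return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@asset(ins={"numbers": AssetIn(metadata={"columns": ["a"]})})
def numbers_plus_one(numbers: pl.LazyFrame) -> pl.LazyFrame:
    # With this commit the input arrives as a LazyFrame over a pyarrow dataset,
    # and the "columns" metadata is applied lazily via .select() in load_input.
    return numbers.select(pl.all() + 1)


# Resource wiring below is an assumption based on the package docs, not on this diff.
defs = Definitions(
    assets=[numbers, numbers_plus_one],
    resources={
        "io_manager": DeltaLakePolarsIOManager(
            root_uri="/tmp/delta", storage_options=LocalConfig(), schema="my_schema"
        )
    },
)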
Changed file 1 of 5: Pandas type handler (DeltaLakePandasTypeHandler)
@@ -1,8 +1,7 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Type

 import pandas as pd
 import pyarrow as pa
-import pyarrow.dataset as ds
 from dagster._core.storage.db_io_manager import (
     DbTypeHandler,
 )
@@ -15,10 +14,8 @@

 class DeltaLakePandasTypeHandler(DeltalakeBaseArrowTypeHandler[pd.DataFrame]):
     def from_arrow(
-        self, obj: Union[ds.dataset, pa.RecordBatchReader], target_type: Type[pd.DataFrame]
+        self, obj: pa.RecordBatchReader, target_type: Type[pd.DataFrame]
     ) -> pd.DataFrame:
-        if isinstance(obj, ds.Dataset):
-            obj = obj.scanner().to_reader()
         return obj.read_pandas()

     def to_arrow(self, obj: pd.DataFrame) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
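The pandas handler's from_arrow now only ever receives a pa.RecordBatchReader (the ds.Dataset branch moved into the shared load_input), so the conversion reduces to a single read_pandas() call. A tiny standalone illustration with an in-memory reader; the table contents are made up:

import pyarrow as pa

# In-memory stand-in for the reader produced by dataset.scanner(...).to_reader().
reader = pa.table({"a": [1, 2, 3]}).to_reader()
df = reader.read_pandas()  # the same call DeltaLakePandasTypeHandler.from_arrow now makes
print(df)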
Changed file 2 of 5: Polars type handler (DeltaLakePolarsTypeHandler)
@@ -3,41 +3,63 @@
 import polars as pl
 import pyarrow as pa
 import pyarrow.dataset as ds
+from dagster import InputContext
 from dagster._core.storage.db_io_manager import (
     DbTypeHandler,
+    TableSlice,
 )
 from dagster_deltalake.handler import (
     DeltalakeBaseArrowTypeHandler,
     DeltaLakePyArrowTypeHandler,
+    _table_reader,
 )
-from dagster_deltalake.io_manager import DeltaLakeIOManager
+from dagster_deltalake.io_manager import DeltaLakeIOManager, TableConnection

 PolarsTypes = Union[pl.DataFrame, pl.LazyFrame]


 class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[PolarsTypes]):
     def from_arrow(
-        self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: Type[PolarsTypes]
+        self,
+        obj: Union[ds.Dataset, pa.RecordBatchReader],
+        target_type: Type[PolarsTypes],
     ) -> PolarsTypes:
         if isinstance(obj, pa.RecordBatchReader):
-            df = pl.DataFrame(obj.read_all())
-            if target_type == pl.LazyFrame:
-                ## Maybe allow this but raise a warning that the data has been sliced earlier otherwise it
-                ## would have received a ds.dataset
-                return df.lazy()
             return pl.DataFrame(obj.read_all())
         elif isinstance(obj, ds.Dataset):
             df = pl.scan_pyarrow_dataset(obj)
             if target_type == pl.DataFrame:
                 return df.collect()
             else:
                 return df
+        raise NotImplementedError("Unsupported objected passed of type: %s", type(obj))

     def to_arrow(self, obj: PolarsTypes) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
         if isinstance(obj, pl.LazyFrame):
             obj = obj.collect()
         return obj.to_arrow().to_reader(), {"large_dtypes": True}

+    def load_input(
+        self,
+        context: InputContext,
+        table_slice: TableSlice,
+        connection: TableConnection,
+    ) -> PolarsTypes:
+        """Loads the input as a pyarrow Table or RecordBatchReader."""
+        dataset = _table_reader(table_slice, connection)
+
+        if table_slice.columns is not None:
+            if context.dagster_type.typing_type == pl.LazyFrame:
+                return self.from_arrow(dataset, context.dagster_type.typing_type).select(
+                    table_slice.columns
+                )
+            else:
+                scanner = dataset.scanner(columns=table_slice.columns)
+                return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)
+        else:
+            return self.from_arrow(dataset, context.dagster_type.typing_type)
+
     @property
     def supported_types(self) -> Sequence[Type[object]]:
         return [pl.DataFrame, pl.LazyFrame]
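The heart of the change is the new load_input above: when the downstream input is typed pl.LazyFrame, the handler hands the pyarrow dataset straight to from_arrow and applies the requested columns with .select(), so nothing is read until the asset calls .collect(); pl.DataFrame inputs still go through a column-pruned scanner. A small standalone sketch of both paths using only polars and pyarrow, with an in-memory dataset standing in for the Delta table:

import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds

# Stand-in for the pyarrow dataset that _table_reader() returns for a Delta table.
dataset = ds.dataset(pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}))

# Lazy path (pl.LazyFrame inputs): scan the dataset, prune columns with .select(),
# and defer all reading until collect().
lazy = pl.scan_pyarrow_dataset(dataset).select(["a"])
print(lazy.collect())

# Eager path (pl.DataFrame inputs): prune columns in the pyarrow scanner, then read.
reader = dataset.scanner(columns=["a"]).to_reader()
print(pl.DataFrame(reader.read_all()))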
Changed file 3 of 5: Polars type handler tests
@@ -143,19 +143,33 @@ def b_plus_one_columns(b_df: pl.DataFrame) -> pl.DataFrame:
     return b_df + 1


-def test_loading_columns(tmp_path, io_manager):
+@asset(
+    key_prefix=["my_schema"], ins={"b_df_lazy": AssetIn("b_df_lazy", metadata={"columns": ["a"]})}
+)
+def b_plus_one_columns_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame:
+    return b_df_lazy.select(pl.all() + 1)
+
+
+@pytest.mark.parametrize(
+    "asset1,asset2,asset1_path,asset2_path",
+    [
+        (b_df, b_plus_one_columns, "b_df", "b_plus_one_columns"),
+        (b_df_lazy, b_plus_one_columns_lazy, "b_df_lazy", "b_plus_one_columns_lazy"),
+    ],
+)
+def test_loading_columns(tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path):
     resource_defs = {"io_manager": io_manager}

     # materialize asset twice to ensure that tables get properly deleted
     for _ in range(2):
-        res = materialize([b_df, b_plus_one_columns], resources=resource_defs)
+        res = materialize([asset1, asset2], resources=resource_defs)
         assert res.success

-    dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df"))
+    dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
     out_df = dt.to_pyarrow_table()
     assert out_df["a"].to_pylist() == [1, 2, 3]

-    dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one_columns"))
+    dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path))
     out_df = dt.to_pyarrow_table()
     assert out_df["a"].to_pylist() == [2, 3, 4]

@@ -205,36 +219,64 @@ def daily_partitioned(context) -> pl.DataFrame:
     )


-def test_time_window_partitioned_asset(tmp_path, io_manager):
+@asset(
+    partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"),
+    key_prefix=["my_schema"],
+    metadata={"partition_expr": "time"},
+    config_schema={"value": str},
+)
+def daily_partitioned_lazy(context) -> pl.LazyFrame:
+    partition = datetime.strptime(
+        context.asset_partition_key_for_output(), DELTA_DATE_FORMAT
+    ).date()
+    value = context.op_config["value"]
+
+    return pl.LazyFrame(
+        {
+            "time": [partition, partition, partition],
+            "a": [value, value, value],
+            "b": [4, 5, 6],
+        }
+    )
+
+
+@pytest.mark.parametrize(
+    "asset1,asset1_path",
+    [
+        (daily_partitioned, "daily_partitioned"),
+        (daily_partitioned_lazy, "daily_partitioned_lazy"),
+    ],
+)
+def test_time_window_partitioned_asset(tmp_path, io_manager, asset1, asset1_path):
     resource_defs = {"io_manager": io_manager}

     materialize(
-        [daily_partitioned],
+        [asset1],
         partition_key="2022-01-01",
         resources=resource_defs,
-        run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "1"}}}},
+        run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "1"}}}},
     )

-    dt = DeltaTable(os.path.join(tmp_path, "my_schema/daily_partitioned"))
+    dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
     out_df = dt.to_pyarrow_table()
     assert out_df["a"].to_pylist() == ["1", "1", "1"]

     materialize(
-        [daily_partitioned],
+        [asset1],
         partition_key="2022-01-02",
         resources=resource_defs,
-        run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "2"}}}},
+        run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "2"}}}},
     )

     dt.update_incremental()
     out_df = dt.to_pyarrow_table()
     assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"]

     materialize(
-        [daily_partitioned],
+        [asset1],
         partition_key="2022-01-01",
         resources=resource_defs,
-        run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "3"}}}},
+        run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "3"}}}},
     )

     dt.update_incremental()
Changed file 4 of 5: Arrow type handler base module (DeltalakeBaseArrowTypeHandler and helpers)
@@ -40,7 +40,7 @@

 class DeltalakeBaseArrowTypeHandler(DbTypeHandler[T], Generic[T]):
     @abstractmethod
-    def from_arrow(self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: type) -> T:
+    def from_arrow(self, obj: pa.RecordBatchReader, target_type: type) -> T:
         pass

     @abstractmethod
@@ -110,41 +110,19 @@ def load_input(
         connection: TableConnection,
     ) -> T:
         """Loads the input as a pyarrow Table or RecordBatchReader."""
-        table = DeltaTable(
-            table_uri=connection.table_uri, storage_options=connection.storage_options
-        )
-
-        partition_expr = None
-        if table_slice.partition_dimensions is not None:
-            partition_filters = partition_dimensions_to_dnf(
-                partition_dimensions=table_slice.partition_dimensions,
-                table_schema=table.schema(),
-            )
-            if partition_filters is not None:
-                partition_expr = _filters_to_expression([partition_filters])
-
-        dataset = table.to_pyarrow_dataset()
-        if partition_expr is not None:
-            dataset = dataset.filter(expression=partition_expr)
+        dataset = _table_reader(table_slice, connection)

         if context.dagster_type.typing_type == ds.Dataset:
             if table_slice.columns is not None:
                 raise ValueError("Cannot select columns when loading as Dataset.")
             return dataset

-        if table_slice.columns is None:
-            return self.from_arrow(dataset, context.dagster_type.typing_type)
-        else:
-            scanner = dataset.scanner(columns=table_slice.columns)
-            return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)
+        scanner = dataset.scanner(columns=table_slice.columns)
+        return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)


 class DeltaLakePyArrowTypeHandler(DeltalakeBaseArrowTypeHandler[ArrowTypes]):
-    def from_arrow(
-        self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: Type[ArrowTypes]
-    ) -> ArrowTypes:
-        if isinstance(obj, ds.Dataset):
-            obj = obj.scanner().to_reader()
+    def from_arrow(self, obj: pa.RecordBatchReader, target_type: Type[ArrowTypes]) -> ArrowTypes:
         if target_type == pa.Table:
             return obj.read_all()
         return obj
@@ -238,3 +216,22 @@ def _get_partition_stats(dt: DeltaTable, partition_filters=None):
     }

     return table, stats
+
+
+def _table_reader(table_slice: TableSlice, connection: TableConnection) -> ds.Dataset:
+    table = DeltaTable(table_uri=connection.table_uri, storage_options=connection.storage_options)
+
+    partition_expr = None
+    if table_slice.partition_dimensions is not None:
+        partition_filters = partition_dimensions_to_dnf(
+            partition_dimensions=table_slice.partition_dimensions,
+            table_schema=table.schema(),
+        )
+        if partition_filters is not None:
+            partition_expr = _filters_to_expression([partition_filters])
+
+    dataset = table.to_pyarrow_dataset()
+    if partition_expr is not None:
+        dataset = dataset.filter(expression=partition_expr)
+
+    return dataset
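The new _table_reader centralizes opening the Delta table, turning the slice's partition dimensions into a filter expression, and exposing the result as a pyarrow dataset, so the base handler and the Polars handler share one code path. A rough standalone equivalent using deltalake and pyarrow directly; the table path and the partition filter are illustrative assumptions, and the expression is built with ds.field instead of the handler's internal DNF helpers:

import pyarrow.dataset as ds
from deltalake import DeltaTable

# Illustrative local table path; in the handler this comes from connection.table_uri.
table = DeltaTable(table_uri="/tmp/delta/my_schema/daily_partitioned")

dataset = table.to_pyarrow_dataset()

# Stands in for the partition_expr built from table_slice.partition_dimensions.
dataset = dataset.filter(expression=(ds.field("time") == "2022-01-01"))

# Column pruning then happens in the scanner, exactly as load_input() does.
reader = dataset.scanner(columns=["a"]).to_reader()
print(reader.read_all())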
Changed file 5 of 5: DeltaTableResource tests
@@ -32,3 +32,33 @@ def read_table(delta_table: DeltaTableResource):
             )
         },
     )
+
+
+def test_resource_versioned(tmp_path):
+    data = pa.table(
+        {
+            "a": pa.array([1, 2, 3], type=pa.int32()),
+            "b": pa.array([5, 6, 7], type=pa.int32()),
+        }
+    )
+
+    @asset
+    def create_table(delta_table: DeltaTableResource):
+        write_deltalake(delta_table.url, data, storage_options=delta_table.storage_options.dict())
+        write_deltalake(
+            delta_table.url, data, storage_options=delta_table.storage_options.dict(), mode="append"
+        )
+
+    @asset
+    def read_table(delta_table: DeltaTableResource):
+        res = delta_table.load().to_pyarrow_table()
+        assert res.equals(data)
+
+    materialize(
+        [create_table, read_table],
+        resources={
+            "delta_table": DeltaTableResource(
+                url=os.path.join(tmp_path, "table"), storage_options=LocalConfig(), version=0
+            )
+        },
+    )
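The new test pins DeltaTableResource to version=0, so the data appended by the second write is not visible when the asset reads the table back. The same time travel done directly with deltalake, using a made-up local path:

import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

data = pa.table({"a": [1, 2, 3]})
write_deltalake("/tmp/versioned_table", data)                 # creates version 0
write_deltalake("/tmp/versioned_table", data, mode="append")  # creates version 1

dt = DeltaTable("/tmp/versioned_table", version=0)  # time travel to the first commit
assert dt.to_pyarrow_table().num_rows == 3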