lazyframe_params
ion-elgreco committed Jan 27, 2024
1 parent 4a7b917 commit c56f856
Showing 10 changed files with 524 additions and 46 deletions.
@@ -1,29 +1,68 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
from dagster import InputContext
from dagster._core.storage.db_io_manager import (
DbTypeHandler,
TableSlice,
)
from dagster_deltalake.handler import (
DeltalakeBaseArrowTypeHandler,
DeltaLakePyArrowTypeHandler,
_table_reader,
)
from dagster_deltalake.io_manager import DeltaLakeIOManager
from dagster_deltalake.io_manager import DeltaLakeIOManager, TableConnection

PolarsTypes = Union[pl.DataFrame, pl.LazyFrame]

class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[pl.DataFrame]):

class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[PolarsTypes]):
def from_arrow(
self, obj: pa.RecordBatchReader, target_type: Type[pl.DataFrame]
) -> pl.DataFrame:
return pl.from_arrow(obj) # type: ignore
self,
obj: Union[ds.Dataset, pa.RecordBatchReader],
target_type: Type[PolarsTypes],
) -> PolarsTypes:
if isinstance(obj, pa.RecordBatchReader):
return pl.DataFrame(obj.read_all())
elif isinstance(obj, ds.Dataset):
df = pl.scan_pyarrow_dataset(obj)
if target_type == pl.DataFrame:
return df.collect()
else:
return df
else:
raise NotImplementedError(f"Unsupported object passed of type: {type(obj)}")

def to_arrow(self, obj: pl.DataFrame) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
def to_arrow(self, obj: PolarsTypes) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
if isinstance(obj, pl.LazyFrame):
obj = obj.collect()
return obj.to_arrow().to_reader(), {"large_dtypes": True}

def load_input(
self,
context: InputContext,
table_slice: TableSlice,
connection: TableConnection,
) -> PolarsTypes:
"""Loads the input as a pyarrow Table or RecordBatchReader."""
dataset = _table_reader(table_slice, connection)

if table_slice.columns is not None:
if context.dagster_type.typing_type == pl.LazyFrame:
return self.from_arrow(dataset, context.dagster_type.typing_type).select(
table_slice.columns
)
else:
scanner = dataset.scanner(columns=table_slice.columns)
return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)
else:
return self.from_arrow(dataset, context.dagster_type.typing_type)

@property
def supported_types(self) -> Sequence[Type[object]]:
return [pl.DataFrame]
return [pl.DataFrame, pl.LazyFrame]


class DeltaLakePolarsIOManager(DeltaLakeIOManager):
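
The handler changes above are what allow user code to hand Polars LazyFrames to the Delta Lake IO manager. Below is a minimal sketch of that usage, mirroring the test assets further down in this commit; the asset names, the root_uri value, and the Definitions wiring are illustrative, not part of the change itself.

import polars as pl
from dagster import AssetIn, Definitions, asset
from dagster_deltalake import LocalConfig
from dagster_deltalake_polars import DeltaLakePolarsIOManager


@asset(key_prefix=["my_schema"])
def raw_events() -> pl.LazyFrame:
    # LazyFrame outputs are collected inside to_arrow() before the Delta write.
    return pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@asset(key_prefix=["my_schema"], ins={"raw_events": AssetIn(metadata={"columns": ["a"]})})
def events_plus_one(raw_events: pl.LazyFrame) -> pl.LazyFrame:
    # Annotating the input as pl.LazyFrame makes load_input return a lazy scan
    # over the stored Delta table instead of an eager DataFrame.
    return raw_events.select(pl.all() + 1)


defs = Definitions(
    assets=[raw_events, events_plus_one],
    resources={
        "io_manager": DeltaLakePolarsIOManager(
            root_uri="/tmp/delta",  # illustrative location
            storage_options=LocalConfig(),
        )
    },
)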
--- next changed file ---
@@ -19,13 +19,16 @@
)
from dagster._check import CheckError
from dagster_deltalake import DELTA_DATE_FORMAT, LocalConfig
from dagster_deltalake.io_manager import WriteMode
from dagster_deltalake_polars import DeltaLakePolarsIOManager
from deltalake import DeltaTable


@pytest.fixture
def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(root_uri=str(tmp_path), storage_options=LocalConfig())
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite
)


@op(out=Out(metadata={"schema": "a_df"}))
@@ -73,19 +76,38 @@ def b_plus_one(b_df: pl.DataFrame) -> pl.DataFrame:
return b_df + 1


def test_deltalake_io_manager_with_assets(tmp_path, io_manager):
@asset(key_prefix=["my_schema"])
def b_df_lazy() -> pl.LazyFrame:
return pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@asset(key_prefix=["my_schema"])
def b_plus_one_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame:
return b_df_lazy.select(pl.all() + 1)


@pytest.mark.parametrize(
"asset1,asset2,asset1_path,asset2_path",
[
(b_df, b_plus_one, "b_df", "b_plus_one"),
(b_df_lazy, b_plus_one_lazy, "b_df_lazy", "b_plus_one_lazy"),
],
)
def test_deltalake_io_manager_with_assets(
tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path
):
resource_defs = {"io_manager": io_manager}

# materialize asset twice to ensure that tables get properly overwritten
for _ in range(2):
res = materialize([b_df, b_plus_one], resources=resource_defs)
res = materialize([asset1, asset2], resources=resource_defs)
assert res.success

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [2, 3, 4]

@@ -124,19 +146,33 @@ def b_plus_one_columns(b_df: pl.DataFrame) -> pl.DataFrame:
return b_df + 1


def test_loading_columns(tmp_path, io_manager):
@asset(
key_prefix=["my_schema"], ins={"b_df_lazy": AssetIn("b_df_lazy", metadata={"columns": ["a"]})}
)
def b_plus_one_columns_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame:
return b_df_lazy.select(pl.all() + 1)


@pytest.mark.parametrize(
"asset1,asset2,asset1_path,asset2_path",
[
(b_df, b_plus_one_columns, "b_df", "b_plus_one_columns"),
(b_df_lazy, b_plus_one_columns_lazy, "b_df_lazy", "b_plus_one_columns_lazy"),
],
)
def test_loading_columns(tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path):
resource_defs = {"io_manager": io_manager}

# materialize asset twice to ensure that tables get properly overwritten
for _ in range(2):
res = materialize([b_df, b_plus_one_columns], resources=resource_defs)
res = materialize([asset1, asset2], resources=resource_defs)
assert res.success

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one_columns"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [2, 3, 4]
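
The column-pruned load exercised by this test goes through two different paths in load_input: the eager pl.DataFrame path projects columns with a pyarrow scanner, while the pl.LazyFrame path wraps the dataset in pl.scan_pyarrow_dataset and leaves the projection to the later select()/collect(). A standalone sketch of both paths, assuming a hypothetical Parquet dataset path purely for illustration:

import polars as pl
import pyarrow.dataset as ds

# Hypothetical on-disk dataset; any pyarrow-readable dataset works for this sketch.
dataset = ds.dataset("/tmp/data/events", format="parquet")

# Eager path (pl.DataFrame): project columns in pyarrow before materializing.
reader = dataset.scanner(columns=["a"]).to_reader()
eager_df = pl.DataFrame(reader.read_all())

# Lazy path (pl.LazyFrame): build a lazy scan over the dataset; the select() is
# only pushed into the pyarrow scan when collect() runs.
lazy_df = pl.scan_pyarrow_dataset(dataset).select("a").collect()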

@@ -186,36 +222,64 @@ def daily_partitioned(context) -> pl.DataFrame:
)


def test_time_window_partitioned_asset(tmp_path, io_manager):
@asset(
partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"),
key_prefix=["my_schema"],
metadata={"partition_expr": "time"},
config_schema={"value": str},
)
def daily_partitioned_lazy(context) -> pl.LazyFrame:
partition = datetime.strptime(
context.asset_partition_key_for_output(), DELTA_DATE_FORMAT
).date()
value = context.op_config["value"]

return pl.LazyFrame(
{
"time": [partition, partition, partition],
"a": [value, value, value],
"b": [4, 5, 6],
}
)


@pytest.mark.parametrize(
"asset1,asset1_path",
[
(daily_partitioned, "daily_partitioned"),
(daily_partitioned_lazy, "daily_partitioned_lazy"),
],
)
def test_time_window_partitioned_asset(tmp_path, io_manager, asset1, asset1_path):
resource_defs = {"io_manager": io_manager}

materialize(
[daily_partitioned],
[asset1],
partition_key="2022-01-01",
resources=resource_defs,
run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "1"}}}},
run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "1"}}}},
)

dt = DeltaTable(os.path.join(tmp_path, "my_schema/daily_partitioned"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == ["1", "1", "1"]

materialize(
[daily_partitioned],
[asset1],
partition_key="2022-01-02",
resources=resource_defs,
run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "2"}}}},
run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "2"}}}},
)

dt.update_incremental()
out_df = dt.to_pyarrow_table()
assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"]

materialize(
[daily_partitioned],
[asset1],
partition_key="2022-01-01",
resources=resource_defs,
run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "3"}}}},
run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "3"}}}},
)

dt.update_incremental()
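
Stripped of the test-only config_schema, the lazy partitioned asset above boils down to the shape sketched below; the asset name and column values are illustrative. The partition_expr metadata names the column the IO manager partitions on, and DELTA_DATE_FORMAT converts the partition key string into the stored date value.

from datetime import datetime

import polars as pl
from dagster import DailyPartitionsDefinition, asset
from dagster_deltalake import DELTA_DATE_FORMAT


@asset(
    partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"),
    key_prefix=["my_schema"],
    metadata={"partition_expr": "time"},  # column holding the partition value
)
def daily_events(context) -> pl.LazyFrame:
    # Parse the partition key (e.g. "2022-01-01") into the date stored in "time".
    day = datetime.strptime(context.asset_partition_key_for_output(), DELTA_DATE_FORMAT).date()
    return pl.LazyFrame({"time": [day] * 3, "value": [1, 2, 3]})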
--- next changed file ---
@@ -0,0 +1,131 @@
import os

import polars as pl
import pytest
from dagster import (
Out,
graph,
op,
)
from dagster_deltalake import LocalConfig
from dagster_deltalake.io_manager import WriteMode
from dagster_deltalake_polars import DeltaLakePolarsIOManager
from deltalake import DeltaTable


@pytest.fixture
def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite
)


@pytest.fixture
def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.append
)


@pytest.fixture
def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.ignore
)


@op(out=Out(metadata={"schema": "a_df"}))
def a_df() -> pl.DataFrame:
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@op(out=Out(metadata={"schema": "add_one"}))
def add_one(df: pl.DataFrame):
return df + 1


@graph
def add_one_to_dataframe():
add_one(a_df())


@graph
def just_a_df():
a_df()


def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append):
resource_defs = {"io_manager": io_manager_append}

job = just_a_df.to_job(resource_defs=resource_defs)

# run the job twice to ensure rows get appended to the table
expected_result1 = [1, 2, 3]

for _ in range(2):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == expected_result1

expected_result1.extend(expected_result1)


def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore):
resource_defs = {"io_manager": io_manager_ignore}

job = just_a_df.to_job(resource_defs=resource_defs)

# run the job 5 times to ensure each subsequent write is ignored
for _ in range(5):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

assert dt.version() == 0


@op(out=Out(metadata={"schema": "a_df", "mode": "append"}))
def a_df_custom() -> pl.DataFrame:
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@graph
def add_one_to_dataframe_custom():
add_one(a_df_custom())


def test_deltalake_io_manager_with_ops_mode_overriden(tmp_path, io_manager):
resource_defs = {"io_manager": io_manager}

job = add_one_to_dataframe_custom.to_job(resource_defs=resource_defs)

# run the job twice to ensure the op-level mode override (append) takes effect

a_df_result = [1, 2, 3]
add_one_result = [2, 3, 4]

for _ in range(2):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == a_df_result

dt = DeltaTable(os.path.join(tmp_path, "add_one/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == add_one_result

a_df_result.extend(a_df_result)
add_one_result.extend(add_one_result)
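
These tests exercise two write-behavior knobs: a resource-level default passed as WriteMode, and a per-output override through the "mode" metadata key (shown here on an op's Out; the same key on asset metadata is assumed to behave the same way). A hedged sketch with illustrative asset names and root_uri:

import polars as pl
from dagster import Definitions, asset
from dagster_deltalake import LocalConfig
from dagster_deltalake.io_manager import WriteMode
from dagster_deltalake_polars import DeltaLakePolarsIOManager


@asset(key_prefix=["my_schema"])
def snapshot_table() -> pl.DataFrame:
    # Uses the IO manager's default mode (overwrite below): replaced on every run.
    return pl.DataFrame({"a": [1, 2, 3]})


@asset(key_prefix=["my_schema"], metadata={"mode": "append"})
def event_log() -> pl.DataFrame:
    # Per-output override via metadata, mirroring the Out(metadata={"mode": "append"})
    # op above; assumed to apply to asset metadata the same way.
    return pl.DataFrame({"a": [1, 2, 3]})


defs = Definitions(
    assets=[snapshot_table, event_log],
    resources={
        "io_manager": DeltaLakePolarsIOManager(
            root_uri="/tmp/delta",  # illustrative location
            storage_options=LocalConfig(),
            mode=WriteMode.overwrite,
        )
    },
)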
--- next changed file ---
@@ -16,6 +16,8 @@
DELTA_DATE_FORMAT as DELTA_DATE_FORMAT,
DELTA_DATETIME_FORMAT as DELTA_DATETIME_FORMAT,
DeltaLakeIOManager as DeltaLakeIOManager,
WriteMode as WriteMode,
WriterEngine as WriterEngine,
)
from .resource import DeltaTableResource as DeltaTableResource
from .version import __version__