Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dagster-deltalake-polars LazyFrame support + expose all writer params #19343

Merged
merged 5 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyright/master/requirements-pinned.txt
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ dbt-extractor==0.4.1
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
deltalake==0.14.0
deltalake==0.15.0
Deprecated==1.2.14
-e examples/development_to_production
dict2css==0.3.0.post1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,68 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
from dagster import InputContext
from dagster._core.storage.db_io_manager import (
DbTypeHandler,
TableSlice,
)
from dagster_deltalake.handler import (
DeltalakeBaseArrowTypeHandler,
DeltaLakePyArrowTypeHandler,
_table_reader,
)
from dagster_deltalake.io_manager import DeltaLakeIOManager
from dagster_deltalake.io_manager import DeltaLakeIOManager, TableConnection

PolarsTypes = Union[pl.DataFrame, pl.LazyFrame]

class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[pl.DataFrame]):

class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[PolarsTypes]):
def from_arrow(
self, obj: pa.RecordBatchReader, target_type: Type[pl.DataFrame]
) -> pl.DataFrame:
return pl.from_arrow(obj) # type: ignore
self,
obj: Union[ds.Dataset, pa.RecordBatchReader],
target_type: Type[PolarsTypes],
) -> PolarsTypes:
if isinstance(obj, pa.RecordBatchReader):
return pl.DataFrame(obj.read_all())
elif isinstance(obj, ds.Dataset):
df = pl.scan_pyarrow_dataset(obj)
if target_type == pl.DataFrame:
return df.collect()
else:
return df
else:
raise NotImplementedError("Unsupported objected passed of type: %s", type(obj))

def to_arrow(self, obj: pl.DataFrame) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
def to_arrow(self, obj: PolarsTypes) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
if isinstance(obj, pl.LazyFrame):
obj = obj.collect()
return obj.to_arrow().to_reader(), {"large_dtypes": True}

def load_input(
self,
context: InputContext,
table_slice: TableSlice,
connection: TableConnection,
) -> PolarsTypes:
"""Loads the input as a Polars DataFrame or LazyFrame."""
dataset = _table_reader(table_slice, connection)

if table_slice.columns is not None:
if context.dagster_type.typing_type == pl.LazyFrame:
return self.from_arrow(dataset, context.dagster_type.typing_type).select(
table_slice.columns
)
else:
scanner = dataset.scanner(columns=table_slice.columns)
return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)
else:
return self.from_arrow(dataset, context.dagster_type.typing_type)

@property
def supported_types(self) -> Sequence[Type[object]]:
return [pl.DataFrame]
return [pl.DataFrame, pl.LazyFrame]


class DeltaLakePolarsIOManager(DeltaLakeIOManager):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
)
from dagster._check import CheckError
from dagster_deltalake import DELTA_DATE_FORMAT, LocalConfig
from dagster_deltalake.io_manager import WriteMode
from dagster_deltalake_polars import DeltaLakePolarsIOManager
from deltalake import DeltaTable


@pytest.fixture
def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(root_uri=str(tmp_path), storage_options=LocalConfig())
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite
)


@op(out=Out(metadata={"schema": "a_df"}))
Expand Down Expand Up @@ -74,19 +77,38 @@ def b_plus_one(b_df: pl.DataFrame) -> pl.DataFrame:
return b_df + 1


def test_deltalake_io_manager_with_assets(tmp_path, io_manager):
@asset(key_prefix=["my_schema"])
def b_df_lazy() -> pl.LazyFrame:
return pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@asset(key_prefix=["my_schema"])
def b_plus_one_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame:
return b_df_lazy.select(pl.all() + 1)


@pytest.mark.parametrize(
"asset1,asset2,asset1_path,asset2_path",
[
(b_df, b_plus_one, "b_df", "b_plus_one"),
(b_df_lazy, b_plus_one_lazy, "b_df_lazy", "b_plus_one_lazy"),
],
)
def test_deltalake_io_manager_with_assets(
tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path
):
resource_defs = {"io_manager": io_manager}

# materialize asset twice to ensure that tables get properly deleted
for _ in range(2):
res = materialize([b_df, b_plus_one], resources=resource_defs)
res = materialize([asset1, asset2], resources=resource_defs)
assert res.success

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [2, 3, 4]

Expand Down Expand Up @@ -125,19 +147,33 @@ def b_plus_one_columns(b_df: pl.DataFrame) -> pl.DataFrame:
return b_df + 1


def test_loading_columns(tmp_path, io_manager):
@asset(
key_prefix=["my_schema"], ins={"b_df_lazy": AssetIn("b_df_lazy", metadata={"columns": ["a"]})}
)
def b_plus_one_columns_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame:
return b_df_lazy.select(pl.all() + 1)


@pytest.mark.parametrize(
"asset1,asset2,asset1_path,asset2_path",
[
(b_df, b_plus_one_columns, "b_df", "b_plus_one_columns"),
(b_df_lazy, b_plus_one_columns_lazy, "b_df_lazy", "b_plus_one_columns_lazy"),
],
)
def test_loading_columns(tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path):
resource_defs = {"io_manager": io_manager}

# materialize asset twice to ensure that tables get properly deleted
for _ in range(2):
res = materialize([b_df, b_plus_one_columns], resources=resource_defs)
res = materialize([asset1, asset2], resources=resource_defs)
assert res.success

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one_columns"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [2, 3, 4]

Expand Down Expand Up @@ -185,36 +221,64 @@ def daily_partitioned(context: AssetExecutionContext) -> pl.DataFrame:
)


def test_time_window_partitioned_asset(tmp_path, io_manager):
@asset(
partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"),
key_prefix=["my_schema"],
metadata={"partition_expr": "time"},
config_schema={"value": str},
)
def daily_partitioned_lazy(context) -> pl.LazyFrame:
partition = datetime.strptime(
context.asset_partition_key_for_output(), DELTA_DATE_FORMAT
).date()
value = context.op_config["value"]

return pl.LazyFrame(
{
"time": [partition, partition, partition],
"a": [value, value, value],
"b": [4, 5, 6],
}
)


@pytest.mark.parametrize(
"asset1,asset1_path",
[
(daily_partitioned, "daily_partitioned"),
(daily_partitioned_lazy, "daily_partitioned_lazy"),
],
)
def test_time_window_partitioned_asset(tmp_path, io_manager, asset1, asset1_path):
resource_defs = {"io_manager": io_manager}

materialize(
[daily_partitioned],
[asset1],
partition_key="2022-01-01",
resources=resource_defs,
run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "1"}}}},
run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "1"}}}},
)

dt = DeltaTable(os.path.join(tmp_path, "my_schema/daily_partitioned"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == ["1", "1", "1"]

materialize(
[daily_partitioned],
[asset1],
partition_key="2022-01-02",
resources=resource_defs,
run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "2"}}}},
run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "2"}}}},
)

dt.update_incremental()
out_df = dt.to_pyarrow_table()
assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"]

materialize(
[daily_partitioned],
[asset1],
partition_key="2022-01-01",
resources=resource_defs,
run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "3"}}}},
run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "3"}}}},
)

dt.update_incremental()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import os

import polars as pl
import pytest
from dagster import (
Out,
graph,
op,
)
from dagster_deltalake import LocalConfig
from dagster_deltalake.io_manager import WriteMode
from dagster_deltalake_polars import DeltaLakePolarsIOManager
from deltalake import DeltaTable


@pytest.fixture
def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite
)


@pytest.fixture
def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.append
)


@pytest.fixture
def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.ignore
)


@op(out=Out(metadata={"schema": "a_df"}))
def a_df() -> pl.DataFrame:
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@op(out=Out(metadata={"schema": "add_one"}))
def add_one(df: pl.DataFrame):
return df + 1


@graph
def add_one_to_dataframe():
add_one(a_df())


@graph
def just_a_df():
a_df()


def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append):
resource_defs = {"io_manager": io_manager_append}

job = just_a_df.to_job(resource_defs=resource_defs)

# run the job twice to ensure tables get appended
expected_result1 = [1, 2, 3]

for _ in range(2):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == expected_result1

expected_result1.extend(expected_result1)


def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore):
resource_defs = {"io_manager": io_manager_ignore}

job = just_a_df.to_job(resource_defs=resource_defs)

# run the job 5 times to ensure tables gets ignored on each write
for _ in range(5):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

assert dt.version() == 0


@op(out=Out(metadata={"schema": "a_df", "mode": "append"}))
def a_df_custom() -> pl.DataFrame:
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@graph
def add_one_to_dataframe_custom():
add_one(a_df_custom())


def test_deltalake_io_manager_with_ops_mode_overriden(tmp_path, io_manager):
resource_defs = {"io_manager": io_manager}

job = add_one_to_dataframe_custom.to_job(resource_defs=resource_defs)

# run the job twice to ensure that tables get properly deleted

a_df_result = [1, 2, 3]
add_one_result = [2, 3, 4]

for _ in range(2):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == a_df_result

dt = DeltaTable(os.path.join(tmp_path, "add_one/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == add_one_result

a_df_result.extend(a_df_result)
add_one_result.extend(add_one_result)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
DELTA_DATE_FORMAT as DELTA_DATE_FORMAT,
DELTA_DATETIME_FORMAT as DELTA_DATETIME_FORMAT,
DeltaLakeIOManager as DeltaLakeIOManager,
WriteMode as WriteMode,
WriterEngine as WriterEngine,
)
from .resource import DeltaTableResource as DeltaTableResource
from .version import __version__
Expand Down
Loading