
Commit

expand param inputs
ion-elgreco committed Jan 23, 2024
1 parent b9d44c9 commit e58956a
Showing 7 changed files with 333 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -19,28 +19,15 @@
)
from dagster._check import CheckError
from dagster_deltalake import DELTA_DATE_FORMAT, LocalConfig
from dagster_deltalake.io_manager import WriteMode
from dagster_deltalake_polars import DeltaLakePolarsIOManager
from deltalake import DeltaTable


@pytest.fixture
def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode="overwrite"
)


@pytest.fixture
def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode="append"
)


@pytest.fixture
def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode="ignore"
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite
)


@@ -79,49 +66,6 @@ def test_deltalake_io_manager_with_ops(tmp_path, io_manager):
assert out_df["a"].to_pylist() == [2, 3, 4]


@graph
def just_a_df():
a_df()


def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append):
resource_defs = {"io_manager": io_manager_append}

job = just_a_df.to_job(resource_defs=resource_defs)

# run the job twice to ensure tables get appended
expected_result1 = [1, 2, 3]

for _ in range(2):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == expected_result1

expected_result1.extend(expected_result1)


def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore):
resource_defs = {"io_manager": io_manager_ignore}

job = just_a_df.to_job(resource_defs=resource_defs)

# run the job 5 times to ensure tables gets ignored on each write
for _ in range(5):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

assert dt.version() == 0


@asset(key_prefix=["my_schema"])
def b_df() -> pl.DataFrame:
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import os

import polars as pl
import pytest
from dagster import (
Out,
graph,
op,
)
from dagster_deltalake import LocalConfig
from dagster_deltalake_polars import DeltaLakePolarsIOManager
from deltalake import DeltaTable


@pytest.fixture
def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite
)


@pytest.fixture
def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.append
)


from dagster_deltalake.io_manager import WriteMode


@pytest.fixture
def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager:
return DeltaLakePolarsIOManager(
root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.ignore
)


@op(out=Out(metadata={"schema": "a_df"}))
def a_df() -> pl.DataFrame:
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@op(out=Out(metadata={"schema": "add_one"}))
def add_one(df: pl.DataFrame):
return df + 1


@graph
def add_one_to_dataframe():
add_one(a_df())


@graph
def just_a_df():
a_df()


def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append):
resource_defs = {"io_manager": io_manager_append}

job = just_a_df.to_job(resource_defs=resource_defs)

# run the job twice to ensure tables get appended
expected_result1 = [1, 2, 3]

for _ in range(2):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == expected_result1

expected_result1.extend(expected_result1)


def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore):
resource_defs = {"io_manager": io_manager_ignore}

job = just_a_df.to_job(resource_defs=resource_defs)

    # run the job 5 times to ensure each write after the first gets ignored
for _ in range(5):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

assert dt.version() == 0


@op(out=Out(metadata={"schema": "a_df", "mode": "append"}))
def a_df_custom() -> pl.DataFrame:
return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@graph
def add_one_to_dataframe_custom():
add_one(a_df_custom())


def test_deltalake_io_manager_with_ops_mode_overriden(tmp_path, io_manager):
resource_defs = {"io_manager": io_manager}

job = add_one_to_dataframe_custom.to_job(resource_defs=resource_defs)

    # run the job twice to ensure the metadata-level mode override (append) takes effect

a_df_result = [1, 2, 3]
add_one_result = [2, 3, 4]

for _ in range(2):
res = job.execute_in_process()

assert res.success

dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == a_df_result

dt = DeltaTable(os.path.join(tmp_path, "add_one/result"))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == add_one_result

a_df_result.extend(a_df_result)
add_one_result.extend(add_one_result)
Original file line number Diff line number Diff line change
@@ -16,6 +16,8 @@
DELTA_DATE_FORMAT as DELTA_DATE_FORMAT,
DELTA_DATETIME_FORMAT as DELTA_DATETIME_FORMAT,
DeltaLakeIOManager as DeltaLakeIOManager,
WriteMode as WriteMode,
WriterEngine as WriterEngine,
)
from .resource import DeltaTableResource as DeltaTableResource
from .version import __version__
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
TablePartitionDimension,
TableSlice,
)
from deltalake import DeltaTable, write_deltalake
from deltalake import DeltaTable, WriterProperties, write_deltalake
from deltalake.schema import (
Field as DeltaField,
PrimitiveType,
@@ -55,11 +55,18 @@ def handle_output(
connection: TableConnection,
):
"""Stores pyarrow types in Delta table."""
assert context.metadata is not None
assert context.resource_config is not None
reader, delta_params = self.to_arrow(obj=obj)
delta_schema = Schema.from_pyarrow(reader.schema)

engine = context.resource_config.get("writer_engine")
save_mode = context.metadata.get("mode")
main_save_mode = context.resource_config.get("mode")
main_custom_metadata = context.resource_config.get("custom_metadata")
overwrite_schema = context.resource_config.get("overwrite_schema")
writerprops = context.resource_config.get("writer_properties")

if save_mode is not None:
context.log.info(
"IO manager mode overridden with the asset metadata mode, %s -> %s",
@@ -71,23 +78,33 @@

partition_filters = None
partition_columns = None

if table_slice.partition_dimensions is not None:
partition_filters = partition_dimensions_to_dnf(
partition_dimensions=table_slice.partition_dimensions,
table_schema=delta_schema,
str_values=True,
)

if partition_filters is not None and engine == "rust":
raise ValueError(
"""Partition dimension with rust engine writer combined is not supported yet, use the default 'pyarrow' engine."""
)
# TODO make robust and move to function
partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions]

write_deltalake(
write_deltalake( # type: ignore
table_or_uri=connection.table_uri,
data=reader,
storage_options=connection.storage_options,
mode=main_save_mode,
partition_filters=partition_filters,
partition_by=partition_columns,
engine=engine,
overwrite_schema=context.metadata.get("overwrite_schema") or overwrite_schema,
custom_metadata=context.metadata.get("custom_metadata") or main_custom_metadata,
writer_properties=WriterProperties(**writerprops)
if writerprops is not None
else writerprops, # type: ignore
**delta_params,
)

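With the handler change above, the per-output metadata keys "mode", "overwrite_schema", and "custom_metadata" take precedence over the resource-level config. A minimal sketch (not part of this commit; the asset name and metadata values are hypothetical) of overriding the write mode for a single asset:

import polars as pl
from dagster import asset


# Hypothetical asset: the "mode" key in the definition metadata is picked up by
# handle_output and overrides the IO manager's resource-level WriteMode for this
# output only; "custom_metadata" is attached to the Delta commit info.
@asset(
    key_prefix=["my_schema"],
    metadata={"mode": "append", "custom_metadata": {"origin": "nightly_run"}},
)
def events() -> pl.DataFrame:
    return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
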
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Iterator, Optional, Sequence, Type, Union, cast

from dagster import OutputContext
@@ -46,12 +47,28 @@ class _StorageOptionsConfig(TypedDict, total=False):
gcs: Dict[str, str]


class WriteMode(str, Enum):
error = "error"
append = "append"
overwrite = "overwrite"
ignore = "ignore"


class WriterEngine(str, Enum):
pyarrow = "pyarrow"
rust = "rust"


class _DeltaTableIOManagerResourceConfig(TypedDict):
root_uri: str
mode: str
mode: WriteMode
overwrite_schema: bool
writer_engine: WriterEngine
storage_options: _StorageOptionsConfig
client_options: NotRequired[Dict[str, str]]
table_config: NotRequired[Dict[str, str]]
custom_metadata: NotRequired[Dict[str, str]]
writer_properties: NotRequired[Dict[str, str]]


class DeltaLakeIOManager(ConfigurableIOManagerFactory):
@@ -107,8 +124,13 @@ def my_table_a(my_table: pd.DataFrame):
"""

root_uri: str = Field(description="Storage location where Delta tables are stored.")

mode: str = Field(default="overwrite", description="The write mode passed to save the output.")
mode: WriteMode = Field(
default=WriteMode.overwrite.value, description="The write mode passed to save the output."
)
overwrite_schema: bool = Field(default=False)
writer_engine: WriterEngine = Field(
default=WriterEngine.pyarrow.value, description="Engine passed to write_deltalake."
)

storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field(
discriminator="provider",
@@ -128,6 +150,13 @@ def my_table_a(my_table: pd.DataFrame):
default=None, alias="schema", description="Name of the schema to use."
) # schema is a reserved word for pydantic

custom_metadata: Optional[Dict[str, str]] = Field(
default=None, description="Custom metadata that is added to transaction commit."
)
writer_properties: Optional[Dict[str, str]] = Field(
default=None, description="Writer properties passed to the rust engine writer."
)

@staticmethod
@abstractmethod
def type_handlers() -> Sequence[DbTypeHandler]:
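Taken together, the io_manager changes expose the write mode, writer engine, custom commit metadata, and writer properties as first-class resource config. A hedged sketch of wiring up the expanded config (the asset, paths, and option values below are illustrative, not taken from this commit):

import polars as pl
from dagster import Definitions, asset
from dagster_deltalake import LocalConfig, WriteMode, WriterEngine
from dagster_deltalake_polars import DeltaLakePolarsIOManager


@asset
def numbers() -> pl.DataFrame:
    return pl.DataFrame({"a": [1, 2, 3]})


defs = Definitions(
    assets=[numbers],
    resources={
        "io_manager": DeltaLakePolarsIOManager(
            root_uri="/tmp/delta",                      # illustrative location
            storage_options=LocalConfig(),
            mode=WriteMode.append,                      # enum replaces the raw string
            writer_engine=WriterEngine.rust,            # partitioned writes still require pyarrow
            overwrite_schema=False,
            custom_metadata={"pipeline": "example"},    # added to each commit's info
            writer_properties={"compression": "ZSTD"},  # unpacked into deltalake WriterProperties
        )
    },
)
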
