diff --git a/pyright/master/requirements-pinned.txt b/pyright/master/requirements-pinned.txt index ca506dce3a07e..90222fddb64fd 100644 --- a/pyright/master/requirements-pinned.txt +++ b/pyright/master/requirements-pinned.txt @@ -172,7 +172,7 @@ dbt-extractor==0.4.1 debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 -deltalake==0.14.0 +deltalake==0.15.0 Deprecated==1.2.14 -e examples/development_to_production dict2css==0.3.0.post1 diff --git a/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars/deltalake_polars_type_handler.py b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars/deltalake_polars_type_handler.py index ff62eece88397..7d4d250ea97b3 100644 --- a/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars/deltalake_polars_type_handler.py +++ b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars/deltalake_polars_type_handler.py @@ -1,29 +1,68 @@ -from typing import Any, Dict, Optional, Sequence, Tuple, Type +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union import polars as pl import pyarrow as pa +import pyarrow.dataset as ds +from dagster import InputContext from dagster._core.storage.db_io_manager import ( DbTypeHandler, + TableSlice, ) from dagster_deltalake.handler import ( DeltalakeBaseArrowTypeHandler, DeltaLakePyArrowTypeHandler, + _table_reader, ) -from dagster_deltalake.io_manager import DeltaLakeIOManager +from dagster_deltalake.io_manager import DeltaLakeIOManager, TableConnection +PolarsTypes = Union[pl.DataFrame, pl.LazyFrame] -class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[pl.DataFrame]): + +class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[PolarsTypes]): def from_arrow( - self, obj: pa.RecordBatchReader, target_type: Type[pl.DataFrame] - ) -> pl.DataFrame: - return pl.from_arrow(obj) # type: ignore + self, + obj: Union[ds.Dataset, pa.RecordBatchReader], + target_type: Type[PolarsTypes], + ) -> PolarsTypes: + if isinstance(obj, pa.RecordBatchReader): + return pl.DataFrame(obj.read_all()) + elif isinstance(obj, ds.Dataset): + df = pl.scan_pyarrow_dataset(obj) + if target_type == pl.DataFrame: + return df.collect() + else: + return df + else: + raise NotImplementedError("Unsupported objected passed of type: %s", type(obj)) - def to_arrow(self, obj: pl.DataFrame) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]: + def to_arrow(self, obj: PolarsTypes) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]: + if isinstance(obj, pl.LazyFrame): + obj = obj.collect() return obj.to_arrow().to_reader(), {"large_dtypes": True} + def load_input( + self, + context: InputContext, + table_slice: TableSlice, + connection: TableConnection, + ) -> PolarsTypes: + """Loads the input as a Polars DataFrame or LazyFrame.""" + dataset = _table_reader(table_slice, connection) + + if table_slice.columns is not None: + if context.dagster_type.typing_type == pl.LazyFrame: + return self.from_arrow(dataset, context.dagster_type.typing_type).select( + table_slice.columns + ) + else: + scanner = dataset.scanner(columns=table_slice.columns) + return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type) + else: + return self.from_arrow(dataset, context.dagster_type.typing_type) + @property def supported_types(self) -> Sequence[Type[object]]: - return [pl.DataFrame] + return [pl.DataFrame, pl.LazyFrame] class DeltaLakePolarsIOManager(DeltaLakeIOManager): diff --git a/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py index 1166b0583ccbe..505421458afc7 100644 --- a/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py +++ b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py @@ -20,13 +20,16 @@ ) from dagster._check import CheckError from dagster_deltalake import DELTA_DATE_FORMAT, LocalConfig +from dagster_deltalake.io_manager import WriteMode from dagster_deltalake_polars import DeltaLakePolarsIOManager from deltalake import DeltaTable @pytest.fixture def io_manager(tmp_path) -> DeltaLakePolarsIOManager: - return DeltaLakePolarsIOManager(root_uri=str(tmp_path), storage_options=LocalConfig()) + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite + ) @op(out=Out(metadata={"schema": "a_df"})) @@ -74,19 +77,38 @@ def b_plus_one(b_df: pl.DataFrame) -> pl.DataFrame: return b_df + 1 -def test_deltalake_io_manager_with_assets(tmp_path, io_manager): +@asset(key_prefix=["my_schema"]) +def b_df_lazy() -> pl.LazyFrame: + return pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@asset(key_prefix=["my_schema"]) +def b_plus_one_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame: + return b_df_lazy.select(pl.all() + 1) + + +@pytest.mark.parametrize( + "asset1,asset2,asset1_path,asset2_path", + [ + (b_df, b_plus_one, "b_df", "b_plus_one"), + (b_df_lazy, b_plus_one_lazy, "b_df_lazy", "b_plus_one_lazy"), + ], +) +def test_deltalake_io_manager_with_assets( + tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path +): resource_defs = {"io_manager": io_manager} # materialize asset twice to ensure that tables get properly deleted for _ in range(2): - res = materialize([b_df, b_plus_one], resources=resource_defs) + res = materialize([asset1, asset2], resources=resource_defs) assert res.success - dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df")) + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path)) out_df = dt.to_pyarrow_table() assert out_df["a"].to_pylist() == [1, 2, 3] - dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one")) + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path)) out_df = dt.to_pyarrow_table() assert out_df["a"].to_pylist() == [2, 3, 4] @@ -125,19 +147,33 @@ def b_plus_one_columns(b_df: pl.DataFrame) -> pl.DataFrame: return b_df + 1 -def test_loading_columns(tmp_path, io_manager): +@asset( + key_prefix=["my_schema"], ins={"b_df_lazy": AssetIn("b_df_lazy", metadata={"columns": ["a"]})} +) +def b_plus_one_columns_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame: + return b_df_lazy.select(pl.all() + 1) + + +@pytest.mark.parametrize( + "asset1,asset2,asset1_path,asset2_path", + [ + (b_df, b_plus_one_columns, "b_df", "b_plus_one_columns"), + (b_df_lazy, b_plus_one_columns_lazy, "b_df_lazy", "b_plus_one_columns_lazy"), + ], +) +def test_loading_columns(tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path): resource_defs = {"io_manager": io_manager} # materialize asset twice to ensure that tables get properly deleted for _ in range(2): - res = materialize([b_df, b_plus_one_columns], resources=resource_defs) + res = materialize([asset1, asset2], resources=resource_defs) assert res.success - dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df")) + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path)) out_df = dt.to_pyarrow_table() assert out_df["a"].to_pylist() == [1, 2, 3] - dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one_columns")) + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path)) out_df = dt.to_pyarrow_table() assert out_df["a"].to_pylist() == [2, 3, 4] @@ -185,25 +221,53 @@ def daily_partitioned(context: AssetExecutionContext) -> pl.DataFrame: ) -def test_time_window_partitioned_asset(tmp_path, io_manager): +@asset( + partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"), + key_prefix=["my_schema"], + metadata={"partition_expr": "time"}, + config_schema={"value": str}, +) +def daily_partitioned_lazy(context) -> pl.LazyFrame: + partition = datetime.strptime( + context.asset_partition_key_for_output(), DELTA_DATE_FORMAT + ).date() + value = context.op_config["value"] + + return pl.LazyFrame( + { + "time": [partition, partition, partition], + "a": [value, value, value], + "b": [4, 5, 6], + } + ) + + +@pytest.mark.parametrize( + "asset1,asset1_path", + [ + (daily_partitioned, "daily_partitioned"), + (daily_partitioned_lazy, "daily_partitioned_lazy"), + ], +) +def test_time_window_partitioned_asset(tmp_path, io_manager, asset1, asset1_path): resource_defs = {"io_manager": io_manager} materialize( - [daily_partitioned], + [asset1], partition_key="2022-01-01", resources=resource_defs, - run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "1"}}}}, + run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "1"}}}}, ) - dt = DeltaTable(os.path.join(tmp_path, "my_schema/daily_partitioned")) + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path)) out_df = dt.to_pyarrow_table() assert out_df["a"].to_pylist() == ["1", "1", "1"] materialize( - [daily_partitioned], + [asset1], partition_key="2022-01-02", resources=resource_defs, - run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "2"}}}}, + run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "2"}}}}, ) dt.update_incremental() @@ -211,10 +275,10 @@ def test_time_window_partitioned_asset(tmp_path, io_manager): assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"] materialize( - [daily_partitioned], + [asset1], partition_key="2022-01-01", resources=resource_defs, - run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "3"}}}}, + run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "3"}}}}, ) dt.update_incremental() diff --git a/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler_save_modes.py b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler_save_modes.py new file mode 100644 index 0000000000000..804909aea91b7 --- /dev/null +++ b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler_save_modes.py @@ -0,0 +1,129 @@ +import os + +import polars as pl +import pytest +from dagster import ( + Out, + graph, + op, +) +from dagster_deltalake import LocalConfig +from dagster_deltalake.io_manager import WriteMode +from dagster_deltalake_polars import DeltaLakePolarsIOManager +from deltalake import DeltaTable + + +@pytest.fixture +def io_manager(tmp_path) -> DeltaLakePolarsIOManager: + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.overwrite + ) + + +@pytest.fixture +def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager: + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.append + ) + + +@pytest.fixture +def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager: + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), storage_options=LocalConfig(), mode=WriteMode.ignore + ) + + +@op(out=Out(metadata={"schema": "a_df"})) +def a_df() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@op(out=Out(metadata={"schema": "add_one"})) +def add_one(df: pl.DataFrame): + return df + 1 + + +@graph +def add_one_to_dataframe(): + add_one(a_df()) + + +@graph +def just_a_df(): + a_df() + + +def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append): + resource_defs = {"io_manager": io_manager_append} + + job = just_a_df.to_job(resource_defs=resource_defs) + + # run the job twice to ensure tables get appended + expected_result1 = [1, 2, 3] + + for _ in range(2): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == expected_result1 + + expected_result1.extend(expected_result1) + + +def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore): + resource_defs = {"io_manager": io_manager_ignore} + + job = just_a_df.to_job(resource_defs=resource_defs) + + # run the job 5 times to ensure tables gets ignored on each write + for _ in range(5): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + assert dt.version() == 0 + + +@op(out=Out(metadata={"schema": "a_df", "mode": "append"})) +def a_df_custom() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@graph +def add_one_to_dataframe_custom(): + add_one(a_df_custom()) + + +def test_deltalake_io_manager_with_ops_mode_overriden(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + job = add_one_to_dataframe_custom.to_job(resource_defs=resource_defs) + + # run the job twice to ensure that tables get properly deleted + + a_df_result = [1, 2, 3] + add_one_result = [2, 3, 4] + + for _ in range(2): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == a_df_result + + dt = DeltaTable(os.path.join(tmp_path, "add_one/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == add_one_result + + a_df_result.extend(a_df_result) + add_one_result.extend(add_one_result) diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake/__init__.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake/__init__.py index d7337ce5bb14f..d8a1b0cd9384e 100644 --- a/python_modules/libraries/dagster-deltalake/dagster_deltalake/__init__.py +++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake/__init__.py @@ -16,6 +16,8 @@ DELTA_DATE_FORMAT as DELTA_DATE_FORMAT, DELTA_DATETIME_FORMAT as DELTA_DATETIME_FORMAT, DeltaLakeIOManager as DeltaLakeIOManager, + WriteMode as WriteMode, + WriterEngine as WriterEngine, ) from .resource import DeltaTableResource as DeltaTableResource from .version import __version__ diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py index f2db7b8a7ae9b..1085ed21c0ef7 100644 --- a/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py +++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py @@ -24,7 +24,7 @@ TablePartitionDimension, TableSlice, ) -from deltalake import DeltaTable, write_deltalake +from deltalake import DeltaTable, WriterProperties, write_deltalake from deltalake.schema import ( Field as DeltaField, PrimitiveType, @@ -55,28 +55,56 @@ def handle_output( connection: TableConnection, ): """Stores pyarrow types in Delta table.""" + metadata = context.metadata or {} + resource_config = context.resource_config or {} reader, delta_params = self.to_arrow(obj=obj) delta_schema = Schema.from_pyarrow(reader.schema) + engine = resource_config.get("writer_engine") + save_mode = metadata.get("mode") + main_save_mode = resource_config.get("mode") + main_custom_metadata = resource_config.get("custom_metadata") + overwrite_schema = resource_config.get("overwrite_schema") + writerprops = resource_config.get("writer_properties") + + if save_mode is not None: + context.log.debug( + "IO manager mode overridden with the asset metadata mode, %s -> %s", + main_save_mode, + save_mode, + ) + main_save_mode = save_mode + context.log.debug("Writing with mode: %s", main_save_mode) + partition_filters = None partition_columns = None + if table_slice.partition_dimensions is not None: partition_filters = partition_dimensions_to_dnf( partition_dimensions=table_slice.partition_dimensions, table_schema=delta_schema, str_values=True, ) - + if partition_filters is not None and engine == "rust": + raise ValueError( + """Partition dimension with rust engine writer combined is not supported yet, use the default 'pyarrow' engine.""" + ) # TODO make robust and move to function partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions] - write_deltalake( + write_deltalake( # type: ignore table_or_uri=connection.table_uri, data=reader, storage_options=connection.storage_options, - mode="overwrite", + mode=main_save_mode, partition_filters=partition_filters, partition_by=partition_columns, + engine=engine, + overwrite_schema=metadata.get("overwrite_schema") or overwrite_schema, + custom_metadata=metadata.get("custom_metadata") or main_custom_metadata, + writer_properties=WriterProperties(**writerprops) # type: ignore + if writerprops is not None + else writerprops, **delta_params, ) @@ -110,22 +138,7 @@ def load_input( connection: TableConnection, ) -> T: """Loads the input as a pyarrow Table or RecordBatchReader.""" - table = DeltaTable( - table_uri=connection.table_uri, storage_options=connection.storage_options - ) - - partition_expr = None - if table_slice.partition_dimensions is not None: - partition_filters = partition_dimensions_to_dnf( - partition_dimensions=table_slice.partition_dimensions, - table_schema=table.schema(), - ) - if partition_filters is not None: - partition_expr = _filters_to_expression([partition_filters]) - - dataset = table.to_pyarrow_dataset() - if partition_expr is not None: - dataset = dataset.filter(expression=partition_expr) + dataset = _table_reader(table_slice, connection) if context.dagster_type.typing_type == ds.Dataset: if table_slice.columns is not None: @@ -231,3 +244,22 @@ def _get_partition_stats(dt: DeltaTable, partition_filters=None): } return table, stats + + +def _table_reader(table_slice: TableSlice, connection: TableConnection) -> ds.Dataset: + table = DeltaTable(table_uri=connection.table_uri, storage_options=connection.storage_options) + + partition_expr = None + if table_slice.partition_dimensions is not None: + partition_filters = partition_dimensions_to_dnf( + partition_dimensions=table_slice.partition_dimensions, + table_schema=table.schema(), + ) + if partition_filters is not None: + partition_expr = _filters_to_expression([partition_filters]) + + dataset = table.to_pyarrow_dataset() + if partition_expr is not None: + dataset = dataset.filter(expression=partition_expr) + + return dataset diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py index 7901992cac0fe..461cb49c9bb2b 100644 --- a/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py +++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py @@ -2,6 +2,7 @@ from abc import abstractmethod from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum from typing import Dict, Iterator, Optional, Sequence, Type, Union, cast from dagster import OutputContext @@ -46,11 +47,28 @@ class _StorageOptionsConfig(TypedDict, total=False): gcs: Dict[str, str] +class WriteMode(str, Enum): + error = "error" + append = "append" + overwrite = "overwrite" + ignore = "ignore" + + +class WriterEngine(str, Enum): + pyarrow = "pyarrow" + rust = "rust" + + class _DeltaTableIOManagerResourceConfig(TypedDict): root_uri: str + mode: WriteMode + overwrite_schema: bool + writer_engine: WriterEngine storage_options: _StorageOptionsConfig client_options: NotRequired[Dict[str, str]] table_config: NotRequired[Dict[str, str]] + custom_metadata: NotRequired[Dict[str, str]] + writer_properties: NotRequired[Dict[str, str]] class DeltaLakeIOManager(ConfigurableIOManagerFactory): @@ -106,6 +124,13 @@ def my_table_a(my_table: pd.DataFrame): """ root_uri: str = Field(description="Storage location where Delta tables are stored.") + mode: WriteMode = Field( + default=WriteMode.overwrite.value, description="The write mode passed to save the output." + ) + overwrite_schema: bool = Field(default=False) + writer_engine: WriterEngine = Field( + default=WriterEngine.pyarrow.value, description="Engine passed to write_deltalake." + ) storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field( discriminator="provider", @@ -125,6 +150,13 @@ def my_table_a(my_table: pd.DataFrame): default=None, alias="schema", description="Name of the schema to use." ) # schema is a reserved word for pydantic + custom_metadata: Optional[Dict[str, str]] = Field( + default=None, description="Custom metadata that is added to transaction commit." + ) + writer_properties: Optional[Dict[str, str]] = Field( + default=None, description="Writer properties passed to the rust engine writer." + ) + @staticmethod @abstractmethod def type_handlers() -> Sequence[DbTypeHandler]: diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py index 5da8e1a52fe07..8269786c4c75c 100644 --- a/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py +++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py @@ -42,9 +42,11 @@ def my_table(delta_table: DeltaTableResource): default=None, description="Additional configuration passed to http client." ) + version: Optional[int] = Field(default=None, description="Version to load delta table.") + def load(self) -> DeltaTable: storage_options = self.storage_options.dict() if self.storage_options else {} client_options = self.client_options.dict() if self.client_options else {} options = {**storage_options, **client_options} - table = DeltaTable(table_uri=self.url, storage_options=options) + table = DeltaTable(table_uri=self.url, storage_options=options, version=self.version) return table diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_delta_table_resource.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_delta_table_resource.py index b0b06b0594840..c4834a7e3489b 100644 --- a/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_delta_table_resource.py +++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_delta_table_resource.py @@ -32,3 +32,33 @@ def read_table(delta_table: DeltaTableResource): ) }, ) + + +def test_resource_versioned(tmp_path): + data = pa.table( + { + "a": pa.array([1, 2, 3], type=pa.int32()), + "b": pa.array([5, 6, 7], type=pa.int32()), + } + ) + + @asset + def create_table(delta_table: DeltaTableResource): + write_deltalake(delta_table.url, data, storage_options=delta_table.storage_options.dict()) + write_deltalake( + delta_table.url, data, storage_options=delta_table.storage_options.dict(), mode="append" + ) + + @asset + def read_table(delta_table: DeltaTableResource): + res = delta_table.load().to_pyarrow_table() + assert res.equals(data) + + materialize( + [create_table, read_table], + resources={ + "delta_table": DeltaTableResource( + url=os.path.join(tmp_path, "table"), storage_options=LocalConfig(), version=0 + ) + }, + ) diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_metadata_inputs.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_metadata_inputs.py new file mode 100644 index 0000000000000..b7af591339ba9 --- /dev/null +++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_metadata_inputs.py @@ -0,0 +1,93 @@ +import os + +import pyarrow as pa +import pytest +from dagster import ( + Out, + graph, + op, +) +from dagster_deltalake import DeltaLakePyarrowIOManager, LocalConfig, WriterEngine +from deltalake import DeltaTable + + +@pytest.fixture +def io_manager(tmp_path) -> DeltaLakePyarrowIOManager: + return DeltaLakePyarrowIOManager( + root_uri=str(tmp_path), storage_options=LocalConfig(), writer_engine=WriterEngine.rust + ) + + +@op( + out=Out( + metadata={"schema": "a_df", "mode": "append", "custom_metadata": {"userName": "John Doe"}} + ) +) +def a_df() -> pa.Table: + return pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@graph +def to_one_df(): + a_df() + + +def test_deltalake_io_manager_with_ops_rust_writer(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + job = to_one_df.to_job(resource_defs=resource_defs) + + result = [1, 2, 3] + for _ in range(1, 4): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + last_action = dt.history(1)[0] + assert last_action["userName"] == "John Doe" + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == result + + result.extend([1, 2, 3]) + + +@pytest.fixture +def io_manager_with_writer_metadata(tmp_path) -> DeltaLakePyarrowIOManager: + return DeltaLakePyarrowIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + writer_engine=WriterEngine.rust, + custom_metadata={"userName": "John Doe"}, + writer_properties={"compression": "ZSTD"}, + ) + + +@op(out=Out(metadata={"schema": "a_df"})) +def a_df2() -> pa.Table: + return pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@graph +def to_one_df2(): + a_df2() + + +def test_deltalake_io_manager_with_additional_configs(tmp_path, io_manager_with_writer_metadata): + resource_defs = {"io_manager": io_manager_with_writer_metadata} + + job = to_one_df2.to_job(resource_defs=resource_defs) + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + + last_action = dt.history(1)[0] + assert last_action["userName"] == "John Doe" + + file = dt.get_add_actions()["path"].to_pylist()[0] + assert os.path.splitext(os.path.splitext(file)[0])[1] == ".zstd" + + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_type_handler_extra_params.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_type_handler_extra_params.py new file mode 100644 index 0000000000000..85d8865ebb404 --- /dev/null +++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake_tests/test_type_handler_extra_params.py @@ -0,0 +1,53 @@ +import os + +import pyarrow as pa +import pytest +from dagster import ( + Out, + graph, + op, +) +from dagster_deltalake import DeltaLakePyarrowIOManager, LocalConfig, WriterEngine +from deltalake import DeltaTable + + +@pytest.fixture +def io_manager(tmp_path) -> DeltaLakePyarrowIOManager: + return DeltaLakePyarrowIOManager( + root_uri=str(tmp_path), storage_options=LocalConfig(), writer_engine=WriterEngine.rust + ) + + +@op(out=Out(metadata={"schema": "a_df"})) +def a_df() -> pa.Table: + return pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@op(out=Out(metadata={"schema": "add_one"})) +def add_one(df: pa.Table): + return df.set_column(0, "a", pa.array([2, 3, 4])) + + +@graph +def add_one_to_dataframe(): + add_one(a_df()) + + +def test_deltalake_io_manager_with_ops_rust_writer(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + job = add_one_to_dataframe.to_job(resource_defs=resource_defs) + + # run the job twice to ensure that tables get properly deleted + for _ in range(2): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + dt = DeltaTable(os.path.join(tmp_path, "add_one/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [2, 3, 4] diff --git a/python_modules/libraries/dagster-deltalake/setup.py b/python_modules/libraries/dagster-deltalake/setup.py index c018954302682..1a4a229d51e53 100644 --- a/python_modules/libraries/dagster-deltalake/setup.py +++ b/python_modules/libraries/dagster-deltalake/setup.py @@ -34,7 +34,7 @@ def get_version() -> str: packages=find_packages(exclude=["dagster_deltalake_tests*"]), include_package_data=True, install_requires=[ - "deltalake>=0.12", + "deltalake>=0.15", f"dagster{pin}", ], extras_require={