diff --git a/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py
index 048785e789cc4..c4b9df7b4a8da 100644
--- a/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py
+++ b/python_modules/libraries/dagster-deltalake-polars/dagster_deltalake_polars_tests/test_type_handler.py
@@ -25,7 +25,23 @@
 
 @pytest.fixture
 def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
-    return DeltaLakePolarsIOManager(root_uri=str(tmp_path), storage_options=LocalConfig())
+    return DeltaLakePolarsIOManager(
+        root_uri=str(tmp_path), storage_options=LocalConfig(), mode="overwrite"
+    )
+
+
+@pytest.fixture
+def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager:
+    return DeltaLakePolarsIOManager(
+        root_uri=str(tmp_path), storage_options=LocalConfig(), mode="append"
+    )
+
+
+@pytest.fixture
+def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager:
+    return DeltaLakePolarsIOManager(
+        root_uri=str(tmp_path), storage_options=LocalConfig(), mode="ignore"
+    )
 
 
 @op(out=Out(metadata={"schema": "a_df"}))
@@ -63,6 +79,49 @@ def test_deltalake_io_manager_with_ops(tmp_path, io_manager):
     assert out_df["a"].to_pylist() == [2, 3, 4]
 
 
+@graph
+def just_a_df():
+    a_df()
+
+
+def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append):
+    resource_defs = {"io_manager": io_manager_append}
+
+    job = just_a_df.to_job(resource_defs=resource_defs)
+
+    # run the job twice to ensure tables get appended
+    expected_result1 = [1, 2, 3]
+
+    for _ in range(2):
+        res = job.execute_in_process()
+
+        assert res.success
+
+        dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
+        out_df = dt.to_pyarrow_table()
+        assert out_df["a"].to_pylist() == expected_result1
+
+        expected_result1.extend(expected_result1)
+
+
+def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore):
+    resource_defs = {"io_manager": io_manager_ignore}
+
+    job = just_a_df.to_job(resource_defs=resource_defs)
+
+    # run the job 5 times to ensure writes get ignored each time
+    for _ in range(5):
+        res = job.execute_in_process()
+
+        assert res.success
+
+        dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
+        out_df = dt.to_pyarrow_table()
+        assert out_df["a"].to_pylist() == [1, 2, 3]
+
+    assert dt.version() == 0
+
+
 @asset(key_prefix=["my_schema"])
 def b_df() -> pl.DataFrame:
     return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py
index 6807347bac163..44ac55d5535d3 100644
--- a/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py
+++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake/handler.py
@@ -58,6 +58,17 @@ def handle_output(
         reader, delta_params = self.to_arrow(obj=obj)
         delta_schema = Schema.from_pyarrow(reader.schema)
 
+        save_mode = context.metadata.get("mode")
+        main_save_mode = context.resource_config.get("mode")
+        if save_mode is not None:
+            context.log.info(
+                "IO manager mode overridden with the asset metadata mode, %s -> %s",
+                main_save_mode,
+                save_mode,
+            )
+            main_save_mode = save_mode
+        context.log.info("Writing with mode: %s", main_save_mode)
+
         partition_filters = None
         partition_columns = None
         if table_slice.partition_dimensions is not None:
@@ -69,13 +80,12 @@
             # TODO make robust and move to function
             partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions]
 
-        context.log.info('The save mode that will be used %s', context.resource_config.get('mode'))  # type: ignore
-
+
         write_deltalake(
             table_or_uri=connection.table_uri,
             data=reader,
             storage_options=connection.storage_options,
-            mode=context.resource_config.get('mode'),  # type: ignore
+            mode=main_save_mode,
             partition_filters=partition_filters,
             partition_by=partition_columns,
             **delta_params,
diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py
index dabf8d3998f00..f76f56c6cc34a 100644
--- a/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py
+++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake/io_manager.py
@@ -15,7 +15,6 @@
     TableSlice,
 )
 from pydantic import Field
-from enum import Enum
 
 if sys.version_info >= (3, 8):
     from typing import TypedDict
@@ -46,12 +45,6 @@ class _StorageOptionsConfig(TypedDict, total=False):
     azure: Dict[str, str]
     gcs: Dict[str, str]
 
-class _DeltaWriteMode(str, Enum):
-    error = "error"
-    append = "append"
-    overwrite = "overwrite"
-    ignore = "ignore"
-
 
 class _DeltaTableIOManagerResourceConfig(TypedDict):
     root_uri: str
@@ -115,8 +108,8 @@ def my_table_a(my_table: pd.DataFrame):
 
     root_uri: str = Field(description="Storage location where Delta tables are stored.")
 
-    mode: str = Field(default='overwrite', description="The write mode passed to save the output.")
-
+    mode: str = Field(default="overwrite", description="The write mode passed to save the output.")
+
     storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field(
         discriminator="provider",
         description="Configuration for accessing storage location.",
diff --git a/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py b/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py
index 9122440e14e0d..8269786c4c75c 100644
--- a/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py
+++ b/python_modules/libraries/dagster-deltalake/dagster_deltalake/resource.py
@@ -42,9 +42,7 @@ def my_table(delta_table: DeltaTableResource):
         default=None, description="Additional configuration passed to http client."
    )
 
-    version: Optional[int] = Field(
-        default = None, description="Version to load delta table."
-    )
+    version: Optional[int] = Field(default=None, description="Version to load delta table.")
 
     def load(self) -> DeltaTable:
         storage_options = self.storage_options.dict() if self.storage_options else {}
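
Usage sketch (not part of the diff): a minimal, hypothetical example of how the resource-level `mode` default and the per-asset metadata override read by `context.metadata.get("mode")` in handler.py could be combined. The `Definitions` wiring, the `events` asset, and the `/tmp/delta` path are assumptions for illustration; `DeltaLakePolarsIOManager`, `LocalConfig`, the `mode` field, and the `"mode"` metadata key come from the changes above.

    import polars as pl
    from dagster import Definitions, asset
    from dagster_deltalake import LocalConfig
    from dagster_deltalake_polars import DeltaLakePolarsIOManager


    # Hypothetical asset: the "mode" metadata key overrides the IO manager's
    # resource-level default, matching the lookup added to handle_output.
    @asset(key_prefix=["my_schema"], metadata={"mode": "overwrite"})
    def events() -> pl.DataFrame:
        return pl.DataFrame({"a": [1, 2, 3]})


    defs = Definitions(
        assets=[events],
        resources={
            "io_manager": DeltaLakePolarsIOManager(
                root_uri="/tmp/delta",  # hypothetical storage location
                storage_options=LocalConfig(),
                mode="append",  # default write mode for assets without an override
            )
        },
    )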