Commit b9d44c9: add save modes
ion-elgreco committed Jan 23, 2024
1 parent 32b4768
Showing 4 changed files with 76 additions and 16 deletions.
File 1 of 4
@@ -25,7 +25,23 @@

@pytest.fixture
def io_manager(tmp_path) -> DeltaLakePolarsIOManager:
-    return DeltaLakePolarsIOManager(root_uri=str(tmp_path), storage_options=LocalConfig())
+    return DeltaLakePolarsIOManager(
+        root_uri=str(tmp_path), storage_options=LocalConfig(), mode="overwrite"
+    )
+
+
+@pytest.fixture
+def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager:
+    return DeltaLakePolarsIOManager(
+        root_uri=str(tmp_path), storage_options=LocalConfig(), mode="append"
+    )
+
+
+@pytest.fixture
+def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager:
+    return DeltaLakePolarsIOManager(
+        root_uri=str(tmp_path), storage_options=LocalConfig(), mode="ignore"
+    )


@op(out=Out(metadata={"schema": "a_df"}))
@@ -63,6 +79,49 @@ def test_deltalake_io_manager_with_ops(tmp_path, io_manager):
    assert out_df["a"].to_pylist() == [2, 3, 4]


+@graph
+def just_a_df():
+    a_df()
+
+
+def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append):
+    resource_defs = {"io_manager": io_manager_append}
+
+    job = just_a_df.to_job(resource_defs=resource_defs)
+
+    # run the job twice to ensure rows are appended on the second write
+    expected_result1 = [1, 2, 3]
+
+    for _ in range(2):
+        res = job.execute_in_process()
+
+        assert res.success
+
+        dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
+        out_df = dt.to_pyarrow_table()
+        assert out_df["a"].to_pylist() == expected_result1
+
+        expected_result1.extend(expected_result1)
+
+
+def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore):
+    resource_defs = {"io_manager": io_manager_ignore}
+
+    job = just_a_df.to_job(resource_defs=resource_defs)
+
+    # run the job 5 times to ensure each write is ignored once the table exists
+    for _ in range(5):
+        res = job.execute_in_process()
+
+        assert res.success
+
+        dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
+        out_df = dt.to_pyarrow_table()
+        assert out_df["a"].to_pylist() == [1, 2, 3]
+
+        assert dt.version() == 0
+
+
@asset(key_prefix=["my_schema"])
def b_df() -> pl.DataFrame:
    return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
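The two new tests pin down the save-mode semantics: "append" adds the same rows on every run, while "ignore" leaves the table untouched once it exists, which is why the table version stays at 0. A minimal sketch of the underlying delta-rs behavior being asserted, assuming a writable local placeholder path:

import polars as pl
from deltalake import DeltaTable, write_deltalake

df = pl.DataFrame({"a": [1, 2, 3]}).to_arrow()

# First write creates the table at version 0.
write_deltalake("/tmp/example_delta", df, mode="append")
# "ignore" is a no-op because the table already exists.
write_deltalake("/tmp/example_delta", df, mode="ignore")

assert DeltaTable("/tmp/example_delta").version() == 0
assert DeltaTable("/tmp/example_delta").to_pyarrow_table()["a"].to_pylist() == [1, 2, 3]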
File 2 of 4
@@ -58,6 +58,17 @@ def handle_output(
        reader, delta_params = self.to_arrow(obj=obj)
        delta_schema = Schema.from_pyarrow(reader.schema)

+        save_mode = context.metadata.get("mode")
+        main_save_mode = context.resource_config.get("mode")
+        if save_mode is not None:
+            context.log.info(
+                "IO manager mode overridden by asset metadata mode: %s -> %s",
+                main_save_mode,
+                save_mode,
+            )
+            main_save_mode = save_mode
+        context.log.info("Writing with mode: %s", main_save_mode)
+
        partition_filters = None
        partition_columns = None
        if table_slice.partition_dimensions is not None:
@@ -69,13 +80,12 @@

            # TODO make robust and move to function
            partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions]
-        context.log.info('The save mode that will be used %s', context.resource_config.get('mode'))  # type: ignore
-

        write_deltalake(
            table_or_uri=connection.table_uri,
            data=reader,
            storage_options=connection.storage_options,
-            mode=context.resource_config.get('mode'),  # type: ignore
+            mode=main_save_mode,
            partition_filters=partition_filters,
            partition_by=partition_columns,
            **delta_params,
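With the override logic above, a single asset can pick its own save mode through definition metadata while the rest of the job keeps the resource-level default; the handler reads the value via context.metadata.get("mode"). A hedged sketch of the asset side (the asset name and data are illustrative):

import polars as pl
from dagster import asset

# Hypothetical asset: the "mode" metadata entry overrides the IO manager's
# resource-level mode for this output only.
@asset(key_prefix=["my_schema"], metadata={"mode": "append"})
def daily_rows() -> pl.DataFrame:
    return pl.DataFrame({"a": [1, 2, 3]})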
File 3 of 4
@@ -15,7 +15,6 @@
    TableSlice,
)
from pydantic import Field
-from enum import Enum

if sys.version_info >= (3, 8):
    from typing import TypedDict
@@ -46,12 +45,6 @@ class _StorageOptionsConfig(TypedDict, total=False):
    azure: Dict[str, str]
    gcs: Dict[str, str]

-class _DeltaWriteMode(str, Enum):
-    error = "error"
-    append = "append"
-    overwrite = "overwrite"
-    ignore = "ignore"


class _DeltaTableIOManagerResourceConfig(TypedDict):
    root_uri: str
@@ -115,8 +108,8 @@ def my_table_a(my_table: pd.DataFrame):

    root_uri: str = Field(description="Storage location where Delta tables are stored.")

-    mode: str = Field(default='overwrite', description="The write mode passed to save the output.")
+    mode: str = Field(default="overwrite", description="The write mode passed to save the output.")

    storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field(
        discriminator="provider",
        description="Configuration for accessing storage location.",
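The mode field keeps its "overwrite" default (only the quoting changed here), and the test fixtures above show it being set per resource instance. A sketch of wiring that into a Definitions object, assuming the import paths used by the dagster-deltalake packages and a placeholder storage path:

from dagster import Definitions
from dagster_deltalake import LocalConfig
from dagster_deltalake_polars import DeltaLakePolarsIOManager

defs = Definitions(
    assets=[b_df],  # the asset defined in the tests above
    resources={
        "io_manager": DeltaLakePolarsIOManager(
            root_uri="/tmp/delta",  # placeholder storage location
            storage_options=LocalConfig(),
            mode="append",  # default save mode for every handled output
        )
    },
)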
File 4 of 4
@@ -42,9 +42,7 @@ def my_table(delta_table: DeltaTableResource):
        default=None, description="Additional configuration passed to http client."
    )

-    version: Optional[int] = Field(
-        default = None, description="Version to load delta table."
-    )
+    version: Optional[int] = Field(default=None, description="Version to load delta table.")

    def load(self) -> DeltaTable:
        storage_options = self.storage_options.dict() if self.storage_options else {}
