Skip to content

Commit

Permalink
add save mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Jan 23, 2024
1 parent a643d87 commit 32b4768
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ def handle_output(

# TODO make robust and move to function
partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions]

context.log.info('The save mode that will be used %s', context.resource_config.get('mode')) # type: ignore

write_deltalake(
table_or_uri=connection.table_uri,
data=reader,
storage_options=connection.storage_options,
mode="overwrite",
mode=context.resource_config.get('mode'), # type: ignore
partition_filters=partition_filters,
partition_by=partition_columns,
**delta_params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TableSlice,
)
from pydantic import Field
from enum import Enum

if sys.version_info >= (3, 8):
from typing import TypedDict
Expand Down Expand Up @@ -45,9 +46,16 @@ class _StorageOptionsConfig(TypedDict, total=False):
azure: Dict[str, str]
gcs: Dict[str, str]

class _DeltaWriteMode(str, Enum):
error = "error"
append = "append"
overwrite = "overwrite"
ignore = "ignore"


class _DeltaTableIOManagerResourceConfig(TypedDict):
root_uri: str
mode: str
storage_options: _StorageOptionsConfig
client_options: NotRequired[Dict[str, str]]
table_config: NotRequired[Dict[str, str]]
Expand Down Expand Up @@ -107,6 +115,8 @@ def my_table_a(my_table: pd.DataFrame):

root_uri: str = Field(description="Storage location where Delta tables are stored.")

mode: str = Field(default='overwrite', description="The write mode passed to save the output.")

storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field(
discriminator="provider",
description="Configuration for accessing storage location.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def my_table(delta_table: DeltaTableResource):
default=None, description="Additional configuration passed to http client."
)

version: Optional[int]
version: Optional[int] = Field(
default = None, description="Version to load delta table."
)

def load(self) -> DeltaTable:
storage_options = self.storage_options.dict() if self.storage_options else {}
Expand Down

0 comments on commit 32b4768

Please sign in to comment.