Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): allow for multiple when calls in MERGE operation #1750

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ class RawDeltaTable:
target_alias: Optional[str],
writer_properties: Optional[Dict[str, int | None]],
safe_cast: bool,
matched_update_updates: Optional[Dict[str, str]],
matched_update_predicate: Optional[str],
matched_delete_predicate: Optional[str],
matched_update_updates: Optional[List[Dict[str, str]]],
matched_update_predicate: Optional[List[Optional[str]]],
matched_delete_predicate: Optional[List[str]],
matched_delete_all: Optional[bool],
not_matched_insert_updates: Optional[Dict[str, str]],
not_matched_insert_predicate: Optional[str],
not_matched_by_source_update_updates: Optional[Dict[str, str]],
not_matched_by_source_update_predicate: Optional[str],
not_matched_by_source_delete_predicate: Optional[str],
not_matched_insert_updates: Optional[List[Dict[str, str]]],
not_matched_insert_predicate: Optional[List[Optional[str]]],
not_matched_by_source_update_updates: Optional[List[Dict[str, str]]],
not_matched_by_source_update_predicate: Optional[List[Optional[str]]],
not_matched_by_source_delete_predicate: Optional[List[str]],
not_matched_by_source_delete_all: Optional[bool],
) -> str: ...
def get_active_partitions(
Expand Down
276 changes: 189 additions & 87 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def repair(self, dry_run: bool = False) -> Dict[str, Any]:


class TableMerger:
"""API for various table MERGE commands."""
"""API for various table `MERGE` commands."""

def __init__(
self,
Expand All @@ -922,15 +922,17 @@ def __init__(
self.target_alias = target_alias
self.safe_cast = safe_cast
self.writer_properties: Optional[Dict[str, Optional[int]]] = None
self.matched_update_updates: Optional[Dict[str, str]] = None
self.matched_update_predicate: Optional[str] = None
self.matched_delete_predicate: Optional[str] = None
self.matched_update_updates: Optional[List[Dict[str, str]]] = None
self.matched_update_predicate: Optional[List[Optional[str]]] = None
self.matched_delete_predicate: Optional[List[str]] = None
self.matched_delete_all: Optional[bool] = None
self.not_matched_insert_updates: Optional[Dict[str, str]] = None
self.not_matched_insert_predicate: Optional[str] = None
self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None
self.not_matched_by_source_update_predicate: Optional[str] = None
self.not_matched_by_source_delete_predicate: Optional[str] = None
self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None
self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_by_source_update_predicate: Optional[
List[Optional[str]]
] = None
self.not_matched_by_source_delete_predicate: Optional[List[str]] = None
self.not_matched_by_source_delete_all: Optional[bool] = None

def with_writer_properties(
Expand Down Expand Up @@ -975,23 +977,32 @@ def when_matched_update(

Returns:
TableMerger: TableMerger Object

Examples:

>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_update(
... updates = {
... "x": "source.x",
... "y": "source.y"
... }
... ).execute()
Examples:
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate="target.x = source.x", \
source_alias="source", \
target_alias="target") \
.when_matched_update(updates={"x": "source.x", "y": "source.y"}) \
.execute() \
)
```
"""
self.matched_update_updates = updates
self.matched_update_predicate = predicate
if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
Expand All @@ -1006,22 +1017,40 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg

Examples:

>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_update_all().execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
(\
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_matched_update_all() \
.execute() \
)
```
"""

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""

self.matched_update_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.matched_update_predicate = predicate

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]

return self

def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Expand All @@ -1037,30 +1066,50 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Examples:

Delete on a predicate

>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_delete(predicate = "source.deleted = true")
... .execute()

```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_matched_delete( \
predicate = "source.deleted = true") \
.execute() \
)
```
Delete all records that were matched

>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_delete()
... .execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_matched_delete() \
.execute() \
)
```
"""
if self.matched_delete_all is not None:
raise ValueError(
"""when_matched_delete without a predicate has already been set, which means
it will delete all, any subsequent when_matched_delete, won't make sense."""
)

if predicate is None:
self.matched_delete_all = True
else:
self.matched_delete_predicate = predicate
if isinstance(self.matched_delete_predicate, list):
self.matched_delete_predicate.append(predicate)
else:
self.matched_delete_predicate = [predicate]
return self

def when_not_matched_insert(
Expand All @@ -1078,21 +1127,35 @@ def when_not_matched_insert(

Examples:

>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_not_matched_insert(
... updates = {
... "x": "source.x",
... "y": "source.y"
... }
... ).execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_not_matched_insert( \
updates = { \
"x": "source.x", \
"y": "source.y", \
}) \
.execute() \
)
```
"""

self.not_matched_insert_updates = updates
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

Expand All @@ -1111,21 +1174,39 @@ def when_not_matched_insert_all(

Examples:

>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_not_matched_insert_all().execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt \
.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_not_matched_insert_all() \
.execute() \
)
```
"""

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
self.not_matched_insert_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

def when_not_matched_by_source_update(
Expand All @@ -1140,21 +1221,34 @@ def when_not_matched_by_source_update(

Returns:
TableMerger: TableMerger Object

>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_not_matched_by_source_update(
... predicate = "y > 3"
... updates = {
... "y": "0",
... }
... ).execute()

Examples:
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_not_matched_by_source_update( \
predicate = "y > 3", \
updates = {"y": "0"}) \
.execute() \
)
```
"""
self.not_matched_by_source_update_updates = updates
self.not_matched_by_source_update_predicate = predicate

if isinstance(self.not_matched_by_source_update_updates, list) and isinstance(
self.not_matched_by_source_update_predicate, list
):
self.not_matched_by_source_update_updates.append(updates)
self.not_matched_by_source_update_predicate.append(predicate)
else:
self.not_matched_by_source_update_updates = [updates]
self.not_matched_by_source_update_predicate = [predicate]
return self

def when_not_matched_by_source_delete(
Expand All @@ -1169,15 +1263,23 @@ def when_not_matched_by_source_delete(
Returns:
TableMerger: TableMerger Object
"""
if self.not_matched_by_source_delete_all is not None:
raise ValueError(
"""when_not_matched_by_source_delete without a predicate has already been set, which means
it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense."""
)

if predicate is None:
self.not_matched_by_source_delete_all = True
else:
self.not_matched_by_source_delete_predicate = predicate
if isinstance(self.not_matched_by_source_delete_predicate, list):
self.not_matched_by_source_delete_predicate.append(predicate)
else:
self.not_matched_by_source_delete_predicate = [predicate]
return self

def execute(self) -> Dict[str, Any]:
"""Executes MERGE with the previously provided settings in Rust with Apache Datafusion query engine.
"""Executes `MERGE` with the previously provided settings in Rust with Apache Datafusion query engine.

Returns:
Dict[str, Any]: metrics
Expand Down
Loading
Loading