diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index d9adbe7cbc..a58139ab3b 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -103,15 +103,15 @@ class RawDeltaTable: target_alias: Optional[str], writer_properties: Optional[Dict[str, int | None]], safe_cast: bool, - matched_update_updates: Optional[Dict[str, str]], - matched_update_predicate: Optional[str], - matched_delete_predicate: Optional[str], + matched_update_updates: Optional[List[Dict[str, str]]], + matched_update_predicate: Optional[List[Optional[str]]], + matched_delete_predicate: Optional[List[str]], matched_delete_all: Optional[bool], - not_matched_insert_updates: Optional[Dict[str, str]], - not_matched_insert_predicate: Optional[str], - not_matched_by_source_update_updates: Optional[Dict[str, str]], - not_matched_by_source_update_predicate: Optional[str], - not_matched_by_source_delete_predicate: Optional[str], + not_matched_insert_updates: Optional[List[Dict[str, str]]], + not_matched_insert_predicate: Optional[List[Optional[str]]], + not_matched_by_source_update_updates: Optional[List[Dict[str, str]]], + not_matched_by_source_update_predicate: Optional[List[Optional[str]]], + not_matched_by_source_delete_predicate: Optional[List[str]], not_matched_by_source_delete_all: Optional[bool], ) -> str: ... def get_active_partitions( diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 51355477b5..3f6930e151 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -772,15 +772,17 @@ def __init__( self.target_alias = target_alias self.safe_cast = safe_cast self.writer_properties: Optional[Dict[str, Optional[int]]] = None - self.matched_update_updates: Optional[Dict[str, str]] = None - self.matched_update_predicate: Optional[str] = None - self.matched_delete_predicate: Optional[str] = None + self.matched_update_updates: Optional[List[Dict[str, str]]] = None + self.matched_update_predicate: Optional[List[Optional[str]]] = None + self.matched_delete_predicate: Optional[List[str]] = None self.matched_delete_all: Optional[bool] = None - self.not_matched_insert_updates: Optional[Dict[str, str]] = None - self.not_matched_insert_predicate: Optional[str] = None - self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None - self.not_matched_by_source_update_predicate: Optional[str] = None - self.not_matched_by_source_delete_predicate: Optional[str] = None + self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None + self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None + self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None + self.not_matched_by_source_update_predicate: Optional[ + List[Optional[str]] + ] = None + self.not_matched_by_source_delete_predicate: Optional[List[str]] = None self.not_matched_by_source_delete_all: Optional[bool] = None def with_writer_properties( @@ -840,8 +842,15 @@ def when_matched_update( ... } ... ).execute() """ - self.matched_update_updates = updates - self.matched_update_predicate = predicate + + if isinstance(self.matched_update_updates, list) and isinstance( + self.matched_update_predicate, list + ): + self.matched_update_updates.append(updates) + self.matched_update_predicate.append(predicate) + else: + self.matched_update_updates = [updates] + self.matched_update_predicate = [predicate] return self def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": @@ -867,11 +876,20 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg src_alias = (self.source_alias + ".") if self.source_alias is not None else "" trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" - self.matched_update_updates = { + updates = { f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" for col in self.source.schema } - self.matched_update_predicate = predicate + + if isinstance(self.matched_update_updates, list) and isinstance( + self.matched_update_predicate, list + ): + self.matched_update_updates.append(updates) + self.matched_update_predicate.append(predicate) + else: + self.matched_update_updates = [updates] + self.matched_update_predicate = [predicate] + return self def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": @@ -906,11 +924,19 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": ... .when_matched_delete() ... .execute() """ + if self.matched_delete_all is not None: + raise ValueError( + """when_matched_delete without a predicate has already been set, which means + it will delete all, any subsequent when_matched_delete, won't make sense.""" + ) if predicate is None: self.matched_delete_all = True else: - self.matched_delete_predicate = predicate + if isinstance(self.matched_delete_predicate, list): + self.matched_delete_predicate.append(predicate) + else: + self.matched_delete_predicate = [predicate] return self def when_not_matched_insert( @@ -941,8 +967,14 @@ def when_not_matched_insert( ... ).execute() """ - self.not_matched_insert_updates = updates - self.not_matched_insert_predicate = predicate + if isinstance(self.not_matched_insert_updates, list) and isinstance( + self.not_matched_insert_predicate, list + ): + self.not_matched_insert_updates.append(updates) + self.not_matched_insert_predicate.append(predicate) + else: + self.not_matched_insert_updates = [updates] + self.not_matched_insert_predicate = [predicate] return self @@ -971,11 +1003,19 @@ def when_not_matched_insert_all( src_alias = (self.source_alias + ".") if self.source_alias is not None else "" trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" - self.not_matched_insert_updates = { + updates = { f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" for col in self.source.schema } - self.not_matched_insert_predicate = predicate + if isinstance(self.not_matched_insert_updates, list) and isinstance( + self.not_matched_insert_predicate, list + ): + self.not_matched_insert_updates.append(updates) + self.not_matched_insert_predicate.append(predicate) + else: + self.not_matched_insert_updates = [updates] + self.not_matched_insert_predicate = [predicate] + return self def when_not_matched_by_source_update( @@ -1003,8 +1043,15 @@ def when_not_matched_by_source_update( ... } ... ).execute() """ - self.not_matched_by_source_update_updates = updates - self.not_matched_by_source_update_predicate = predicate + + if isinstance(self.not_matched_by_source_update_updates, list) and isinstance( + self.not_matched_by_source_update_predicate, list + ): + self.not_matched_by_source_update_updates.append(updates) + self.not_matched_by_source_update_predicate.append(predicate) + else: + self.not_matched_by_source_update_updates = [updates] + self.not_matched_by_source_update_predicate = [predicate] return self def when_not_matched_by_source_delete( @@ -1020,11 +1067,19 @@ def when_not_matched_by_source_delete( Returns: TableMerger: TableMerger Object """ + if self.not_matched_by_source_delete_all is not None: + raise ValueError( + """when_not_matched_by_source_delete without a predicate has already been set, which means + it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense.""" + ) if predicate is None: self.not_matched_by_source_delete_all = True else: - self.not_matched_by_source_delete_predicate = predicate + if isinstance(self.not_matched_by_source_delete_predicate, list): + self.not_matched_by_source_delete_predicate.append(predicate) + else: + self.not_matched_by_source_delete_predicate = [predicate] return self def execute(self) -> Dict[str, Any]: diff --git a/python/src/lib.rs b/python/src/lib.rs index 2f46436984..9343f3c3a1 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -426,15 +426,15 @@ impl RawDeltaTable { target_alias: Option, safe_cast: bool, writer_properties: Option>, - matched_update_updates: Option>, - matched_update_predicate: Option, - matched_delete_predicate: Option, + matched_update_updates: Option>>, + matched_update_predicate: Option>>, + matched_delete_predicate: Option>, matched_delete_all: Option, - not_matched_insert_updates: Option>, - not_matched_insert_predicate: Option, - not_matched_by_source_update_updates: Option>, - not_matched_by_source_update_predicate: Option, - not_matched_by_source_delete_predicate: Option, + not_matched_insert_updates: Option>>, + not_matched_insert_predicate: Option>>, + not_matched_by_source_update_updates: Option>>, + not_matched_by_source_update_predicate: Option>>, + not_matched_by_source_delete_predicate: Option>, not_matched_by_source_delete_all: Option, ) -> PyResult { let ctx = SessionContext::new(); @@ -488,23 +488,29 @@ impl RawDeltaTable { if let Some(mu_updates) = matched_update_updates { if let Some(mu_predicate) = matched_update_predicate { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in mu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update.predicate(mu_predicate) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in mu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; + for it in mu_updates.iter().zip(mu_predicate.iter()) { + let (update_values, predicate_value) = it; + + if let Some(pred) = predicate_value { + cmd = cmd + .when_matched_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + update.predicate(pred.clone()) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_matched_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + update + }) + .map_err(PythonError::from)?; + } + } } } @@ -513,52 +519,64 @@ impl RawDeltaTable { .when_matched_delete(|delete| delete) .map_err(PythonError::from)?; } else if let Some(md_predicate) = matched_delete_predicate { - cmd = cmd - .when_matched_delete(|delete| delete.predicate(md_predicate)) - .map_err(PythonError::from)?; + for pred in md_predicate.iter() { + cmd = cmd + .when_matched_delete(|delete| delete.predicate(pred.clone())) + .map_err(PythonError::from)?; + } } if let Some(nmi_updates) = not_matched_insert_updates { if let Some(nmi_predicate) = not_matched_insert_predicate { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in nmi_updates { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert.predicate(nmi_predicate) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in nmi_updates { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert - }) - .map_err(PythonError::from)?; + for it in nmi_updates.iter().zip(nmi_predicate.iter()) { + let (update_values, predicate_value) = it; + if let Some(pred) = predicate_value { + cmd = cmd + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in update_values { + insert = insert.set(col_name.clone(), expression.clone()); + } + insert.predicate(pred.clone()) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in update_values { + insert = insert.set(col_name.clone(), expression.clone()); + } + insert + }) + .map_err(PythonError::from)?; + } + } } } if let Some(nmbsu_updates) = not_matched_by_source_update_updates { if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in nmbsu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update.predicate(nmbsu_predicate) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in nmbsu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; + for it in nmbsu_updates.iter().zip(nmbsu_predicate.iter()) { + let (update_values, predicate_value) = it; + if let Some(pred) = predicate_value { + cmd = cmd + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + update.predicate(pred.clone()) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + update + }) + .map_err(PythonError::from)?; + } + } } } @@ -567,9 +585,11 @@ impl RawDeltaTable { .when_not_matched_by_source_delete(|delete| delete) .map_err(PythonError::from)?; } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { - cmd = cmd - .when_not_matched_by_source_delete(|delete| delete.predicate(nmbs_predicate)) - .map_err(PythonError::from)?; + for pred in nmbs_predicate.iter() { + cmd = cmd + .when_not_matched_by_source_delete(|delete| delete.predicate(pred.clone())) + .map_err(PythonError::from)?; + } } let (table, metrics) = rt()? diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index fc08563443..76e3543160 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -490,3 +490,47 @@ def test_merge_when_not_matched_by_source_delete_wo_predicate( assert last_action["operation"] == "MERGE" assert result == expected + + +def test_merge_multiple_when_matched_update_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, True]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_matched_update( + updates={"price": "source.price", "sold": "source.sold"}, + predicate="source.deleted = False", + ).when_matched_update( + updates={"price": "source.price", "sold": "source.sold"}, + predicate="source.deleted = True", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 20], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected