Skip to content

Commit

Permalink
Add functionality for multiple when calls
Browse files: browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Oct 31, 2023
1 parent 8054882 commit 9c335d4
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 93 deletions.
16 changes: 8 additions & 8 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ class RawDeltaTable:
target_alias: Optional[str],
writer_properties: Optional[Dict[str, int | None]],
safe_cast: bool,
matched_update_updates: Optional[Dict[str, str]],
matched_update_predicate: Optional[str],
matched_delete_predicate: Optional[str],
matched_update_updates: Optional[List[Dict[str, str]]],
matched_update_predicate: Optional[List[Optional[str]]],
matched_delete_predicate: Optional[List[str]],
matched_delete_all: Optional[bool],
not_matched_insert_updates: Optional[Dict[str, str]],
not_matched_insert_predicate: Optional[str],
not_matched_by_source_update_updates: Optional[Dict[str, str]],
not_matched_by_source_update_predicate: Optional[str],
not_matched_by_source_delete_predicate: Optional[str],
not_matched_insert_updates: Optional[List[Dict[str, str]]],
not_matched_insert_predicate: Optional[List[Optional[str]]],
not_matched_by_source_update_updates: Optional[List[Dict[str, str]]],
not_matched_by_source_update_predicate: Optional[List[Optional[str]]],
not_matched_by_source_delete_predicate: Optional[List[str]],
not_matched_by_source_delete_all: Optional[bool],
) -> str: ...
def get_active_partitions(
Expand Down
95 changes: 75 additions & 20 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,15 +871,17 @@ def __init__(
self.target_alias = target_alias
self.safe_cast = safe_cast
self.writer_properties: Optional[Dict[str, Optional[int]]] = None
self.matched_update_updates: Optional[Dict[str, str]] = None
self.matched_update_predicate: Optional[str] = None
self.matched_delete_predicate: Optional[str] = None
self.matched_update_updates: Optional[List[Dict[str, str]]] = None
self.matched_update_predicate: Optional[List[Optional[str]]] = None
self.matched_delete_predicate: Optional[List[str]] = None
self.matched_delete_all: Optional[bool] = None
self.not_matched_insert_updates: Optional[Dict[str, str]] = None
self.not_matched_insert_predicate: Optional[str] = None
self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None
self.not_matched_by_source_update_predicate: Optional[str] = None
self.not_matched_by_source_delete_predicate: Optional[str] = None
self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None
self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_by_source_update_predicate: Optional[
List[Optional[str]]
] = None
self.not_matched_by_source_delete_predicate: Optional[List[str]] = None
self.not_matched_by_source_delete_all: Optional[bool] = None

def with_writer_properties(
Expand Down Expand Up @@ -939,8 +941,15 @@ def when_matched_update(
... }
... ).execute()
"""
self.matched_update_updates = updates
self.matched_update_predicate = predicate

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
Expand All @@ -966,11 +975,20 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""

self.matched_update_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.matched_update_predicate = predicate

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]

return self

def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Expand Down Expand Up @@ -1005,11 +1023,19 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
... .when_matched_delete()
... .execute()
"""
if self.matched_delete_all is not None:
raise ValueError(
"""when_matched_delete without a predicate has already been set, which means
it will delete all, any subsequent when_matched_delete, won't make sense."""
)

if predicate is None:
self.matched_delete_all = True
else:
self.matched_delete_predicate = predicate
if isinstance(self.matched_delete_predicate, list):
self.matched_delete_predicate.append(predicate)
else:
self.matched_delete_predicate = [predicate]
return self

def when_not_matched_insert(
Expand Down Expand Up @@ -1040,8 +1066,14 @@ def when_not_matched_insert(
... ).execute()
"""

self.not_matched_insert_updates = updates
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

Expand Down Expand Up @@ -1070,11 +1102,19 @@ def when_not_matched_insert_all(

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
self.not_matched_insert_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

def when_not_matched_by_source_update(
Expand Down Expand Up @@ -1102,8 +1142,15 @@ def when_not_matched_by_source_update(
... }
... ).execute()
"""
self.not_matched_by_source_update_updates = updates
self.not_matched_by_source_update_predicate = predicate

if isinstance(self.not_matched_by_source_update_updates, list) and isinstance(
self.not_matched_by_source_update_predicate, list
):
self.not_matched_by_source_update_updates.append(updates)
self.not_matched_by_source_update_predicate.append(predicate)
else:
self.not_matched_by_source_update_updates = [updates]
self.not_matched_by_source_update_predicate = [predicate]
return self

def when_not_matched_by_source_delete(
Expand All @@ -1118,11 +1165,19 @@ def when_not_matched_by_source_delete(
Returns:
TableMerger: TableMerger Object
"""
if self.not_matched_by_source_delete_all is not None:
raise ValueError(
"""when_not_matched_by_source_delete without a predicate has already been set, which means
it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense."""
)

if predicate is None:
self.not_matched_by_source_delete_all = True
else:
self.not_matched_by_source_delete_predicate = predicate
if isinstance(self.not_matched_by_source_delete_predicate, list):
self.not_matched_by_source_delete_predicate.append(predicate)
else:
self.not_matched_by_source_delete_predicate = [predicate]
return self

def execute(self) -> Dict[str, Any]:
Expand Down
150 changes: 85 additions & 65 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,15 +427,15 @@ impl RawDeltaTable {
target_alias: Option<String>,
safe_cast: bool,
writer_properties: Option<HashMap<String, usize>>,
matched_update_updates: Option<HashMap<String, String>>,
matched_update_predicate: Option<String>,
matched_delete_predicate: Option<String>,
matched_update_updates: Option<Vec<HashMap<String, String>>>,
matched_update_predicate: Option<Vec<Option<String>>>,
matched_delete_predicate: Option<Vec<String>>,
matched_delete_all: Option<bool>,
not_matched_insert_updates: Option<HashMap<String, String>>,
not_matched_insert_predicate: Option<String>,
not_matched_by_source_update_updates: Option<HashMap<String, String>>,
not_matched_by_source_update_predicate: Option<String>,
not_matched_by_source_delete_predicate: Option<String>,
not_matched_insert_updates: Option<Vec<HashMap<String, String>>>,
not_matched_insert_predicate: Option<Vec<Option<String>>>,
not_matched_by_source_update_updates: Option<Vec<HashMap<String, String>>>,
not_matched_by_source_update_predicate: Option<Vec<Option<String>>>,
not_matched_by_source_delete_predicate: Option<Vec<String>>,
not_matched_by_source_delete_all: Option<bool>,
) -> PyResult<String> {
let ctx = SessionContext::new();
Expand Down Expand Up @@ -489,23 +489,29 @@ impl RawDeltaTable {

if let Some(mu_updates) = matched_update_updates {
if let Some(mu_predicate) = matched_update_predicate {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in mu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(mu_predicate)
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in mu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
for it in mu_updates.iter().zip(mu_predicate.iter()) {
let (update_values, predicate_value) = it;

if let Some(pred) = predicate_value {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(pred.clone())
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
}
}
}
}

Expand All @@ -514,52 +520,64 @@ impl RawDeltaTable {
.when_matched_delete(|delete| delete)
.map_err(PythonError::from)?;
} else if let Some(md_predicate) = matched_delete_predicate {
cmd = cmd
.when_matched_delete(|delete| delete.predicate(md_predicate))
.map_err(PythonError::from)?;
for pred in md_predicate.iter() {
cmd = cmd
.when_matched_delete(|delete| delete.predicate(pred.clone()))
.map_err(PythonError::from)?;
}
}

if let Some(nmi_updates) = not_matched_insert_updates {
if let Some(nmi_predicate) = not_matched_insert_predicate {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in nmi_updates {
insert = insert.set(col_name.clone(), expression.clone());
}
insert.predicate(nmi_predicate)
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in nmi_updates {
insert = insert.set(col_name.clone(), expression.clone());
}
insert
})
.map_err(PythonError::from)?;
for it in nmi_updates.iter().zip(nmi_predicate.iter()) {
let (update_values, predicate_value) = it;
if let Some(pred) = predicate_value {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in update_values {
insert = insert.set(col_name.clone(), expression.clone());
}
insert.predicate(pred.clone())
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in update_values {
insert = insert.set(col_name.clone(), expression.clone());
}
insert
})
.map_err(PythonError::from)?;
}
}
}
}

if let Some(nmbsu_updates) = not_matched_by_source_update_updates {
if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in nmbsu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(nmbsu_predicate)
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in nmbsu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
for it in nmbsu_updates.iter().zip(nmbsu_predicate.iter()) {
let (update_values, predicate_value) = it;
if let Some(pred) = predicate_value {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(pred.clone())
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
}
}
}
}

Expand All @@ -568,9 +586,11 @@ impl RawDeltaTable {
.when_not_matched_by_source_delete(|delete| delete)
.map_err(PythonError::from)?;
} else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate {
cmd = cmd
.when_not_matched_by_source_delete(|delete| delete.predicate(nmbs_predicate))
.map_err(PythonError::from)?;
for pred in nmbs_predicate.iter() {
cmd = cmd
.when_not_matched_by_source_delete(|delete| delete.predicate(pred.clone()))
.map_err(PythonError::from)?;
}
}

let (table, metrics) = rt()?
Expand Down
Loading

0 comments on commit 9c335d4

Please sign in to comment.