Skip to content

Commit

Permalink
Add functionality for multiple when calls
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Oct 21, 2023
1 parent a9cdd60 commit 787b5eb
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 93 deletions.
16 changes: 8 additions & 8 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ class RawDeltaTable:
target_alias: Optional[str],
writer_properties: Optional[Dict[str, int | None]],
safe_cast: bool,
matched_update_updates: Optional[Dict[str, str]],
matched_update_predicate: Optional[str],
matched_delete_predicate: Optional[str],
matched_update_updates: Optional[List[Dict[str, str]]],
matched_update_predicate: Optional[List[Optional[str]]],
matched_delete_predicate: Optional[List[str]],
matched_delete_all: Optional[bool],
not_matched_insert_updates: Optional[Dict[str, str]],
not_matched_insert_predicate: Optional[str],
not_matched_by_source_update_updates: Optional[Dict[str, str]],
not_matched_by_source_update_predicate: Optional[str],
not_matched_by_source_delete_predicate: Optional[str],
not_matched_insert_updates: Optional[List[Dict[str, str]]],
not_matched_insert_predicate: Optional[List[Optional[str]]],
not_matched_by_source_update_updates: Optional[List[Dict[str, str]]],
not_matched_by_source_update_predicate: Optional[List[Optional[str]]],
not_matched_by_source_delete_predicate: Optional[List[str]],
not_matched_by_source_delete_all: Optional[bool],
) -> str: ...
def get_active_partitions(
Expand Down
95 changes: 75 additions & 20 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,15 +772,17 @@ def __init__(
self.target_alias = target_alias
self.safe_cast = safe_cast
self.writer_properties: Optional[Dict[str, Optional[int]]] = None
self.matched_update_updates: Optional[Dict[str, str]] = None
self.matched_update_predicate: Optional[str] = None
self.matched_delete_predicate: Optional[str] = None
self.matched_update_updates: Optional[List[Dict[str, str]]] = None
self.matched_update_predicate: Optional[List[Optional[str]]] = None
self.matched_delete_predicate: Optional[List[str]] = None
self.matched_delete_all: Optional[bool] = None
self.not_matched_insert_updates: Optional[Dict[str, str]] = None
self.not_matched_insert_predicate: Optional[str] = None
self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None
self.not_matched_by_source_update_predicate: Optional[str] = None
self.not_matched_by_source_delete_predicate: Optional[str] = None
self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None
self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_by_source_update_predicate: Optional[
List[Optional[str]]
] = None
self.not_matched_by_source_delete_predicate: Optional[List[str]] = None
self.not_matched_by_source_delete_all: Optional[bool] = None

def with_writer_properties(
Expand Down Expand Up @@ -840,8 +842,15 @@ def when_matched_update(
... }
... ).execute()
"""
self.matched_update_updates = updates
self.matched_update_predicate = predicate

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
Expand All @@ -867,11 +876,20 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""

self.matched_update_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.matched_update_predicate = predicate

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]

return self

def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Expand Down Expand Up @@ -906,11 +924,19 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
... .when_matched_delete()
... .execute()
"""
if self.matched_delete_all is not None:
raise ValueError(
"""when_matched_delete without a predicate has already been set, which means
it will delete all, any subsequent when_matched_delete, won't make sense."""
)

if predicate is None:
self.matched_delete_all = True
else:
self.matched_delete_predicate = predicate
if isinstance(self.matched_delete_predicate, list):
self.matched_delete_predicate.append(predicate)
else:
self.matched_delete_predicate = [predicate]
return self

def when_not_matched_insert(
Expand Down Expand Up @@ -941,8 +967,14 @@ def when_not_matched_insert(
... ).execute()
"""

self.not_matched_insert_updates = updates
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

Expand Down Expand Up @@ -971,11 +1003,19 @@ def when_not_matched_insert_all(

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
self.not_matched_insert_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

def when_not_matched_by_source_update(
Expand Down Expand Up @@ -1003,8 +1043,15 @@ def when_not_matched_by_source_update(
... }
... ).execute()
"""
self.not_matched_by_source_update_updates = updates
self.not_matched_by_source_update_predicate = predicate

if isinstance(self.not_matched_by_source_update_updates, list) and isinstance(
self.not_matched_by_source_update_predicate, list
):
self.not_matched_by_source_update_updates.append(updates)
self.not_matched_by_source_update_predicate.append(predicate)
else:
self.not_matched_by_source_update_updates = [updates]
self.not_matched_by_source_update_predicate = [predicate]
return self

def when_not_matched_by_source_delete(
Expand All @@ -1020,11 +1067,19 @@ def when_not_matched_by_source_delete(
Returns:
TableMerger: TableMerger Object
"""
if self.not_matched_by_source_delete_all is not None:
raise ValueError(
"""when_not_matched_by_source_delete without a predicate has already been set, which means
it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense."""
)

if predicate is None:
self.not_matched_by_source_delete_all = True
else:
self.not_matched_by_source_delete_predicate = predicate
if isinstance(self.not_matched_by_source_delete_predicate, list):
self.not_matched_by_source_delete_predicate.append(predicate)
else:
self.not_matched_by_source_delete_predicate = [predicate]
return self

def execute(self) -> Dict[str, Any]:
Expand Down
150 changes: 85 additions & 65 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,15 +426,15 @@ impl RawDeltaTable {
target_alias: Option<String>,
safe_cast: bool,
writer_properties: Option<HashMap<String, usize>>,
matched_update_updates: Option<HashMap<String, String>>,
matched_update_predicate: Option<String>,
matched_delete_predicate: Option<String>,
matched_update_updates: Option<Vec<HashMap<String, String>>>,
matched_update_predicate: Option<Vec<Option<String>>>,
matched_delete_predicate: Option<Vec<String>>,
matched_delete_all: Option<bool>,
not_matched_insert_updates: Option<HashMap<String, String>>,
not_matched_insert_predicate: Option<String>,
not_matched_by_source_update_updates: Option<HashMap<String, String>>,
not_matched_by_source_update_predicate: Option<String>,
not_matched_by_source_delete_predicate: Option<String>,
not_matched_insert_updates: Option<Vec<HashMap<String, String>>>,
not_matched_insert_predicate: Option<Vec<Option<String>>>,
not_matched_by_source_update_updates: Option<Vec<HashMap<String, String>>>,
not_matched_by_source_update_predicate: Option<Vec<Option<String>>>,
not_matched_by_source_delete_predicate: Option<Vec<String>>,
not_matched_by_source_delete_all: Option<bool>,
) -> PyResult<String> {
let ctx = SessionContext::new();
Expand Down Expand Up @@ -488,23 +488,29 @@ impl RawDeltaTable {

if let Some(mu_updates) = matched_update_updates {
if let Some(mu_predicate) = matched_update_predicate {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in mu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(mu_predicate)
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in mu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
for it in mu_updates.iter().zip(mu_predicate.iter()) {
let (update_values, predicate_value) = it;

if let Some(pred) = predicate_value {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(pred.clone())
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_matched_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
}
}
}
}

Expand All @@ -513,52 +519,64 @@ impl RawDeltaTable {
.when_matched_delete(|delete| delete)
.map_err(PythonError::from)?;
} else if let Some(md_predicate) = matched_delete_predicate {
cmd = cmd
.when_matched_delete(|delete| delete.predicate(md_predicate))
.map_err(PythonError::from)?;
for pred in md_predicate.iter() {
cmd = cmd
.when_matched_delete(|delete| delete.predicate(pred.clone()))
.map_err(PythonError::from)?;
}
}

if let Some(nmi_updates) = not_matched_insert_updates {
if let Some(nmi_predicate) = not_matched_insert_predicate {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in nmi_updates {
insert = insert.set(col_name.clone(), expression.clone());
}
insert.predicate(nmi_predicate)
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in nmi_updates {
insert = insert.set(col_name.clone(), expression.clone());
}
insert
})
.map_err(PythonError::from)?;
for it in nmi_updates.iter().zip(nmi_predicate.iter()) {
let (update_values, predicate_value) = it;
if let Some(pred) = predicate_value {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in update_values {
insert = insert.set(col_name.clone(), expression.clone());
}
insert.predicate(pred.clone())
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_insert(|mut insert| {
for (col_name, expression) in update_values {
insert = insert.set(col_name.clone(), expression.clone());
}
insert
})
.map_err(PythonError::from)?;
}
}
}
}

if let Some(nmbsu_updates) = not_matched_by_source_update_updates {
if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in nmbsu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(nmbsu_predicate)
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in nmbsu_updates {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
for it in nmbsu_updates.iter().zip(nmbsu_predicate.iter()) {
let (update_values, predicate_value) = it;
if let Some(pred) = predicate_value {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update.predicate(pred.clone())
})
.map_err(PythonError::from)?;
} else {
cmd = cmd
.when_not_matched_by_source_update(|mut update| {
for (col_name, expression) in update_values {
update = update.update(col_name.clone(), expression.clone());
}
update
})
.map_err(PythonError::from)?;
}
}
}
}

Expand All @@ -567,9 +585,11 @@ impl RawDeltaTable {
.when_not_matched_by_source_delete(|delete| delete)
.map_err(PythonError::from)?;
} else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate {
cmd = cmd
.when_not_matched_by_source_delete(|delete| delete.predicate(nmbs_predicate))
.map_err(PythonError::from)?;
for pred in nmbs_predicate.iter() {
cmd = cmd
.when_not_matched_by_source_delete(|delete| delete.predicate(pred.clone()))
.map_err(PythonError::from)?;
}
}

let (table, metrics) = rt()?
Expand Down
Loading

0 comments on commit 787b5eb

Please sign in to comment.