Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): allow python objects to be passed as new values in .update() #1749

Merged
52 changes: 45 additions & 7 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,18 +441,23 @@ def vacuum(

def update(
    self,
    updates: Optional[Dict[str, str]] = None,
    new_values: Optional[
        Dict[str, Union[int, float, str, datetime, bool, list]]
    ] = None,
    predicate: Optional[str] = None,
    writer_properties: Optional[Dict[str, int]] = None,
    error_on_type_mismatch: bool = True,
) -> Dict[str, Any]:
    """`UPDATE` records in the Delta Table that match an optional predicate.

    Exactly one of ``updates`` or ``new_values`` needs to be passed for the
    update to execute.

    :param updates: a mapping of column name to update SQL expression.
    :param new_values: a mapping of column name to python object. Each object
        is rendered as a SQL literal: ``str`` values are single-quoted (with
        embedded single quotes doubled), ``int``/``float``/``bool``/``list``
        values are rendered with ``str()``, and ``datetime`` values are
        rendered as integer microseconds since the epoch.
    :param predicate: a logical expression, defaults to None
    :param writer_properties: Pass writer properties to the Rust parquet writer, see options
        https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html,
        only the fields: data_page_size_limit, dictionary_page_size_limit,
        data_page_row_count_limit, write_batch_size, max_row_group_size are supported.
    :param error_on_type_mismatch: specify if update will return error if data
        types are mismatching :default = True
    :return: the metrics from update
    :raises ValueError: if both or neither of ``updates`` and ``new_values``
        are passed.
    :raises TypeError: if a value in ``new_values`` has an unsupported type.

    Examples:

    Update some row values with a SQL expression and a predicate:

    >>> from deltalake import DeltaTable
    >>> dt = DeltaTable("tmp")
    >>> dt.update(predicate="id = '5'",
    ...           updates = {
    ...             "deleted": 'True',
    ...             }
    ...           )

    Update all row values with SQL expressions:

    >>> from deltalake import DeltaTable
    >>> dt = DeltaTable("tmp")
    >>> dt.update(updates = {
    ...             "deleted": 'True',
    ...             "id": "concat(id, '_old')"
    ...             }
    ...           )

    Update some row values with python object. This is equivalent to
    ``UPDATE table SET price = 150.10 WHERE id = '5'``

    >>> from deltalake import DeltaTable
    >>> dt = DeltaTable("tmp")
    >>> dt.update(predicate="id = '5'",
    ...           new_values = {
    ...             "price": 150.10,
    ...             }
    ...           )
    """
    # Validate the mutually exclusive inputs up front.
    if updates is not None and new_values is not None:
        raise ValueError(
            "Passing updates and new_values at same time is not allowed, pick one."
        )
    if updates is None and new_values is None:
        raise ValueError(
            "Either updates or new_values need to be passed to update the table."
        )

    if new_values is not None:
        # Render each python object as a SQL literal expression string.
        updates = {}
        for key, value in new_values.items():
            if isinstance(value, str):
                # Double embedded single quotes so the value stays a single
                # string literal instead of terminating early.
                escaped = value.replace("'", "''")
                updates[key] = f"'{escaped}'"
            elif isinstance(value, datetime):
                # Seconds since the epoch converted to integer microseconds.
                # Naive datetimes are interpreted in local time by
                # ``datetime.timestamp()``.
                updates[key] = str(int(value.timestamp() * 1000 * 1000))
            elif isinstance(value, (int, float, bool, list)):
                updates[key] = str(value)
            else:
                # Fail early with a clear message rather than handing an
                # unconverted object to the native layer.
                raise TypeError(
                    f"Unsupported type {type(value).__name__!r} for column "
                    f"{key!r} in new_values."
                )

    metrics = self._table.update(
        updates,
        predicate,
        writer_properties,
        safe_cast=not error_on_type_mismatch,
    )
    return json.loads(metrics)

Expand Down
75 changes: 75 additions & 0 deletions python/tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def sample_table():
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([False] * nrows),
}
)
Expand All @@ -30,6 +32,8 @@ def test_update_with_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([False, False, False, False, True]),
}
)
Expand All @@ -54,6 +58,8 @@ def test_update_wo_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([True] * 5),
}
)
Expand Down Expand Up @@ -93,6 +99,8 @@ def test_update_wo_predicate_multiple_updates(
"id": pa.array(["1_1", "2_1", "3_1", "4_1", "5_1"]),
"price": pa.array([0, 1, 2, 3, 4], pa.int64()),
"sold": pa.array([0, 1, 4, 9, 16], pa.int64()),
"price_float": pa.array(list(range(5)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * 5),
"deleted": pa.array([True] * 5),
}
)
Expand All @@ -107,3 +115,70 @@ def test_update_wo_predicate_multiple_updates(

assert last_action["operation"] == "UPDATE"
assert result == expected


def test_update_with_predicate_and_new_values(
    tmp_path: pathlib.Path, sample_table: pa.Table
):
    """Update rows matching a predicate using python objects via ``new_values``.

    Writes the sample table, applies an update restricted by ``price > 3``
    and checks both the resulting table contents and the last commit's
    operation name.
    """
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    # Only the last row is expected to carry the new values.
    original_items = ["item1", "item2", "item3"]
    replaced_items = ["item4", "item5", "item6"]
    expected_columns = {
        "id": pa.array(["1", "2", "3", "4", "new_id"]),
        "price": pa.array([0, 1, 2, 3, 4], pa.int64()),
        "sold": pa.array([0, 1, 2, 3, 100], pa.int64()),
        "price_float": pa.array([0, 1, 2, 3, 9999], pa.float64()),
        "items_in_bucket": pa.array([original_items] * 4 + [replaced_items]),
        "deleted": pa.array([False, False, False, False, True]),
    }
    expected = pa.table(expected_columns)

    python_values = {
        "id": "new_id",
        "deleted": True,
        "sold": 100,
        "price_float": 9999,
        "items_in_bucket": ["item4", "item5", "item6"],
    }
    dt.update(new_values=python_values, predicate="price > 3")

    result = dt.to_pyarrow_table()
    latest_commit = dt.history(1)[0]

    assert latest_commit["operation"] == "UPDATE"
    assert result == expected


def test_update_no_inputs(tmp_path: pathlib.Path, sample_table: pa.Table):
    """Calling ``update()`` with neither ``updates`` nor ``new_values`` fails.

    Expects a ``ValueError`` carrying the exact "need to be passed" message.
    """
    write_deltalake(tmp_path, sample_table, mode="append")

    dt = DeltaTable(tmp_path)

    # Pin the concrete exception type so unrelated failures are not
    # mistaken for the expected validation error.
    with pytest.raises(ValueError) as excinfo:
        dt.update()

    assert (
        str(excinfo.value)
        == "Either updates or new_values need to be passed to update the table."
    )


def test_update_to_many_inputs(tmp_path: pathlib.Path, sample_table: pa.Table):
    """Calling ``update()`` with both ``updates`` and ``new_values`` fails.

    Expects a ``ValueError`` carrying the exact "pick one" message, even
    when both mappings are empty.
    """
    write_deltalake(tmp_path, sample_table, mode="append")

    dt = DeltaTable(tmp_path)

    # Pin the concrete exception type so unrelated failures are not
    # mistaken for the expected validation error.
    with pytest.raises(ValueError) as excinfo:
        dt.update(updates={}, new_values={})

    assert (
        str(excinfo.value)
        == "Passing updates and new_values at same time is not allowed, pick one."
    )
Loading