Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): allow python objects to be passed as new values in .update() #1749

Merged
73 changes: 61 additions & 12 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,24 +496,28 @@ def vacuum(

def update(
    self,
    updates: Optional[Dict[str, str]] = None,
    new_values: Optional[
        Dict[str, Union[int, float, str, datetime, bool, List[Any]]]
    ] = None,
    predicate: Optional[str] = None,
    writer_properties: Optional[Dict[str, int]] = None,
    error_on_type_mismatch: bool = True,
) -> Dict[str, Any]:
    """`UPDATE` records in the Delta Table that matches an optional predicate.

    Exactly one of ``updates`` or ``new_values`` must be passed; supplying
    both (or neither) raises ``ValueError``.

    Args:
        updates: a mapping of column name to update SQL expression (a string).
        new_values: a mapping of column name to a plain Python value
            (``int``, ``float``, ``bool``, ``list``, ``str`` or ``datetime``).
            Each value is converted to an SQL literal before execution:
            numbers/bools/lists via ``str()``, strings wrapped in single
            quotes, datetimes as integer microseconds since the epoch.
        predicate: a logical expression, defaults to None
        writer_properties: Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html,
            only the following fields are supported: `data_page_size_limit`, `dictionary_page_size_limit`,
            `data_page_row_count_limit`, `write_batch_size`, `max_row_group_size`.
        error_on_type_mismatch: specify if update will return error if data types are mismatching :default = True

    Returns:
        the metrics from update

    Raises:
        TypeError: if a ``new_values`` value has an unsupported type, or a
            ``updates`` key/value is not a string.
        ValueError: if both or neither of ``updates``/``new_values`` are given.

    Examples:

    Update some row values with SQL predicate. This is equivalent to
    ``UPDATE table SET deleted = true WHERE id = '5'``
    ```
    from deltalake import DeltaTable
    dt = DeltaTable("tmp")
    dt.update(predicate="id = '5'", updates = {"deleted": 'True'})
    ```

    Update all row values. This is equivalent to
    ``UPDATE table SET deleted = true, id = concat(id, '_old')``.
    ```
    from deltalake import DeltaTable
    dt = DeltaTable("tmp")
    dt.update(updates = {"deleted": 'True', "id": "concat(id, '_old')"})
    ```

    To use Python objects instead of SQL strings, use the `new_values` parameter
    instead of the `updates` parameter. For example, this is equivalent to
    ``UPDATE table SET price = 150.10 WHERE id = '5'``
    ```
    from deltalake import DeltaTable
    dt = DeltaTable("tmp")
    dt.update(predicate="id = '5'", new_values = {"price": 150.10})
    ```
    """
    if updates is None and new_values is not None:
        # Translate each Python object into the SQL-expression string form
        # that the Rust update kernel expects.
        updates = {}
        for key, value in new_values.items():
            if isinstance(value, (int, float, bool, list)):
                value = str(value)
            elif isinstance(value, str):
                # NOTE(review): embedded single quotes in the value are not
                # escaped here — confirm upstream whether that is acceptable.
                value = f"'{value}'"
            elif isinstance(value, datetime):
                value = str(
                    int(value.timestamp() * 1000 * 1000)
                )  # convert to microseconds
            else:
                raise TypeError(
                    "Invalid datatype provided in new_values, only int, float, bool, list, str or datetime or accepted."
                )
            updates[key] = value
    elif updates is not None and new_values is None:
        # Guard against callers passing raw Python objects in `updates`,
        # which must contain only SQL-expression strings.
        for key, value in updates.items():
            if not isinstance(value, str) or not isinstance(key, str):
                raise TypeError(
                    f"The values of the updates parameter must all be SQL strings. Got {updates}. Did you mean to use the new_values parameter?"
                )
    elif updates is not None and new_values is not None:
        raise ValueError(
            "Passing updates and new_values at same time is not allowed, pick one."
        )
    else:
        raise ValueError(
            "Either updates or new_values need to be passed to update the table."
        )
    metrics = self._table.update(
        updates,
        predicate,
        writer_properties,
        safe_cast=not error_on_type_mismatch,
    )
    return json.loads(metrics)

Expand Down
91 changes: 91 additions & 0 deletions python/tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def sample_table():
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([False] * nrows),
}
)
Expand All @@ -30,6 +32,8 @@ def test_update_with_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([False, False, False, False, True]),
}
)
Expand All @@ -54,6 +58,8 @@ def test_update_wo_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([True] * 5),
}
)
Expand Down Expand Up @@ -93,6 +99,8 @@ def test_update_wo_predicate_multiple_updates(
"id": pa.array(["1_1", "2_1", "3_1", "4_1", "5_1"]),
"price": pa.array([0, 1, 2, 3, 4], pa.int64()),
"sold": pa.array([0, 1, 4, 9, 16], pa.int64()),
"price_float": pa.array(list(range(5)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * 5),
"deleted": pa.array([True] * 5),
}
)
Expand All @@ -107,3 +115,86 @@ def test_update_wo_predicate_multiple_updates(

assert last_action["operation"] == "UPDATE"
assert result == expected


def test_update_with_predicate_and_new_values(
    tmp_path: pathlib.Path, sample_table: pa.Table
):
    # Only the last row satisfies the predicate; every column of it is
    # rewritten through plain Python objects via `new_values`.
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    row_count = 5
    expected = pa.table(
        {
            "id": pa.array(["1", "2", "3", "4", "new_id"]),
            "price": pa.array(list(range(row_count)), pa.int64()),
            "sold": pa.array([0, 1, 2, 3, 100], pa.int64()),
            "price_float": pa.array([0, 1, 2, 3, 9999], pa.float64()),
            "items_in_bucket": pa.array(
                [["item1", "item2", "item3"]] * 4 + [["item4", "item5", "item6"]]
            ),
            "deleted": pa.array([False, False, False, False, True]),
        }
    )

    replacement = {
        "id": "new_id",
        "deleted": True,
        "sold": 100,
        "price_float": 9999,
        "items_in_bucket": ["item4", "item5", "item6"],
    }
    dt.update(new_values=replacement, predicate="price > 3")

    assert dt.history(1)[0]["operation"] == "UPDATE"
    assert dt.to_pyarrow_table() == expected


def test_update_no_inputs(tmp_path: pathlib.Path, sample_table: pa.Table):
    # Calling update() with neither `updates` nor `new_values` must fail.
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    expected_message = (
        "Either updates or new_values need to be passed to update the table."
    )
    with pytest.raises(Exception) as excinfo:
        dt.update()
    assert str(excinfo.value) == expected_message


def test_update_to_many_inputs(tmp_path: pathlib.Path, sample_table: pa.Table):
    # Supplying both `updates` and `new_values` at once must be rejected.
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    expected_message = (
        "Passing updates and new_values at same time is not allowed, pick one."
    )
    with pytest.raises(Exception) as excinfo:
        dt.update(updates={}, new_values={})
    assert str(excinfo.value) == expected_message


def test_update_with_incorrect_updates_input(
    tmp_path: pathlib.Path, sample_table: pa.Table
):
    # A dict is not one of the accepted `new_values` datatypes.
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    bad_values = {"col": {}}
    with pytest.raises(Exception) as excinfo:
        dt.update(new_values=bad_values)

    assert str(excinfo.value) == (
        "Invalid datatype provided in new_values, only int, float, bool, list, str or datetime or accepted."
    )
Loading