diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 896dd3b7a2..61dce5ee0f 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -496,24 +496,28 @@ def vacuum( def update( self, - updates: Dict[str, str], + updates: Optional[Dict[str, str]] = None, + new_values: Optional[ + Dict[str, Union[int, float, str, datetime, bool, List[Any]]] + ] = None, predicate: Optional[str] = None, writer_properties: Optional[Dict[str, int]] = None, error_on_type_mismatch: bool = True, ) -> Dict[str, Any]: - """UPDATE records in the Delta Table that matches an optional predicate. + """`UPDATE` records in the Delta Table that matches an optional predicate. Either updates or new_values needs + to be passed for it to execute. Args: updates: a mapping of column name to update SQL expression. + new_values: a mapping of column name to python datatype. predicate: a logical expression, defaults to None - writer_properties: Pass writer properties to the Rust parquet writer, see options - https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html, - only the following fields are supported: `data_page_size_limit`, `dictionary_page_size_limit`, - `data_page_row_count_limit`, `write_batch_size`, `max_row_group_size`. - error_on_type_mismatch: specify if merge will return error if data types are mismatching, default = True + writer_properties: Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html, + only the following fields are supported: `data_page_size_limit`, `dictionary_page_size_limit`, + `data_page_row_count_limit`, `write_batch_size`, `max_row_group_size`. + error_on_type_mismatch: specify if update will return error if data types are mismatching :default = True Returns: - the metrics from delete + the metrics from update Examples: @@ -522,18 +526,63 @@ def update( ``` from deltalake import DeltaTable dt = DeltaTable("tmp") - dt.update(predicate="id = '5'", updates = {"deleted": True}) + dt.update(predicate="id = '5'", updates = {"deleted": 'True'}) ``` - Update all row values. This is equivalent to `UPDATE table SET id = concat(id, '_old')`. + Update all row values. This is equivalent to + ``UPDATE table SET deleted = true, id = concat(id, '_old')``. ``` from deltalake import DeltaTable dt = DeltaTable("tmp") - dt.update(updates={"deleted": True, "id": "concat(id, '_old')"}) + dt.update(updates = {"deleted": 'True', "id": "concat(id, '_old')"}) + ``` + + To use Python objects instead of SQL strings, use the `new_values` parameter + instead of the `updates` parameter. For example, this is equivalent to + ``UPDATE table SET price = 150.10 WHERE id = '5'`` + ``` + from deltalake import DeltaTable + dt = DeltaTable("tmp") + dt.update(predicate="id = '5'", new_values = {"price": 150.10}) ``` """ + if updates is None and new_values is not None: + updates = {} + for key, value in new_values.items(): + if isinstance(value, (int, float, bool, list)): + value = str(value) + elif isinstance(value, str): + value = f"'{value}'" + elif isinstance(value, datetime): + value = str( + int(value.timestamp() * 1000 * 1000) + ) # convert to microseconds + else: + raise TypeError( + "Invalid datatype provided in new_values, only int, float, bool, list, str or datetime or accepted." + ) + updates[key] = value + elif updates is not None and new_values is None: + for key, value in updates.items(): + print(type(key), type(value)) + if not isinstance(value, str) or not isinstance(key, str): + raise TypeError( + f"The values of the updates parameter must all be SQL strings. Got {updates}. Did you mean to use the new_values parameter?" + ) + + elif updates is not None and new_values is not None: + raise ValueError( + "Passing updates and new_values at same time is not allowed, pick one." + ) + else: + raise ValueError( + "Either updates or new_values need to be passed to update the table." + ) metrics = self._table.update( - updates, predicate, writer_properties, safe_cast=not error_on_type_mismatch + updates, + predicate, + writer_properties, + safe_cast=not error_on_type_mismatch, ) return json.loads(metrics) diff --git a/python/tests/test_update.py b/python/tests/test_update.py index defdd1a396..2e3fc82fdd 100644 --- a/python/tests/test_update.py +++ b/python/tests/test_update.py @@ -14,6 +14,8 @@ def sample_table(): "id": pa.array(["1", "2", "3", "4", "5"]), "price": pa.array(list(range(nrows)), pa.int64()), "sold": pa.array(list(range(nrows)), pa.int64()), + "price_float": pa.array(list(range(nrows)), pa.float64()), + "items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows), "deleted": pa.array([False] * nrows), } ) @@ -30,6 +32,8 @@ def test_update_with_predicate(tmp_path: pathlib.Path, sample_table: pa.Table): "id": pa.array(["1", "2", "3", "4", "5"]), "price": pa.array(list(range(nrows)), pa.int64()), "sold": pa.array(list(range(nrows)), pa.int64()), + "price_float": pa.array(list(range(nrows)), pa.float64()), + "items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows), "deleted": pa.array([False, False, False, False, True]), } ) @@ -54,6 +58,8 @@ def test_update_wo_predicate(tmp_path: pathlib.Path, sample_table: pa.Table): "id": pa.array(["1", "2", "3", "4", "5"]), "price": pa.array(list(range(nrows)), pa.int64()), "sold": pa.array(list(range(nrows)), pa.int64()), + "price_float": pa.array(list(range(nrows)), pa.float64()), + "items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows), "deleted": pa.array([True] * 5), } ) @@ -93,6 +99,8 @@ def test_update_wo_predicate_multiple_updates( "id": pa.array(["1_1", "2_1", "3_1", "4_1", "5_1"]), "price": pa.array([0, 1, 2, 3, 4], pa.int64()), "sold": pa.array([0, 1, 4, 9, 16], pa.int64()), + "price_float": pa.array(list(range(5)), pa.float64()), + "items_in_bucket": pa.array([["item1", "item2", "item3"]] * 5), "deleted": pa.array([True] * 5), } ) @@ -107,3 +115,86 @@ def test_update_wo_predicate_multiple_updates( assert last_action["operation"] == "UPDATE" assert result == expected + + +def test_update_with_predicate_and_new_values( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + nrows = 5 + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "new_id"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array([0, 1, 2, 3, 100], pa.int64()), + "price_float": pa.array([0, 1, 2, 3, 9999], pa.float64()), + "items_in_bucket": pa.array( + [["item1", "item2", "item3"]] * 4 + [["item4", "item5", "item6"]] + ), + "deleted": pa.array([False, False, False, False, True]), + } + ) + + dt.update( + new_values={ + "id": "new_id", + "deleted": True, + "sold": 100, + "price_float": 9999, + "items_in_bucket": ["item4", "item5", "item6"], + }, + predicate="price > 3", + ) + + result = dt.to_pyarrow_table() + last_action = dt.history(1)[0] + + assert last_action["operation"] == "UPDATE" + assert result == expected + + +def test_update_no_inputs(tmp_path: pathlib.Path, sample_table: pa.Table): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + with pytest.raises(Exception) as excinfo: + dt.update() + + assert ( + str(excinfo.value) + == "Either updates or new_values need to be passed to update the table." + ) + + +def test_update_to_many_inputs(tmp_path: pathlib.Path, sample_table: pa.Table): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + with pytest.raises(Exception) as excinfo: + dt.update(updates={}, new_values={}) + + assert ( + str(excinfo.value) + == "Passing updates and new_values at same time is not allowed, pick one." + ) + + +def test_update_with_incorrect_updates_input( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + updates = {"col": {}} + with pytest.raises(Exception) as excinfo: + dt.update(new_values=updates) + + assert ( + str(excinfo.value) + == "Invalid datatype provided in new_values, only int, float, bool, list, str or datetime or accepted." + )