From 346f51a277c23be09eca373cd6d775b745d59c05 Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Sun, 20 Mar 2022 15:47:06 -0700
Subject: [PATCH] [Python] Initial PyArrow writer (#566)

* Initial writer implementation
* Add basic partitioning support
* Update docs and link to other projects
* Add Pandas support
* Test writer stats and partitioning
* Test statistics
* Enforce protocol version
* Add experimental to docstring
* Need typing extensions for type checking now
* Add nitpick ignore for typing_extensions
---
 .gitignore                                    |   4 +-
 README.adoc                                   |   2 +-
 python/deltalake/__init__.py                  |   1 +
 python/deltalake/schema.py                    |   4 +-
 python/deltalake/table.py                     |  10 +-
 python/deltalake/writer.py                    | 243 +++++++++++++++
 python/docs/source/api_reference.rst          |   5 +
 python/docs/source/conf.py                    |  15 +-
 python/docs/source/usage.rst                  |  30 +-
 python/pyproject.toml                         |   3 +-
 python/src/lib.rs                             | 148 ++++++++-
 python/stubs/deltalake/deltalake.pyi          |  11 +-
 python/stubs/pyarrow/__init__.pyi             |   1 +
 python/stubs/pyarrow/dataset.pyi              |   1 +
 python/stubs/pyarrow/lib.pyi                  |   3 +
 python/tests/__init__.py                      |   0
 .../_delta_log/.00000000000000000000.json.crc | Bin 16 -> 0 bytes
 .../_delta_log/00000000000000000000.json      |   4 -
 ...-b5ec-511b932751ea.c000.snappy.parquet.crc | Bin 12 -> 0 bytes
 ...4007-b5ec-511b932751ea.c000.snappy.parquet | Bin 500 -> 0 bytes
 .../_delta_log/.00000000000000000000.json.crc | Bin 16 -> 0 bytes
 .../_delta_log/00000000000000000000.json      |   4 -
 ...-bbfa-0479c4c1e704.c000.snappy.parquet.crc | Bin 12 -> 0 bytes
 ...44ea-bbfa-0479c4c1e704.c000.snappy.parquet | Bin 500 -> 0 bytes
 python/tests/test_schema.py                   |   6 +-
 python/tests/test_table_read.py               |  22 +-
 python/tests/test_writer.py                   | 294 ++++++++++++++++++
 rust/Cargo.toml                               |   5 +-
 rust/src/delta.rs                             |  10 +-
 rust/src/delta_arrow.rs                       |  94 ++++++
 rust/tests/adls_gen2_table_test.rs            |   2 +-
 rust/tests/concurrent_writes_test.rs          |   2 +-
 rust/tests/fs_common/mod.rs                   |   2 +-
 33 files changed, 882 insertions(+), 44 deletions(-)
 create mode 100644 python/deltalake/writer.py
 create mode 100644 python/stubs/pyarrow/lib.pyi
 create mode 100644 python/tests/__init__.py
 delete mode 100644 python/tests/data/date_partitioned_df/_delta_log/.00000000000000000000.json.crc
 delete mode 100644 python/tests/data/date_partitioned_df/_delta_log/00000000000000000000.json
 delete mode 100644 python/tests/data/date_partitioned_df/date=2021-01-01/.part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet.crc
 delete mode 100644 python/tests/data/date_partitioned_df/date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet
 delete mode 100644 python/tests/data/timestamp_partitioned_df/_delta_log/.00000000000000000000.json.crc
 delete mode 100644 python/tests/data/timestamp_partitioned_df/_delta_log/00000000000000000000.json
 delete mode 100644 python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/.part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet.crc
 delete mode 100644 python/tests/data/timestamp_partitioned_df/date=2021-01-01 00%3A00%3A00/part-00000-6177a755-69ce-44ea-bbfa-0479c4c1e704.c000.snappy.parquet
 create mode 100644 python/tests/test_writer.py

diff --git a/.gitignore b/.gitignore
index 776885c858..ff22decf25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,6 @@ tlaplus/*.toolbox/*/MC.cfg
 tlaplus/*.toolbox/*/[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*/
 /.idea
 .vscode
-.env
\ No newline at end of file
+.env
+**/.DS_Store
+**/.python-version
\ No newline at end of file
diff --git a/README.adoc b/README.adoc
index db8c98e744..aa902a8303 100644
---
a/README.adoc +++ b/README.adoc @@ -73,7 +73,7 @@ link:https://github.com/rajasekarv/vega[vega], etc. It also provides bindings to | High-level file writer | -| +| link:https://github.com/delta-io/delta-rs/issues/542[#542] | | Optimize diff --git a/python/deltalake/__init__.py b/python/deltalake/__init__.py index aeb999c97a..eaa3c39c9c 100644 --- a/python/deltalake/__init__.py +++ b/python/deltalake/__init__.py @@ -2,3 +2,4 @@ from .deltalake import PyDeltaTableError, RawDeltaTable, rust_core_version from .schema import DataType, Field, Schema from .table import DeltaTable, Metadata +from .writer import write_deltalake diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py index 9ac5123098..9363b6d097 100644 --- a/python/deltalake/schema.py +++ b/python/deltalake/schema.py @@ -205,7 +205,7 @@ def pyarrow_datatype_from_dict(json_dict: Dict[str, Any]) -> pyarrow.DataType: key, pyarrow.list_( pyarrow.field( - "element", pyarrow.struct([pyarrow_field_from_dict(value_type)]) + "entries", pyarrow.struct([pyarrow_field_from_dict(value_type)]) ) ), ) @@ -218,7 +218,7 @@ def pyarrow_datatype_from_dict(json_dict: Dict[str, Any]) -> pyarrow.DataType: elif type_class == "list": field = json_dict["children"][0] element_type = pyarrow_datatype_from_dict(field) - return pyarrow.list_(pyarrow.field("element", element_type)) + return pyarrow.list_(pyarrow.field("item", element_type)) elif type_class == "struct": fields = [pyarrow_field_from_dict(field) for field in json_dict["children"]] return pyarrow.struct(fields) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 1c8cc8b4e7..59096c38db 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1,7 +1,7 @@ import json import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union import pyarrow import pyarrow.fs as pa_fs @@ -63,6 +63,11 @@ def __str__(self) -> str: ) +class ProtocolVersions(NamedTuple): + min_reader_version: int + min_writer_version: int + + @dataclass(init=False) class DeltaTable: """Create a DeltaTable instance.""" @@ -219,6 +224,9 @@ def metadata(self) -> Metadata: """ return self._metadata + def protocol(self) -> ProtocolVersions: + return ProtocolVersions(*self._table.protocol_versions()) + def history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: """ Run the history command on the DeltaTable. 
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
new file mode 100644
index 0000000000..7c38887301
--- /dev/null
+++ b/python/deltalake/writer.py
@@ -0,0 +1,243 @@
+import json
+import uuid
+from dataclasses import dataclass
+from datetime import date, datetime
+from decimal import Decimal
+from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union
+
+import pandas as pd
+import pyarrow as pa
+import pyarrow.dataset as ds
+import pyarrow.fs as pa_fs
+from pyarrow.lib import RecordBatchReader
+from typing_extensions import Literal
+
+from .deltalake import PyDeltaTableError
+from .deltalake import write_new_deltalake as _write_new_deltalake
+from .table import DeltaTable
+
+
+class DeltaTableProtocolError(PyDeltaTableError):
+    pass
+
+
+@dataclass
+class AddAction:
+    path: str
+    size: int
+    partition_values: Mapping[str, Optional[str]]
+    modification_time: int
+    data_change: bool
+    stats: str
+
+
+def write_deltalake(
+    table_or_uri: Union[str, DeltaTable],
+    data: Union[
+        pd.DataFrame,
+        pa.Table,
+        pa.RecordBatch,
+        Iterable[pa.RecordBatch],
+        RecordBatchReader,
+    ],
+    schema: Optional[pa.Schema] = None,
+    partition_by: Optional[List[str]] = None,
+    filesystem: Optional[pa_fs.FileSystem] = None,
+    mode: Literal["error", "append", "overwrite", "ignore"] = "error",
+) -> None:
+    """Write to a Delta Lake table (Experimental)
+
+    If the table does not already exist, it will be created.
+
+    This function currently only supports writer protocol version 1. If
+    attempting to write to an existing table with a higher min_writer_version,
+    this function will raise DeltaTableProtocolError.
+
+    :param table_or_uri: URI of a table or a DeltaTable object.
+    :param data: Data to write. If passing an iterable, the schema must also be given.
+    :param schema: Optional schema to write.
+    :param partition_by: List of columns to partition the table by. Only required
+        when creating a new table.
+    :param filesystem: Optional filesystem to pass to PyArrow. If not provided,
+        it will be inferred from the URI.
+    :param mode: How to handle existing data. Default is to error if the table
+        already exists. If 'append', will add new data. If 'overwrite', will
+        replace the table with new data. If 'ignore', will not write anything if
+        the table already exists.
+    """
+    if isinstance(data, pd.DataFrame):
+        data = pa.Table.from_pandas(data)
+
+    if schema is None:
+        if isinstance(data, RecordBatchReader):
+            schema = data.schema
+        elif isinstance(data, Iterable):
+            raise ValueError("You must provide schema if data is Iterable")
+        else:
+            schema = data.schema
+
+    if isinstance(table_or_uri, str):
+        table = try_get_deltatable(table_or_uri)
+        table_uri = table_or_uri
+    else:
+        table = table_or_uri
+        table_uri = table._table.table_uri()
+
+    # TODO: Pass through filesystem once it is complete
+    # if filesystem is None:
+    #     filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri))
+
+    if table:  # already exists
+        if mode == "error":
+            raise AssertionError("DeltaTable already exists.")
+        elif mode == "ignore":
+            return
+
+        current_version = table.version()
+
+        if partition_by:
+            assert partition_by == table.metadata().partition_columns
+
+        if table.protocol().min_writer_version > 1:
+            raise DeltaTableProtocolError(
+                "This table's min_writer_version is "
+                f"{table.protocol().min_writer_version}, "
+                "but this method only supports version 1."
+ ) + else: # creating a new table + current_version = -1 + + # TODO: Don't allow writing to non-empty directory + # Blocked on: Finish filesystem implementation in fs.py + # assert len(filesystem.get_file_info(pa_fs.FileSelector(table_uri, allow_not_found=True))) == 0 + + if partition_by: + partition_schema = pa.schema([schema.field(name) for name in partition_by]) + partitioning = ds.partitioning(partition_schema, flavor="hive") + else: + partitioning = None + + add_actions: List[AddAction] = [] + + def visitor(written_file: Any) -> None: + partition_values = get_partitions_from_path(table_uri, written_file.path) + stats = get_file_stats_from_metadata(written_file.metadata) + + add_actions.append( + AddAction( + written_file.path, + written_file.metadata.serialized_size, + partition_values, + int(datetime.now().timestamp()), + True, + json.dumps(stats, cls=DeltaJSONEncoder), + ) + ) + + ds.write_dataset( + data, + base_dir=table_uri, + basename_template=f"{current_version + 1}-{uuid.uuid4()}-{{i}}.parquet", + format="parquet", + partitioning=partitioning, + # It will not accept a schema if using a RBR + schema=schema if not isinstance(data, RecordBatchReader) else None, + file_visitor=visitor, + existing_data_behavior="overwrite_or_ignore", + ) + + if table is None: + _write_new_deltalake(table_uri, schema, add_actions, mode, partition_by or []) + else: + table._table.create_write_transaction( + add_actions, + mode, + partition_by or [], + ) + + +class DeltaJSONEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, bytes): + return obj.decode("unicode_escape") + elif isinstance(obj, date): + return obj.isoformat() + elif isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, Decimal): + return str(obj) + # Let the base class default method raise the TypeError + return json.JSONEncoder.default(self, obj) + + +def try_get_deltatable(table_uri: str) -> Optional[DeltaTable]: + try: + return DeltaTable(table_uri) + except PyDeltaTableError as err: + if "Not a Delta table" not in str(err): + raise + return None + + +def get_partitions_from_path(base_path: str, path: str) -> Dict[str, str]: + path = path.split(base_path, maxsplit=1)[1] + parts = path.split("/") + parts.pop() # remove filename + out = {} + for part in parts: + if part == "": + continue + key, value = part.split("=", maxsplit=1) + out[key] = value + return out + + +def get_file_stats_from_metadata( + metadata: Any, +) -> Dict[str, Union[int, Dict[str, Any]]]: + stats = { + "numRecords": metadata.num_rows, + "minValues": {}, + "maxValues": {}, + "nullCount": {}, + } + + def iter_groups(metadata: Any) -> Iterator[Any]: + for i in range(metadata.num_row_groups): + yield metadata.row_group(i) + + for column_idx in range(metadata.num_columns): + name = metadata.row_group(0).column(column_idx).path_in_schema + # If stats missing, then we can't know aggregate stats + if all( + group.column(column_idx).is_stats_set for group in iter_groups(metadata) + ): + stats["nullCount"][name] = sum( + group.column(column_idx).statistics.null_count + for group in iter_groups(metadata) + ) + + # I assume for now this is based on data type, and thus is + # consistent between groups + if metadata.row_group(0).column(column_idx).statistics.has_min_max: + # Min and Max are recorded in physical type, not logical type + # https://stackoverflow.com/questions/66753485/decoding-parquet-min-max-statistics-for-decimal-type + # TODO: Add logic to decode physical type for DATE, DECIMAL + logical_type = ( + 
                    metadata.row_group(0)
+                    .column(column_idx)
+                    .statistics.logical_type.type
+                )
+                #
+                if logical_type not in ["STRING", "INT", "TIMESTAMP", "NONE"]:
+                    continue
+                # import pdb; pdb.set_trace()
+                stats["minValues"][name] = min(
+                    group.column(column_idx).statistics.min
+                    for group in iter_groups(metadata)
+                )
+                stats["maxValues"][name] = max(
+                    group.column(column_idx).statistics.max
+                    for group in iter_groups(metadata)
+                )
+    return stats
diff --git a/python/docs/source/api_reference.rst b/python/docs/source/api_reference.rst
index 0fcb6579c3..09659ebc10 100644
--- a/python/docs/source/api_reference.rst
+++ b/python/docs/source/api_reference.rst
@@ -7,6 +7,11 @@ DeltaTable
 .. automodule:: deltalake.table
     :members:
 
+Writing DeltaTables
+-------------------
+
+.. autofunction:: deltalake.write_deltalake
+
 DeltaSchema
 -----------
 
diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py
index 5fbd59eb17..bb6c11237b 100644
--- a/python/docs/source/conf.py
+++ b/python/docs/source/conf.py
@@ -42,7 +42,12 @@ def get_release_version() -> str:
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ["sphinx_rtd_theme", "sphinx.ext.autodoc", "edit_on_github"]
+extensions = [
+    "sphinx_rtd_theme",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "edit_on_github",
+]
 autodoc_typehints = "description"
 nitpicky = True
 nitpick_ignore = [
@@ -52,6 +57,7 @@ def get_release_version() -> str:
     ("py:class", "pyarrow.lib.DataType"),
     ("py:class", "pyarrow.lib.Field"),
     ("py:class", "pyarrow.lib.NativeFile"),
+    ("py:class", "pyarrow.lib.RecordBatchReader"),
     ("py:class", "pyarrow._fs.FileSystem"),
     ("py:class", "pyarrow._fs.FileInfo"),
     ("py:class", "pyarrow._fs.FileSelector"),
@@ -84,3 +90,10 @@ def get_release_version() -> str:
 edit_on_github_project = "delta-io/delta-rs"
 edit_on_github_branch = "main"
 page_source_prefix = "python/docs/source"
+
+
+intersphinx_mapping = {
+    "pyarrow": ("https://arrow.apache.org/docs/", None),
+    "pyspark": ("https://spark.apache.org/docs/latest/api/python/", None),
+    "pandas": ("https://pandas.pydata.org/docs/", None),
+}
diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst
index e72aa802c2..249eee01c2 100644
--- a/python/docs/source/usage.rst
+++ b/python/docs/source/usage.rst
@@ -328,4 +328,32 @@ Optimizing tables is not currently supported.
 Writing Delta Tables
 --------------------
 
-Writing Delta tables is not currently supported.
+.. py:currentmodule:: deltalake
+
+.. warning::
+   The writer is currently *experimental*. Please use on test data first, not
+   on production data. Report any issues at https://github.com/delta-io/delta-rs/issues.
+
+For overwrites and appends, use :py:func:`write_deltalake`. If the table does not
+already exist, it will be created. The ``data`` parameter will accept a Pandas
+DataFrame, a PyArrow Table, or an iterator of PyArrow Record Batches.
+
+.. code-block:: python
+
+    >>> from deltalake.writer import write_deltalake
+    >>> df = pd.DataFrame({'x': [1, 2, 3]})
+    >>> write_deltalake('path/to/table', df)
+
+.. note::
+   :py:func:`write_deltalake` accepts a Pandas DataFrame, but will convert it to
+   an Arrow table before writing. See caveats in :doc:`pyarrow:python/pandas`.
+
+By default, writes create a new table and error if it already exists. This is
+controlled by the ``mode`` parameter, which mirrors the behavior of Spark's
+:py:meth:`pyspark.sql.DataFrameWriter.saveAsTable` DataFrame method. To overwrite, pass in ``mode='overwrite'`` and
+to append, pass in ``mode='append'``:
+
+.. code-block:: python
+
+    >>> write_deltalake('path/to/table', df, mode='overwrite')
+    >>> write_deltalake('path/to/table', df, mode='append')
diff --git a/python/pyproject.toml b/python/pyproject.toml
index bd02df5790..7b096de6f5 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -36,7 +36,8 @@ devel = [
     "sphinx",
     "sphinx-rtd-theme",
     "toml",
-    "pandas"
+    "pandas",
+    "typing-extensions"
 ]
 
 [project.urls]
diff --git a/python/src/lib.rs b/python/src/lib.rs
index fcab8b11b0..64762a6beb 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -3,19 +3,28 @@ extern crate pyo3;
 
 use chrono::{DateTime, FixedOffset, Utc};
-use deltalake::action::Stats;
-use deltalake::action::{ColumnCountStat, ColumnValueStat};
+use deltalake::action;
+use deltalake::action::Action;
+use deltalake::action::{ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats};
 use deltalake::arrow::datatypes::Schema as ArrowSchema;
+use deltalake::get_backend_for_uri;
 use deltalake::partitions::PartitionFilter;
 use deltalake::storage;
+use deltalake::DeltaDataTypeLong;
+use deltalake::DeltaDataTypeTimestamp;
+use deltalake::DeltaTableMetaData;
+use deltalake::DeltaTransactionOptions;
 use deltalake::{arrow, StorageBackend};
 use pyo3::create_exception;
 use pyo3::exceptions::PyException;
+use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyTuple, PyType};
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::convert::TryFrom;
+use std::time::SystemTime;
+use std::time::UNIX_EPOCH;
 
 create_exception!(deltalake, PyDeltaTableError, PyException);
 
@@ -145,6 +154,13 @@ impl RawDeltaTable {
         })
     }
 
+    pub fn protocol_versions(&self) -> PyResult<(i32, i32)> {
+        Ok((
+            self._table.get_min_reader_version(),
+            self._table.get_min_writer_version(),
+        ))
+    }
+
     pub fn load_version(&mut self, version: deltalake::DeltaDataTypeVersion) -> PyResult<()> {
         rt()?
             .block_on(self._table.load_version(version))
@@ -272,6 +288,50 @@ impl RawDeltaTable {
             })
             .collect()
     }
+
+    fn create_write_transaction(
+        &mut self,
+        add_actions: Vec<PyAddAction>,
+        mode: &str,
+        partition_by: Vec<String>,
+    ) -> PyResult<()> {
+        let mode = save_mode_from_str(mode)?;
+
+        let mut actions: Vec<Action> = add_actions
+            .iter()
+            .map(|add| Action::add(add.into()))
+            .collect();
+
+        if let SaveMode::Overwrite = mode {
+            // Remove all current files
+            for old_add in self._table.get_state().files().iter() {
+                let remove_action = Action::remove(action::Remove {
+                    path: old_add.path.clone(),
+                    deletion_timestamp: Some(current_timestamp()),
+                    data_change: true,
+                    extended_file_metadata: Some(old_add.tags.is_some()),
+                    partition_values: Some(old_add.partition_values.clone()),
+                    size: Some(old_add.size),
+                    tags: old_add.tags.clone(),
+                });
+                actions.push(remove_action);
+            }
+        }
+
+        let mut transaction = self
+            ._table
+            .create_transaction(Some(DeltaTransactionOptions::new(3)));
+        transaction.add_actions(actions);
+        rt()?
+            .block_on(transaction.commit(Some(DeltaOperation::Write {
+                mode,
+                partitionBy: Some(partition_by),
+                predicate: None,
+            })))
+            .map_err(PyDeltaTableError::from_raw)?;
+
+        Ok(())
+    }
 }
 
 fn json_value_to_py(value: &serde_json::Value, py: Python) -> PyObject {
@@ -409,12 +469,96 @@ fn rust_core_version() -> &'static str {
     deltalake::crate_version()
 }
 
+fn save_mode_from_str(value: &str) -> PyResult<SaveMode> {
+    match value {
+        "append" => Ok(SaveMode::Append),
+        "overwrite" => Ok(SaveMode::Overwrite),
+        "error" => Ok(SaveMode::ErrorIfExists),
+        "ignore" => Ok(SaveMode::Ignore),
+        _ => Err(PyValueError::new_err("Invalid save mode")),
+    }
+}
+
+fn current_timestamp() -> DeltaDataTypeTimestamp {
+    let start = SystemTime::now();
+    let since_the_epoch = start
+        .duration_since(UNIX_EPOCH)
+        .expect("Time went backwards");
+    since_the_epoch.as_millis().try_into().unwrap()
+}
+
+#[derive(FromPyObject)]
+pub struct PyAddAction {
+    path: String,
+    size: DeltaDataTypeLong,
+    partition_values: HashMap<String, Option<String>>,
+    modification_time: DeltaDataTypeTimestamp,
+    data_change: bool,
+    stats: Option<String>,
+}
+
+impl From<&PyAddAction> for action::Add {
+    fn from(action: &PyAddAction) -> Self {
+        action::Add {
+            path: action.path.clone(),
+            size: action.size,
+            partition_values: action.partition_values.clone(),
+            partition_values_parsed: None,
+            modification_time: action.modification_time,
+            data_change: action.data_change,
+            stats: action.stats.clone(),
+            stats_parsed: None,
+            tags: None,
+        }
+    }
+}
+
+#[pyfunction]
+fn write_new_deltalake(
+    table_uri: String,
+    schema: ArrowSchema,
+    add_actions: Vec<PyAddAction>,
+    _mode: &str,
+    partition_by: Vec<String>,
+) -> PyResult<()> {
+    let mut table = deltalake::DeltaTable::new(
+        &table_uri,
+        get_backend_for_uri(&table_uri).map_err(PyDeltaTableError::from_storage)?,
+        deltalake::DeltaTableConfig::default(),
+    )
+    .map_err(PyDeltaTableError::from_raw)?;
+
+    let metadata = DeltaTableMetaData::new(
+        None,
+        None,
+        None,
+        (&schema).try_into()?,
+        partition_by,
+        HashMap::new(),
+    );
+
+    let fut = table.create(
+        metadata,
+        action::Protocol {
+            min_reader_version: 1,
+            min_writer_version: 1, // TODO: Make sure we comply with protocol
+        },
+        None, // TODO
+        Some(add_actions.iter().map(|add| add.into()).collect()),
+    );
+
+    rt()?.block_on(fut).map_err(PyDeltaTableError::from_raw)?;
+
+    Ok(())
+}
+
 #[pymodule]
 // module name need to match project name
 fn deltalake(py: Python, m: &PyModule) -> PyResult<()> {
     env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init();
     m.add_function(pyo3::wrap_pyfunction!(rust_core_version, m)?)?;
+    m.add_function(pyo3::wrap_pyfunction!(write_new_deltalake, m)?)?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/python/stubs/deltalake/deltalake.pyi b/python/stubs/deltalake/deltalake.pyi
index f420569e5d..b872e61bd3 100644
--- a/python/stubs/deltalake/deltalake.pyi
+++ b/python/stubs/deltalake/deltalake.pyi
@@ -1,6 +1,13 @@
-from typing import Any, Callable
+from typing import Any, Callable, List
+
+import pyarrow as pa
+
+from deltalake.writer import AddAction
 
 RawDeltaTable: Any
-PyDeltaTableError: Any
 rust_core_version: Callable[[], str]
 DeltaStorageFsBackend: Any
+
+write_new_deltalake: Callable[[str, pa.Schema, List[AddAction], str, List[str]], None]
+
+class PyDeltaTableError(BaseException): ...
diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi index 3b28a2a714..c7cef34ba9 100644 --- a/python/stubs/pyarrow/__init__.pyi +++ b/python/stubs/pyarrow/__init__.pyi @@ -2,6 +2,7 @@ from typing import Any, Callable Schema: Any Table: Any +RecordBatch: Any Field: Any DataType: Any schema: Any diff --git a/python/stubs/pyarrow/dataset.pyi b/python/stubs/pyarrow/dataset.pyi index 5d9683dee1..d06f843246 100644 --- a/python/stubs/pyarrow/dataset.pyi +++ b/python/stubs/pyarrow/dataset.pyi @@ -5,3 +5,4 @@ dataset: Any partitioning: Any FileSystemDataset: Any ParquetFileFormat: Any +write_dataset: Any diff --git a/python/stubs/pyarrow/lib.pyi b/python/stubs/pyarrow/lib.pyi new file mode 100644 index 0000000000..fc97dea727 --- /dev/null +++ b/python/stubs/pyarrow/lib.pyi @@ -0,0 +1,3 @@ +from typing import Any + +RecordBatchReader: Any diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/tests/data/date_partitioned_df/_delta_log/.00000000000000000000.json.crc b/python/tests/data/date_partitioned_df/_delta_log/.00000000000000000000.json.crc deleted file mode 100644 index f141a1d1b77c482f0540175c8e170192b2f4e1a5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16 XcmYc;N@ieSU}DHPkYgZgvXTV=Aj$*% diff --git a/python/tests/data/date_partitioned_df/_delta_log/00000000000000000000.json b/python/tests/data/date_partitioned_df/_delta_log/00000000000000000000.json deleted file mode 100644 index 9c01cf24d8..0000000000 --- a/python/tests/data/date_partitioned_df/_delta_log/00000000000000000000.json +++ /dev/null @@ -1,4 +0,0 @@ -{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} -{"metaData":{"id":"588135b2-b298-4d9f-aab6-6dd9bf90d575","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["date"],"configuration":{},"createdTime":1645893400586}} -{"add":{"path":"date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet","partitionValues":{"date":"2021-01-01"},"size":500,"modificationTime":1645893404567,"dataChange":true}} -{"commitInfo":{"timestamp":1645893404671,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists","partitionBy":"[\"date\"]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"5","numOutputBytes":"500"},"engineInfo":"Apache-Spark/3.2.1 Delta-Lake/1.1.0"}} diff --git a/python/tests/data/date_partitioned_df/date=2021-01-01/.part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet.crc b/python/tests/data/date_partitioned_df/date=2021-01-01/.part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet.crc deleted file mode 100644 index 3271be8603d6a86d6b0c806f614ee9ae2315b70b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}9*k3_1e<5^n=n diff --git a/python/tests/data/date_partitioned_df/date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet b/python/tests/data/date_partitioned_df/date=2021-01-01/part-00000-6ae76612-3903-4007-b5ec-511b932751ea.c000.snappy.parquet deleted file mode 100644 index c59e5b21d81dbadd3f198c879be88a813350a5eb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 500 
[binary payload not recoverable from this copy of the patch; also lost here: the matching deletions under python/tests/data/timestamp_partitioned_df/, the hunks for python/tests/test_schema.py, python/tests/test_table_read.py, python/tests/test_writer.py and rust/Cargo.toml listed in the diffstat above, and the header of the first rust/src/delta.rs hunk, which adds an add_actions parameter to DeltaTable::create]
diff --git a/rust/src/delta.rs b/rust/src/delta.rs
--- a/rust/src/delta.rs
+++ b/rust/src/delta.rs
         commit_info: Option<Map<String, Value>>,
+        add_actions: Option<Vec<action::Add>>,
     ) -> Result<(), DeltaTableError> {
         let meta = action::MetaData::try_from(metadata)?;
@@ -1307,11 +1308,16 @@ impl DeltaTable {
             Value::Number(serde_json::Number::from(Utc::now().timestamp_millis())),
         );
 
-        let actions = vec![
+        let mut actions = vec![
             Action::commitInfo(enriched_commit_info),
             Action::protocol(protocol),
             Action::metaData(meta),
         ];
+        if let Some(add_actions) = add_actions {
+            for add_action in add_actions {
+                actions.push(Action::add(add_action));
+            }
+        };
 
         let mut transaction = self.create_transaction(None);
         transaction.add_actions(actions.clone());
@@ -1812,7 +1818,7 @@ mod tests {
             serde_json::Value::String("test user".to_string()),
         );
         // Action
-        dt.create(delta_md.clone(), protocol.clone(), Some(commit_info))
+        dt.create(delta_md.clone(), protocol.clone(), Some(commit_info), None)
             .await
             .unwrap();
diff --git a/rust/src/delta_arrow.rs b/rust/src/delta_arrow.rs
index d2f1bd9ce9..4df46fc48e 100644
--- a/rust/src/delta_arrow.rs
+++ b/rust/src/delta_arrow.rs
@@ -7,6 +7,7 @@ use arrow::datatypes::{
 use arrow::error::ArrowError;
 use lazy_static::lazy_static;
 use regex::Regex;
+use std::collections::HashMap;
 use std::convert::TryFrom;
 
 impl TryFrom<&schema::Schema> for ArrowSchema {
@@ -162,6 +163,99 @@ impl TryFrom<&schema::SchemaDataType> for ArrowDataType {
     }
 }
 
+impl TryFrom<&ArrowSchema> for schema::Schema {
+    type Error = ArrowError;
+    fn try_from(arrow_schema: &ArrowSchema) -> Result<Self, ArrowError> {
+        let new_fields: Result<Vec<schema::SchemaField>, _> = arrow_schema
+            .fields()
+            .iter()
+            .map(|field| field.try_into())
+            .collect();
+        Ok(schema::Schema::new(new_fields?))
+    }
+}
+
+impl TryFrom<&ArrowField> for schema::SchemaField {
+    type Error = ArrowError;
+    fn try_from(arrow_field: &ArrowField) -> Result<Self, ArrowError> {
+        Ok(schema::SchemaField::new(
+            arrow_field.name().clone(),
+            arrow_field.data_type().try_into()?,
+            arrow_field.is_nullable(),
+            arrow_field
+                .metadata()
+                .as_ref()
+                .map_or_else(HashMap::new, |m| {
+                    m.iter()
+                        .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
+                        .collect()
+                }),
+        ))
+    }
+}
+
+impl TryFrom<&ArrowDataType> for schema::SchemaDataType {
+    type Error = ArrowError;
+    fn try_from(arrow_datatype: &ArrowDataType) -> Result<Self, ArrowError> {
+        match arrow_datatype {
+            ArrowDataType::Utf8 => Ok(schema::SchemaDataType::primitive("string".to_string())),
+            ArrowDataType::Int64 => Ok(schema::SchemaDataType::primitive("long".to_string())), // undocumented type
+            ArrowDataType::Int32 => Ok(schema::SchemaDataType::primitive("integer".to_string())),
+            ArrowDataType::Int16 => Ok(schema::SchemaDataType::primitive("short".to_string())),
+            ArrowDataType::Int8 => Ok(schema::SchemaDataType::primitive("byte".to_string())),
+            ArrowDataType::Float32 => Ok(schema::SchemaDataType::primitive("float".to_string())),
+            ArrowDataType::Float64 => Ok(schema::SchemaDataType::primitive("double".to_string())),
+            ArrowDataType::Boolean => Ok(schema::SchemaDataType::primitive("boolean".to_string())),
+            ArrowDataType::Binary => Ok(schema::SchemaDataType::primitive("binary".to_string())),
+            ArrowDataType::Decimal(p, s) => Ok(schema::SchemaDataType::primitive(format!(
+                "decimal({},{})",
+                p, s
+            ))),
+            ArrowDataType::Date32 => Ok(schema::SchemaDataType::primitive("date".to_string())),
+            ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => {
+                Ok(schema::SchemaDataType::primitive("timestamp".to_string()))
+            }
+            ArrowDataType::Struct(fields) => {
+                let converted_fields: Result<Vec<schema::SchemaField>, _> =
+                    fields.iter().map(|field| field.try_into()).collect();
+                Ok(schema::SchemaDataType::r#struct(
+                    schema::SchemaTypeStruct::new(converted_fields?),
+                ))
+            }
+            ArrowDataType::List(field) => {
+                Ok(schema::SchemaDataType::array(schema::SchemaTypeArray::new(
+                    Box::new((*field).data_type().try_into()?),
+                    (*field).is_nullable(),
+                )))
+            }
+            ArrowDataType::FixedSizeList(field, _) => {
+                Ok(schema::SchemaDataType::array(schema::SchemaTypeArray::new(
+                    Box::new((*field).data_type().try_into()?),
+                    (*field).is_nullable(),
+                )))
+            }
+            ArrowDataType::Map(field, _) => {
+                if let ArrowDataType::Struct(struct_fields) = field.data_type() {
+                    let key_type = struct_fields[0].data_type().try_into()?;
+                    let value_type = struct_fields[1].data_type().try_into()?;
+                    let value_type_nullable = struct_fields[1].is_nullable();
+                    Ok(schema::SchemaDataType::map(schema::SchemaTypeMap::new(
+                        Box::new(key_type),
+                        Box::new(value_type),
+                        value_type_nullable,
+                    )))
+                } else {
+                    panic!("DataType::Map should contain a struct field child");
+                }
+            }
+            s => Err(ArrowError::SchemaError(format!(
+                "Invalid data type for Delta Lake: {}",
+                s
+            ))),
+        }
+    }
+}
+
 /// Returns an arrow schema representing the delta log for use in checkpoints
 ///
 /// # Arguments
diff --git a/rust/tests/adls_gen2_table_test.rs b/rust/tests/adls_gen2_table_test.rs
index 2f3b95efef..b1ad2f3d6b 100644
--- a/rust/tests/adls_gen2_table_test.rs
+++ b/rust/tests/adls_gen2_table_test.rs
@@ -88,7 +88,7 @@ mod adls_gen2_table {
         let (metadata, protocol) = table_info();
 
         // Act 1
-        dt.create(metadata.clone(), protocol.clone(), None)
+        dt.create(metadata.clone(), protocol.clone(), None, None)
             .await
             .unwrap();
diff --git a/rust/tests/concurrent_writes_test.rs b/rust/tests/concurrent_writes_test.rs
index 14d378f929..be7851197c 100644
--- a/rust/tests/concurrent_writes_test.rs
+++ b/rust/tests/concurrent_writes_test.rs
@@ -83,7 +83,7 @@ async fn concurrent_writes_azure() {
         min_writer_version: 2,
     };
 
-    dt.create(metadata.clone(), protocol.clone(), None)
+    dt.create(metadata.clone(), protocol.clone(), None, None)
         .await
         .unwrap();
diff --git a/rust/tests/fs_common/mod.rs b/rust/tests/fs_common/mod.rs
index 63c1b2b017..80984eab04 100644
--- a/rust/tests/fs_common/mod.rs
+++ b/rust/tests/fs_common/mod.rs
@@ -53,7 +53,7 @@ pub async fn create_test_table(
         min_reader_version: 1,
         min_writer_version: 2,
     };
-    table.create(md, protocol, None).await.unwrap();
+    table.create(md, protocol, None, None).await.unwrap();
     table
 }
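
End-to-end sketch of the new writer API (illustrative only, not part of the diff above): it uses only names that appear in this patch, assumes the Python package is built from this branch, and the local path /tmp/example_table is hypothetical.

    import pandas as pd
    import pyarrow as pa

    from deltalake import DeltaTable, write_deltalake

    df = pd.DataFrame({"id": [1, 2, 3], "country": ["US", "DE", "US"]})

    # Create a new table partitioned by "country" (mode defaults to "error",
    # which refuses to write if the table already exists).
    write_deltalake("/tmp/example_table", df, partition_by=["country"])

    # Append more rows; a PyArrow Table works as well as a pandas DataFrame.
    write_deltalake(
        "/tmp/example_table",
        pa.Table.from_pandas(df),
        mode="append",
        partition_by=["country"],
    )

    # Replace the table contents entirely.
    write_deltalake("/tmp/example_table", df, mode="overwrite", partition_by=["country"])

    dt = DeltaTable("/tmp/example_table")
    print(dt.version())                     # 2 after the three commits above (versions are 0-indexed)
    print(dt.metadata().partition_columns)  # ['country']
    print(dt.protocol())                    # ProtocolVersions(min_reader_version=1, min_writer_version=1)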