diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py
index f90307eb96..abc88dba46 100644
--- a/python/deltalake/schema.py
+++ b/python/deltalake/schema.py
@@ -1,7 +1,42 @@
-from typing import Union
+from typing import TYPE_CHECKING, Tuple, Union
+
+import pyarrow as pa
+
+if TYPE_CHECKING:
+    import pandas as pd
 
 from ._internal import ArrayType, Field, MapType, PrimitiveType, Schema, StructType
 
 # Can't implement inheritance (see note in src/schema.rs), so this is next
 # best thing.
 DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]
+
+
+def delta_arrow_schema_from_pandas(
+    data: "pd.DataFrame",
+) -> Tuple[pa.Table, pa.Schema]:
+    """
+    Infers the schema for the delta table from the Pandas DataFrame.
+    Necessary because of issues such as: https://github.com/delta-io/delta-rs/issues/686
+
+    :param data: Data to write.
+    :return: A PyArrow Table and the inferred schema for the Delta Table
+    """
+
+    table = pa.Table.from_pandas(data)
+    schema = table.schema
+    schema_out = []
+    for field in schema:
+        if isinstance(field.type, pa.TimestampType):
+            f = pa.field(
+                name=field.name,
+                type=pa.timestamp("us"),
+                nullable=field.nullable,
+                metadata=field.metadata,
+            )
+            schema_out.append(f)
+        else:
+            schema_out.append(field)
+    schema = pa.schema(schema_out, metadata=schema.metadata)
+    data = pa.Table.from_pandas(data, schema=schema)
+    return data, schema
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index 41de1961c6..ab3481f50a 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -35,6 +35,8 @@
 import pyarrow.fs as pa_fs
 from pyarrow.lib import RecordBatchReader
 
+from deltalake.schema import delta_arrow_schema_from_pandas
+
 from ._internal import DeltaDataChecker as _DeltaDataChecker
 from ._internal import PyDeltaTableError
 from ._internal import write_new_deltalake as _write_new_deltalake
@@ -132,8 +134,12 @@ def write_deltalake(
     :param overwrite_schema: If True, allows updating the schema of the table.
     :param storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
     """
+
     if _has_pandas and isinstance(data, pd.DataFrame):
-        data = pa.Table.from_pandas(data)
+        if schema is not None:
+            data = pa.Table.from_pandas(data, schema=schema)
+        else:
+            data, schema = delta_arrow_schema_from_pandas(data)
 
     if schema is None:
         if isinstance(data, RecordBatchReader):
diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi
index dde6a48050..fb01f5796d 100644
--- a/python/stubs/pyarrow/__init__.pyi
+++ b/python/stubs/pyarrow/__init__.pyi
@@ -23,6 +23,8 @@
 float16: Any
 float32: Any
 float64: Any
 dictionary: Any
+timestamp: Any
+TimestampType: Any
 py_buffer: Callable[[bytes], Any]
 NativeFile: Any
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index cda1c3601e..6f1e5a2077 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -272,12 +272,16 @@ def test_fails_wrong_partitioning(existing_table: DeltaTable, sample_data: pa.Ta
 
 
 @pytest.mark.pandas
-def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table):
+@pytest.mark.parametrize("schema_provided", [True, False])
+def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table, schema_provided):
     # When timestamp is converted to Pandas, it gets casted to ns resolution,
     # but Delta Lake schemas only support us resolution.
-    sample_pandas = sample_data.to_pandas().drop(["timestamp"], axis=1)
-    write_deltalake(str(tmp_path), sample_pandas)
-
+    sample_pandas = sample_data.to_pandas()
+    if schema_provided is True:
+        schema = sample_data.schema
+    else:
+        schema = None
+    write_deltalake(str(tmp_path), sample_pandas, schema=schema)
     delta_table = DeltaTable(str(tmp_path))
     df = delta_table.to_pandas()
     assert_frame_equal(df, sample_pandas)
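
For context, a minimal usage sketch of the behaviour this patch introduces. The DataFrame contents and the "/tmp/example_table" path are illustrative assumptions, not part of the change: pandas stores datetimes at nanosecond resolution, the new delta_arrow_schema_from_pandas helper downcasts timestamp fields to the microsecond resolution Delta Lake supports, and write_deltalake now applies it automatically when a DataFrame is passed without an explicit schema.

import pandas as pd
import pyarrow as pa

from deltalake import DeltaTable, write_deltalake
from deltalake.schema import delta_arrow_schema_from_pandas

# Illustrative data: pandas represents the "ts" column as datetime64[ns],
# which Delta Lake schemas cannot store (see delta-rs issue #686).
df = pd.DataFrame(
    {"id": [1, 2], "ts": pd.to_datetime(["2022-01-01", "2022-01-02"])}
)

# The new helper converts every timestamp field to us resolution.
table, schema = delta_arrow_schema_from_pandas(df)
assert schema.field("ts").type == pa.timestamp("us")

# write_deltalake takes this code path automatically when no schema is
# given; "/tmp/example_table" is a hypothetical destination.
write_deltalake("/tmp/example_table", df)
print(DeltaTable("/tmp/example_table").to_pandas())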