diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py
index e6854c3779..a22725fdc5 100644
--- a/python/deltalake/schema.py
+++ b/python/deltalake/schema.py
@@ -1,9 +1,7 @@
-from typing import TYPE_CHECKING, Tuple, Union
+from typing import Generator, Union
 
 import pyarrow as pa
-
-if TYPE_CHECKING:
-    import pandas as pd
+import pyarrow.dataset as ds
 
 from ._internal import ArrayType as ArrayType
 from ._internal import Field as Field
@@ -17,34 +15,109 @@
 DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]
 
 
-def delta_arrow_schema_from_pandas(
-    data: "pd.DataFrame",
-) -> Tuple[pa.Table, pa.Schema]:
-    """
-    Infers the schema for the delta table from the Pandas DataFrame.
-    Necessary because of issues such as: https://github.com/delta-io/delta-rs/issues/686
-
-    Args:
-        data: Data to write.
+### Inspired by the Pola-rs repository - licensed under the MIT License, see license in python/licenses/polars_license.txt ###
+def _convert_pa_schema_to_delta(
+    schema: pa.Schema, large_dtypes: bool = False
+) -> pa.Schema:
+    """Convert a PyArrow schema to a schema compatible with Delta Lake. Converts unsigned types to their signed
+    equivalents and casts all timestamps to `us` precision. The boolean flag large_dtypes controls whether
+    normal types in the schema are cast to large types, or large types back to normal ones.
 
-    Returns:
-        A PyArrow Table and the inferred schema for the Delta Table
+    Args:
+        schema: Source schema
+        large_dtypes: If True, the PyArrow schema is cast to large types
     """
+    dtype_map = {
+        pa.uint8(): pa.int8(),
+        pa.uint16(): pa.int16(),
+        pa.uint32(): pa.int32(),
+        pa.uint64(): pa.int64(),
+    }
+    if large_dtypes:
+        dtype_map = {
+            **dtype_map,
+            **{pa.string(): pa.large_string(), pa.binary(): pa.large_binary()},
+        }
+    else:
+        dtype_map = {
+            **dtype_map,
+            **{pa.large_string(): pa.string(), pa.large_binary(): pa.binary()},
+        }
 
-    table = pa.Table.from_pandas(data)
-    schema = table.schema
-    schema_out = []
-    for field in schema:
-        if isinstance(field.type, pa.TimestampType):
-            f = pa.field(
-                name=field.name,
-                type=pa.timestamp("us"),
-                nullable=field.nullable,
-                metadata=field.metadata,
-            )
-            schema_out.append(f)
+    def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType:
+        # Handle nested types
+        if isinstance(dtype, (pa.LargeListType, pa.ListType)):
+            return list_to_delta_dtype(dtype)
+        elif isinstance(dtype, pa.StructType):
+            return struct_to_delta_dtype(dtype)
+        elif isinstance(dtype, pa.TimestampType):
+            return pa.timestamp(
+                "us"
+            )  # TODO(ion): also propagate timezone information during write, once we can properly read TZ in the delta schema
+        try:
+            return dtype_map[dtype]
+        except KeyError:
+            return dtype
+
+    def list_to_delta_dtype(
+        dtype: Union[pa.LargeListType, pa.ListType],
+    ) -> Union[pa.LargeListType, pa.ListType]:
+        nested_dtype = dtype.value_type
+        nested_dtype_cast = dtype_to_delta_dtype(nested_dtype)
+        if large_dtypes:
+            return pa.large_list(nested_dtype_cast)
         else:
-            schema_out.append(field)
-    schema = pa.schema(schema_out, metadata=schema.metadata)
-    table = table.cast(target_schema=schema)
-    return table, schema
+            return pa.list_(nested_dtype_cast)
+
+    def struct_to_delta_dtype(dtype: pa.StructType) -> pa.StructType:
+        fields = [dtype[i] for i in range(dtype.num_fields)]
+        fields_cast = [f.with_type(dtype_to_delta_dtype(f.type)) for f in fields]
+        return pa.struct(fields_cast)
+
+    return pa.schema([f.with_type(dtype_to_delta_dtype(f.type)) for f in schema])
+
+
+def _cast_schema_to_recordbatchreader(
+    reader: pa.RecordBatchReader, schema: pa.Schema
+) -> Generator[pa.RecordBatch, None, None]:
+    """Yields record batches from the reader, cast to the given schema."""
+    for batch in reader:
+        yield pa.Table.from_batches([batch]).cast(schema).to_batches()[0]
+
+
+def convert_pyarrow_recordbatchreader(
+    data: pa.RecordBatchReader, large_dtypes: bool
+) -> pa.RecordBatchReader:
+    """Converts a PyArrow RecordBatchReader into a PyArrow RecordBatchReader with a Delta Lake compatible schema"""
+    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
+
+    data = pa.RecordBatchReader.from_batches(
+        schema,
+        _cast_schema_to_recordbatchreader(data, schema),
+    )
+    return data
+
+
+def convert_pyarrow_recordbatch(
+    data: pa.RecordBatch, large_dtypes: bool
+) -> pa.RecordBatchReader:
+    """Converts a PyArrow RecordBatch into a PyArrow RecordBatchReader with a Delta Lake compatible schema"""
+    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
+    data = pa.Table.from_batches([data]).cast(schema).to_reader()
+    return data
+
+
+def convert_pyarrow_table(data: pa.Table, large_dtypes: bool) -> pa.RecordBatchReader:
+    """Converts a PyArrow Table into a PyArrow RecordBatchReader with a Delta Lake compatible schema"""
+    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
+    data = data.cast(schema).to_reader()
+    return data
+
+
+def convert_pyarrow_dataset(
+    data: ds.Dataset, large_dtypes: bool
+) -> pa.RecordBatchReader:
+    """Converts a PyArrow Dataset into a PyArrow RecordBatchReader with a Delta Lake compatible schema"""
+    data = data.scanner().to_reader()
+    data = convert_pyarrow_recordbatchreader(data, large_dtypes)
+    return data
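For illustration only (not part of the patch): a minimal sketch of what the internal `_convert_pa_schema_to_delta` helper does, assuming a deltalake build that includes this change.

```python
import pyarrow as pa

from deltalake.schema import _convert_pa_schema_to_delta

source = pa.schema(
    [
        ("id", pa.uint32()),                   # unsigned -> signed
        ("name", pa.large_string()),           # large -> normal when large_dtypes=False
        ("ts", pa.timestamp("ns", tz="UTC")),  # any unit/tz -> "us" (tz dropped, see TODO above)
    ]
)

print(_convert_pa_schema_to_delta(source, large_dtypes=False))
# id: int32
# name: string
# ts: timestamp[us]
```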
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index 941e0d1fce..b238af7929 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -20,6 +20,7 @@
 )
 
 import pyarrow
+import pyarrow.dataset as ds
 import pyarrow.fs as pa_fs
 from pyarrow.dataset import (
     Expression,
@@ -596,7 +597,13 @@ def optimize(
 
     def merge(
         self,
-        source: Union[pyarrow.Table, pyarrow.RecordBatch, pyarrow.RecordBatchReader],
+        source: Union[
+            pyarrow.Table,
+            pyarrow.RecordBatch,
+            pyarrow.RecordBatchReader,
+            ds.Dataset,
+            "pandas.DataFrame",
+        ],
         predicate: str,
         source_alias: Optional[str] = None,
         target_alias: Optional[str] = None,
@@ -619,17 +626,28 @@
         invariants = self.schema().invariants
         checker = _DeltaDataChecker(invariants)
 
+        from .schema import (
+            convert_pyarrow_dataset,
+            convert_pyarrow_recordbatch,
+            convert_pyarrow_recordbatchreader,
+            convert_pyarrow_table,
+        )
+
         if isinstance(source, pyarrow.RecordBatchReader):
-            schema = source.schema
+            source = convert_pyarrow_recordbatchreader(source, large_dtypes=True)
         elif isinstance(source, pyarrow.RecordBatch):
-            schema = source.schema
-            source = [source]
+            source = convert_pyarrow_recordbatch(source, large_dtypes=True)
         elif isinstance(source, pyarrow.Table):
-            schema = source.schema
-            source = source.to_reader()
+            source = convert_pyarrow_table(source, large_dtypes=True)
+        elif isinstance(source, ds.Dataset):
+            source = convert_pyarrow_dataset(source, large_dtypes=True)
+        elif isinstance(source, pandas.DataFrame):
+            source = convert_pyarrow_table(
+                pyarrow.Table.from_pandas(source), large_dtypes=True
+            )
         else:
             raise TypeError(
-                f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch or Table are valid inputs for source."
+                f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Table, Dataset or Pandas DataFrame are valid inputs for source."
             )
 
         def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch:
@@ -637,7 +655,7 @@ def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch:
             return batch
 
         source = pyarrow.RecordBatchReader.from_batches(
-            schema, (validate_batch(batch) for batch in source)
+            source.schema, (validate_batch(batch) for batch in source)
         )
 
         return TableMerger(
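With the change above, `DeltaTable.merge` accepts a PyArrow Dataset or a pandas DataFrame as the source directly. A rough usage sketch (not part of the patch; the table path and column names are placeholders, and the `TableMerger` builder methods are assumed from the existing API rather than shown in this diff):

```python
import pandas as pd

from deltalake import DeltaTable

dt = DeltaTable("path/to/existing_table")  # placeholder path; table assumed to have an "id" column
new_rows = pd.DataFrame({"id": [1, 2], "value": ["a", "b"]})

(
    dt.merge(
        source=new_rows,  # converted internally via pyarrow.Table.from_pandas
        predicate="target.id = source.id",
        source_alias="source",
        target_alias="target",
    )
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute()
)
```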
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index dd0d350eb4..065803f5c7 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -34,13 +34,17 @@
 import pyarrow.fs as pa_fs
 from pyarrow.lib import RecordBatchReader
 
-from deltalake.schema import delta_arrow_schema_from_pandas
-
 from ._internal import DeltaDataChecker as _DeltaDataChecker
 from ._internal import batch_distinct
 from ._internal import convert_to_deltalake as _convert_to_deltalake
 from ._internal import write_new_deltalake as _write_new_deltalake
 from .exceptions import DeltaProtocolError, TableNotFoundError
+from .schema import (
+    convert_pyarrow_dataset,
+    convert_pyarrow_recordbatch,
+    convert_pyarrow_recordbatchreader,
+    convert_pyarrow_table,
+)
 from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable
 
 try:
@@ -159,13 +163,8 @@ def write_deltalake(
         overwrite_schema: If True, allows updating the schema of the table.
         storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
         partition_filters: the partition filters that will be used for partition overwrite.
-        large_dtypes: If True, the table schema is checked against large_dtypes
+        large_dtypes: If True, the data schema is kept in large dtypes; has no effect on pandas DataFrame input
     """
-    if _has_pandas and isinstance(data, pd.DataFrame):
-        if schema is not None:
-            data = pa.Table.from_pandas(data, schema=schema)
-        else:
-            data, schema = delta_arrow_schema_from_pandas(data)
 
     table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)
 
@@ -173,13 +172,29 @@
     if table:
         table.update_incremental()
 
-    if schema is None:
-        if isinstance(data, RecordBatchReader):
-            schema = data.schema
-        elif isinstance(data, Iterable):
-            raise ValueError("You must provide schema if data is Iterable")
+    if isinstance(data, RecordBatchReader):
+        data = convert_pyarrow_recordbatchreader(data, large_dtypes)
+    elif isinstance(data, pa.RecordBatch):
+        data = convert_pyarrow_recordbatch(data, large_dtypes)
+    elif isinstance(data, pa.Table):
+        data = convert_pyarrow_table(data, large_dtypes)
+    elif isinstance(data, ds.Dataset):
+        data = convert_pyarrow_dataset(data, large_dtypes)
+    elif _has_pandas and isinstance(data, pd.DataFrame):
+        if schema is not None:
+            data = pa.Table.from_pandas(data, schema=schema)
         else:
-            schema = data.schema
+            data = convert_pyarrow_table(pa.Table.from_pandas(data), False)
+    elif isinstance(data, Iterable):
+        if schema is None:
+            raise ValueError("You must provide schema if data is Iterable")
+    else:
+        raise TypeError(
+            f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame are valid inputs for data."
+        )
+
+    if schema is None:
+        schema = data.schema
 
     if filesystem is not None:
         raise NotImplementedError("Filesystem support is not yet implemented. #570")
@@ -226,7 +241,7 @@
         current_version = -1
 
     dtype_map = {
-        pa.large_string(): pa.string(),  # type: ignore
+        pa.large_string(): pa.string(),
     }
 
     def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
@@ -328,19 +343,8 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
 
         return batch
 
-    if isinstance(data, RecordBatchReader):
-        batch_iter = data
-    elif isinstance(data, pa.RecordBatch):
-        batch_iter = [data]
-    elif isinstance(data, pa.Table):
-        batch_iter = data.to_batches()
-    elif isinstance(data, ds.Dataset):
-        batch_iter = data.to_batches()
-    else:
-        batch_iter = data
-
     data = RecordBatchReader.from_batches(
-        schema, (validate_batch(batch) for batch in batch_iter)
+        schema, (validate_batch(batch) for batch in data)
    )
 
    if file_options is not None:
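Correspondingly, `write_deltalake` now routes every supported input through the `convert_pyarrow_*` helpers above. A minimal sketch of the accepted inputs (not part of the patch; paths are placeholders):

```python
import pandas as pd
import pyarrow as pa

from deltalake import write_deltalake

df = pd.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})
table = pa.Table.from_pandas(df)

# pandas input is converted via pa.Table.from_pandas; large_dtypes has no effect here
write_deltalake("/tmp/example_pandas", df, mode="overwrite")

# Table / RecordBatch / RecordBatchReader / Dataset input:
# large_dtypes=True keeps string and binary columns as large_string / large_binary
write_deltalake("/tmp/example_arrow", table, mode="overwrite", large_dtypes=True)
```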
#570") @@ -226,7 +241,7 @@ def write_deltalake( current_version = -1 dtype_map = { - pa.large_string(): pa.string(), # type: ignore + pa.large_string(): pa.string(), } def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType: @@ -328,19 +343,8 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch: return batch - if isinstance(data, RecordBatchReader): - batch_iter = data - elif isinstance(data, pa.RecordBatch): - batch_iter = [data] - elif isinstance(data, pa.Table): - batch_iter = data.to_batches() - elif isinstance(data, ds.Dataset): - batch_iter = data.to_batches() - else: - batch_iter = data - data = RecordBatchReader.from_batches( - schema, (validate_batch(batch) for batch in batch_iter) + schema, (validate_batch(batch) for batch in data) ) if file_options is not None: diff --git a/python/licenses/README.md b/python/licenses/README.md new file mode 100644 index 0000000000..7f8f61c9d4 --- /dev/null +++ b/python/licenses/README.md @@ -0,0 +1,8 @@ +# Licenses +Below are described which licenses apply to the deltalake package and to which areas of the source code. + +### deltalake_license.txt (APACHE 2.0 License) +Applies to the full deltalake package source code. + +### polars_license.txt (MIT License) +Applies solely to the `_convert_pa_schema_to_delta` function in `deltalake/schema.py`. \ No newline at end of file diff --git a/python/LICENSE.txt b/python/licenses/deltalake_license.txt similarity index 100% rename from python/LICENSE.txt rename to python/licenses/deltalake_license.txt diff --git a/python/licenses/polars_license.txt b/python/licenses/polars_license.txt new file mode 100644 index 0000000000..06d01f6abf --- /dev/null +++ b/python/licenses/polars_license.txt @@ -0,0 +1,19 @@ +Copyright (c) 2020 Ritchie Vink + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/python/pyproject.toml b/python/pyproject.toml index aaeda6bfd2..6ffe4ca14c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "maturin" name = "deltalake" description = "Native Delta Lake Python binding based on delta-rs with Pandas integration" readme = "README.md" -license = {file = "LICENSE.txt"} +license = {file = "licenses/deltalake_license.txt"} requires-python = ">=3.8" keywords = ["deltalake", "delta", "datalake", "pandas", "arrow"] classifiers = [ diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi index f8c9d152aa..10edfcf663 100644 --- a/python/stubs/pyarrow/__init__.pyi +++ b/python/stubs/pyarrow/__init__.pyi @@ -19,10 +19,23 @@ type_for_alias: Any date32: Any date64: Any decimal128: Any +int8: Any +int16: Any int32: Any +int64: Any +uint8: Any +uint16: Any +uint32: Any +uint64: Any float16: Any float32: Any float64: Any +large_string: Any +string: Any +large_binary: Any +binary: Any +large_list: Any +LargeListType: Any dictionary: Any timestamp: Any TimestampType: Any diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index f63df0e9fb..6a30ca684e 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -1,10 +1,17 @@ import json -import pyarrow +import pyarrow as pa import pytest from deltalake import DeltaTable, Field -from deltalake.schema import ArrayType, MapType, PrimitiveType, Schema, StructType +from deltalake.schema import ( + ArrayType, + MapType, + PrimitiveType, + Schema, + StructType, + _convert_pa_schema_to_delta, +) def test_table_schema(): @@ -36,7 +43,7 @@ def test_table_schema_pyarrow_simple(): field = schema.field(0) assert len(schema.types) == 1 assert field.name == "id" - assert field.type == pyarrow.int64() + assert field.type == pa.int64() assert field.nullable is True assert field.metadata is None @@ -48,7 +55,7 @@ def test_table_schema_pyarrow_020(): field = schema.field(0) assert len(schema.types) == 1 assert field.name == "value" - assert field.type == pyarrow.int32() + assert field.type == pa.int32() assert field.nullable is True assert field.metadata is None @@ -213,3 +220,236 @@ def test_delta_schema(): schema_without_metadata = schema = Schema(fields) pa_schema = schema_without_metadata.to_pyarrow() assert schema_without_metadata == Schema.from_pyarrow(pa_schema) + + +@pytest.mark.parametrize( + "schema,expected_schema,large_dtypes", + [ + ( + pa.schema([("some_int", pa.uint32()), ("some_string", pa.string())]), + pa.schema([("some_int", pa.int32()), ("some_string", pa.string())]), + False, + ), + ( + pa.schema( + [ + pa.field("some_int", pa.uint32(), nullable=True), + pa.field("some_string", pa.string(), nullable=False), + ] + ), + pa.schema( + [ + pa.field("some_int", pa.int32(), nullable=True), + pa.field("some_string", pa.string(), nullable=False), + ] + ), + False, + ), + ( + pa.schema( + [ + pa.field("some_int", pa.uint32(), nullable=True), + pa.field("some_string", pa.string(), nullable=False), + ] + ), + pa.schema( + [ + pa.field("some_int", pa.int32(), nullable=True), + pa.field("some_string", pa.large_string(), nullable=False), + ] + ), + True, + ), + ( + pa.schema([("some_int", pa.uint32()), ("some_string", pa.string())]), + pa.schema([("some_int", pa.int32()), ("some_string", pa.large_string())]), + True, + ), + ( + pa.schema([("some_int", pa.uint32()), ("some_string", pa.large_string())]), + pa.schema([("some_int", pa.int32()), ("some_string", pa.string())]), + False, + ), + ( + pa.schema( + [ + ("some_int", 
pa.uint8()), + ("some_int1", pa.uint16()), + ("some_int2", pa.uint32()), + ("some_int3", pa.uint64()), + ] + ), + pa.schema( + [ + ("some_int", pa.int8()), + ("some_int1", pa.int16()), + ("some_int2", pa.int32()), + ("some_int3", pa.int64()), + ] + ), + True, + ), + ( + pa.schema( + [ + ("some_list", pa.list_(pa.string())), + ("some_list_binary", pa.list_(pa.binary())), + ("some_string", pa.large_string()), + ] + ), + pa.schema( + [ + ("some_list", pa.large_list(pa.large_string())), + ("some_list_binary", pa.large_list(pa.large_binary())), + ("some_string", pa.large_string()), + ] + ), + True, + ), + ( + pa.schema( + [ + ("some_list", pa.large_list(pa.string())), + ("some_string", pa.large_string()), + ("some_binary", pa.large_binary()), + ] + ), + pa.schema( + [ + ("some_list", pa.list_(pa.string())), + ("some_string", pa.string()), + ("some_binary", pa.binary()), + ] + ), + False, + ), + ( + pa.schema( + [ + ("highly_nested_list", pa.list_(pa.list_(pa.list_(pa.string())))), + ( + "highly_nested_list_binary", + pa.list_(pa.list_(pa.list_(pa.binary()))), + ), + ("some_string", pa.large_string()), + ("some_binary", pa.large_binary()), + ] + ), + pa.schema( + [ + ( + "highly_nested_list", + pa.large_list(pa.large_list(pa.large_list(pa.large_string()))), + ), + ( + "highly_nested_list_binary", + pa.large_list(pa.large_list(pa.large_list(pa.large_binary()))), + ), + ("some_string", pa.large_string()), + ("some_binary", pa.large_binary()), + ] + ), + True, + ), + ( + pa.schema( + [ + ( + "highly_nested_list", + pa.large_list(pa.list_(pa.large_list(pa.string()))), + ), + ( + "highly_nested_list_int", + pa.large_list(pa.list_(pa.large_list(pa.uint64()))), + ), + ("some_string", pa.large_string()), + ("some_binary", pa.large_binary()), + ] + ), + pa.schema( + [ + ("highly_nested_list", pa.list_(pa.list_(pa.list_(pa.string())))), + ( + "highly_nested_list_int", + pa.list_(pa.list_(pa.list_(pa.int64()))), + ), + ("some_string", pa.string()), + ("some_binary", pa.binary()), + ] + ), + False, + ), + ( + pa.schema( + [ + ("timestamp", pa.timestamp("s")), + ("timestamp1", pa.timestamp("ms")), + ("timestamp2", pa.timestamp("us")), + ("timestamp3", pa.timestamp("ns")), + ("timestamp4", pa.timestamp("s", tz="UTC")), + ("timestamp5", pa.timestamp("ms", tz="UTC")), + ("timestamp6", pa.timestamp("ns", tz="UTC")), + ("timestamp7", pa.timestamp("ns", tz="UTC")), + ] + ), + pa.schema( + [ + ("timestamp", pa.timestamp("us")), + ("timestamp1", pa.timestamp("us")), + ("timestamp2", pa.timestamp("us")), + ("timestamp3", pa.timestamp("us")), + ("timestamp4", pa.timestamp("us")), + ("timestamp5", pa.timestamp("us")), + ("timestamp6", pa.timestamp("us")), + ("timestamp7", pa.timestamp("us")), + ] + ), + False, + ), + ( + pa.schema( + [ + ( + "struct", + pa.struct( + { + "highly_nested_list": pa.large_list( + pa.list_(pa.large_list(pa.string())) + ), + "highly_nested_list_int": pa.large_list( + pa.list_(pa.large_list(pa.uint64())) + ), + "some_string": pa.large_string(), + "some_binary": pa.large_binary(), + } + ), + ) + ] + ), + pa.schema( + [ + ( + "struct", + pa.struct( + { + "highly_nested_list": pa.list_( + pa.list_(pa.list_(pa.string())) + ), + "highly_nested_list_int": pa.list_( + pa.list_(pa.list_(pa.int64())) + ), + "some_string": pa.string(), + "some_binary": pa.binary(), + } + ), + ) + ] + ), + False, + ), + ], +) +def test_schema_conversions(schema, expected_schema, large_dtypes): + result_schema = _convert_pa_schema_to_delta(schema, large_dtypes) + + assert result_schema == expected_schema diff --git 
a/python/tests/test_writer.py b/python/tests/test_writer.py index d048f8b79b..4330489e4a 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -304,15 +304,41 @@ def test_write_iterator( assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data +@pytest.mark.parametrize("large_dtypes", [True, False]) +@pytest.mark.parametrize( + "constructor", + [ + lambda table: table.to_pyarrow_dataset(), + lambda table: table.to_pyarrow_table(), + lambda table: table.to_pyarrow_table().to_batches()[0], + ], +) +def test_write_dataset_table_recordbatch( + tmp_path: pathlib.Path, + existing_table: DeltaTable, + sample_data: pa.Table, + large_dtypes: bool, + constructor, +): + dataset = constructor(existing_table) + + write_deltalake(tmp_path, dataset, mode="overwrite", large_dtypes=large_dtypes) + assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data + + +@pytest.mark.parametrize("large_dtypes", [True, False]) def test_write_recordbatchreader( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table + tmp_path: pathlib.Path, + existing_table: DeltaTable, + sample_data: pa.Table, + large_dtypes: bool, ): batches = existing_table.to_pyarrow_dataset().to_batches() reader = RecordBatchReader.from_batches( existing_table.to_pyarrow_dataset().schema, batches ) - write_deltalake(tmp_path, reader, mode="overwrite") + write_deltalake(tmp_path, reader, mode="overwrite", large_dtypes=large_dtypes) assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data