fix: more consistent handling of partition values and file paths #1661

Merged · 15 commits · Sep 25, 2023
Changes from all commits
1 change: 1 addition & 0 deletions .gitignore
@@ -27,3 +27,4 @@ Cargo.lock
!/delta-inspect/Cargo.lock
!/proofs/Cargo.lock

justfile
20 changes: 20 additions & 0 deletions python/deltalake/_util.py
@@ -0,0 +1,20 @@
from datetime import date, datetime
from typing import Any


def encode_partition_value(val: Any) -> str:
# Rules based on: https://github.com/delta-io/delta/blob/master/PROTOCOL.md#partition-value-serialization
if isinstance(val, bool):
return str(val).lower()
if isinstance(val, str):
return val
elif isinstance(val, (int, float)):
return str(val)
elif isinstance(val, date):
return val.isoformat()
elif isinstance(val, datetime):
return val.isoformat(sep=" ")
elif isinstance(val, bytes):
return val.decode("unicode_escape", "backslashreplace")
else:
raise ValueError(f"Could not encode partition value for type: {val}")
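For reference, a quick sketch of how the new helper serializes a few common Python values under the protocol rules it links to (expected results shown as comments; they follow directly from the implementation above):

```python
from datetime import date, datetime

from deltalake._util import encode_partition_value

encode_partition_value(True)                   # "true"
encode_partition_value(42)                     # "42"
encode_partition_value(1.5)                    # "1.5"
encode_partition_value(date(2023, 9, 15))      # "2023-09-15"
encode_partition_value(datetime(2023, 9, 15))  # "2023-09-15 00:00:00"
```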
5 changes: 3 additions & 2 deletions python/deltalake/table.py
@@ -31,6 +31,7 @@
import pandas

from ._internal import RawDeltaTable
from ._util import encode_partition_value
from .data_catalog import DataCatalog
from .exceptions import DeltaProtocolError
from .fs import DeltaStorageHandler
@@ -625,9 +626,9 @@ def __stringify_partition_values(
for field, op, value in partition_filters:
str_value: Union[str, List[str]]
if isinstance(value, (list, tuple)):
str_value = [str(val) for val in value]
str_value = [encode_partition_value(val) for val in value]
else:
str_value = str(value)
str_value = encode_partition_value(value)
out.append((field, op, str_value))
return out

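The effect of this change (illustrated with a hypothetical table partitioned by a boolean column) is that typed filter values now go through the same protocol encoding as written partition values, so Python's `True` matches the stored `"true"` rather than the old `str(True)` result `"True"`:

```python
from deltalake import DeltaTable

dt = DeltaTable("path/to/table")  # hypothetical table partitioned by "bool_col"

# Both spellings now select the same files; previously only the
# pre-encoded string form matched the stored partition value.
dt.file_uris(partition_filters=[("bool_col", "=", True)])
dt.file_uris(partition_filters=[("bool_col", "=", "true")])
```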
25 changes: 5 additions & 20 deletions python/deltalake/writer.py
@@ -17,9 +17,12 @@
Tuple,
Union,
)
from urllib.parse import unquote

from deltalake.fs import DeltaStorageHandler

from ._util import encode_partition_value

if TYPE_CHECKING:
import pandas as pd

@@ -262,7 +265,7 @@ def check_data_is_aligned_with_partition_filtering(
for i in range(partition_values.num_rows):
# Map will maintain order of partition_columns
partition_map = {
column_name: __encode_partition_value(
column_name: encode_partition_value(
batch.column(column_name)[i].as_py()
)
for column_name in table.metadata().partition_columns
@@ -422,7 +425,7 @@ def get_partitions_from_path(path: str) -> Tuple[str, Dict[str, Optional[str]]]:
if value == "__HIVE_DEFAULT_PARTITION__":
out[key] = None
else:
out[key] = value
out[key] = unquote(value)
return path, out
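For context, partition values containing characters that are not path-safe are percent-encoded when they become `key=value` path segments, so reading them back requires decoding; a minimal sketch of what `unquote` recovers here (the path segment is illustrative):

```python
from urllib.parse import unquote

segment = "string=x%3D4%20y%3D5"  # encoded form of the partition value "x=4 y=5"
key, value = segment.split("=", maxsplit=1)
assert unquote(value) == "x=4 y=5"
```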


@@ -489,21 +492,3 @@ def iter_groups(metadata: Any) -> Iterator[Any]:
maximum for maximum in maximums if maximum is not None
)
return stats


def __encode_partition_value(val: Any) -> str:
# Rules based on: https://github.com/delta-io/delta/blob/master/PROTOCOL.md#partition-value-serialization
if isinstance(val, bool):
return str(val).lower()
if isinstance(val, str):
return val
elif isinstance(val, (int, float)):
return str(val)
elif isinstance(val, date):
return val.isoformat()
elif isinstance(val, datetime):
return val.isoformat(sep=" ")
elif isinstance(val, bytes):
return val.decode("unicode_escape", "backslashreplace")
else:
raise ValueError(f"Could not encode partition value for type: {val}")
34 changes: 33 additions & 1 deletion python/tests/pyspark_integration/test_writer_readable.py
@@ -12,6 +12,7 @@
import delta
import delta.pip_utils
import delta.tables
import pyspark.pandas as ps

spark = get_spark()
except ModuleNotFoundError:
@@ -34,7 +35,7 @@ def test_basic_read(sample_data: pa.Table, existing_table: DeltaTable):
@pytest.mark.pyspark
@pytest.mark.integration
def test_partitioned(tmp_path: pathlib.Path, sample_data: pa.Table):
partition_cols = ["date32", "utf8"]
partition_cols = ["date32", "utf8", "timestamp", "bool"]

# Add null values to sample data to verify we can read null partitions
sample_data_with_null = sample_data
@@ -63,3 +64,34 @@ def test_overwrite(

write_deltalake(path, sample_data, mode="overwrite")
assert_spark_read_equal(sample_data, path)


@pytest.mark.pyspark
@pytest.mark.integration
def test_issue_1591_roundtrip_special_characters(tmp_path: pathlib.Path):
test_string = r'$%&/()=^"[]#*?.:_-{=}|`<>~/\r\n+'
poisoned = "}|`<>~"
for char in poisoned:
test_string = test_string.replace(char, "")

data = pa.table(
{
"string": pa.array([test_string], type=pa.utf8()),
"data": pa.array(["python-module-test-write"]),
}
)

deltalake_path = tmp_path / "deltalake"
write_deltalake(
table_or_uri=deltalake_path, mode="append", data=data, partition_by=["string"]
)

loaded = ps.read_delta(str(deltalake_path), index_col=None).to_pandas()
assert loaded.shape == data.shape

spark_path = tmp_path / "spark"
spark_df = spark.createDataFrame(data.to_pandas())
spark_df.write.format("delta").partitionBy(["string"]).save(str(spark_path))

loaded = DeltaTable(spark_path).to_pandas()
assert loaded.shape == data.shape
46 changes: 46 additions & 0 deletions python/tests/test_table_read.py
@@ -648,3 +648,49 @@ def assert_scan_equals(table, predicate, expected):
assert_num_fragments(table, predicate, 2)
expected = pa.table({"part": ["a", "a", "b", "b"], "value": [1, 1, None, None]})
assert_scan_equals(table, predicate, expected)


def test_issue_1653_filter_bool_partition(tmp_path: Path):
ta = pa.Table.from_pydict(
{
"bool_col": [True, False, True, False],
"int_col": [0, 1, 2, 3],
"str_col": ["a", "b", "c", "d"],
}
)
write_deltalake(
tmp_path, ta, partition_by=["bool_col", "int_col"], mode="overwrite"
)
dt = DeltaTable(tmp_path)

assert (
dt.to_pyarrow_table(
filters=[
("int_col", "=", 0),
("bool_col", "=", True),
]
).num_rows
== 1
)
assert (
len(
dt.file_uris(
partition_filters=[
("int_col", "=", 0),
("bool_col", "=", "true"),
]
)
)
== 1
)
assert (
len(
dt.file_uris(
partition_filters=[
("int_col", "=", 0),
("bool_col", "=", True),
]
)
)
== 1
)
17 changes: 17 additions & 0 deletions python/tests/test_writer.py
@@ -168,6 +168,7 @@ def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table):
"bool",
"binary",
"date32",
"timestamp",
],
)
def test_roundtrip_partitioned(
@@ -888,3 +889,19 @@ def comp():
"a concurrent transaction deleted the same data your transaction deletes"
in str(exception)
)


def test_issue_1651_roundtrip_timestamp(tmp_path: pathlib.Path):
data = pa.table(
{
"id": pa.array([425], type=pa.int32()),
"data": pa.array(["python-module-test-write"]),
"t": pa.array([datetime(2023, 9, 15)]),
}
)

write_deltalake(table_or_uri=tmp_path, mode="append", data=data, partition_by=["t"])
dt = DeltaTable(table_uri=tmp_path)
dataset = dt.to_pyarrow_dataset()

assert dataset.count_rows() == 1
4 changes: 4 additions & 0 deletions rust/Cargo.toml
@@ -185,6 +185,10 @@ harness = false
name = "basic_operations"
required-features = ["datafusion"]

[[example]]
name = "load_table"
required-features = ["datafusion"]

[[example]]
name = "recordbatch-writer"
required-features = ["arrow"]
41 changes: 33 additions & 8 deletions rust/examples/basic_operations.rs
@@ -1,6 +1,6 @@
use arrow::{
array::{Int32Array, StringArray},
datatypes::{DataType, Field, Schema as ArrowSchema},
array::{Int32Array, StringArray, TimestampMicrosecondArray},
datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit},
record_batch::RecordBatch,
};
use deltalake::operations::collect_sendable_stream;
@@ -26,34 +26,59 @@ fn get_table_columns() -> Vec<SchemaField> {
true,
Default::default(),
),
SchemaField::new(
String::from("timestamp"),
SchemaDataType::primitive(String::from("timestamp")),
true,
Default::default(),
),
]
}

fn get_table_batches() -> RecordBatch {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("int", DataType::Int32, false),
Field::new("string", DataType::Utf8, true),
Field::new(
"timestamp",
DataType::Timestamp(TimeUnit::Microsecond, None),
true,
),
]));

let int_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
let str_values = StringArray::from(vec!["A", "B", "A", "B", "A", "A", "A", "B", "B", "A", "A"]);

RecordBatch::try_new(schema, vec![Arc::new(int_values), Arc::new(str_values)]).unwrap()
let ts_values = TimestampMicrosecondArray::from(vec![
1000000012, 1000000012, 1000000012, 1000000012, 500012305, 500012305, 500012305, 500012305,
500012305, 500012305, 500012305,
]);
RecordBatch::try_new(
schema,
vec![
Arc::new(int_values),
Arc::new(str_values),
Arc::new(ts_values),
],
)
.unwrap()
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), deltalake::errors::DeltaTableError> {
// Create a delta operations client pointing at an un-initialized in-memory location.
// In a production environment this would be created with "try_new" and point at
// a real storage location.
let ops = DeltaOps::new_in_memory();
// Create a delta operations client pointing at an un-initialized location.
let ops = if let Ok(table_uri) = std::env::var("TABLE_URI") {
DeltaOps::try_from_uri(table_uri).await?
} else {
DeltaOps::new_in_memory()
};

// The operations module uses a builder pattern that allows specifying several options
// on how the command behaves. The builders implement `Into<Future>`, so once
// options are set you can run the command using `.await`.
let table = ops
.create()
.with_columns(get_table_columns())
.with_partition_columns(["timestamp"])
.with_table_name("my_table")
.with_comment("A table to show how delta-rs works")
.await?;
20 changes: 20 additions & 0 deletions rust/examples/load_table.rs
@@ -0,0 +1,20 @@
use arrow::record_batch::RecordBatch;
use deltalake::operations::collect_sendable_stream;
use deltalake::DeltaOps;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), deltalake::errors::DeltaTableError> {
// Create a delta operations client pointing at an un-initialized location.
let ops = if let Ok(table_uri) = std::env::var("TABLE_URI") {
DeltaOps::try_from_uri(table_uri).await?
} else {
DeltaOps::try_from_uri("./rust/tests/data/delta-0.8.0").await?
};

let (_table, stream) = ops.load().await?;
let data: Vec<RecordBatch> = collect_sendable_stream(stream).await?;

println!("{:?}", data);

Ok(())
}
2 changes: 1 addition & 1 deletion rust/examples/recordbatch-writer.rs
@@ -36,7 +36,7 @@ async fn main() -> Result<(), DeltaTableError> {
})?;
info!("Using the location of: {:?}", table_uri);

let table_path = Path::from(table_uri.as_ref());
let table_path = Path::parse(&table_uri)?;

let maybe_table = deltalake::open_table(&table_path).await;
let mut table = match maybe_table {