diff --git a/rust/src/operations/writer.rs b/rust/src/operations/writer.rs index c2cdb8aa65..b6134386c1 100644 --- a/rust/src/operations/writer.rs +++ b/rust/src/operations/writer.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::action::Add; use crate::storage::ObjectStoreRef; use crate::writer::record_batch::{divide_by_partition_values, PartitionResult}; -use crate::writer::stats::{apply_null_counts, create_add, NullCounts}; +use crate::writer::stats::create_add; use crate::writer::utils::{ arrow_schema_without_partitions, record_batch_without_partitions, PartitionPath, ShareableBuffer, @@ -16,7 +16,6 @@ use arrow::datatypes::SchemaRef as ArrowSchemaRef; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use bytes::Bytes; -use log::warn; use object_store::{path::Path, ObjectStore}; use parquet::arrow::ArrowWriter; use parquet::basic::Compression; @@ -269,7 +268,6 @@ pub(crate) struct PartitionWriter { buffer: ShareableBuffer, arrow_writer: ArrowWriter, part_counter: usize, - null_counts: NullCounts, files_written: Vec, } @@ -293,7 +291,6 @@ impl PartitionWriter { buffer, arrow_writer, part_counter: 0, - null_counts: NullCounts::new(), files_written: Vec::new(), }) } @@ -307,11 +304,8 @@ impl PartitionWriter { self.config.prefix.child(file_name) } - fn replace_arrow_buffer( - &mut self, - seed: impl AsRef<[u8]>, - ) -> DeltaResult<(ArrowWriter, ShareableBuffer)> { - let new_buffer = ShareableBuffer::from_bytes(seed.as_ref()); + fn reset_writer(&mut self) -> DeltaResult<(ArrowWriter, ShareableBuffer)> { + let new_buffer = ShareableBuffer::default(); let arrow_writer = ArrowWriter::try_new( new_buffer.clone(), self.config.file_schema.clone(), @@ -324,40 +318,27 @@ impl PartitionWriter { } fn write_batch(&mut self, batch: &RecordBatch) -> DeltaResult<()> { - // copy current cursor bytes so we can recover from failures - // TODO is copying this something we should be doing? - let buffer_bytes = self.buffer.to_vec(); - match self.arrow_writer.write(batch) { - Ok(_) => { - apply_null_counts(&batch.clone().into(), &mut self.null_counts, 0); - Ok(()) - } - Err(err) => { - // if a write fails we need to reset the state of the PartitionWriter - warn!("error writing to arrow buffer, resetting writer state."); - self.replace_arrow_buffer(buffer_bytes)?; - Err(err.into()) - } - } + Ok(self.arrow_writer.write(batch)?) 
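// A minimal standalone sketch (not the delta-rs API) of the swap-and-return pattern a
// method like `reset_writer` above can use: the hunk only shows the fresh buffer/writer
// being built, and this assumes the elided remainder hands the old pair back (e.g. via
// std::mem::replace) so it can be flushed.
struct BufferedWriter {
    buffer: Vec<u8>,
}

impl BufferedWriter {
    fn reset(&mut self) -> Vec<u8> {
        // Install an empty buffer and return the previously filled one for flushing.
        std::mem::replace(&mut self.buffer, Vec::new())
    }
}

fn main() {
    let mut w = BufferedWriter { buffer: b"already written".to_vec() };
    let old = w.reset();
    assert_eq!(old, b"already written".to_vec());
    assert!(w.buffer.is_empty());
}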
} async fn flush_arrow_writer(&mut self) -> DeltaResult<()> { // replace counter / buffers and close the current writer - let (writer, buffer) = self.replace_arrow_buffer(vec![])?; - let null_counts = std::mem::take(&mut self.null_counts); + let (writer, buffer) = self.reset_writer()?; let metadata = writer.close()?; + let buffer = match buffer.into_inner() { + Some(buffer) => Bytes::from(buffer), + None => return Ok(()), // Nothing to write + }; // collect metadata let path = self.next_data_path(); - let obj_bytes = Bytes::from(buffer.to_vec()); - let file_size = obj_bytes.len() as i64; + let file_size = buffer.len() as i64; // write file to object store - self.object_store.put(&path, obj_bytes).await?; + self.object_store.put(&path, buffer).await?; self.files_written.push( create_add( &self.config.partition_values, - null_counts, path.to_string(), file_size, &metadata, diff --git a/rust/src/writer/json.rs b/rust/src/writer/json.rs index 0b378e2107..c601f7c1e4 100644 --- a/rust/src/writer/json.rs +++ b/rust/src/writer/json.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::sync::Arc; -use super::stats::{apply_null_counts, create_add, NullCounts}; +use super::stats::create_add; use super::utils::{ arrow_schema_without_partitions, next_data_path, record_batch_from_message, record_batch_without_partitions, stringified_partition_value, @@ -42,7 +42,6 @@ pub(crate) struct DataArrowWriter { buffer: ShareableBuffer, arrow_writer: ArrowWriter, partition_values: HashMap>, - null_counts: NullCounts, buffered_record_batch_count: usize, } @@ -120,7 +119,6 @@ impl DataArrowWriter { match result { Ok(_) => { self.buffered_record_batch_count += 1; - apply_null_counts(&record_batch.into(), &mut self.null_counts, 0); Ok(()) } // If a write fails we need to reset the state of the DeltaArrowWriter @@ -152,7 +150,6 @@ impl DataArrowWriter { )?; let partition_values = HashMap::new(); - let null_counts = NullCounts::new(); let buffered_record_batch_count = 0; Ok(Self { @@ -161,7 +158,6 @@ impl DataArrowWriter { buffer, arrow_writer, partition_values, - null_counts, buffered_record_batch_count, }) } @@ -363,19 +359,15 @@ impl DeltaWriter> for JsonWriter { let writers = std::mem::take(&mut self.arrow_writers); let mut actions = Vec::new(); - for (_, mut writer) in writers { + for (_, writer) in writers { let metadata = writer.arrow_writer.close()?; let path = next_data_path(&self.partition_columns, &writer.partition_values, None)?; let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; self.storage.put(&path, obj_bytes).await?; - // Replace self null_counts with an empty map. Use the other for stats. 
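// A self-contained sketch (arrow + parquet crates only, assuming the versions pinned by
// this branch) of the pattern the new flush path relies on: write batches into an
// in-memory buffer, then close the ArrowWriter to obtain the parquet FileMetaData that
// `create_add` now derives all statistics from, instead of tracking null counts
// batch-by-batch on the write path.
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;

fn write_and_inspect() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, true)]));
    let id: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]));
    let batch = RecordBatch::try_new(schema.clone(), vec![id])?;

    // Stand-in for ShareableBuffer: any std::io::Write sink works.
    let mut buffer: Vec<u8> = Vec::new();
    let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)?;
    writer.write(&batch)?;

    // Closing returns thrift FileMetaData; per-row-group column statistics
    // (min / max / null_count) live under metadata.row_groups.
    let metadata = writer.close()?;
    assert_eq!(metadata.num_rows, 3);
    let file_size = buffer.len() as i64; // the same value the Add action records
    assert!(file_size > 0);
    Ok(())
}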
- let null_counts = std::mem::take(&mut writer.null_counts); - actions.push(create_add( &writer.partition_values, - null_counts, path.to_string(), file_size, &metadata, diff --git a/rust/src/writer/mod.rs b/rust/src/writer/mod.rs index 522ab3614d..2dfe5a3822 100644 --- a/rust/src/writer/mod.rs +++ b/rust/src/writer/mod.rs @@ -4,10 +4,10 @@ use crate::action::{Action, Add, ColumnCountStat}; use crate::{DeltaTable, DeltaTableError}; -use arrow::{datatypes::SchemaRef, datatypes::*, error::ArrowError}; +use arrow::{datatypes::SchemaRef, error::ArrowError}; use async_trait::async_trait; use object_store::Error as ObjectStoreError; -use parquet::{basic::LogicalType, errors::ParquetError}; +use parquet::errors::ParquetError; use serde_json::Value; pub use json::JsonWriter; @@ -55,9 +55,15 @@ pub(crate) enum DeltaWriterError { }, /// Serialization of delta log statistics failed. - #[error("Serialization of delta log statistics failed: {source}")] - StatsSerializationFailed { - /// error raised during stats serialization. + #[error("Failed to write statistics value {debug_value} with logical type {logical_type:?}")] + StatsParsingFailed { + debug_value: String, + logical_type: Option, + }, + + /// JSON serialization failed + #[error("Failed to serialize data to JSON: {source}")] + JSONSerializationFailed { #[from] source: serde_json::Error, }, diff --git a/rust/src/writer/record_batch.rs b/rust/src/writer/record_batch.rs index fc9421599c..4f2789cd32 100644 --- a/rust/src/writer/record_batch.rs +++ b/rust/src/writer/record_batch.rs @@ -26,31 +26,32 @@ //! })) //! } //! ``` - -use std::collections::HashMap; -use std::convert::TryFrom; -use std::sync::Arc; - -use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; -use arrow_ord::{partition::lexicographical_partition_ranges, sort::SortColumn}; +use std::{collections::HashMap, sync::Arc}; + +use super::{ + stats::create_add, + utils::{ + arrow_schema_without_partitions, next_data_path, record_batch_without_partitions, + stringified_partition_value, PartitionPath, + }, + DeltaWriter, DeltaWriterError, +}; +use crate::builder::DeltaTableBuilder; +use crate::writer::utils::ShareableBuffer; +use crate::DeltaTableError; +use crate::{action::Add, storage::DeltaObjectStore, DeltaTable, DeltaTableMetaData, Schema}; +use arrow::array::{Array, UInt32Array}; +use arrow::compute::{lexicographical_partition_ranges, take, SortColumn}; +use arrow::datatypes::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; +use arrow_array::ArrayRef; use arrow_row::{RowConverter, SortField}; -use arrow_schema::{ArrowError, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; -use arrow_select::take::take; use bytes::Bytes; use object_store::ObjectStore; use parquet::{arrow::ArrowWriter, errors::ParquetError}; use parquet::{basic::Compression, file::properties::WriterProperties}; -use super::stats::{create_add, NullCounts}; -use super::utils::{ - arrow_schema_without_partitions, next_data_path, record_batch_without_partitions, - stringified_partition_value, PartitionPath, -}; -use super::{DeltaTableError, DeltaWriter, DeltaWriterError}; -use crate::builder::DeltaTableBuilder; -use crate::writer::{stats::apply_null_counts, utils::ShareableBuffer}; -use crate::{action::Add, storage::DeltaObjectStore, DeltaTable, DeltaTableMetaData, Schema}; - /// Writes messages to a delta lake table. 
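// An illustrative sketch of how the reworked DeltaWriterError variants in
// rust/src/writer/mod.rs above behave (a standalone enum here, not the crate's actual
// type): `#[from]` on serde_json::Error lets the `?` in create_add convert JSON failures
// automatically, while StatsParsingFailed carries the offending value and its parquet
// logical type for context.
use parquet::basic::LogicalType;
use thiserror::Error;

#[derive(Debug, Error)]
enum SketchError {
    #[error("Failed to write statistics value {debug_value} with logical type {logical_type:?}")]
    StatsParsingFailed {
        debug_value: String,
        logical_type: Option<LogicalType>,
    },
    #[error("Failed to serialize data to JSON: {source}")]
    JSONSerializationFailed {
        #[from]
        source: serde_json::Error,
    },
}

fn to_stats_string(stats: &serde_json::Value) -> Result<String, SketchError> {
    // The `?` uses the generated From<serde_json::Error> impl.
    Ok(serde_json::to_string(stats)?)
}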
pub struct RecordBatchWriter { storage: Arc, @@ -225,19 +226,15 @@ impl DeltaWriter for RecordBatchWriter { let writers = std::mem::take(&mut self.arrow_writers); let mut actions = Vec::new(); - for (_, mut writer) in writers { + for (_, writer) in writers { let metadata = writer.arrow_writer.close()?; let path = next_data_path(&self.partition_columns, &writer.partition_values, None)?; let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; self.storage.put(&path, obj_bytes).await?; - // Replace self null_counts with an empty map. Use the other for stats. - let null_counts = std::mem::take(&mut writer.null_counts); - actions.push(create_add( &writer.partition_values, - null_counts, path.to_string(), file_size, &metadata, @@ -262,7 +259,6 @@ struct PartitionWriter { pub(super) buffer: ShareableBuffer, pub(super) arrow_writer: ArrowWriter, pub(super) partition_values: HashMap>, - pub(super) null_counts: NullCounts, pub(super) buffered_record_batch_count: usize, } @@ -279,7 +275,6 @@ impl PartitionWriter { Some(writer_properties.clone()), )?; - let null_counts = NullCounts::new(); let buffered_record_batch_count = 0; Ok(Self { @@ -288,7 +283,6 @@ impl PartitionWriter { buffer, arrow_writer, partition_values, - null_counts, buffered_record_batch_count, }) } @@ -310,7 +304,6 @@ impl PartitionWriter { match self.arrow_writer.write(record_batch) { Ok(_) => { self.buffered_record_batch_count += 1; - apply_null_counts(&record_batch.clone().into(), &mut self.null_counts, 0); Ok(()) } // If a write fails we need to reset the state of the PartitionWriter diff --git a/rust/src/writer/stats.rs b/rust/src/writer/stats.rs index a120c39b3a..faef0ee670 100644 --- a/rust/src/writer/stats.rs +++ b/rust/src/writer/stats.rs @@ -1,97 +1,23 @@ use super::*; -use crate::{ - action::{Add, ColumnValueStat, Stats}, - time_utils::timestamp_to_delta_stats_string, -}; -use arrow::{ - array::{ - as_boolean_array, as_primitive_array, as_struct_array, make_array, Array, ArrayData, - StructArray, - }, - buffer::MutableBuffer, -}; -use parquet::errors::ParquetError; -use parquet::file::{metadata::RowGroupMetaData, statistics::Statistics}; +use crate::action::{Add, ColumnValueStat, Stats}; use parquet::format::FileMetaData; use parquet::schema::types::{ColumnDescriptor, SchemaDescriptor}; -use serde_json::{Number, Value}; -use std::collections::HashMap; +use parquet::{basic::LogicalType, errors::ParquetError}; +use parquet::{ + file::{metadata::RowGroupMetaData, statistics::Statistics}, + format::TimeUnit, +}; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; - -pub type NullCounts = HashMap; -pub type MinAndMaxValues = ( - HashMap, - HashMap, -); - -pub(crate) fn apply_null_counts( - array: &StructArray, - null_counts: &mut HashMap, - _nest_level: i32, -) { - let fields = match array.data_type() { - DataType::Struct(fields) => fields, - _ => unreachable!(), - }; - - array - .columns() - .iter() - .zip(fields) - .for_each(|(column, field)| { - let key = field.name().to_owned(); - - match column.data_type() { - // Recursive case - DataType::Struct(_) => { - let col_struct = null_counts - .entry(key) - .or_insert_with(|| ColumnCountStat::Column(HashMap::new())); - - match col_struct { - ColumnCountStat::Column(map) => { - apply_null_counts(as_struct_array(column), map, _nest_level + 1); - } - _ => unreachable!(), - } - } - // Base case - _ => { - let col_struct = null_counts - .entry(key.clone()) - .or_insert_with(|| ColumnCountStat::Value(0)); - - match col_struct { - 
ColumnCountStat::Value(n) => { - let null_count = column.null_count() as i64; - let n = null_count + *n; - null_counts.insert(key, ColumnCountStat::Value(n)); - } - _ => unreachable!(), - } - } - } - }); -} +use std::{collections::HashMap, ops::AddAssign}; pub(crate) fn create_add( partition_values: &HashMap>, - null_counts: NullCounts, path: String, size: i64, file_metadata: &FileMetaData, ) -> Result { - let (min_values, max_values) = - min_max_values_from_file_metadata(partition_values, file_metadata)?; - - let stats = Stats { - num_records: file_metadata.num_rows, - min_values, - max_values, - null_count: null_counts, - }; - + let stats = stats_from_file_metadata(partition_values, file_metadata)?; let stats_string = serde_json::to_string(&stats)?; // Determine the modification timestamp to include in the add action - milliseconds since epoch @@ -112,15 +38,16 @@ pub(crate) fn create_add( }) } -fn min_max_values_from_file_metadata( +fn stats_from_file_metadata( partition_values: &HashMap>, file_metadata: &FileMetaData, -) -> Result { +) -> Result { let type_ptr = parquet::schema::types::from_thrift(file_metadata.schema.as_slice()); let schema_descriptor = type_ptr.map(|type_| Arc::new(SchemaDescriptor::new(type_)))?; let mut min_values: HashMap = HashMap::new(); let mut max_values: HashMap = HashMap::new(); + let mut null_count: HashMap = HashMap::new(); let row_group_metadata: Result, ParquetError> = file_metadata .row_groups @@ -132,14 +59,6 @@ fn min_max_values_from_file_metadata( for i in 0..schema_descriptor.num_columns() { let column_descr = schema_descriptor.column(i); - // If max rep level is > 0, this is an array element or a struct element of an array or something downstream of an array. - // delta/databricks only computes null counts for arrays - not min max. - // null counts are tracked at the record batch level, so skip any column with max_rep_level - // > 0 - if column_descr.max_rep_level() > 0 { - continue; - } - let column_path = column_descr.path(); let column_path_parts = column_path.parts(); @@ -148,45 +67,338 @@ fn min_max_values_from_file_metadata( continue; } - let statistics: Vec<&Statistics> = row_group_metadata + let maybe_stats: Option = row_group_metadata .iter() - .filter_map(|g| g.column(i).statistics()) - .collect(); - - apply_min_max_for_column( - statistics.as_slice(), - column_descr.clone(), - column_path_parts, - &mut min_values, - &mut max_values, - )?; + .map(|g| { + g.column(i) + .statistics() + .map(|s| AggregatedStats::from((s, &column_descr.logical_type()))) + }) + .reduce(|left, right| match (left, right) { + (Some(mut left), Some(right)) => { + left += right; + Some(left) + } + _ => None, + }) + .flatten(); + + if let Some(stats) = maybe_stats { + apply_min_max_for_column( + stats, + column_descr.clone(), + column_descr.path().parts(), + &mut min_values, + &mut max_values, + &mut null_count, + )?; + } } - Ok((min_values, max_values)) + Ok(Stats { + min_values, + max_values, + num_records: file_metadata.num_rows, + null_count, + }) +} + +/// Logical scalars extracted from statistics. These are used to aggregate +/// minimums and maximums. We can't use the physical scalars because they +/// are not ordered correctly for some types. For example, decimals are stored +/// as fixed length binary, and can't be sorted lexicographically.
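// A small standalone illustration of the ordering problem described above (names here
// are ad hoc, not delta-rs API): a parquet DECIMAL is stored as a big-endian unscaled
// integer, so its raw bytes do not order correctly once signs are involved; comparing
// the decoded logical value does.
fn decimal_from_unscaled(unscaled: i32, scale: i32) -> f64 {
    unscaled as f64 / 10f64.powi(scale)
}

fn main() {
    let a = decimal_from_unscaled(-1000, 3); // -1.000
    let b = decimal_from_unscaled(1234, 3); //  1.234
    assert!(a < b);
    // The physical big-endian encodings compare the other way around, because the sign
    // bit makes the negative value start with 0xFF bytes.
    assert!((-1000i32).to_be_bytes() > 1234i32.to_be_bytes());
}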
+#[derive(Debug, Clone, PartialEq, PartialOrd)] +enum StatsScalar { + Boolean(bool), + Int32(i32), + Int64(i64), + Float32(f32), + Float64(f64), + Date(chrono::NaiveDate), + Timestamp(chrono::NaiveDateTime), + // We are serializing to f64 later and the ordering should be the same + Decimal(f64), + String(String), + Bytes(Vec), +} + +impl StatsScalar { + fn try_from_stats( + stats: &Statistics, + logical_type: &Option, + use_min: bool, + ) -> Result { + macro_rules! get_stat { + ($val: expr) => { + if use_min { + *$val.min() + } else { + *$val.max() + } + }; + } + + match (stats, logical_type) { + (Statistics::Boolean(v), _) => Ok(Self::Boolean(get_stat!(v))), + // Int32 can be date, decimal, or just int32 + (Statistics::Int32(v), Some(LogicalType::Date)) => { + let date = chrono::NaiveDate::from_num_days_from_ce_opt(get_stat!(v)).ok_or( + DeltaWriterError::StatsParsingFailed { + debug_value: v.to_string(), + logical_type: Some(LogicalType::Date), + }, + )?; + Ok(Self::Date(date)) + } + (Statistics::Int32(v), Some(LogicalType::Decimal { scale, .. })) => { + let val = get_stat!(v) as f64 / 10.0_f64.powi(*scale); + // Spark serializes these as numbers + Ok(Self::Decimal(val)) + } + (Statistics::Int32(v), _) => Ok(Self::Int32(get_stat!(v))), + // Int64 can be timestamp, decimal, or integer + (Statistics::Int64(v), Some(LogicalType::Timestamp { unit, .. })) => { + // For now, we assume timestamps are adjusted to UTC. Non-UTC timestamps + // are behind a feature gate in Delta: + // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#timestamp-without-timezone-timestampntz + let v = get_stat!(v); + let timestamp = match unit { + TimeUnit::MILLIS(_) => chrono::NaiveDateTime::from_timestamp_millis(v), + TimeUnit::MICROS(_) => chrono::NaiveDateTime::from_timestamp_micros(v), + TimeUnit::NANOS(_) => { + let secs = v / 1_000_000_000; + let nanosecs = (v % 1_000_000_000) as u32; + chrono::NaiveDateTime::from_timestamp_opt(secs, nanosecs) + } + }; + let timestamp = timestamp.ok_or(DeltaWriterError::StatsParsingFailed { + debug_value: v.to_string(), + logical_type: logical_type.clone(), + })?; + Ok(Self::Timestamp(timestamp)) + } + (Statistics::Int64(v), Some(LogicalType::Decimal { scale, .. 
})) => { + let val = get_stat!(v) as f64 / 10.0_f64.powi(*scale); + // Spark serializes these as numbers + Ok(Self::Decimal(val)) + } + (Statistics::Int64(v), _) => Ok(Self::Int64(get_stat!(v))), + (Statistics::Float(v), _) => Ok(Self::Float32(get_stat!(v))), + (Statistics::Double(v), _) => Ok(Self::Float64(get_stat!(v))), + (Statistics::ByteArray(v), logical_type) => { + let bytes = if use_min { + v.min_bytes() + } else { + v.max_bytes() + }; + match logical_type { + None => Ok(Self::Bytes(bytes.to_vec())), + Some(LogicalType::String) => { + Ok(Self::String(String::from_utf8(bytes.to_vec()).map_err( + |_| DeltaWriterError::StatsParsingFailed { + debug_value: format!("{bytes:?}"), + logical_type: Some(LogicalType::String), + }, + )?)) + } + _ => Err(DeltaWriterError::StatsParsingFailed { + debug_value: format!("{bytes:?}"), + logical_type: logical_type.clone(), + }), + } + } + (Statistics::FixedLenByteArray(v), Some(LogicalType::Decimal { scale, precision })) => { + let val = if use_min { + v.min_bytes() + } else { + v.max_bytes() + }; + + let val = if val.len() <= 4 { + let mut bytes = [0; 4]; + bytes[..val.len()].copy_from_slice(val); + i32::from_be_bytes(bytes) as f64 + } else if val.len() <= 8 { + let mut bytes = [0; 8]; + bytes[..val.len()].copy_from_slice(val); + i64::from_be_bytes(bytes) as f64 + } else if val.len() <= 16 { + let mut bytes = [0; 16]; + bytes[..val.len()].copy_from_slice(val); + i128::from_be_bytes(bytes) as f64 + } else { + return Err(DeltaWriterError::StatsParsingFailed { + debug_value: format!("{val:?}"), + logical_type: Some(LogicalType::Decimal { + scale: *scale, + precision: *precision, + }), + }); + }; + + let val = val / 10.0_f64.powi(*scale); + Ok(Self::Decimal(val)) + } + (stats, _) => Err(DeltaWriterError::StatsParsingFailed { + debug_value: format!("{stats:?}"), + logical_type: logical_type.clone(), + }), + } + } +} + +impl From for serde_json::Value { + fn from(scalar: StatsScalar) -> Self { + match scalar { + StatsScalar::Boolean(v) => serde_json::Value::Bool(v), + StatsScalar::Int32(v) => serde_json::Value::from(v), + StatsScalar::Int64(v) => serde_json::Value::from(v), + StatsScalar::Float32(v) => serde_json::Value::from(v), + StatsScalar::Float64(v) => serde_json::Value::from(v), + StatsScalar::Date(v) => serde_json::Value::from(v.format("%Y-%m-%d").to_string()), + StatsScalar::Timestamp(v) => { + serde_json::Value::from(v.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) + } + StatsScalar::Decimal(v) => serde_json::Value::from(v), + StatsScalar::String(v) => serde_json::Value::from(v), + StatsScalar::Bytes(v) => { + let escaped_bytes = v + .into_iter() + .flat_map(std::ascii::escape_default) + .collect::>(); + let escaped_string = String::from_utf8(escaped_bytes).unwrap(); + serde_json::Value::from(escaped_string) + } + } + } +} + +/// Aggregated stats +struct AggregatedStats { + pub min: Option, + pub max: Option, + pub null_count: u64, +} + +impl From<(&Statistics, &Option)> for AggregatedStats { + fn from(value: (&Statistics, &Option)) -> Self { + let (stats, logical_type) = value; + let null_count = stats.null_count(); + if stats.has_min_max_set() { + let min = StatsScalar::try_from_stats(stats, logical_type, true).ok(); + let max = StatsScalar::try_from_stats(stats, logical_type, false).ok(); + Self { + min, + max, + null_count, + } + } else { + Self { + min: None, + max: None, + null_count, + } + } + } +} + +impl AddAssign for AggregatedStats { + fn add_assign(&mut self, rhs: Self) { + self.min = match (self.min.take(), rhs.min) { + (Some(lhs), 
Some(rhs)) => { + if lhs < rhs { + Some(lhs) + } else { + Some(rhs) + } + } + (lhs, rhs) => lhs.or(rhs), + }; + self.max = match (self.max.take(), rhs.max) { + (Some(lhs), Some(rhs)) => { + if lhs > rhs { + Some(lhs) + } else { + Some(rhs) + } + } + (lhs, rhs) => lhs.or(rhs), + }; + + self.null_count += rhs.null_count; + } +} + +/// For a list field, we don't want the inner field names. We need to chuck out +/// the list and items fields from the path, but also need to handle the +/// peculiar case where the user named the list field "list" or "item". +/// +/// For example: +/// +/// * ["some_nested_list", "list", "item", "list", "item"] -> "some_nested_list" +/// * ["some_list", "list", "item"] -> "some_list" +/// * ["list", "list", "item"] -> "list" +/// * ["item", "list", "item"] -> "item" +fn get_list_field_name(column_descr: &Arc) -> Option { + let max_rep_levels = column_descr.max_rep_level(); + let column_path_parts = column_descr.path().parts(); + + // If there are more nested names, we can't handle them yet. + if column_path_parts.len() > (2 * max_rep_levels + 1) as usize { + return None; + } + + let mut column_path_parts = column_path_parts.to_vec(); + let mut items_seen = 0; + let mut lists_seen = 0; + while let Some(part) = column_path_parts.pop() { + match (part.as_str(), lists_seen, items_seen) { + ("list", seen, _) if seen == max_rep_levels => return Some("list".to_string()), + ("item", _, seen) if seen == max_rep_levels => return Some("item".to_string()), + ("list", _, _) => lists_seen += 1, + ("item", _, _) => items_seen += 1, + (other, _, _) => return Some(other.to_string()), + } + } + None } fn apply_min_max_for_column( - statistics: &[&Statistics], + statistics: AggregatedStats, column_descr: Arc, column_path_parts: &[String], min_values: &mut HashMap, max_values: &mut HashMap, + null_counts: &mut HashMap, ) -> Result<(), DeltaWriterError> { + // Special handling for list column + if column_descr.max_rep_level() > 0 { + let key = get_list_field_name(&column_descr); + + if let Some(key) = key { + null_counts.insert(key, ColumnCountStat::Value(statistics.null_count as i64)); + } + + return Ok(()); + } + match (column_path_parts.len(), column_path_parts.first()) { // Base case - we are at the leaf struct level in the path (1, _) => { - let (min, max) = min_and_max_from_parquet_statistics(statistics, column_descr.clone())?; + let key = column_descr.name().to_string(); - if let Some(min) = min { - let min = ColumnValueStat::Value(min); - min_values.insert(column_descr.name().to_string(), min); + if let Some(min) = statistics.min { + let min = ColumnValueStat::Value(min.into()); + min_values.insert(key.clone(), min); } - if let Some(max) = max { - let max = ColumnValueStat::Value(max); - max_values.insert(column_descr.name().to_string(), max); + if let Some(max) = statistics.max { + let max = ColumnValueStat::Value(max.into()); + max_values.insert(key.clone(), max); } + null_counts.insert(key, ColumnCountStat::Value(statistics.null_count as i64)); + Ok(()) } // Recurse to load value at the appropriate level of HashMap @@ -197,9 +409,16 @@ fn apply_min_max_for_column( let child_max_values = max_values .entry(key.to_owned()) .or_insert_with(|| ColumnValueStat::Column(HashMap::new())); - - match (child_min_values, child_max_values) { - (ColumnValueStat::Column(mins), ColumnValueStat::Column(maxes)) => { + let child_null_counts = null_counts + .entry(key.to_owned()) + .or_insert_with(|| ColumnCountStat::Column(HashMap::new())); + + match (child_min_values, child_max_values,
child_null_counts) { + ( + ColumnValueStat::Column(mins), + ColumnValueStat::Column(maxes), + ColumnCountStat::Column(null_counts), + ) => { let remaining_parts: Vec = column_path_parts .iter() .skip(1) @@ -212,6 +431,7 @@ fn apply_min_max_for_column( remaining_parts.as_slice(), mins, maxes, + null_counts, )?; Ok(()) @@ -228,200 +448,147 @@ fn apply_min_max_for_column( } } -#[inline] -fn is_utf8(opt: Option) -> bool { - matches!(opt.as_ref(), Some(LogicalType::String)) -} - -fn min_and_max_from_parquet_statistics( - statistics: &[&Statistics], - column_descr: Arc, -) -> Result<(Option, Option), DeltaWriterError> { - let stats_with_min_max: Vec<&Statistics> = statistics - .iter() - .filter(|s| s.has_min_max_set()) - .copied() - .collect(); - - if stats_with_min_max.is_empty() { - return Ok((None, None)); - } - - let (data_size, data_type) = match stats_with_min_max.first() { - Some(Statistics::Boolean(_)) => (std::mem::size_of::(), DataType::Boolean), - Some(Statistics::Int32(_)) => (std::mem::size_of::(), DataType::Int32), - Some(Statistics::Int64(_)) => (std::mem::size_of::(), DataType::Int64), - Some(Statistics::Float(_)) => (std::mem::size_of::(), DataType::Float32), - Some(Statistics::Double(_)) => (std::mem::size_of::(), DataType::Float64), - Some(Statistics::ByteArray(_)) if is_utf8(column_descr.logical_type()) => { - (0, DataType::Utf8) - } - _ => { - // NOTE: Skips - // Statistics::Int96(_) - // Statistics::ByteArray(_) - // Statistics::FixedLenByteArray(_) - - return Ok((None, None)); - } - }; - - if data_type == DataType::Utf8 { - return Ok(min_max_strings_from_stats(&stats_with_min_max)); - } - - let arrow_buffer_capacity = stats_with_min_max.len() * data_size; - - let min_array = arrow_array_from_bytes( - data_type.clone(), - arrow_buffer_capacity, - stats_with_min_max.iter().map(|s| s.min_bytes()).collect(), - )?; - - let max_array = arrow_array_from_bytes( - data_type.clone(), - arrow_buffer_capacity, - stats_with_min_max.iter().map(|s| s.max_bytes()).collect(), - )?; - - match data_type { - DataType::Boolean => { - let min = arrow::compute::min_boolean(as_boolean_array(&min_array)); - let min = min.map(Value::Bool); - - let max = arrow::compute::max_boolean(as_boolean_array(&max_array)); - let max = max.map(Value::Bool); - - Ok((min, max)) - } - DataType::Int32 => { - let min_array = as_primitive_array::(&min_array); - let min = arrow::compute::min(min_array); - let min = min.map(|i| Value::Number(Number::from(i))); - - let max_array = as_primitive_array::(&max_array); - let max = arrow::compute::max(max_array); - let max = max.map(|i| Value::Number(Number::from(i))); - - Ok((min, max)) - } - DataType::Int64 => { - let min_array = as_primitive_array::(&min_array); - let min = arrow::compute::min(min_array); - let max_array = as_primitive_array::(&max_array); - let max = arrow::compute::max(max_array); - - match column_descr.logical_type().as_ref() { - Some(LogicalType::Timestamp { unit, .. 
}) => { - let min = min - .and_then(|n| timestamp_to_delta_stats_string(n, unit).map(Value::String)); - let max = max - .and_then(|n| timestamp_to_delta_stats_string(n, unit).map(Value::String)); - - Ok((min, max)) - } - _ => { - let min = min.map(|i| Value::Number(Number::from(i))); - let max = max.map(|i| Value::Number(Number::from(i))); - - Ok((min, max)) - } - } - } - DataType::Float32 => { - let min_array = as_primitive_array::(&min_array); - let min = arrow::compute::min(min_array); - let min = min.and_then(|f| Number::from_f64(f as f64).map(Value::Number)); - - let max_array = as_primitive_array::(&max_array); - let max = arrow::compute::max(max_array); - let max = max.and_then(|f| Number::from_f64(f as f64).map(Value::Number)); - - Ok((min, max)) - } - DataType::Float64 => { - let min_array = as_primitive_array::(&min_array); - let min = arrow::compute::min(min_array); - let min = min.and_then(|f| Number::from_f64(f).map(Value::Number)); - - let max_array = as_primitive_array::(&max_array); - let max = arrow::compute::max(max_array); - let max = max.and_then(|f| Number::from_f64(f).map(Value::Number)); - - Ok((min, max)) - } - _ => Ok((None, None)), - } -} - -fn min_max_strings_from_stats( - stats_with_min_max: &[&Statistics], -) -> (Option, Option) { - let min_string_candidates = stats_with_min_max - .iter() - .filter_map(|s| std::str::from_utf8(s.min_bytes()).ok()); - - let min_value = min_string_candidates - .min() - .map(|s| Value::String(s.to_string())); - - let max_string_candidates = stats_with_min_max - .iter() - .filter_map(|s| std::str::from_utf8(s.max_bytes()).ok()); - - let max_value = max_string_candidates - .max() - .map(|s| Value::String(s.to_string())); - - (min_value, max_value) -} - -fn arrow_array_from_bytes( - data_type: DataType, - capacity: usize, - byte_arrays: Vec<&[u8]>, -) -> Result, DeltaWriterError> { - let mut buffer = MutableBuffer::new(capacity); - - for arr in byte_arrays.iter() { - buffer.extend_from_slice(arr); - } - - let builder = ArrayData::builder(data_type) - .len(byte_arrays.len()) - .add_buffer(buffer.into()); - - let data = builder.build()?; - - Ok(make_array(data)) -} - #[cfg(test)] mod tests { + use super::utils::record_batch_from_message; use super::*; - use super::{test_utils::get_record_batch, utils::record_batch_from_message}; use crate::{ action::{ColumnCountStat, ColumnValueStat}, builder::DeltaTableBuilder, DeltaTable, DeltaTableError, }; use lazy_static::lazy_static; + use parquet::data_type::{ByteArray, FixedLenByteArray}; + use parquet::file::statistics::ValueStatistics; use serde_json::{json, Value}; use std::collections::HashMap; use std::path::Path; - #[test] - fn test_apply_null_counts() { - let record_batch = get_record_batch(None, true); - let mut ref_null_counts = HashMap::new(); - ref_null_counts.insert("id".to_string(), ColumnCountStat::Value(3)); - ref_null_counts.insert("value".to_string(), ColumnCountStat::Value(1)); - ref_null_counts.insert("modified".to_string(), ColumnCountStat::Value(0)); - - let mut null_counts = HashMap::new(); - apply_null_counts(&record_batch.into(), &mut null_counts, 0); + macro_rules! 
simple_parquet_stat { + ($variant:expr, $value:expr) => { + $variant(ValueStatistics::new( + Some($value), + Some($value), + None, + 0, + false, + )) + }; + } - assert_eq!(null_counts, ref_null_counts) + #[test] + fn test_stats_scalar_serialization() { + let cases = &[ + ( + simple_parquet_stat!(Statistics::Boolean, true), + Some(LogicalType::Integer { + bit_width: 1, + is_signed: true, + }), + Value::Bool(true), + ), + ( + simple_parquet_stat!(Statistics::Int32, 1), + Some(LogicalType::Integer { + bit_width: 32, + is_signed: true, + }), + Value::from(1), + ), + ( + simple_parquet_stat!(Statistics::Int32, 1234), + Some(LogicalType::Decimal { + scale: 3, + precision: 4, + }), + Value::from(1.234), + ), + ( + simple_parquet_stat!(Statistics::Int32, 1234), + Some(LogicalType::Decimal { + scale: -1, + precision: 4, + }), + Value::from(12340.0), + ), + ( + simple_parquet_stat!(Statistics::Int32, 737821), + Some(LogicalType::Date), + Value::from("2021-01-31"), + ), + ( + simple_parquet_stat!(Statistics::Int64, 1641040496789123456), + Some(LogicalType::Timestamp { + is_adjusted_to_u_t_c: true, + unit: parquet::format::TimeUnit::NANOS(parquet::format::NanoSeconds {}), + }), + Value::from("2022-01-01T12:34:56.789123456Z"), + ), + ( + simple_parquet_stat!(Statistics::Int64, 1641040496789123), + Some(LogicalType::Timestamp { + is_adjusted_to_u_t_c: true, + unit: parquet::format::TimeUnit::MICROS(parquet::format::MicroSeconds {}), + }), + Value::from("2022-01-01T12:34:56.789123Z"), + ), + ( + simple_parquet_stat!(Statistics::Int64, 1641040496789), + Some(LogicalType::Timestamp { + is_adjusted_to_u_t_c: true, + unit: parquet::format::TimeUnit::MILLIS(parquet::format::MilliSeconds {}), + }), + Value::from("2022-01-01T12:34:56.789Z"), + ), + ( + simple_parquet_stat!(Statistics::Int64, 1234), + Some(LogicalType::Decimal { + scale: 3, + precision: 4, + }), + Value::from(1.234), + ), + ( + simple_parquet_stat!(Statistics::Int64, 1234), + Some(LogicalType::Decimal { + scale: -1, + precision: 4, + }), + Value::from(12340.0), + ), + ( + simple_parquet_stat!(Statistics::Int64, 1234), + None, + Value::from(1234), + ), + ( + simple_parquet_stat!(Statistics::ByteArray, ByteArray::from(b"hello".to_vec())), + Some(LogicalType::String), + Value::from("hello"), + ), + ( + simple_parquet_stat!(Statistics::ByteArray, ByteArray::from(b"\x00\\".to_vec())), + None, + Value::from("\\x00\\\\"), + ), + ( + simple_parquet_stat!( + Statistics::FixedLenByteArray, + FixedLenByteArray::from(1243124142314423i128.to_be_bytes().to_vec()) + ), + Some(LogicalType::Decimal { + scale: 3, + precision: 16, + }), + Value::from(1243124142314.423), + ), + ]; + + for (stats, logical_type, expected) in cases { + let scalar = StatsScalar::try_from_stats(stats, logical_type, true).unwrap(); + let actual = serde_json::Value::from(scalar); + assert_eq!(&actual, expected); + } } #[tokio::test] @@ -528,7 +695,7 @@ mod tests { ("some_bool", ColumnCountStat::Value(v)) => assert_eq!(100, *v), ("some_string", ColumnCountStat::Value(v)) => assert_eq!(100, *v), ("some_list", ColumnCountStat::Value(v)) => assert_eq!(100, *v), - ("some_nested_list", ColumnCountStat::Value(v)) => assert_eq!(0, *v), + ("some_nested_list", ColumnCountStat::Value(v)) => assert_eq!(100, *v), ("date", ColumnCountStat::Value(v)) => assert_eq!(0, *v), _ => panic!("Key should not be present"), } diff --git a/rust/src/writer/test_utils.rs b/rust/src/writer/test_utils.rs index 97a401dbe4..8a5432535a 100644 --- a/rust/src/writer/test_utils.rs +++ b/rust/src/writer/test_utils.rs @@ -1,10 
+1,10 @@ #![allow(deprecated)] //! Utilities for writing unit tests -use super::*; use crate::{ action::Protocol, schema::Schema, DeltaTable, DeltaTableBuilder, DeltaTableMetaData, SchemaDataType, SchemaField, }; +use arrow::datatypes::{DataType, Field}; use arrow::record_batch::RecordBatch; use arrow::{ array::{Int32Array, StringArray, UInt32Array},
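// Rough shape of the per-file `stats` JSON that the new stats_from_file_metadata path
// serializes into the Add action (field values below are made up for illustration; the
// key names follow the Delta protocol).
use serde_json::json;

fn main() {
    let stats = json!({
        "numRecords": 3,
        "minValues": {"id": 1, "nested": {"b": "a"}},
        "maxValues": {"id": 3, "nested": {"b": "z"}},
        "nullCount": {"id": 1, "nested": {"b": 0}, "some_list": 2},
    });
    println!("{}", stats);
}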