From 63fe9af7e318ffc11b05d513e3a3b3cbe19574c0 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 17 Sep 2023 16:29:22 -0700 Subject: [PATCH] cleanup --- .github/workflows/build.yml | 2 +- kernel/src/actions/arrow.rs | 275 +++++++++++++++++++++++++++++++++ kernel/src/scan/file_stream.rs | 22 +-- 3 files changed, 283 insertions(+), 16 deletions(-) create mode 100644 kernel/src/actions/arrow.rs diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 392d713b..d408147c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,7 +33,7 @@ jobs: - name: build and lint with clippy run: cargo clippy --tests -- -D warnings - name: lint without default features - run: cargo clippy --no-default--features -- -D warnings + run: cargo clippy --no-default-features -- -D warnings test: runs-on: ubuntu-latest steps: diff --git a/kernel/src/actions/arrow.rs b/kernel/src/actions/arrow.rs new file mode 100644 index 00000000..5408eb8c --- /dev/null +++ b/kernel/src/actions/arrow.rs @@ -0,0 +1,275 @@ +use std::sync::Arc; + +use arrow_schema::{ + ArrowError, DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema, + SchemaRef as ArrowSchemaRef, TimeUnit, +}; + +use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, StructField, StructType}; + +impl TryFrom<&StructType> for ArrowSchema { + type Error = ArrowError; + + fn try_from(s: &StructType) -> Result { + let fields = s + .fields() + .iter() + .map(|f| >::try_from(*f)) + .collect::, ArrowError>>()?; + + Ok(ArrowSchema::new(fields)) + } +} + +impl TryFrom<&StructField> for ArrowField { + type Error = ArrowError; + + fn try_from(f: &StructField) -> Result { + let metadata = f + .metadata() + .iter() + .map(|(key, val)| Ok((key.clone(), serde_json::to_string(val)?))) + .collect::>() + .map_err(|err| ArrowError::JsonError(err.to_string()))?; + + let field = ArrowField::new( + f.name(), + ArrowDataType::try_from(f.data_type())?, + f.is_nullable(), + ) + .with_metadata(metadata); + + Ok(field) + } +} + +impl TryFrom<&ArrayType> for ArrowField { + type Error = ArrowError; + + fn try_from(a: &ArrayType) -> Result { + Ok(ArrowField::new( + "item", + ArrowDataType::try_from(a.element_type())?, + a.contains_null(), + )) + } +} + +impl TryFrom<&MapType> for ArrowField { + type Error = ArrowError; + + fn try_from(a: &MapType) -> Result { + Ok(ArrowField::new( + "entries", + ArrowDataType::Struct( + vec![ + ArrowField::new("key", ArrowDataType::try_from(a.key_type())?, false), + ArrowField::new( + "value", + ArrowDataType::try_from(a.value_type())?, + a.value_contains_null(), + ), + ] + .into(), + ), + false, // always non-null + )) + } +} + +impl TryFrom<&DataType> for ArrowDataType { + type Error = ArrowError; + + fn try_from(t: &DataType) -> Result { + match t { + DataType::Primitive(p) => { + match p { + PrimitiveType::String => Ok(ArrowDataType::Utf8), + PrimitiveType::Long => Ok(ArrowDataType::Int64), // undocumented type + PrimitiveType::Integer => Ok(ArrowDataType::Int32), + PrimitiveType::Short => Ok(ArrowDataType::Int16), + PrimitiveType::Byte => Ok(ArrowDataType::Int8), + PrimitiveType::Float => Ok(ArrowDataType::Float32), + PrimitiveType::Double => Ok(ArrowDataType::Float64), + PrimitiveType::Boolean => Ok(ArrowDataType::Boolean), + PrimitiveType::Binary => Ok(ArrowDataType::Binary), + PrimitiveType::Decimal(precision, scale) => { + let precision = u8::try_from(*precision).map_err(|_| { + ArrowError::SchemaError(format!( + "Invalid precision for decimal: {}", + precision + )) + })?; + let scale = i8::try_from(*scale).map_err(|_| { + ArrowError::SchemaError(format!("Invalid scale for decimal: {}", scale)) + })?; + + if precision <= 38 { + Ok(ArrowDataType::Decimal128(precision, scale)) + } else if precision <= 76 { + Ok(ArrowDataType::Decimal256(precision, scale)) + } else { + Err(ArrowError::SchemaError(format!( + "Precision too large to be represented in Arrow: {}", + precision + ))) + } + } + PrimitiveType::Date => { + // A calendar date, represented as a year-month-day triple without a + // timezone. Stored as 4 bytes integer representing days since 1970-01-01 + Ok(ArrowDataType::Date32) + } + PrimitiveType::Timestamp => { + // Issue: https://github.com/delta-io/delta/issues/643 + Ok(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)) + } + } + } + DataType::Struct(s) => Ok(ArrowDataType::Struct( + s.fields() + .iter() + .map(|f| >::try_from(*f)) + .collect::, ArrowError>>()? + .into(), + )), + DataType::Array(a) => Ok(ArrowDataType::List(Arc::new(>::try_from(a)?))), + DataType::Map(m) => Ok(ArrowDataType::Map( + Arc::new(ArrowField::new( + "entries", + ArrowDataType::Struct( + vec![ + ArrowField::new( + "key", + >::try_from(m.key_type())?, + false, + ), + ArrowField::new( + "value", + >::try_from(m.value_type())?, + m.value_contains_null(), + ), + ] + .into(), + ), + true, + )), + false, + )), + } + } +} + +impl TryFrom<&ArrowSchema> for StructType { + type Error = ArrowError; + + fn try_from(arrow_schema: &ArrowSchema) -> Result { + let new_fields: Result, _> = arrow_schema + .fields() + .iter() + .map(|field| field.as_ref().try_into()) + .collect(); + Ok(StructType::new(new_fields?)) + } +} + +impl TryFrom for StructType { + type Error = ArrowError; + + fn try_from(arrow_schema: ArrowSchemaRef) -> Result { + arrow_schema.as_ref().try_into() + } +} + +impl TryFrom<&ArrowField> for StructField { + type Error = ArrowError; + + fn try_from(arrow_field: &ArrowField) -> Result { + Ok(StructField::new( + arrow_field.name().clone(), + arrow_field.data_type().try_into()?, + arrow_field.is_nullable(), + ) + .with_metadata(arrow_field.metadata().iter().map(|(k, v)| (k.clone(), v)))) + } +} + +impl TryFrom<&ArrowDataType> for DataType { + type Error = ArrowError; + + fn try_from(arrow_datatype: &ArrowDataType) -> Result { + match arrow_datatype { + ArrowDataType::Utf8 => Ok(DataType::Primitive(PrimitiveType::String)), + ArrowDataType::LargeUtf8 => Ok(DataType::Primitive(PrimitiveType::String)), + ArrowDataType::Int64 => Ok(DataType::Primitive(PrimitiveType::Long)), // undocumented type + ArrowDataType::Int32 => Ok(DataType::Primitive(PrimitiveType::Integer)), + ArrowDataType::Int16 => Ok(DataType::Primitive(PrimitiveType::Short)), + ArrowDataType::Int8 => Ok(DataType::Primitive(PrimitiveType::Byte)), + ArrowDataType::UInt64 => Ok(DataType::Primitive(PrimitiveType::Long)), // undocumented type + ArrowDataType::UInt32 => Ok(DataType::Primitive(PrimitiveType::Integer)), + ArrowDataType::UInt16 => Ok(DataType::Primitive(PrimitiveType::Short)), + ArrowDataType::UInt8 => Ok(DataType::Primitive(PrimitiveType::Boolean)), + ArrowDataType::Float32 => Ok(DataType::Primitive(PrimitiveType::Float)), + ArrowDataType::Float64 => Ok(DataType::Primitive(PrimitiveType::Double)), + ArrowDataType::Boolean => Ok(DataType::Primitive(PrimitiveType::Boolean)), + ArrowDataType::Binary => Ok(DataType::Primitive(PrimitiveType::Binary)), + ArrowDataType::FixedSizeBinary(_) => Ok(DataType::Primitive(PrimitiveType::Binary)), + ArrowDataType::LargeBinary => Ok(DataType::Primitive(PrimitiveType::Binary)), + ArrowDataType::Decimal128(p, s) => Ok(DataType::Primitive(PrimitiveType::Decimal( + *p as i32, *s as i32, + ))), + ArrowDataType::Decimal256(p, s) => Ok(DataType::Primitive(PrimitiveType::Decimal( + *p as i32, *s as i32, + ))), + ArrowDataType::Date32 => Ok(DataType::Primitive(PrimitiveType::Date)), + ArrowDataType::Date64 => Ok(DataType::Primitive(PrimitiveType::Date)), + ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => { + Ok(DataType::Primitive(PrimitiveType::Timestamp)) + } + ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) + if tz.eq_ignore_ascii_case("utc") => + { + Ok(DataType::Primitive(PrimitiveType::Timestamp)) + } + ArrowDataType::Struct(fields) => { + let converted_fields: Result, _> = fields + .iter() + .map(|field| field.as_ref().try_into()) + .collect(); + Ok(DataType::Struct(Box::new(StructType::new( + converted_fields?, + )))) + } + ArrowDataType::List(field) => Ok(DataType::Array(Box::new(ArrayType::new( + (*field).data_type().try_into()?, + (*field).is_nullable(), + )))), + ArrowDataType::LargeList(field) => Ok(DataType::Array(Box::new(ArrayType::new( + (*field).data_type().try_into()?, + (*field).is_nullable(), + )))), + ArrowDataType::FixedSizeList(field, _) => Ok(DataType::Array(Box::new( + ArrayType::new((*field).data_type().try_into()?, (*field).is_nullable()), + ))), + ArrowDataType::Map(field, _) => { + if let ArrowDataType::Struct(struct_fields) = field.data_type() { + let key_type = struct_fields[0].data_type().try_into()?; + let value_type = struct_fields[1].data_type().try_into()?; + let value_type_nullable = struct_fields[1].is_nullable(); + Ok(DataType::Map(Box::new(MapType::new( + key_type, + value_type, + value_type_nullable, + )))) + } else { + panic!("DataType::Map should contain a struct field child"); + } + } + s => Err(ArrowError::SchemaError(format!( + "Invalid data type for Delta Lake: {s}" + ))), + } + } +} diff --git a/kernel/src/scan/file_stream.rs b/kernel/src/scan/file_stream.rs index a52e0a5e..a47c8782 100644 --- a/kernel/src/scan/file_stream.rs +++ b/kernel/src/scan/file_stream.rs @@ -5,6 +5,7 @@ use crate::actions::{parse_actions, Action, ActionType, Add}; use crate::expressions::Expression; use crate::DeltaResult; use arrow_array::RecordBatch; +use either::Either; struct LogReplayScanner { predicate: Option, @@ -70,20 +71,11 @@ pub fn log_replay_iter( ) -> impl Iterator> { let mut log_scanner = LogReplayScanner::new(predicate); - action_iter.flat_map( - move |actions| -> Box>> { - match actions { - Ok(actions) => match log_scanner.process_batch(actions) { - Ok(adds) => Box::new(adds.into_iter().map(Ok)), - Err(err) => Box::new(std::iter::once(Err(err))), - }, - Err(err) => Box::new(std::iter::once(Err(err))), - } + action_iter.flat_map(move |actions| match actions { + Ok(actions) => match log_scanner.process_batch(actions) { + Ok(adds) => Either::Left(adds.into_iter().map(Ok)), + Err(err) => Either::Right(std::iter::once(Err(err))), }, - ) -} - -#[cfg(test)] -mod tests { - // TODO: add tests + Err(err) => Either::Right(std::iter::once(Err(err))), + }) }