fix: fixedsizelist in delta #2607

Merged · 8 commits · Feb 7, 2024
Changes from 4 commits
78 changes: 65 additions & 13 deletions crates/datafusion_ext/src/planner/expr/arrow_cast.rs
@@ -21,11 +21,20 @@
 use std::fmt::Display;
 use std::iter::Peekable;
 use std::str::Chars;
+use std::sync::Arc;
 
-use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
-use datafusion::common::{DFSchema, DataFusionError, Result, ScalarValue};
+use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
+use datafusion::common::{
+    plan_datafusion_err,
+    plan_err,
+    DFSchema,
+    DataFusionError,
+    Result,
+    ScalarValue,
+};
 use datafusion::logical_expr::{Expr, ExprSchemable};
 
 pub const ARROW_CAST_NAME: &str = "arrow_cast";

/// Create an [`Expr`] that evaluates the `arrow_cast` function
@@ -52,21 +61,18 @@ pub const ARROW_CAST_NAME: &str = "arrow_cast";
 /// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction
 pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> {
     if args.len() != 2 {
-        return Err(DataFusionError::Plan(format!(
-            "arrow_cast needs 2 arguments, {} provided",
-            args.len()
-        )));
+        return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len());
     }
     let arg1 = args.pop().unwrap();
     let arg0 = args.pop().unwrap();
 
-    // arg1 must be a stirng
+    // arg1 must be a string
     let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 {
         v
     } else {
-        return Err(DataFusionError::Plan(format!(
+        return plan_err!(
             "arrow_cast requires its second argument to be a constant string, got {arg1}"
-        )));
+        );
     };
 
     // do the actual lookup to the appropriate data type
@@ -101,9 +107,7 @@ pub fn parse_data_type(val: &str) -> Result<DataType> {
 }
 
 fn make_error(val: &str, msg: &str) -> DataFusionError {
-    DataFusionError::Plan(
-        format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}")
-    )
+    plan_datafusion_err!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}")
 }
 
 fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError {
@@ -151,13 +155,49 @@ impl<'a> Parser<'a> {
             Token::Decimal128 => self.parse_decimal_128(),
             Token::Decimal256 => self.parse_decimal_256(),
             Token::Dictionary => self.parse_dictionary(),
+            Token::List => self.parse_list(),
+            Token::LargeList => self.parse_large_list(),
+            Token::FixedSizeList => self.parse_fixed_size_list(),
             tok => Err(make_error(
                 self.val,
                 &format!("finding next type, got unexpected '{tok}'"),
             )),
         }
     }
 
+    /// Parses the List type
+    fn parse_list(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let data_type = self.parse_next_type()?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::List(Arc::new(Field::new(
+            "item", data_type, true,
+        ))))
+    }
+
+    /// Parses the LargeList type
+    fn parse_large_list(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let data_type = self.parse_next_type()?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::LargeList(Arc::new(Field::new(
+            "item", data_type, true,
+        ))))
+    }
+
+    /// Parses the FixedSizeList type
+    fn parse_fixed_size_list(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let length = self.parse_i32("FixedSizeList")?;
+        self.expect_token(Token::Comma)?;
+        let data_type = self.parse_next_type()?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::FixedSizeList(
+            Arc::new(Field::new("item", data_type, true)),
+            length,
+        ))
+    }
+
     /// Parses the next timeunit
     fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
         match self.next_token()? {
@@ -490,6 +530,10 @@ impl<'a> Tokenizer<'a> {
             "Date32" => Token::SimpleType(DataType::Date32),
             "Date64" => Token::SimpleType(DataType::Date64),
 
+            "List" => Token::List,
+            "LargeList" => Token::LargeList,
+            "FixedSizeList" => Token::FixedSizeList,
+
             "Second" => Token::TimeUnit(TimeUnit::Second),
             "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
             "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond),
@@ -576,12 +620,18 @@ enum Token {
     None,
     Integer(i64),
     DoubleQuotedString(String),
+    List,
+    LargeList,
+    FixedSizeList,
 }
 
 impl Display for Token {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             Token::SimpleType(t) => write!(f, "{t}"),
+            Token::List => write!(f, "List"),
+            Token::LargeList => write!(f, "LargeList"),
+            Token::FixedSizeList => write!(f, "FixedSizeList"),
             Token::Timestamp => write!(f, "Timestamp"),
             Token::Time32 => write!(f, "Time32"),
             Token::Time64 => write!(f, "Time64"),
@@ -606,6 +656,8 @@ impl Display for Token {
 
 #[cfg(test)]
 mod test {
+
+
     use super::*;
 
     #[test]
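As a sanity check on the grammar above, a hedged sketch of what the new parser paths should accept: hypothetical assertions in the style of the existing `mod test`, assuming `parse_data_type` and the imports shown above are in scope.

```rust
// Hypothetical test, not part of this diff: expected results for the new
// List / LargeList / FixedSizeList paths through parse_data_type.
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field};

#[test]
fn parse_list_types() {
    // The parser always names the child field "item" and marks it nullable.
    let item = Arc::new(Field::new("item", DataType::Int64, true));

    assert_eq!(
        parse_data_type("List(Int64)").unwrap(),
        DataType::List(item.clone()),
    );
    assert_eq!(
        parse_data_type("LargeList(Int64)").unwrap(),
        DataType::LargeList(item.clone()),
    );
    // FixedSizeList takes the length first, then the element type.
    assert_eq!(
        parse_data_type("FixedSizeList(2, Int64)").unwrap(),
        DataType::FixedSizeList(item, 2),
    );
}
```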
3 changes: 2 additions & 1 deletion crates/datafusion_ext/src/planner/expr/function.rs
@@ -47,6 +47,7 @@ use datafusion::sql::sqlparser::ast::{
 };
 
 use super::arrow_cast::ARROW_CAST_NAME;
+use crate::planner::expr::arrow_cast::create_arrow_cast;
 use crate::planner::{AsyncContextProvider, SqlQueryPlanner};
 
 impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> {
@@ -215,7 +216,7 @@ impl<'a, S: AsyncContextProvider> SqlQueryPlanner<'a, S> {
 
             // Special case arrow_cast (as its type is dependent on its argument value)
             if name == ARROW_CAST_NAME {
-                return super::arrow_cast::create_arrow_cast(args, schema);
+                return create_arrow_cast(args, schema);
             }
         }
 
2 changes: 1 addition & 1 deletion crates/datasources/Cargo.toml
@@ -45,7 +45,7 @@ reqwest = { workspace = true }
 rust_decimal = { version = "1.34.2", features = ["db-tokio-postgres"] }
 serde = { workspace = true }
 serde_with = "3.6.0"
-serde_json = { workspace = true }
+serde_json.workspace = true
 snowflake_connector = { path = "../snowflake_connector" }
 tempfile = { workspace = true }
 ssh-key = { version = "0.6.4", features = ["ed25519", "alloc"] }
141 changes: 120 additions & 21 deletions crates/datasources/src/native/access.rs
@@ -1,16 +1,22 @@
 use std::any::Any;
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema, TimeUnit};
+use datafusion::arrow::datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit};
+use datafusion::common::ToDFSchema;
 use datafusion::datasource::TableProvider;
 use datafusion::error::Result as DataFusionResult;
 use datafusion::execution::context::SessionState;
-use datafusion::logical_expr::{LogicalPlan, TableProviderFilterPushDown, TableType};
+use datafusion::logical_expr::{col, Cast, LogicalPlan, TableProviderFilterPushDown, TableType};
+use datafusion::physical_expr::create_physical_expr;
+use datafusion::physical_expr::execution_props::ExecutionProps;
 use datafusion::physical_plan::empty::EmptyExec;
+use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::{ExecutionPlan, Statistics};
 use datafusion::prelude::Expr;
 use datafusion_ext::metrics::ReadOnlyDataSourceMetricsExecAdapter;
+use deltalake::kernel::{ArrayType, DataType as DeltaDataType};
 use deltalake::logstore::{default_logstore, logstores, LogStore, LogStoreFactory};
 use deltalake::operations::create::CreateBuilder;
 use deltalake::operations::delete::DeleteBuilder;
@@ -24,11 +30,8 @@ use object_store::prefix::PrefixStore;
 use object_store::ObjectStore;
 use object_store_util::shared::SharedObjectStore;
 use protogen::metastore::types::catalog::TableEntry;
-use protogen::metastore::types::options::{
-    InternalColumnDefinition,
-    TableOptions,
-    TableOptionsInternal,
-};
+use protogen::metastore::types::options::{TableOptions, TableOptionsInternal};
+use serde_json::{json, Value};
 use url::Url;
 use uuid::Uuid;
 
@@ -85,6 +88,48 @@ impl LogStoreFactory for FakeStoreFactory {
         Ok(default_logstore(store, location, options))
     }
 }
+
+struct DeltaField {
+    data_type: DeltaDataType,
+    metadata: Option<HashMap<String, Value>>,
+}
+
+// Some data types get downgraded to a different type when stored in delta-lake,
+// so we add metadata to the field indicating that it needs to be converted back
+// to its original type.
+fn arrow_to_delta_safe(arrow_type: &DataType) -> DeltaResult<DeltaField> {
+    match arrow_type {
+        dtype @ DataType::Timestamp(_, tz) => {
+            let delta_type =
+                (&DataType::Timestamp(TimeUnit::Microsecond, tz.clone())).try_into()?;
+            let mut metadata = HashMap::new();
+            metadata.insert("arrow_type".to_string(), json!(dtype));
+
+            Ok(DeltaField {
+                data_type: delta_type,
+                metadata: Some(metadata),
+            })
+        }
+        dtype @ DataType::FixedSizeList(fld, _) => {
+            let inner_type = arrow_to_delta_safe(fld.data_type())?;
+            let arr_type = ArrayType::new(inner_type.data_type, fld.is_nullable());
+            let mut metadata = HashMap::new();
+
+            metadata.insert("arrow_type".to_string(), json!(dtype));
+
+            Ok(DeltaField {
+                data_type: DeltaDataType::Array(Box::new(arr_type)),
+                metadata: Some(metadata),
+            })
+        }
+        other => {
+            let delta_type = other.try_into()?;
+            Ok(DeltaField {
+                data_type: delta_type,
+                metadata: None,
+            })
+        }
+    }
+}
 
 impl NativeTableStorage {
     /// Create a native table storage provider from a URL and an object store instance
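To make the new helper concrete, a hedged sketch (type names from the diff; the call itself is illustrative) of what `arrow_to_delta_safe` produces for a fixed-size list column:

```rust
// Illustrative only: Delta has no fixed-size list type, so the column is
// stored as a plain (variable-length) Array, and the original arrow type
// is recorded in the field metadata under the "arrow_type" key.
let original = DataType::FixedSizeList(
    Arc::new(Field::new("item", DataType::Int64, true)),
    2,
);
let delta_col = arrow_to_delta_safe(&original)?;
// delta_col.data_type: DeltaDataType::Array(ArrayType { element: long, .. })
// delta_col.metadata:  Some({"arrow_type": json!(original)})
```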
@@ -149,20 +194,15 @@ impl NativeTableStorage {
             .with_table_name(&table.meta.name)
             .with_log_store(delta_store);
 
         for col in &opts.columns {
-            let column = match &col.arrow_type {
-                DataType::Timestamp(_, tz) => InternalColumnDefinition {
-                    name: col.name.clone(),
-                    nullable: col.nullable,
-                    arrow_type: DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
-                },
-                _ => col.to_owned(),
-            };
+            let delta_col = arrow_to_delta_safe(&col.arrow_type)?;
 
             builder = builder.with_column(
-                column.name.clone(),
-                (&column.arrow_type).try_into()?,
-                column.nullable,
-                None,
+                col.name.clone(),
+                delta_col.data_type,
+                col.nullable,
+                delta_col.metadata,
             );
         }
 
@@ -315,8 +355,34 @@ impl TableProvider for NativeTable {
         self
     }
 
+    /// Delta downgrades some types when it stores them, so we project the
+    /// stored schema back to the original types. Ideally we'd store the
+    /// original type in a more accessible way (such as a binary column we
+    /// deserialize ourselves), but for now we just do a projection.
     fn schema(&self) -> Arc<ArrowSchema> {
-        TableProvider::schema(&self.delta)
+        let mut fields = vec![];
+        let arrow_schema = self.delta.snapshot().unwrap().arrow_schema().unwrap();
+
+        for col in arrow_schema.fields() {
+            let mut field = col.clone();
+            let metadata = col.metadata();
+
+            // If the field requires conversion, use the original arrow type.
+            if let Some(arrow_type) = metadata.get("arrow_type") {
+                // delta-lake returns the JSON object serialized as a string, so
+                // we have to deserialize it twice. Any panics here are bugs in
+                // writing the metadata in the first place.
+                let s: String =
+                    serde_json::from_str(arrow_type).expect("metadata was not correctly written");
+                let arrow_type: DataType =
+                    serde_json::from_str(&s).expect("metadata was not correctly written");
+
+                field = Arc::new(Field::new(col.name(), arrow_type, col.is_nullable()));
+            }
+            fields.push(field);
+        }
+        Arc::new(ArrowSchema::new(fields))
     }
 
     fn table_type(&self) -> TableType {
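Why `schema()` deserializes twice: `json!(dtype)` goes in as a JSON value, but delta-rs stores field metadata values as strings, so what comes back out is JSON-encoded JSON. A minimal sketch of the decode path (the helper name is mine, not from this PR):

```rust
// Hypothetical helper showing the two-layer decode used in schema().
// Relies on arrow's serde support for DataType, the same support the
// diff uses via json!(dtype) and serde_json::from_str.
use datafusion::arrow::datatypes::DataType;

fn decode_arrow_type(raw: &str) -> DataType {
    // Pass 1: unwrap the outer JSON string, e.g. "\"{\\\"FixedSizeList\\\":[..]}\"".
    let inner: String = serde_json::from_str(raw).expect("outer layer should be a JSON string");
    // Pass 2: parse the actual DataType out of the inner JSON.
    serde_json::from_str(&inner).expect("inner layer should be a DataType")
}
```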
@@ -345,10 +411,43 @@ impl TableProvider for NativeTable {
         };
 
         if num_rows == 0 {
-            let schema = TableProvider::schema(&self.delta);
+            let schema = self.schema();
             Ok(Arc::new(EmptyExec::new(schema)))
         } else {
             let plan = self.delta.scan(session, projection, filters, limit).await?;
+            let output_schema = plan.schema();
+            let mut schema = self.schema();
+            if let Some(projection) = projection {
+                schema = Arc::new(schema.project(projection)?);
+            }
+            let df_schema = output_schema.clone().to_dfschema_ref()?;
+
+            let plan = if output_schema != schema {
+                let exprs = output_schema
+                    .fields()
+                    .into_iter()
+                    .zip(schema.fields())
+                    .map(|(f1, f2)| {
+                        let expr = if f1.data_type() == f2.data_type() {
+                            col(f1.name())
+                        } else {
+                            let cast_expr =
+                                Cast::new(Box::new(col(f1.name())), f2.data_type().clone());
+                            Expr::Cast(cast_expr)
+                        };
+                        let execution_props = ExecutionProps::new();
+                        (
+                            create_physical_expr(&expr, &df_schema, &execution_props).unwrap(),
+                            f1.name().clone(),
+                        )
+                    })
+                    .collect::<Vec<_>>();
+                // We need a projection so the stored types match the logical schema.
+                let prj = ProjectionExec::try_new(exprs, plan)?;
+                Arc::new(prj)
+            } else {
+                plan
+            };
             Ok(Arc::new(ReadOnlyDataSourceMetricsExecAdapter::new(plan)))
         }
     }
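The projection above leans on arrow's cast kernel taking the stored, downgraded type back to the original, here `List<Int64>` back to `FixedSizeList(2, Int64)`. A self-contained sketch of that cast at the array level, assuming the kernel supports this pair (which this scan path depends on):

```rust
// Illustrative round trip for a single column, outside any plan.
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, ListArray};
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::{DataType, Field, Int64Type};

fn demo() {
    // What delta hands back: a variable-length list column.
    let stored: ArrayRef = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(
        vec![Some(vec![Some(1), Some(2)])],
    ));

    // What the logical schema (recovered from "arrow_type" metadata) expects.
    let target = DataType::FixedSizeList(
        Arc::new(Field::new("item", DataType::Int64, true)),
        2,
    );

    let restored = cast(&stored, &target).expect("cast back to fixed-size list");
    assert_eq!(restored.data_type(), &target);
}
```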
10 changes: 10 additions & 0 deletions testdata/sqllogictests/datatypes.slt
@@ -0,0 +1,10 @@
+# fixed size list round trip
+
+statement ok
+create table test as select arrow_cast([1,2], 'FixedSizeList(2, Int64)') as f;
+
+query III
+describe test
+----
+f FixedSizeList<Int64; 2> t
+