Skip to content

Commit

Permalink
Infer data type from schema for Values and add struct coercion to `…
Browse files Browse the repository at this point in the history
…coalesce` (#12864)

* first draft

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add values table without schema

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm unused import

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* use option instead of vec<err>

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* Fix clippy

* add values back and rename

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* invalid query

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* use values if no schema

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add doc

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
jayzhan211 and alamb authored Oct 24, 2024
1 parent de526a9 commit 18b2aaa
Show file tree
Hide file tree
Showing 20 changed files with 368 additions and 126 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ impl DFSchema {
None => self_unqualified_names.contains(field.name().as_str()),
};
if !duplicated_field {
// self.inner.fields.push(field.clone());
schema_builder.push(Arc::clone(field));
qualifiers.push(qualifier.cloned());
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ path = "src/lib.rs"
[dependencies]
arrow = { workspace = true }
datafusion-common = { workspace = true }
itertools = { workspace = true }
paste = "^1.0"
90 changes: 89 additions & 1 deletion datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use arrow::datatypes::{
DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result};
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result,
};
use itertools::Itertools;

/// The type signature of an instantiation of binary operator expression such as
/// `lhs + rhs`
Expand Down Expand Up @@ -372,6 +375,8 @@ impl From<&DataType> for TypeCategory {
/// decimal precision and scale when coercing decimal types.
///
/// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type.
///
/// Returns Option because we might want to continue on the code even if the data types are not coercible to the common type
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
if data_types.is_empty() {
return None;
Expand Down Expand Up @@ -529,6 +534,89 @@ fn type_union_resolution_coercion(
}
}

/// Handle type union resolution including struct type and others.
pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType>> {
let err = match try_type_union_resolution_with_struct(data_types) {
Ok(struct_types) => return Ok(struct_types),
Err(e) => Some(e),
};

if let Some(new_type) = type_union_resolution(data_types) {
Ok(vec![new_type; data_types.len()])
} else {
exec_err!("Fail to find the coerced type, errors: {:?}", err)
}
}

// Handle struct where we only change the data type but preserve the field name and nullability.
// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1"
pub fn try_type_union_resolution_with_struct(
data_types: &[DataType],
) -> Result<Vec<DataType>> {
let mut keys_string: Option<String> = None;
for data_type in data_types {
if let DataType::Struct(fields) = data_type {
let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
if let Some(ref k) = keys_string {
if *k != keys {
return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
}
} else {
keys_string = Some(keys);
}
} else {
return exec_err!("Expect to get struct but got {}", data_type);
}
}

let mut struct_types: Vec<DataType> = if let DataType::Struct(fields) = &data_types[0]
{
fields.iter().map(|f| f.data_type().to_owned()).collect()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

for data_type in data_types.iter().skip(1) {
if let DataType::Struct(fields) = data_type {
let incoming_struct_types: Vec<DataType> =
fields.iter().map(|f| f.data_type().to_owned()).collect();
// The order of field is verified above
for (lhs_type, rhs_type) in
struct_types.iter_mut().zip(incoming_struct_types.iter())
{
if let Some(coerced_type) =
type_union_resolution_coercion(lhs_type, rhs_type)
{
*lhs_type = coerced_type;
} else {
return exec_err!(
"Fail to find the coerced type for {} and {}",
lhs_type,
rhs_type
);
}
}
} else {
return exec_err!("Expect to get struct but got {}", data_type);
}
}

let mut final_struct_types = vec![];
for s in data_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(struct_types[i].to_owned());
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}

Ok(final_struct_types)
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a
/// comparison operation
///
Expand Down
97 changes: 89 additions & 8 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ use crate::{

use super::dml::InsertOp;
use super::plan::ColumnUnnestList;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err,
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema,
UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

Expand Down Expand Up @@ -172,12 +174,45 @@ impl LogicalPlanBuilder {
/// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html)
/// documentation for more details.
///
/// so it's usually better to override the default names with a table alias list.
///
/// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided.
pub fn values(values: Vec<Vec<Expr>>) -> Result<Self> {
if values.is_empty() {
return plan_err!("Values list cannot be empty");
}
let n_cols = values[0].len();
if n_cols == 0 {
return plan_err!("Values list cannot be zero length");
}
for (i, row) in values.iter().enumerate() {
if row.len() != n_cols {
return plan_err!(
"Inconsistent data length across values list: got {} values in row {} but expected {}",
row.len(),
i,
n_cols
);
}
}

// Infer from data itself
Self::infer_data(values)
}

/// Create a values list based relation, and the schema is inferred from data itself or table schema if provided, consuming
/// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html)
/// documentation for more details.
///
/// By default, it assigns the names column1, column2, etc. to the columns of a VALUES table.
/// The column names are not specified by the SQL standard and different database systems do it differently,
/// so it's usually better to override the default names with a table alias list.
///
/// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided.
pub fn values(mut values: Vec<Vec<Expr>>) -> Result<Self> {
pub fn values_with_schema(
values: Vec<Vec<Expr>>,
schema: &DFSchemaRef,
) -> Result<Self> {
if values.is_empty() {
return plan_err!("Values list cannot be empty");
}
Expand All @@ -196,16 +231,53 @@ impl LogicalPlanBuilder {
}
}

let empty_schema = DFSchema::empty();
// Check the type of value against the schema
Self::infer_values_from_schema(values, schema)
}

fn infer_values_from_schema(
values: Vec<Vec<Expr>>,
schema: &DFSchema,
) -> Result<Self> {
let n_cols = values[0].len();
let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols);
for j in 0..n_cols {
let field_type = schema.field(j).data_type();
for row in values.iter() {
let value = &row[j];
let data_type = value.get_type(schema)?;

if !data_type.equals_datatype(field_type) {
if can_cast_types(&data_type, field_type) {
} else {
return exec_err!(
"type mistmatch and can't cast to got {} and {}",
data_type,
field_type
);
}
}
}
field_types.push(field_type.to_owned());
}

Self::infer_inner(values, &field_types, schema)
}

fn infer_data(values: Vec<Vec<Expr>>) -> Result<Self> {
let n_cols = values[0].len();
let schema = DFSchema::empty();

let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols);
for j in 0..n_cols {
let mut common_type: Option<DataType> = None;
for (i, row) in values.iter().enumerate() {
let value = &row[j];
let data_type = value.get_type(&empty_schema)?;
let data_type = value.get_type(&schema)?;
if data_type == DataType::Null {
continue;
}

if let Some(prev_type) = common_type {
// get common type of each column values.
let data_types = vec![prev_type.clone(), data_type.clone()];
Expand All @@ -221,14 +293,22 @@ impl LogicalPlanBuilder {
// since the code loop skips NULL
field_types.push(common_type.unwrap_or(DataType::Null));
}

Self::infer_inner(values, &field_types, &schema)
}

fn infer_inner(
mut values: Vec<Vec<Expr>>,
field_types: &[DataType],
schema: &DFSchema,
) -> Result<Self> {
// wrap cast if data type is not same as common type.
for row in &mut values {
for (j, field_type) in field_types.iter().enumerate() {
if let Expr::Literal(ScalarValue::Null) = row[j] {
row[j] = Expr::Literal(ScalarValue::try_from(field_type)?);
} else {
row[j] =
std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)?;
row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?;
}
}
}
Expand All @@ -243,6 +323,7 @@ impl LogicalPlanBuilder {
.collect::<Vec<_>>();
let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?;
let schema = DFSchemaRef::new(dfschema);

Ok(Self::new(LogicalPlan::Values(Values { schema, values })))
}

Expand Down
64 changes: 14 additions & 50 deletions datafusion/functions-nested/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, internal_err};
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::binary::{
try_type_union_resolution_with_struct, type_union_resolution,
};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::TypeSignature;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use itertools::Itertools;

use crate::utils::make_scalar_function;

Expand Down Expand Up @@ -111,33 +111,16 @@ impl ScalarUDFImpl for MakeArray {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if let Some(new_type) = type_union_resolution(arg_types) {
// TODO: Move the logic to type_union_resolution if this applies to other functions as well
// Handle struct where we only change the data type but preserve the field name and nullability.
// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1"
let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?;
if is_struct_and_has_same_key {
let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] {
fields.iter().map(|f| f.data_type().to_owned()).collect()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

let mut final_struct_types = vec![];
for s in arg_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(data_types[i].to_owned());
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}
return Ok(final_struct_types);
let mut errors = vec![];
match try_type_union_resolution_with_struct(arg_types) {
Ok(r) => return Ok(r),
Err(e) => {
errors.push(e);
}
}

if let Some(new_type) = type_union_resolution(arg_types) {
// TODO: Move FixedSizeList to List in type_union_resolution
if let DataType::FixedSizeList(field, _) = new_type {
Ok(vec![DataType::List(field); arg_types.len()])
} else if new_type.is_null() {
Expand All @@ -147,9 +130,10 @@ impl ScalarUDFImpl for MakeArray {
}
} else {
plan_err!(
"Fail to find the valid type between {:?} for {}",
"Fail to find the valid type between {:?} for {}, errors are {:?}",
arg_types,
self.name()
self.name(),
errors
)
}
}
Expand Down Expand Up @@ -188,26 +172,6 @@ fn get_make_array_doc() -> &'static Documentation {
})
}

fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
let mut keys_string: Option<String> = None;
for data_type in data_types {
if let DataType::Struct(fields) = data_type {
let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
if let Some(ref k) = keys_string {
if *k != keys {
return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
}
} else {
keys_string = Some(keys);
}
} else {
return Ok(false);
}
}

Ok(true)
}

// Empty array is a special case that is useful for many other array functions
pub(super) fn empty_array_type() -> DataType {
DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))
Expand Down
Loading

0 comments on commit 18b2aaa

Please sign in to comment.