Merge pull request #5980 from sundy-li/fix-if
fix(function): fix incorrect return datatype of function if
BohuTANG authored Jun 15, 2022
2 parents 2b67b08 + ce9a5c2 commit fccf6a8
Showing 21 changed files with 114 additions and 138 deletions.
5 changes: 5 additions & 0 deletions common/datablocks/src/data_block.rs
@@ -34,6 +34,11 @@ pub struct DataBlock {
impl DataBlock {
#[inline]
pub fn create(schema: DataSchemaRef, columns: Vec<ColumnRef>) -> Self {
debug_assert!(schema.fields().iter().zip(columns.iter()).all(|(f, c)| f
.data_type()
.data_type_id()
.to_physical_type()
== c.data_type().data_type_id().to_physical_type()));
DataBlock { schema, columns }
}

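
The new debug_assert! is the thread running through most of this diff: DataBlock::create now insists (in debug builds) that every column's physical type matches the schema field it is paired with, which is why the test changes below add explicit literal suffixes and drop new_nullable fields. A self-contained sketch of the invariant, using stand-in types rather than Databend's:

```rust
// Stand-in types; this only illustrates the invariant, not Databend's real API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PhysicalType {
    Int32,
    Int64,
}

struct Field {
    name: &'static str,
    physical_type: PhysicalType,
}

struct Column {
    physical_type: PhysicalType,
    len: usize,
}

fn create_block(fields: &[Field], columns: &[Column]) {
    // Mirrors the new check in DataBlock::create: pair each field with its column
    // and require that their physical types agree (debug builds only, as upstream).
    debug_assert!(fields
        .iter()
        .zip(columns.iter())
        .all(|(f, c)| f.physical_type == c.physical_type));
    for (f, c) in fields.iter().zip(columns.iter()) {
        println!("column `{}`: {:?}, {} rows", f.name, c.physical_type, c.len);
    }
}

fn main() {
    let fields = vec![
        Field { name: "a", physical_type: PhysicalType::Int64 },
        Field { name: "b", physical_type: PhysicalType::Int32 },
    ];
    let columns = vec![
        Column { physical_type: PhysicalType::Int64, len: 3 },
        // Pairing an Int64 column with field `b` instead would trip the assert;
        // that is exactly the class of schema/column mismatch the tests below fix.
        Column { physical_type: PhysicalType::Int32, len: 3 },
    ];
    create_block(&fields, &columns);
}
```
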
2 changes: 1 addition & 1 deletion common/datablocks/tests/it/data_block.rs
@@ -23,7 +23,7 @@ use pretty_assertions::assert_eq;
fn test_data_block() -> Result<()> {
let schema = DataSchemaRefExt::create(vec![DataField::new("a", i64::to_data_type())]);

let block = DataBlock::create(schema.clone(), vec![Series::from_data(vec![1, 2, 3])]);
let block = DataBlock::create(schema.clone(), vec![Series::from_data(vec![1i64, 2, 3])]);
assert_eq!(&schema, block.schema());

assert_eq!(3, block.num_rows());
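
The suffix added here (and in the sort and merge-sort tests that follow) matters because Rust infers an unconstrained integer literal as i32, so Series::from_data(vec![1, 2, 3]) yields an Int32-backed column even when the field was declared as i64, a mismatch the new assert now catches. A plain-Rust illustration, with no Databend crates involved:

```rust
// Unsuffixed integer literals default to i32 when nothing else constrains them;
// the `i64` suffix is what keeps the column's physical type aligned with the field.
fn main() {
    let unsuffixed = vec![1, 2, 3]; // inferred as Vec<i32>
    let suffixed = vec![1i64, 2, 3]; // explicitly Vec<i64>, matching an Int64 field

    assert_eq!(std::mem::size_of_val(&unsuffixed[0]), 4);
    assert_eq!(std::mem::size_of_val(&suffixed[0]), 8);
    println!(
        "unsuffixed literal element: {} bytes (i32); suffixed: {} bytes (i64)",
        std::mem::size_of_val(&unsuffixed[0]),
        std::mem::size_of_val(&suffixed[0]),
    );
}
```
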
7 changes: 3 additions & 4 deletions common/datablocks/tests/it/kernels/data_block_sort.rs
@@ -25,7 +25,7 @@ fn test_data_block_sort() -> Result<()> {
]);

let raw = DataBlock::create(schema, vec![
Series::from_data(vec![6, 4, 3, 2, 1, 7]),
Series::from_data(vec![6i64, 4, 3, 2, 1, 7]),
Series::from_data(vec!["b1", "b2", "b3", "b4", "b5", "b6"]),
]);

@@ -185,7 +185,6 @@ fn test_data_block_sort() -> Result<()> {
asc: true,
nulls_first: false,
}];
println!("raw={:?}", raw);
let results = DataBlock::sort_block(&raw, &options, Some(3))?;
assert_eq!(raw.schema(), results.schema());

@@ -233,12 +232,12 @@ fn test_data_block_merge_sort() -> Result<()> {
]);

let raw1 = DataBlock::create(schema.clone(), vec![
Series::from_data(vec![3, 5, 7]),
Series::from_data(vec![3i64, 5, 7]),
Series::from_data(vec!["b1", "b2", "b3"]),
]);

let raw2 = DataBlock::create(schema, vec![
Series::from_data(vec![2, 4, 6]),
Series::from_data(vec![2i64, 4, 6]),
Series::from_data(vec!["b4", "b5", "b6"]),
]);

12 changes: 7 additions & 5 deletions common/functions/src/scalars/conditionals/if.rs
@@ -57,16 +57,18 @@ impl IfFunction {
&self,
cond_col: &ColumnRef,
columns: &ColumnsWithField,
_func_ctx: &FunctionContext,
func_ctx: &FunctionContext,
) -> Result<ColumnRef> {
debug_assert!(cond_col.is_const());
// whether nullable or not, we can use viewer to make it
let cond_viewer = bool::try_create_viewer(cond_col)?;
if cond_viewer.value_at(0) {
return Ok(columns[0].column().clone());
let c = if cond_viewer.value_at(0) {
columns[0].clone()
} else {
return Ok(columns[1].column().clone());
}
columns[1].clone()
};

cast_column_field(&c, c.data_type(), &self.least_supertype, func_ctx)
}

// lhs is const column and:
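
This is the heart of the fix. Previously, when the condition was a constant, the selected branch column was returned unchanged, so its type could differ from the function's declared return type (the least supertype of both branches). The chosen branch is now cast to self.least_supertype via cast_column_field before being returned. A simplified, self-contained sketch of that behavior; the types and the cast are stand-ins, not Databend's real API:

```rust
// Stand-in numeric types for illustration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NumType {
    Int32,
    Int64,
    Float64,
}

fn least_supertype(a: NumType, b: NumType) -> NumType {
    use NumType::*;
    match (a, b) {
        (Float64, _) | (_, Float64) => Float64,
        (Int64, _) | (_, Int64) => Int64,
        _ => Int32,
    }
}

fn const_if(cond: bool, lhs: (NumType, f64), rhs: (NumType, f64)) -> (NumType, f64) {
    // Declared return type: the least common supertype of both branches.
    let target = least_supertype(lhs.0, rhs.0);
    // Old behavior returned the chosen branch as-is (possibly Int32 here);
    // the fix casts it to `target` first. The numeric cast itself is elided,
    // since only the type bookkeeping matters for this sketch.
    let chosen = if cond { lhs } else { rhs };
    (target, chosen.1)
}

fn main() {
    // if(true, <Int32 branch>, <Float64 branch>) must come back as Float64, not Int32.
    let (ty, val) = const_if(true, (NumType::Int32, 1.0), (NumType::Float64, 2.5));
    assert_eq!(ty, NumType::Float64);
    println!("result: {val} as {ty:?}");
}
```
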
12 changes: 11 additions & 1 deletion common/functions/src/scalars/strings/regexp_instr.rs
@@ -33,6 +33,7 @@ use crate::scalars::FunctionFeatures;
#[derive(Clone)]
pub struct RegexpInStrFunction {
display_name: String,
return_type: DataTypeImpl,
}

impl RegexpInStrFunction {
@@ -53,8 +54,17 @@ impl RegexpInStrFunction {
}
}

let has_null = args.iter().any(|arg| arg.is_null());

let return_type = if has_null {
NullType::new_impl()
} else {
NullableType::new_impl(u64::to_data_type())
};

Ok(Box::new(Self {
display_name: display_name.to_string(),
return_type,
}))
}

@@ -74,7 +84,7 @@ impl Function for RegexpInStrFunction {
}

fn return_type(&self) -> DataTypeImpl {
NullableType::new_impl(u64::to_data_type())
self.return_type.clone()
}

// Notes: https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-instr
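
The return type is now fixed at construction time: if any argument is the NULL type the function returns NULL, otherwise Nullable(UInt64) as before, and return_type() simply hands back the stored value. regexp_replace and regexp_substr below adopt the same pattern with Nullable(String). A minimal sketch of the pattern, using simplified stand-in types:

```rust
// Simplified stand-in types; this only shows the constructor-time decision.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DataType {
    Null,
    Nullable(Box<DataType>),
    UInt64,
}

struct RegexpInstrLike {
    return_type: DataType,
}

impl RegexpInstrLike {
    fn try_create(arg_types: &[DataType]) -> Self {
        // If any argument is the NULL type, the whole expression is NULL;
        // otherwise the result stays a nullable UInt64, as before.
        let has_null = arg_types.iter().any(|t| *t == DataType::Null);
        let return_type = if has_null {
            DataType::Null
        } else {
            DataType::Nullable(Box::new(DataType::UInt64))
        };
        RegexpInstrLike { return_type }
    }

    fn return_type(&self) -> DataType {
        self.return_type.clone()
    }
}

fn main() {
    let f = RegexpInstrLike::try_create(&[DataType::Null, DataType::UInt64]);
    assert_eq!(f.return_type(), DataType::Null);

    let g = RegexpInstrLike::try_create(&[DataType::UInt64, DataType::UInt64]);
    assert_eq!(g.return_type(), DataType::Nullable(Box::new(DataType::UInt64)));
    println!("constructed return types check out");
}
```
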
12 changes: 11 additions & 1 deletion common/functions/src/scalars/strings/regexp_replace.rs
@@ -34,6 +34,7 @@ use crate::scalars::FunctionFeatures;
#[derive(Clone)]
pub struct RegexpReplaceFunction {
display_name: String,
return_type: DataTypeImpl,
}

impl RegexpReplaceFunction {
@@ -54,8 +55,17 @@ impl RegexpReplaceFunction {
}
}

let has_null = args.iter().any(|arg| arg.is_null());

let return_type = if has_null {
NullType::new_impl()
} else {
NullableType::new_impl(StringType::new_impl())
};

Ok(Box::new(Self {
display_name: display_name.to_string(),
return_type,
}))
}

@@ -75,7 +85,7 @@ impl Function for RegexpReplaceFunction {
}

fn return_type(&self) -> DataTypeImpl {
NullableType::new_impl(StringType::new_impl())
self.return_type.clone()
}

// Notes: https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-replace
11 changes: 10 additions & 1 deletion common/functions/src/scalars/strings/regexp_substr.rs
@@ -34,6 +34,7 @@ use crate::scalars::FunctionFeatures;
#[derive(Clone)]
pub struct RegexpSubStrFunction {
display_name: String,
return_type: DataTypeImpl,
}

impl RegexpSubStrFunction {
@@ -54,8 +55,16 @@ impl RegexpSubStrFunction {
}
}

let has_null = args.iter().any(|arg| arg.is_null());
let return_type = if has_null {
NullType::new_impl()
} else {
NullableType::new_impl(StringType::new_impl())
};

Ok(Box::new(Self {
display_name: display_name.to_string(),
return_type,
}))
}

@@ -75,7 +84,7 @@ impl Function for RegexpSubStrFunction {
}

fn return_type(&self) -> DataTypeImpl {
NullableType::new_impl(StringType::new_impl())
self.return_type.clone()
}

// Notes: https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-substr
4 changes: 2 additions & 2 deletions common/streams/tests/it/stream_datablock.rs
@@ -21,8 +21,8 @@ use futures::stream::StreamExt;
#[tokio::test]
async fn test_datablock_stream() {
let schema = DataSchemaRefExt::create(vec![
DataField::new("name", i32::to_data_type()),
DataField::new("age", Vu8::to_data_type()),
DataField::new("name", Vu8::to_data_type()),
DataField::new("age", i32::to_data_type()),
]);

let data_blocks = vec![
2 changes: 1 addition & 1 deletion query/src/procedures/stats/tenant_tables.rs
@@ -71,7 +71,7 @@ impl Procedure for TenantTablesProcedure {
fn schema(&self) -> Arc<DataSchema> {
DataSchemaRefExt::create(vec![
DataField::new("tenant_id", Vu8::to_data_type()),
DataField::new("table_count", u32::to_data_type()),
DataField::new("table_count", u64::to_data_type()),
])
}
}
19 changes: 19 additions & 0 deletions query/src/sql/exec/mod.rs
@@ -26,6 +26,8 @@ use common_datavalues::DataSchema;
use common_datavalues::DataSchemaRef;
use common_datavalues::DataSchemaRefExt;
use common_datavalues::DataTypeImpl;
use common_datavalues::ToDataType;
use common_datavalues::Vu8;
use common_exception::ErrorCode;
use common_exception::Result;
use common_planners::find_aggregate_exprs;
@@ -444,6 +446,23 @@ impl PipelineBuilder {
// Get partial schema from agg_expressions
let partial_data_fields =
RewriteHelper::exprs_to_fields(agg_expressions.as_slice(), &input_schema)?;
let mut partial_data_fields = partial_data_fields
.iter()
.map(|f| DataField::new(f.name(), Vu8::to_data_type()))
.collect::<Vec<_>>();

if !group_expressions.is_empty() {
// Fields. [aggrs, key]
// aggrs: aggr_len aggregate states
// key: Varint by hash method
let group_cols: Vec<String> = group_expressions
.iter()
.map(|expr| expr.column_name())
.collect();
let sample_block = DataBlock::empty_with_schema(input_schema.clone());
let method = DataBlock::choose_hash_method(&sample_block, &group_cols)?;
partial_data_fields.push(DataField::new("_group_by_key", method.data_type()));
}
let partial_schema = DataSchemaRefExt::create(partial_data_fields);

// Get final schema from agg_expression and group expression
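
The partial-aggregation schema is rebuilt here: every aggregate state column is declared as a binary (Vu8) field, and when there are GROUP BY expressions a trailing _group_by_key field is appended whose data type comes from the hash method chosen for the group columns. A self-contained sketch of that assembly; the types and the hash-method choice are simplified placeholders, not the real DataBlock::choose_hash_method:

```rust
// Simplified stand-ins for the field/type machinery used above.
#[derive(Debug, Clone)]
enum DataType {
    Binary, // stand-in for the Vu8 fields carrying serialized aggregate state
    UInt64, // stand-in for a fixed-width serialized group key
}

#[derive(Debug)]
struct Field {
    name: String,
    data_type: DataType,
}

fn choose_group_key_type(group_cols: &[&str]) -> DataType {
    // Placeholder for DataBlock::choose_hash_method: the real code inspects the
    // group columns' types to pick a keys layout; here one variant is returned.
    let _ = group_cols;
    DataType::UInt64
}

fn partial_schema(aggregate_names: &[&str], group_cols: &[&str]) -> Vec<Field> {
    // Field order: [aggrs..., key], mirroring the comment in the diff above.
    let mut fields: Vec<Field> = aggregate_names
        .iter()
        .map(|name| Field {
            name: (*name).to_string(),
            data_type: DataType::Binary,
        })
        .collect();

    if !group_cols.is_empty() {
        fields.push(Field {
            name: "_group_by_key".to_string(),
            data_type: choose_group_key_type(group_cols),
        });
    }
    fields
}

fn main() {
    let fields = partial_schema(&["sum(a)", "count(b)"], &["c"]);
    for f in &fields {
        println!("{}: {:?}", f.name, f.data_type);
    }
}
```
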
20 changes: 8 additions & 12 deletions query/tests/it/formats/output_format_tcsv.rs
@@ -38,29 +38,25 @@ fn test_data_block(is_nullable: bool) -> Result<()> {
]),
};

let block = DataBlock::create(schema.clone(), vec![
Series::from_data(vec![1, 2, 3]),
let mut columns = vec![
Series::from_data(vec![1i32, 2, 3]),
Series::from_data(vec!["a", "b", "c"]),
Series::from_data(vec![true, true, false]),
Series::from_data(vec![1.1, 2.2, 3.3]),
Series::from_data(vec![1.1f64, 2.2, 3.3]),
Series::from_data(vec![1_i32, 2_i32, 3_i32]),
]);
];

let block = if is_nullable {
let columns = block
.columns()
if is_nullable {
columns = columns
.iter()
.map(|c| {
let mut validity = MutableBitmap::new();
validity.extend_constant(c.len(), true);
NullableColumn::wrap_inner(c.clone(), Some(validity.into()))
})
.collect();
DataBlock::create(schema.clone(), columns)
} else {
block
};

}
let block = DataBlock::create(schema.clone(), columns);
let mut format_setting = FormatSettings::default();

{
17 changes: 8 additions & 9 deletions query/tests/it/servers/http/json_block.rs
@@ -46,28 +46,27 @@ fn test_data_block(is_nullable: bool) -> Result<()> {
]),
};

let block = DataBlock::create(schema.clone(), vec![
let mut columns = vec![
Series::from_data(vec![1, 2, 3]),
Series::from_data(vec!["a", "b", "c"]),
Series::from_data(vec![true, true, false]),
Series::from_data(vec![1.1, 2.2, 3.3]),
Series::from_data(vec![1_i32, 2_i32, 3_i32]),
]);
];

let block = if is_nullable {
let columns = block
.columns()
if is_nullable {
columns = columns
.iter()
.map(|c| {
let mut validity = MutableBitmap::new();
validity.extend_constant(c.len(), true);
NullableColumn::wrap_inner(c.clone(), Some(validity.into()))
})
.collect();
DataBlock::create(schema, columns)
} else {
block
};
}

let block = DataBlock::create(schema, columns);

let format = FormatSettings::default();
let json_block = JsonBlock::new(&block, &format, false)?;
let expect = vec![
41 changes: 17 additions & 24 deletions query/tests/it/storages/index/bloom_filter.rs
@@ -42,8 +42,7 @@ fn test_num_bits_hashes() -> Result<()> {

#[tokio::test]
async fn test_bloom_add_find_string() -> Result<()> {
let schema =
DataSchemaRefExt::create(vec![DataField::new_nullable("name", Vu8::to_data_type())]);
let schema = DataSchemaRefExt::create(vec![DataField::new("name", Vu8::to_data_type())]);
let block = DataBlock::create(schema, vec![Series::from_data(vec![
"Alice", "Bob", "Batman", "Superman", "123",
])]);
@@ -233,25 +232,23 @@ async fn create_bloom(
// create test data, all numerics are odd number, even numbers are reserved for testing.
fn create_blocks() -> Vec<DataBlock> {
let schema = DataSchemaRefExt::create(vec![
DataField::new_nullable("ColumnUInt8", u8::to_data_type()),
DataField::new_nullable("ColumnUInt16", u16::to_data_type()),
DataField::new_nullable("ColumnUInt32", u32::to_data_type()),
DataField::new_nullable("ColumnUInt64", u64::to_data_type()),
DataField::new_nullable("ColumnInt8", i8::to_data_type()),
DataField::new_nullable("ColumnInt16", i16::to_data_type()),
DataField::new_nullable("ColumnInt32", i32::to_data_type()),
DataField::new_nullable("ColumnInt64", i64::to_data_type()),
DataField::new_nullable("ColumnFloat32", f32::to_data_type()),
DataField::new_nullable("ColumnFloat64", f64::to_data_type()),
DataField::new_nullable("ColumnDate16", DateType::new_impl()),
DataField::new_nullable("ColumnDate32", DateType::new_impl()),
DataField::new_nullable("ColumnDateTime32", TimestampType::new_impl(0)),
DataField::new_nullable("ColumnDateTime64", TimestampType::new_impl(3)),
DataField::new_nullable(
DataField::new("ColumnUInt8", u8::to_data_type()),
DataField::new("ColumnUInt16", u16::to_data_type()),
DataField::new("ColumnUInt32", u32::to_data_type()),
DataField::new("ColumnUInt64", u64::to_data_type()),
DataField::new("ColumnInt8", i8::to_data_type()),
DataField::new("ColumnInt16", i16::to_data_type()),
DataField::new("ColumnInt32", i32::to_data_type()),
DataField::new("ColumnInt64", i64::to_data_type()),
DataField::new("ColumnFloat32", f32::to_data_type()),
DataField::new("ColumnFloat64", f64::to_data_type()),
DataField::new("ColumnDate32", DateType::new_impl()),
DataField::new("ColumnDateTime64", TimestampType::new_impl(3)),
DataField::new(
"ColumnIntervalDays",
IntervalType::new_impl(IntervalKind::Day),
),
DataField::new_nullable("ColumnString", Vu8::to_data_type()),
DataField::new("ColumnString", Vu8::to_data_type()),
]);

let block1 = DataBlock::create(schema.clone(), vec![
@@ -265,9 +262,7 @@ fn create_blocks() -> Vec<DataBlock> {
Series::from_data(vec![-1_i64, -3, -5, -7]),
Series::from_data(vec![1.0_f32, 3.0, 5.0, 7.0]),
Series::from_data(vec![1.0_f64, 3.0, 5.0, 7.0]),
Series::from_data(vec![1_u16, 3, 5, 7]),
Series::from_data(vec![1_u32, 3, 5, 7]),
Series::from_data(vec![1_u32, 3, 5, 7]),
Series::from_data(vec![1_i32, 3, 5, 7]),
Series::from_data(vec![1_i64, 3, 5, 7]),
Series::from_data(vec![1_i64, 3, 5, 7]),
Series::from_data(vec!["Alice", "Bob", "Batman", "Superman"]),
@@ -284,9 +279,7 @@ fn create_blocks() -> Vec<DataBlock> {
Series::from_data(vec![-9_i64, -11, -13, -15]),
Series::from_data(vec![9.0_f32, 11.0, 13.0, 15.0]),
Series::from_data(vec![9.0_f64, 11.0, 13.0, 15.0]),
Series::from_data(vec![9_u16, 11, 13, 15]),
Series::from_data(vec![9_u32, 11, 13, 15]),
Series::from_data(vec![9_u32, 11, 13, 15]),
Series::from_data(vec![9_i32, 11, 13, 15]),
Series::from_data(vec![9_i64, 11, 13, 15]),
Series::from_data(vec![9_i64, 11, 13, 15]),
Series::from_data(vec!["Iron man", "Thor", "Professor X", "Wolverine"]),
2 changes: 1 addition & 1 deletion query/tests/it/storages/index/index_min_max.rs
@@ -26,7 +26,7 @@ use pretty_assertions::assert_eq;
#[test]
fn test_min_max_index() -> Result<()> {
let schema = DataSchemaRefExt::create(vec![
DataField::new_nullable("name", Vu8::to_data_type()),
DataField::new("name", Vu8::to_data_type()),
DataField::new("age", i32::to_data_type()),
]);

2 changes: 1 addition & 1 deletion query/tests/it/storages/index/index_sparse.rs
@@ -27,7 +27,7 @@ use pretty_assertions::assert_eq;
#[test]
fn test_sparse_index() -> Result<()> {
let schema = DataSchemaRefExt::create(vec![
DataField::new_nullable("name", Vu8::to_data_type()),
DataField::new("name", Vu8::to_data_type()),
DataField::new("age", i32::to_data_type()),
]);
