From 49252766894f3bb5f5d4767037142cb1b10d92bc Mon Sep 17 00:00:00 2001
From: baishen
Date: Wed, 25 May 2022 16:05:22 +0800
Subject: [PATCH] feat(function): Support variant group by

Implement GroupHash::serialize for VariantColumn and ArrayColumn, add
BinaryWrite::write_raw, and extend the stateless GROUP BY test with
Variant and Array(Int32) cases.

---
Review notes (this section is ignored by git am):
- ArrayColumn::serialize now advances the inner-key offset for NULL rows
  as well; previously a NULL row that still owned inner values shifted the
  keys read for every following row.
- Doc comments were added; hunk line counts and the diffstat were updated
  to match.
- The index lines still carry the blob ids of the original submission.

 common/datavalues/src/columns/group_hash.rs   | 94 ++++++++++++++++++-
 common/io/src/binary_write.rs                 |  8 ++
 .../03_dml/03_0003_select_group_by.result     |  9 +++
 .../03_dml/03_0003_select_group_by.sql        | 29 +++++++
 4 files changed, 137 insertions(+), 3 deletions(-)

diff --git a/common/datavalues/src/columns/group_hash.rs b/common/datavalues/src/columns/group_hash.rs
index 84452c36236e..72acd5aea410 100644
--- a/common/datavalues/src/columns/group_hash.rs
+++ b/common/datavalues/src/columns/group_hash.rs
@@ -289,6 +289,94 @@ impl GroupHash for StringColumn {
     }
 }
 
-// TODO(b41sh): implement GroupHash for VariantColumn
-impl GroupHash for VariantColumn {}
-impl GroupHash for ArrayColumn {}
+impl GroupHash for VariantColumn {
+    /// Appends each row's variant value to the matching key buffer in `vec`.
+    ///
+    /// Values are rendered with `to_string()` and written via
+    /// `BinaryWrite::write_binary`. When `nulls` is provided, a validity flag
+    /// is written first and the value follows only for valid rows.
+    ///
+    /// Panics if `vec` and the column differ in length.
+    fn serialize(&self, vec: &mut Vec<SmallVu8>, nulls: Option<Bitmap>) -> Result<()> {
+        assert_eq!(vec.len(), self.len());
+
+        match nulls {
+            Some(bitmap) => {
+                for ((value, valid), key) in self.iter().zip(bitmap.iter()).zip(vec) {
+                    BinaryWrite::write_scalar(key, &valid)?;
+                    if valid {
+                        BinaryWrite::write_binary(key, value.to_string().as_bytes())?;
+                    }
+                }
+            }
+            None => {
+                for (value, key) in self.iter().zip(vec) {
+                    BinaryWrite::write_binary(key, value.to_string().as_bytes())?;
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
+
+impl GroupHash for ArrayColumn {
+    /// Appends each row's array to the matching key buffer in `vec`.
+    ///
+    /// All inner values are serialized once into per-element keys; each row
+    /// then writes its element count as a uvarint followed by the raw keys of
+    /// its elements. When `nulls` is provided, a validity flag is written
+    /// first and the array follows only for valid rows.
+    ///
+    /// Panics if `vec` and the column differ in length.
+    fn serialize(&self, vec: &mut Vec<SmallVu8>, nulls: Option<Bitmap>) -> Result<()> {
+        assert_eq!(vec.len(), self.len());
+
+        let offsets = self.offsets();
+        if offsets.len() <= 1 {
+            return Ok(());
+        }
+        let inner_column = self.values();
+        let inner_length = *offsets.last().unwrap() as usize;
+        let mut inner_keys = Vec::with_capacity(inner_length);
+        for _i in 0..inner_length {
+            inner_keys.push(SmallVu8::new());
+        }
+        Series::serialize(inner_column, &mut inner_keys, None)?;
+
+        match nulls {
+            Some(bitmap) => {
+                let mut offset = 0;
+                for i in 0..self.len() {
+                    let valid = bitmap.get(i).unwrap();
+                    let v = vec.get_mut(i).unwrap();
+                    let length = self.size_at_index(i);
+                    BinaryWrite::write_scalar(v, &valid)?;
+                    if valid {
+                        BinaryWrite::write_uvarint(v, length as u64)?;
+                        for j in offset..offset + length {
+                            BinaryWrite::write_raw(v, inner_keys.get(j).unwrap())?;
+                        }
+                    }
+                    // Advance even for NULL rows: they may still own inner
+                    // values, and skipping them would misalign later rows.
+                    offset += length;
+                }
+            }
+            None => {
+                let mut offset = 0;
+                for i in 0..self.len() {
+                    let v = vec.get_mut(i).unwrap();
+                    let length = self.size_at_index(i);
+                    BinaryWrite::write_uvarint(v, length as u64)?;
+                    for j in offset..offset + length {
+                        BinaryWrite::write_raw(v, inner_keys.get(j).unwrap())?;
+                    }
+                    offset += length;
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
diff --git a/common/io/src/binary_write.rs b/common/io/src/binary_write.rs
index 9223333582c1..363cea49a9fc 100644
--- a/common/io/src/binary_write.rs
+++ b/common/io/src/binary_write.rs
@@ -27,6 +27,8 @@ pub trait BinaryWrite {
     fn write_string(&mut self, text: impl AsRef<str>) -> Result<()>;
     fn write_uvarint(&mut self, v: u64) -> Result<()>;
     fn write_binary(&mut self, text: impl AsRef<[u8]>) -> Result<()>;
+    /// Writes `text` as-is via `write_all`, adding nothing around the bytes.
+    fn write_raw(&mut self, text: impl AsRef<[u8]>) -> Result<()>;
 
     fn write_opt_scalar<V>(&mut self, v: &Option<V>) -> Result<()>
     where V: Marshal + StatBuffer {
@@ -71,6 +73,12 @@ where T: std::io::Write
         self.write_all(bytes)?;
         Ok(())
     }
+
+    fn write_raw(&mut self, text: impl AsRef<[u8]>) -> Result<()> {
+        let bytes = text.as_ref();
+        self.write_all(bytes)?;
+        Ok(())
+    }
 }
 
 // Another trait like BinaryWrite
diff --git a/tests/suites/0_stateless/03_dml/03_0003_select_group_by.result b/tests/suites/0_stateless/03_dml/03_0003_select_group_by.result
index d0ee43608d49..d7ffb8639a1e 100644
--- a/tests/suites/0_stateless/03_dml/03_0003_select_group_by.result
+++ b/tests/suites/0_stateless/03_dml/03_0003_select_group_by.result
@@ -53,3 +53,12 @@ NULL	2	3
 2	1
 3	1
 4	1
+==GROUP BY Variant==
+6	5	12
+4	3	"abcd"
+2	1	{"k":"v"}
+8	7	[1,2,3]
+==GROUP BY Array(Int32)==
+2	1	[]
+4	3	[1, 2, 3]
+6	5	[4, 5, 6]
diff --git a/tests/suites/0_stateless/03_dml/03_0003_select_group_by.sql b/tests/suites/0_stateless/03_dml/03_0003_select_group_by.sql
index 2b43969c60a0..00154d091cc8 100644
--- a/tests/suites/0_stateless/03_dml/03_0003_select_group_by.sql
+++ b/tests/suites/0_stateless/03_dml/03_0003_select_group_by.sql
@@ -49,3 +49,32 @@ SELECT number, count(*) FROM numbers_mt(1000) group by number order by number li
 set group_by_two_level_threshold=1000000000;
 SELECT number, count(*) FROM numbers_mt(1000) group by number order by number limit 5;
 
+SELECT '==GROUP BY Variant==';
+CREATE TABLE IF NOT EXISTS t_variant(id Int null, var Variant null) Engine = Fuse;
+
+INSERT INTO t_variant VALUES(1, parse_json('{"k":"v"}')),
+                            (2, parse_json('{"k":"v"}')),
+                            (3, parse_json('"abcd"')),
+                            (4, parse_json('"abcd"')),
+                            (5, parse_json('12')),
+                            (6, parse_json('12')),
+                            (7, parse_json('[1,2,3]')),
+                            (8, parse_json('[1,2,3]'));
+
+SELECT max(id), min(id), var FROM t_variant GROUP BY var ORDER BY var ASC;
+
+DROP TABLE t_variant;
+
+SELECT '==GROUP BY Array(Int32)==';
+
+CREATE TABLE IF NOT EXISTS t_array(id Int null, arr Array(Int32) null) Engine = Fuse;
+INSERT INTO t_array VALUES(1, []),
+                          (2, []),
+                          (3, [1,2,3]),
+                          (4, [1,2,3]),
+                          (5, [4,5,6]),
+                          (6, [4,5,6]);
+
+SELECT max(id), min(id), arr FROM t_array GROUP BY arr ORDER BY arr ASC;
+
+DROP TABLE t_array;