diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs
index bb31d2f1afe5..c49f9d9daecc 100755
--- a/src/query/expression/src/aggregate/aggregate_function.rs
+++ b/src/query/expression/src/aggregate/aggregate_function.rs
@@ -16,6 +16,7 @@ use std::alloc::Layout;
 use std::fmt;
 use std::ops::Index;
 use std::ops::Range;
+use std::slice::SliceIndex;
 use std::sync::Arc;
 
 use databend_common_arrow::arrow::bitmap::Bitmap;
@@ -201,7 +202,8 @@ impl<'a> InputColumns<'a> {
         }
     }
 
-    pub fn slice(&self, index: Range<usize>) -> InputColumns<'_> {
+    pub fn slice<I>(&self, index: I) -> InputColumns<'_>
+    where I: SliceIndex<[usize], Output = [usize]> + SliceIndex<[Column], Output = [Column]> {
         match self {
             Self::Slice(s) => Self::Slice(&s[index]),
             Self::Block(BlockProxy { args, data }) => Self::Block(BlockProxy {
@@ -234,12 +236,32 @@ pub struct InputColumnsIter<'a> {
     this: &'a InputColumns<'a>,
 }
 
+unsafe impl<'a> std::iter::TrustedLen for InputColumnsIter<'a> {}
+
 impl<'a> Iterator for InputColumnsIter<'a> {
     type Item = &'a Column;
 
     fn next(&mut self) -> Option<Self::Item> {
         self.iter.next().map(|index| self.this.index(index))
     }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.iter.size_hint()
+    }
+
+    fn count(self) -> usize
+    where Self: Sized {
+        self.iter.count()
+    }
+
+    fn nth(&mut self, n: usize) -> Option<Self::Item> {
+        self.iter.nth(n).map(|index| self.this.index(index))
+    }
+
+    fn last(self) -> Option<Self::Item>
+    where Self: Sized {
+        self.iter.last().map(|index| self.this.index(index))
+    }
 }
 
 impl<'a> From<&'a [Column]> for InputColumns<'a> {
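
The new `Iterator` overrides above all delegate to the inner index iterator, which is what makes the `TrustedLen` claim sound: the wrapper is exactly as long as the index range it walks. A minimal, std-only sketch of the same delegation pattern; `Col` and `ColsIter` are stand-ins invented here, not the real `Column`/`InputColumnsIter`, and the sketch omits `TrustedLen` since that trait is nightly-only:

    use std::ops::Range;

    // Stand-in column type; the real iterator yields &Column from InputColumns.
    struct Col(&'static str);

    struct ColsIter<'a> {
        iter: Range<usize>, // the index iterator knows its exact length
        cols: &'a [Col],
    }

    impl<'a> Iterator for ColsIter<'a> {
        type Item = &'a Col;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next().map(|i| &self.cols[i])
        }

        // Delegating keeps the hint exact, so collect() and friends can
        // pre-allocate; nth() skips without resolving the skipped columns.
        fn size_hint(&self) -> (usize, Option<usize>) {
            self.iter.size_hint()
        }

        fn nth(&mut self, n: usize) -> Option<Self::Item> {
            self.iter.nth(n).map(|i| &self.cols[i])
        }
    }

    fn main() {
        let cols = [Col("a"), Col("b"), Col("c")];
        let mut it = ColsIter { iter: 0..3, cols: &cols };
        assert_eq!(it.size_hint(), (3, Some(3)));
        assert_eq!(it.nth(1).unwrap().0, "b"); // skipped "a" without touching it
        assert_eq!(it.size_hint(), (1, Some(1)));
    }
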
diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs
index 7e1873b09268..00c48e83293e 100644
--- a/src/query/expression/src/aggregate/aggregate_hashtable.rs
+++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs
@@ -32,6 +32,7 @@ use crate::AggregateFunctionRef;
 use crate::Column;
 use crate::ColumnBuilder;
 use crate::HashTableConfig;
+use crate::InputColumns;
 use crate::Payload;
 use crate::StateAddr;
 use crate::BATCH_SIZE;
@@ -127,9 +128,9 @@ impl AggregateHashTable {
     pub fn add_groups(
         &mut self,
         state: &mut ProbeState,
-        group_columns: &[Column],
-        params: &[Vec<Column>],
-        agg_states: &[Column],
+        group_columns: InputColumns,
+        params: &[InputColumns],
+        agg_states: InputColumns,
         row_count: usize,
     ) -> Result<usize> {
         if row_count <= BATCH_ADD_SIZE {
@@ -147,6 +148,7 @@ impl AggregateHashTable {
                 .iter()
                 .map(|c| c.iter().map(|x| x.slice(start..end)).collect())
                 .collect::<Vec<_>>();
+            let step_params = step_params.iter().map(|v| v.into()).collect::<Vec<_>>();
             let agg_states = agg_states
                 .iter()
                 .map(|c| c.slice(start..end))
@@ -154,9 +156,9 @@ impl AggregateHashTable {
                 .collect::<Vec<_>>();
             new_count += self.add_groups_inner(
                 state,
-                &step_group_columns,
+                (&step_group_columns).into(),
                 &step_params,
-                &agg_states,
+                (&agg_states).into(),
                 end - start,
             )?;
         }
@@ -168,9 +170,9 @@ impl AggregateHashTable {
    fn add_groups_inner(
         &mut self,
         state: &mut ProbeState,
-        group_columns: &[Column],
-        params: &[Vec<Column>],
-        agg_states: &[Column],
+        group_columns: InputColumns,
+        params: &[InputColumns],
+        agg_states: InputColumns,
         row_count: usize,
     ) -> Result<usize> {
         state.row_count = row_count;
@@ -205,7 +207,7 @@ impl AggregateHashTable {
                 .zip(params.iter())
                 .zip(self.payload.state_addr_offsets.iter())
             {
-                aggr.accumulate_keys(state_places, *addr_offset, params.into(), row_count)?;
+                aggr.accumulate_keys(state_places, *addr_offset, *params, row_count)?;
             }
         } else {
             for ((aggr, agg_state), addr_offset) in self
@@ -242,7 +244,7 @@ impl AggregateHashTable {
     fn probe_and_create(
         &mut self,
         state: &mut ProbeState,
-        group_columns: &[Column],
+        group_columns: InputColumns,
         row_count: usize,
     ) -> usize {
         // exceed capacity or should resize
@@ -390,7 +392,7 @@ impl AggregateHashTable {
 
         let _ = self.probe_and_create(
             &mut flush_state.probe_state,
-            &flush_state.group_columns,
+            (&flush_state.group_columns).into(),
            row_count,
        );
diff --git a/src/query/expression/src/aggregate/group_hash.rs b/src/query/expression/src/aggregate/group_hash.rs
index f878a80faad1..e5229148c7cd 100644
--- a/src/query/expression/src/aggregate/group_hash.rs
+++ b/src/query/expression/src/aggregate/group_hash.rs
@@ -33,17 +33,17 @@ use crate::types::ValueType;
 use crate::types::VariantType;
 use crate::with_number_mapped_type;
 use crate::Column;
+use crate::InputColumns;
 use crate::ScalarRef;
 
 const NULL_HASH_VAL: u64 = 0xd1cefa08eb382d69;
 
-pub fn group_hash_columns(cols: &[Column], values: &mut [u64]) {
+pub fn group_hash_columns(cols: InputColumns, values: &mut [u64]) {
     debug_assert!(!cols.is_empty());
-    combine_group_hash_column::<true>(&cols[0], values);
-    if cols.len() > 1 {
-        for col in &cols[1..] {
-            combine_group_hash_column::<false>(col, values);
-        }
+    let mut iter = cols.iter();
+    combine_group_hash_column::<true>(iter.next().unwrap(), values);
+    for col in iter {
+        combine_group_hash_column::<false>(col, values);
     }
 }
 
diff --git a/src/query/expression/src/aggregate/partitioned_payload.rs b/src/query/expression/src/aggregate/partitioned_payload.rs
index 6a0152b7978a..f813355c18a8 100644
--- a/src/query/expression/src/aggregate/partitioned_payload.rs
+++ b/src/query/expression/src/aggregate/partitioned_payload.rs
@@ -23,7 +23,7 @@ use super::probe_state::ProbeState;
 use crate::read;
 use crate::types::DataType;
 use crate::AggregateFunctionRef;
-use crate::Column;
+use crate::InputColumns;
 use crate::PayloadFlushState;
 use crate::BATCH_SIZE;
 
@@ -101,7 +101,7 @@ impl PartitionedPayload {
         &mut self,
         state: &mut ProbeState,
         new_group_rows: usize,
-        group_columns: &[Column],
+        group_columns: InputColumns,
     ) {
         if self.payloads.len() == 1 {
             self.payloads[0].reserve_append_rows(
diff --git a/src/query/expression/src/aggregate/payload.rs b/src/query/expression/src/aggregate/payload.rs
index bb029998dedf..4396df667e37 100644
--- a/src/query/expression/src/aggregate/payload.rs
+++ b/src/query/expression/src/aggregate/payload.rs
@@ -31,6 +31,7 @@ use crate::AggregateFunctionRef;
 use crate::Column;
 use crate::ColumnBuilder;
 use crate::DataBlock;
+use crate::InputColumns;
 use crate::PayloadFlushState;
 use crate::SelectVector;
 use crate::StateAddr;
@@ -194,7 +195,7 @@ impl Payload {
         group_hashes: &[u64],
         address: &mut [*const u8],
         new_group_rows: usize,
-        group_columns: &[Column],
+        group_columns: InputColumns,
     ) {
         let tuple_size = self.tuple_size;
         let mut page = self.writable_page();
@@ -229,11 +230,11 @@ impl Payload {
         group_hashes: &[u64],
         address: &mut [*const u8],
         new_group_rows: usize,
-        group_columns: &[Column],
+        group_columns: InputColumns,
     ) {
         let mut write_offset = 0;
         // write validity
-        for col in group_columns {
+        for col in group_columns.iter() {
             if let Column::Nullable(c) = col {
                 let bitmap = &c.validity;
                 if bitmap.unset_bits() == 0 || bitmap.unset_bits() == bitmap.len() {
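
`group_hash_columns` keeps its old two-phase shape, where the first column seeds the hash values and the remaining columns fold into them, but now drives both phases from a single iterator, since `InputColumns` offers no cheap `&cols[1..]`. A std-only sketch of that first/rest split; the toy `combine` stands in for `combine_group_hash_column::<FIRST>` and is invented for this sketch:

    // Toy stand-ins: u64 "columns" and a hash combiner parameterized like
    // combine_group_hash_column::<FIRST>.
    fn combine<const FIRST: bool>(col: &[u64], values: &mut [u64]) {
        for (v, x) in values.iter_mut().zip(col) {
            if FIRST {
                *v = *x; // the first column seeds the hash slots
            } else {
                *v = v.rotate_left(31) ^ *x; // later columns fold in
            }
        }
    }

    fn group_hash_columns(cols: &[Vec<u64>], values: &mut [u64]) {
        debug_assert!(!cols.is_empty());
        let mut iter = cols.iter();
        // FIRST = true for the leading column, false for the rest.
        combine::<true>(iter.next().unwrap(), values);
        for col in iter {
            combine::<false>(col, values);
        }
    }

    fn main() {
        let cols = vec![vec![1, 2, 3], vec![10, 20, 30]];
        let mut values = vec![0u64; 3];
        group_hash_columns(&cols, &mut values);
        assert_ne!(values, [0, 0, 0]);
    }
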
diff --git a/src/query/expression/src/aggregate/payload_row.rs b/src/query/expression/src/aggregate/payload_row.rs
index fb2dc158e3ce..06b52b594dcd 100644
--- a/src/query/expression/src/aggregate/payload_row.rs
+++ b/src/query/expression/src/aggregate/payload_row.rs
@@ -36,6 +36,7 @@ use crate::types::ValueType;
 use crate::with_decimal_mapped_type;
 use crate::with_number_mapped_type;
 use crate::Column;
+use crate::InputColumns;
 use crate::Scalar;
 use crate::SelectVector;
 
@@ -165,7 +166,7 @@ pub unsafe fn serialize_column_to_rowformat(
 }
 
 pub unsafe fn row_match_columns(
-    cols: &[Column],
+    cols: InputColumns,
     address: &[*const u8],
     select_vector: &mut SelectVector,
     temp_vector: &mut SelectVector,
diff --git a/src/query/expression/tests/it/aggregate.rs b/src/query/expression/tests/it/aggregate.rs
new file mode 100644
index 000000000000..0467e1e73b00
--- /dev/null
+++ b/src/query/expression/tests/it/aggregate.rs
@@ -0,0 +1,50 @@
+use databend_common_expression::types::*;
+use databend_common_expression::FromData;
+use databend_common_expression::InputColumns;
+
+use crate::common::new_block;
+
+#[test]
+fn test_input_columns() {
+    let strings = (0..10).map(|i: i32| i.to_string()).collect::<Vec<_>>();
+    let nums = (0..10).collect::<Vec<_>>();
+    let bools = (0..10).map(|i: usize| i % 2 == 0).collect();
+
+    let columns = vec![
+        StringType::from_data(strings),
+        Int32Type::from_data(nums),
+        BooleanType::from_data(bools),
+    ];
+    let block = new_block(&columns);
+
+    let proxy = InputColumns::new_block_proxy(&[1], &block);
+    assert_eq!(proxy.len(), 1);
+
+    let proxy = InputColumns::new_block_proxy(&[2, 0, 1], &block);
+    assert_eq!(proxy.len(), 3);
+    assert!(proxy[0].as_boolean().is_some());
+    assert!(proxy[1].as_string().is_some());
+    assert!(proxy[2].as_number().is_some());
+
+    assert_eq!(proxy.iter().count(), 3);
+
+    let mut iter = proxy.iter();
+    assert_eq!(iter.size_hint(), (3, Some(3)));
+    let col = iter.nth(1);
+    assert!(col.unwrap().as_string().is_some());
+
+    assert_eq!(iter.size_hint(), (1, Some(1)));
+    assert_eq!(iter.count(), 1);
+
+    assert!(proxy.iter().last().unwrap().as_number().is_some());
+    assert_eq!(proxy.iter().count(), 3);
+    assert_eq!(proxy.iter().size_hint(), (3, Some(3)));
+
+    let s = proxy.slice(..1);
+    assert_eq!(s.len(), 1);
+    assert!(s[0].as_boolean().is_some());
+
+    let s = proxy.slice(1..=1);
+    assert_eq!(s.len(), 1);
+    assert!(s[0].as_string().is_some());
+}
diff --git a/src/query/expression/tests/it/main.rs b/src/query/expression/tests/it/main.rs
index 5a530493117d..0929645f1180 100644
--- a/src/query/expression/tests/it/main.rs
+++ b/src/query/expression/tests/it/main.rs
@@ -24,10 +24,12 @@ use databend_common_expression::DataBlock;
 
 extern crate core;
 
+mod aggregate;
 mod block;
 mod column;
 mod common;
 mod decimal;
+mod fill_field_default_value;
 mod group_by;
 mod kernel;
 mod meta_scalar;
@@ -35,6 +37,7 @@ mod row;
 mod schema;
 mod serde;
 mod sort;
+mod types;
 
 fn rand_block_for_all_types(num_rows: usize) -> DataBlock {
     let types = get_all_test_data_types();
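
The `slice(..1)` and `slice(1..=1)` assertions in the new test exercise the generic `slice<I>` introduced in aggregate_function.rs: one index value has to slice both the proxy's argument indices (`[usize]`) and a plain `[Column]` slice. A std-only sketch of why the double `SliceIndex` bound is needed; `Col` and `Input` are stand-ins invented here, not the real types:

    use std::slice::SliceIndex;

    #[derive(Debug)]
    struct Col(i32);

    #[derive(Clone, Copy)]
    enum Input<'a> {
        Slice(&'a [Col]),
        Block { args: &'a [usize], data: &'a [Col] },
    }

    impl<'a> Input<'a> {
        // The same index value must slice both variants' backing slices;
        // Range, RangeTo and RangeInclusive all satisfy both bounds.
        fn slice<I>(self, index: I) -> Input<'a>
        where I: SliceIndex<[usize], Output = [usize]> + SliceIndex<[Col], Output = [Col]> {
            match self {
                Input::Slice(s) => Input::Slice(&s[index]),
                Input::Block { args, data } => Input::Block { args: &args[index], data },
            }
        }
    }

    fn main() {
        let data = [Col(10), Col(20), Col(30)];
        let proxy = Input::Block { args: &[2, 0, 1], data: &data };
        if let Input::Block { args, .. } = proxy.slice(1..=1) {
            assert_eq!(args, &[0]); // the second argument index survives the slice
        }
    }
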
diff --git a/src/query/expression/tests/it/testdata/fill_field_default_value.txt b/src/query/expression/tests/it/testdata/fill_field_default_value.txt
index 91acc6e23a77..e08a98334334 100644
--- a/src/query/expression/tests/it/testdata/fill_field_default_value.txt
+++ b/src/query/expression/tests/it/testdata/fill_field_default_value.txt
@@ -3,17 +3,17 @@ Source:
 +----------+----------+----------+----------+----------+
 | Column 0 | Column 1 | Column 2 | Column 3 | Column 4 |
 +----------+----------+----------+----------+----------+
-| 1        | 2        | 1        | 4        | "x1"     |
-| 2        | 2        | 2        | 4        | "x2"     |
-| 3        | 2        | 3        | 4        | "x3"     |
+| 1        | 2        | 1        | 4        | 'x1'     |
+| 2        | 2        | 2        | 4        | 'x2'     |
+| 3        | 2        | 3        | 4        | 'x3'     |
 +----------+----------+----------+----------+----------+
 
 Result:
 +----------+----------+----------+----------+----------+
 | Column 0 | Column 1 | Column 2 | Column 3 | Column 4 |
 +----------+----------+----------+----------+----------+
-| 1        | 2        | 1        | 4        | "x1"     |
-| 2        | 2        | 2        | 4        | "x2"     |
-| 3        | 2        | 3        | 4        | "x3"     |
+| 1        | 2        | 1        | 4        | 'x1'     |
+| 2        | 2        | 2        | 4        | 'x2'     |
+| 3        | 2        | 3        | 4        | 'x3'     |
 +----------+----------+----------+----------+----------+
 
@@ -22,17 +22,17 @@ Source:
 +----------+----------+----------+----------+----------+
 | Column 0 | Column 1 | Column 2 | Column 3 | Column 4 |
 +----------+----------+----------+----------+----------+
-| 1        | 0        | 10       | 0        | "ab"     |
-| 1        | 0        | 10       | 0        | "ab"     |
-| 1        | 0        | 10       | 0        | "ab"     |
+| 1        | 0        | 10       | 0        | 'ab'     |
+| 1        | 0        | 10       | 0        | 'ab'     |
+| 1        | 0        | 10       | 0        | 'ab'     |
 +----------+----------+----------+----------+----------+
 
 Result:
 +----------+----------+----------+----------+----------+
 | Column 0 | Column 1 | Column 2 | Column 3 | Column 4 |
 +----------+----------+----------+----------+----------+
-| 1        | 0        | 10       | 0        | "ab"     |
-| 1        | 0        | 10       | 0        | "ab"     |
-| 1        | 0        | 10       | 0        | "ab"     |
+| 1        | 0        | 10       | 0        | 'ab'     |
+| 1        | 0        | 10       | 0        | 'ab'     |
+| 1        | 0        | 10       | 0        | 'ab'     |
 +----------+----------+----------+----------+----------+
 
@@ -41,17 +41,17 @@ Source:
 +----------+----------+----------+----------+----------+
 | Column 0 | Column 1 | Column 2 | Column 3 | Column 4 |
 +----------+----------+----------+----------+----------+
-| 1        | 2        | 10       | 4        | "ab"     |
-| 2        | 2        | 10       | 5        | "ab"     |
-| 3        | 2        | 10       | 6        | "ab"     |
+| 1        | 2        | 10       | 4        | 'ab'     |
+| 2        | 2        | 10       | 5        | 'ab'     |
+| 3        | 2        | 10       | 6        | 'ab'     |
 +----------+----------+----------+----------+----------+
 
 Result:
 +----------+----------+----------+----------+----------+
 | Column 0 | Column 1 | Column 2 | Column 3 | Column 4 |
 +----------+----------+----------+----------+----------+
-| 1        | 2        | 10       | 4        | "ab"     |
-| 2        | 2        | 10       | 5        | "ab"     |
-| 3        | 2        | 10       | 6        | "ab"     |
+| 1        | 2        | 10       | 4        | 'ab'     |
+| 2        | 2        | 10       | 5        | 'ab'     |
+| 3        | 2        | 10       | 6        | 'ab'     |
 +----------+----------+----------+----------+----------+
 
diff --git a/src/query/expression/tests/it/types.rs b/src/query/expression/tests/it/types.rs
index 0b927af44651..d098600f2148 100644
--- a/src/query/expression/tests/it/types.rs
+++ b/src/query/expression/tests/it/types.rs
@@ -18,29 +18,11 @@ use databend_common_expression::types::timestamp::timestamp_to_string;
 #[test]
 fn test_timestamp_to_string_formats() {
     // Unix timestamp for "2024-01-01 01:02:03" UTC
-    let ts = 1_704_070_923;
-
+    let ts = 1_704_070_923_000_000;
     let tz = Tz::UTC;
 
-    // Test with a valid format
-    let ts_format = "%Y-%m-%d %H:%M:%S";
     assert_eq!(
-        timestamp_to_string(ts, tz, ts_format).to_string(),
-        "2024-01-01 01:02:03"
-    );
-
-    // Test with a format including fraction of a second
-    let ts_format = "%Y-%m-%d %H:%M:%S%.6f";
-    assert_eq!(
-        timestamp_to_string(ts, tz, ts_format).to_string(),
+        timestamp_to_string(ts, tz).to_string(),
         "2024-01-01 01:02:03.000000"
     );
-
-    // Test with an invalid format (should use default format)
-    // let ts_format = "%Y-%Q-%W"; // Invalid format specifiers
-    // assert_eq!(
-    //     timestamp_to_string(ts, tz, ts_format).to_string(),
-    //     "2024-01-01 01:02:03.000000" // Default format
-    // );
-    // }
 }
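
The rewritten test reflects two changes to `timestamp_to_string`: it no longer takes a format string, and the input is a microsecond-resolution `i64`, so the old seconds value gains a `_000_000` suffix. A quick std-only check of that conversion; the microsecond convention here is inferred from the test itself:

    fn main() {
        let seconds: i64 = 1_704_070_923; // 2024-01-01 01:02:03 UTC
        let micros = seconds * 1_000_000; // timestamp carried as i64 microseconds
        assert_eq!(micros, 1_704_070_923_000_000);
        // Round-trip back to seconds plus a zero fractional part (".000000").
        assert_eq!(micros.div_euclid(1_000_000), seconds);
        assert_eq!(micros.rem_euclid(1_000_000), 0);
    }
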
diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs
index 9aac40302a6f..81d2aadd6e44 100644
--- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs
+++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs
@@ -83,24 +83,12 @@ where State: DistinctStateFunc
         input_rows: usize,
     ) -> Result<()> {
         let state = place.get::<State>();
-        match columns {
-            InputColumns::Slice(s) => state.batch_add(s, validity, input_rows),
-            _ => {
-                let columns = columns.iter().cloned().collect::<Vec<_>>();
-                state.batch_add(columns.as_slice(), validity, input_rows)
-            }
-        }
+        state.batch_add(columns, validity, input_rows)
     }
 
     fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> {
         let state = place.get::<State>();
-        match columns {
-            InputColumns::Slice(s) => state.add(s, row),
-            _ => {
-                let columns = columns.iter().cloned().collect::<Vec<_>>();
-                state.add(columns.as_slice(), row)
-            }
-        }
+        state.add(columns, row)
     }
 
     fn serialize(&self, place: StateAddr, writer: &mut Vec<u8>) -> Result<()> {
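
With `DistinctStateFunc::add`/`batch_add` accepting `InputColumns`, the combinator no longer needs the fallback that cloned a block proxy into a temporary `Vec<Column>`. A std-only before/after sketch of that change; `Input` is a stand-in for `InputColumns`, and the `String` columns are toys:

    // Stand-in for InputColumns: either a plain slice or a block-backed view.
    #[derive(Clone, Copy)]
    enum Input<'a> {
        Slice(&'a [String]),
        Block(&'a [String]), // simplified proxy
    }

    impl<'a> Input<'a> {
        fn iter(self) -> impl Iterator<Item = &'a String> {
            match self {
                Input::Slice(s) | Input::Block(s) => s.iter(),
            }
        }
    }

    // Before: the distinct state only took a slice, so a Block input had to
    // be materialized into a temporary Vec first.
    fn add_before(state: &mut Vec<String>, cols: Input) {
        let owned: Vec<String> = cols.iter().cloned().collect(); // extra allocation
        state.extend(owned);
    }

    // After: the state consumes the view directly; no temporary Vec.
    fn add_after(state: &mut Vec<String>, cols: Input) {
        state.extend(cols.iter().cloned());
    }

    fn main() {
        let data = vec!["a".to_string(), "b".to_string()];
        let (mut s1, mut s2) = (Vec::new(), Vec::new());
        add_before(&mut s1, Input::Block(&data));
        add_after(&mut s2, Input::Block(&data));
        assert_eq!(s1, s2);
    }
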
diff --git a/src/query/functions/src/aggregates/aggregate_distinct_state.rs b/src/query/functions/src/aggregates/aggregate_distinct_state.rs
index b9be795b7376..424043c2903b 100644
--- a/src/query/functions/src/aggregates/aggregate_distinct_state.rs
+++ b/src/query/functions/src/aggregates/aggregate_distinct_state.rs
@@ -35,6 +35,7 @@ use databend_common_expression::types::StringType;
 use databend_common_expression::types::ValueType;
 use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
+use databend_common_expression::InputColumns;
 use databend_common_expression::Scalar;
 use databend_common_hashtable::HashSet as CommonHashSet;
 use databend_common_hashtable::HashtableKeyable;
@@ -54,10 +55,10 @@ pub trait DistinctStateFunc: Sized + Send + Sync {
     fn deserialize(reader: &mut &[u8]) -> Result<Self>;
     fn is_empty(&self) -> bool;
     fn len(&self) -> usize;
-    fn add(&mut self, columns: &[Column], row: usize) -> Result<()>;
+    fn add(&mut self, columns: InputColumns, row: usize) -> Result<()>;
     fn batch_add(
         &mut self,
-        columns: &[Column],
+        columns: InputColumns,
         validity: Option<&Bitmap>,
         input_rows: usize,
     ) -> Result<()>;
@@ -102,7 +103,7 @@ impl DistinctStateFunc for AggregateDistinctState {
         self.set.len()
     }
 
-    fn add(&mut self, columns: &[Column], row: usize) -> Result<()> {
+    fn add(&mut self, columns: InputColumns, row: usize) -> Result<()> {
         let values = columns
             .iter()
             .map(|col| unsafe { AnyType::index_column_unchecked(col, row).to_owned() })
@@ -115,7 +116,7 @@ impl DistinctStateFunc for AggregateDistinctState {
 
     fn batch_add(
         &mut self,
-        columns: &[Column],
+        columns: InputColumns,
         validity: Option<&Bitmap>,
         input_rows: usize,
     ) -> Result<()> {
@@ -191,7 +192,7 @@ impl DistinctStateFunc for AggregateDistinctStringState {
         self.set.len()
     }
 
-    fn add(&mut self, columns: &[Column], row: usize) -> Result<()> {
+    fn add(&mut self, columns: InputColumns, row: usize) -> Result<()> {
         let column = StringType::try_downcast_column(&columns[0]).unwrap();
         let data = unsafe { column.index_unchecked(row) };
         let _ = self.set.set_insert(data.as_bytes());
@@ -200,7 +201,7 @@ impl DistinctStateFunc for AggregateDistinctStringState {
 
     fn batch_add(
         &mut self,
-        columns: &[Column],
+        columns: InputColumns,
         validity: Option<&Bitmap>,
         input_rows: usize,
     ) -> Result<()> {
@@ -275,7 +276,7 @@ where T: Number + BorshSerialize + BorshDeserialize + HashtableKeyable
         self.set.len()
     }
 
-    fn add(&mut self, columns: &[Column], row: usize) -> Result<()> {
+    fn add(&mut self, columns: InputColumns, row: usize) -> Result<()> {
         let col = NumberType::<T>::try_downcast_column(&columns[0]).unwrap();
         let v = unsafe { col.get_unchecked(row) };
         let _ = self.set.set_insert(*v).is_ok();
@@ -284,7 +285,7 @@ where T: Number + BorshSerialize + BorshDeserialize + HashtableKeyable
 
     fn batch_add(
         &mut self,
-        columns: &[Column],
+        columns: InputColumns,
         validity: Option<&Bitmap>,
         input_rows: usize,
     ) -> Result<()> {
@@ -356,7 +357,7 @@ impl DistinctStateFunc for AggregateUniqStringState {
         self.set.len()
     }
 
-    fn add(&mut self, columns: &[Column], row: usize) -> Result<()> {
+    fn add(&mut self, columns: InputColumns, row: usize) -> Result<()> {
         let column = columns[0].as_string().unwrap();
         let data = unsafe { column.index_unchecked(row) };
         let mut hasher = SipHasher24::new();
@@ -368,7 +369,7 @@ impl DistinctStateFunc for AggregateUniqStringState {
 
     fn batch_add(
         &mut self,
-        columns: &[Column],
+        columns: InputColumns,
         validity: Option<&Bitmap>,
         input_rows: usize,
     ) -> Result<()> {
diff --git a/src/query/functions/tests/it/aggregates/agg_hashtable.rs b/src/query/functions/tests/it/aggregates/agg_hashtable.rs
index 6f00b119c004..c329f5e52039 100644
--- a/src/query/functions/tests/it/aggregates/agg_hashtable.rs
+++ b/src/query/functions/tests/it/aggregates/agg_hashtable.rs
@@ -88,6 +88,7 @@ fn test_agg_hashtable() {
     ];
 
     let params: Vec<Vec<Column>> = aggrs.iter().map(|_| vec![columns[1].clone()]).collect();
+    let params = params.iter().map(|v| v.into()).collect_vec();
 
     let config = HashTableConfig::default();
     let mut hashtable = AggregateHashTable::new(
@@ -99,7 +100,13 @@ fn test_agg_hashtable() {
 
     let mut state = ProbeState::default();
     let _ = hashtable
-        .add_groups(&mut state, &group_columns, &params, &[], n)
+        .add_groups(
+            &mut state,
+            (&group_columns).into(),
+            &params,
+            (&[]).into(),
+            n,
+        )
         .unwrap();
 
     let mut hashtable2 = AggregateHashTable::new(
@@ -111,7 +118,13 @@ fn test_agg_hashtable() {
 
     let mut state2 = ProbeState::default();
     let _ = hashtable2
-        .add_groups(&mut state2, &group_columns, &params, &[], n)
+        .add_groups(
+            &mut state2,
+            (&group_columns).into(),
+            &params,
+            (&[]).into(),
+            n,
+        )
        .unwrap();
 
     let mut flush_state = PayloadFlushState::default();
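
The updated test leans on the `From` conversions `InputColumns` provides: `(&group_columns).into()` borrows an owned `Vec`, and `(&[]).into()` passes an empty view where no aggregate-state columns exist. For the latter to resolve, something like a const-generic `From<&[Column; N]>` impl must exist; the std-only sketch below supplies one under that assumption, with `i64` standing in for `Column`:

    #[derive(Clone, Copy)]
    enum Input<'a> {
        Slice(&'a [i64]),
    }

    impl<'a> From<&'a Vec<i64>> for Input<'a> {
        fn from(v: &'a Vec<i64>) -> Self {
            Input::Slice(v.as_slice())
        }
    }

    // A const-generic array impl is what lets a bare `(&[]).into()` resolve.
    impl<'a, const N: usize> From<&'a [i64; N]> for Input<'a> {
        fn from(s: &'a [i64; N]) -> Self {
            Input::Slice(s)
        }
    }

    fn add_groups(groups: Input, params: &[Input], states: Input) -> usize {
        let Input::Slice(g) = groups;
        let Input::Slice(s) = states;
        g.len() + params.len() + s.len()
    }

    fn main() {
        let group_columns = vec![1i64, 2, 3];
        let params: Vec<Vec<i64>> = vec![vec![4, 5]];
        // Borrow each owned Vec as a view, as the updated test does.
        let params: Vec<Input> = params.iter().map(|v| v.into()).collect();
        // No aggregate-state columns: an empty view via the array impl.
        let n = add_groups((&group_columns).into(), &params, (&[]).into());
        assert_eq!(n, 4);
    }
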
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs
index f8016c9e3ff1..2358d2d544a4 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs
@@ -27,6 +27,7 @@ use databend_common_expression::BlockMetaInfoPtr;
 use databend_common_expression::Column;
 use databend_common_expression::DataBlock;
 use databend_common_expression::HashTableConfig;
+use databend_common_expression::InputColumns;
 use databend_common_expression::PartitionedPayload;
 use databend_common_expression::Payload;
 use databend_common_expression::ProbeState;
@@ -76,29 +77,19 @@
             need_init_entry,
         );
 
-        let agg_states = (0..agg_len)
-            .map(|i| {
-                self.data_block
-                    .get_by_offset(i)
-                    .value
-                    .as_column()
-                    .unwrap()
-                    .clone()
-            })
-            .collect::<Vec<_>>();
-        let group_columns = (agg_len..(agg_len + group_len))
-            .map(|i| {
-                self.data_block
-                    .get_by_offset(i)
-                    .value
-                    .as_column()
-                    .unwrap()
-                    .clone()
-            })
-            .collect::<Vec<_>>();
+        let states_index: Vec<usize> = (0..agg_len).collect();
+        let agg_states = InputColumns::new_block_proxy(&states_index, &self.data_block);
 
-        let _ =
-            hashtable.add_groups(&mut state, &group_columns, &[vec![]], &agg_states, rows_num)?;
+        let group_index: Vec<usize> = (agg_len..(agg_len + group_len)).collect();
+        let group_columns = InputColumns::new_block_proxy(&group_index, &self.data_block);
+
+        let _ = hashtable.add_groups(
+            &mut state,
+            group_columns,
+            &[(&[]).into()],
+            agg_states,
+            rows_num,
+        )?;
 
         hashtable.payload.mark_min_cardinality();
         Ok(hashtable)
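
`SerializedPayload` now addresses the aggregate-state and group columns by position through block proxies instead of cloning every `Column` out of the `DataBlock`: states occupy the leading offsets, groups the trailing ones. A std-only sketch of the index-view idea; `Block` and `Proxy` are invented stand-ins, with a `Vec<Vec<i64>>` playing the block's columns:

    // A "block" of columns; the proxy borrows it plus a list of column indices.
    struct Block {
        columns: Vec<Vec<i64>>,
    }

    struct Proxy<'a> {
        args: Vec<usize>,
        data: &'a Block,
    }

    impl<'a> Proxy<'a> {
        fn new(args: Vec<usize>, data: &'a Block) -> Self {
            Proxy { args, data }
        }

        fn get(&self, i: usize) -> &'a Vec<i64> {
            &self.data.columns[self.args[i]]
        }
    }

    fn main() {
        let block = Block { columns: vec![vec![1], vec![2], vec![3], vec![4]] };
        let (agg_len, group_len) = (2, 2);
        // States occupy the leading columns, groups the trailing ones.
        let agg_states = Proxy::new((0..agg_len).collect(), &block);
        let group_columns = Proxy::new((agg_len..agg_len + group_len).collect(), &block);
        assert_eq!(agg_states.get(1), &vec![2]);
        assert_eq!(group_columns.get(0), &vec![3]); // no column was cloned
    }
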
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs
index d6be772031d2..b91f9c583340 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs
@@ -25,9 +25,9 @@ use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::AggregateHashTable;
 use databend_common_expression::BlockMetaInfoDowncast;
-use databend_common_expression::Column;
 use databend_common_expression::DataBlock;
 use databend_common_expression::HashTableConfig;
+use databend_common_expression::InputColumns;
 use databend_common_expression::PayloadFlushState;
 use databend_common_expression::ProbeState;
 use databend_common_functions::aggregates::StateAddr;
@@ -170,30 +170,14 @@ impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
 
     // Block should be `convert_to_full`.
     #[inline(always)]
-    fn aggregate_arguments(
-        block: &DataBlock,
-        params: &Arc<AggregatorParams>,
-    ) -> Result<Vec<Vec<Column>>> {
-        let aggregate_functions_arguments = &params.aggregate_functions_arguments;
-        let mut aggregate_arguments_columns =
-            Vec::with_capacity(aggregate_functions_arguments.len());
-        for function_arguments in aggregate_functions_arguments {
-            let mut function_arguments_column = Vec::with_capacity(function_arguments.len());
-
-            for argument_index in function_arguments {
-                // Unwrap safety: chunk has been `convert_to_full`.
-                let argument_column = block
-                    .get_by_offset(*argument_index)
-                    .value
-                    .as_column()
-                    .unwrap();
-                function_arguments_column.push(argument_column.clone());
-            }
-
-            aggregate_arguments_columns.push(function_arguments_column);
-        }
-
-        Ok(aggregate_arguments_columns)
+    fn aggregate_arguments<'a>(
+        block: &'a DataBlock,
+        aggregate_functions_arguments: &'a [Vec<usize>],
+    ) -> Vec<InputColumns<'a>> {
+        aggregate_functions_arguments
+            .iter()
+            .map(|function_arguments| InputColumns::new_block_proxy(function_arguments, block))
+            .collect::<Vec<_>>()
     }
 
     #[inline(always)]
@@ -203,20 +187,26 @@
         block: &DataBlock,
         places: &StateAddrs,
     ) -> Result<()> {
-        let aggregate_functions = &params.aggregate_functions;
-        let offsets_aggregate_states = &params.offsets_aggregate_states;
-        let aggregate_arguments_columns = Self::aggregate_arguments(block, params)?;
+        let AggregatorParams {
+            aggregate_functions,
+            offsets_aggregate_states,
+            aggregate_functions_arguments,
+            ..
+        } = &**params;
 
         // This can beneficial for the case of dereferencing
         // This will help improve the performance ~hundreds of megabits per second
-        let aggr_arg_columns_slice = &aggregate_arguments_columns;
-
+        let aggr_arg_columns = Self::aggregate_arguments(block, aggregate_functions_arguments);
+        let aggr_arg_columns = aggr_arg_columns.as_slice();
         let rows = block.num_rows();
         for index in 0..aggregate_functions.len() {
             let function = &aggregate_functions[index];
-            let state_offset = offsets_aggregate_states[index];
-            let function_arguments = &aggr_arg_columns_slice[index];
-            function.accumulate_keys(places, state_offset, function_arguments.into(), rows)?;
+            function.accumulate_keys(
+                places,
+                offsets_aggregate_states[index],
+                aggr_arg_columns[index],
+                rows,
+            )?;
         }
 
         Ok(())
@@ -259,7 +249,7 @@ impl TransformPartialAggregate {
                 .map(|c| (c.value.as_column().unwrap().clone(), c.data_type.clone()))
                 .collect::<Vec<_>>();
 
-        unsafe {
+        {
             let rows_num = block.num_rows();
 
             match &mut self.hash_table {
@@ -269,7 +259,7 @@
                     let mut places = Vec::with_capacity(rows_num);
 
                    for key in self.method.build_keys_iter(&state)? {
-                        places.push(match hashtable.hashtable.insert_and_entry(key) {
+                        places.push(match unsafe { hashtable.hashtable.insert_and_entry(key) } {
                             Err(entry) => Into::<StateAddr>::into(*entry.get()),
                             Ok(mut entry) => {
                                 let place = self.params.alloc_layout(&mut hashtable.arena);
@@ -290,7 +280,7 @@
                     let mut places = Vec::with_capacity(rows_num);
 
                    for key in self.method.build_keys_iter(&state)? {
-                        places.push(match hashtable.hashtable.insert_and_entry(key) {
+                        places.push(match unsafe { hashtable.hashtable.insert_and_entry(key) } {
                             Err(entry) => Into::<StateAddr>::into(*entry.get()),
                             Ok(mut entry) => {
                                 let place = self.params.alloc_layout(&mut hashtable.arena);
@@ -307,36 +297,37 @@
                     }
                 }
                 HashTable::AggregateHashTable(hashtable) => {
-                    let group_columns: Vec<Column> =
-                        group_columns.into_iter().map(|c| c.0).collect();
+                    let group_columns =
+                        InputColumns::new_block_proxy(&self.params.group_columns, &block);
 
-                    let (params_columns, agg_states) = if is_agg_index_block {
+                    let (params_columns, states_index) = if is_agg_index_block {
+                        let num_columns = block.num_columns();
+                        let functions_count = self.params.aggregate_functions.len();
+                        (
+                            vec![],
+                            (num_columns - functions_count..num_columns).collect::<Vec<_>>(),
+                        )
+                    } else {
                         (
+                            Self::aggregate_arguments(
+                                &block,
+                                &self.params.aggregate_functions_arguments,
+                            ),
                             vec![],
-                            (0..self.params.aggregate_functions.len())
-                                .map(|index| {
-                                    block
-                                        .get_by_offset(
-                                            block.num_columns()
-                                                - self.params.aggregate_functions.len()
-                                                + index,
-                                        )
-                                        .value
-                                        .as_column()
-                                        .cloned()
-                                        .unwrap()
-                                })
-                                .collect(),
                         )
+                    };
+
+                    let agg_states = if !states_index.is_empty() {
+                        InputColumns::new_block_proxy(&states_index, &block)
                     } else {
-                        (Self::aggregate_arguments(&block, &self.params)?, vec![])
+                        (&[]).into()
                     };
 
                     let _ = hashtable.add_groups(
                         &mut self.probe_state,
-                        &group_columns,
+                        group_columns,
                         &params_columns,
-                        &agg_states,
+                        agg_states,
                         rows_num,
                     )?;
 
                     Ok(())
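
Alongside the `InputColumns` migration, the transform narrows its `unsafe` region: the block that previously wrapped the whole accumulation path becomes a plain scope, and only the raw `insert_and_entry` call is wrapped. A std-only sketch of that scoping pattern; the toy method is marked `unsafe` purely to mirror the call shape:

    struct Table {
        keys: Vec<u64>,
    }

    impl Table {
        // Stand-in for the raw hashtable API; Err = existing entry, Ok = new.
        unsafe fn insert_and_entry(&mut self, key: u64) -> Result<usize, usize> {
            match self.keys.iter().position(|&k| k == key) {
                Some(i) => Err(i),
                None => {
                    self.keys.push(key);
                    Ok(self.keys.len() - 1)
                }
            }
        }
    }

    fn main() {
        let mut table = Table { keys: Vec::new() };
        let mut places = Vec::with_capacity(4);
        for key in [1u64, 2, 1, 3] {
            // Safe surroundings; only the insertion itself is unsafe.
            places.push(match unsafe { table.insert_and_entry(key) } {
                Err(slot) => slot,
                Ok(slot) => slot,
            });
        }
        assert_eq!(places, vec![0, 1, 0, 2]);
    }
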
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs
index 70815d0113a1..915a2c6989c0 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs
@@ -23,9 +23,9 @@ use databend_common_catalog::table_context::TableContext;
 use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::AggregateHashTable;
-use databend_common_expression::Column;
 use databend_common_expression::DataBlock;
 use databend_common_expression::HashTableConfig;
+use databend_common_expression::InputColumns;
 use databend_common_expression::PayloadFlushState;
 use databend_common_expression::ProbeState;
 use databend_common_hashtable::HashtableLike;
@@ -156,15 +156,13 @@ impl<Method: HashMethodBounds> AccumulatingTransform for TransformPartialGroupBy<Method> {
             .params
             .group_columns
             .iter()
-            .map(|&index| block.get_by_offset(index))
+            .map(|&index| {
+                let c = block.get_by_offset(index);
+                (c.value.as_column().unwrap().clone(), c.data_type.clone())
+            })
             .collect::<Vec<_>>();
 
-        let group_columns = group_columns
-            .iter()
-            .map(|c| (c.value.as_column().unwrap().clone(), c.data_type.clone()))
-            .collect::<Vec<_>>();
-
-        unsafe {
+        {
             let rows_num = block.num_rows();
 
             match &mut self.hash_table {
@@ -172,23 +170,27 @@
                 HashTable::HashTable(cell) => {
                     let state = self.method.build_keys_state(&group_columns, rows_num)?;
                     for key in self.method.build_keys_iter(&state)? {
-                        let _ = cell.hashtable.insert_and_entry(key);
+                        unsafe {
+                            let _ = cell.hashtable.insert_and_entry(key);
+                        }
                     }
                 }
                 HashTable::PartitionedHashTable(cell) => {
                     let state = self.method.build_keys_state(&group_columns, rows_num)?;
                     for key in self.method.build_keys_iter(&state)? {
-                        let _ = cell.hashtable.insert_and_entry(key);
+                        unsafe {
+                            let _ = cell.hashtable.insert_and_entry(key);
+                        }
                     }
                 }
                 HashTable::AggregateHashTable(hashtable) => {
-                    let group_columns: Vec<Column> =
-                        group_columns.into_iter().map(|c| c.0).collect();
+                    let group_columns =
+                        InputColumns::new_block_proxy(&self.params.group_columns, &block);
                     let _ = hashtable.add_groups(
                         &mut self.probe_state,
-                        &group_columns,
-                        &[vec![]],
-                        &[],
+                        group_columns,
+                        &[(&[]).into()],
+                        (&[]).into(),
                         rows_num,
                     )?;
                 }