fix(query): spill block need consider scalar (#15387)
* fix(query): spill block need consider scalar

* refactor some as_column() to convert_to_full_column

Note:

as_column is safe in some cases, because sometimes the block has already
been converted to full (convert_to_full), and sometimes we can confirm the
Value<AnyType> is a Column, e.g. in topK.

* add function BlockEntry::to_column
TCeason authored May 2, 2024
1 parent 3f8e36c commit 61293e2
Showing 10 changed files with 37 additions and 25 deletions.
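
The core of the fix: a block entry's Value<AnyType> can hold either a scalar (a constant column represented by a single value) or a materialized Column, and the spill path previously assumed the latter. Below is a minimal self-contained sketch of the idea, with toy types standing in for Databend's real Value, Column, and BlockEntry definitions:

// Toy model, not Databend's real definitions: a column value is either a
// single scalar shared by every row (a constant column) or fully
// materialized data.
enum Value {
    Scalar(i64),
    Column(Vec<i64>),
}

impl Value {
    // Mirrors `as_column`: only succeeds when the data is already a Column,
    // so `.unwrap()` on a constant column is the spill panic being fixed.
    fn as_column(&self) -> Option<&Vec<i64>> {
        match self {
            Value::Column(c) => Some(c),
            Value::Scalar(_) => None,
        }
    }

    // Mirrors `convert_to_full_column`: replicates a scalar to num_rows,
    // so it is safe for both representations.
    fn convert_to_full_column(&self, num_rows: usize) -> Vec<i64> {
        match self {
            Value::Column(c) => c.clone(),
            Value::Scalar(v) => vec![*v; num_rows],
        }
    }
}

fn main() {
    let constant = Value::Scalar(42);
    assert!(constant.as_column().is_none()); // unwrap() would panic here
    assert_eq!(constant.convert_to_full_column(3), vec![42, 42, 42]);
}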
4 changes: 4 additions & 0 deletions src/query/expression/src/block.rs
@@ -73,6 +73,10 @@ impl BlockEntry {
             _ => self,
         }
     }
+
+    pub fn to_column(&self, num_rows: usize) -> Column {
+        self.value.convert_to_full_column(&self.data_type, num_rows)
+    }
 }
 
 #[typetag::serde(tag = "type")]
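
The helper takes &self plus an explicit num_rows because a bare Value does not know the block's row count: an already-full entry can simply hand back its existing column, while a scalar entry is replicated to num_rows. The call sites below then become a mechanical swap: value.as_column().unwrap() turns into to_column(num_rows), and serialize_column now takes &column since the result is an owned Column rather than a borrowed one.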
@@ -230,8 +230,8 @@ pub fn agg_spilling_aggregate_payload<Method: HashMethodBounds>(
         let mut columns_data = Vec::with_capacity(columns.len());
         let mut columns_layout = Vec::with_capacity(columns.len());
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column.to_column(data_block.num_rows());
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
@@ -327,8 +327,8 @@ pub fn spilling_aggregate_payload<Method: HashMethodBounds>(
         let mut columns_layout = Vec::with_capacity(columns.len());
 
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column.to_column(data_block.num_rows());
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
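
All of the spill-writer hunks share the same write pattern: serialize each (now fully materialized) column to bytes, record each column's byte length as its layout, and accumulate the total write size. A self-contained sketch of that bookkeeping, with a toy fixed-width serializer standing in for Databend's serialize_column:

// Toy serializer: 8-byte little-endian i64s. Databend's real on-disk
// column format is not shown here.
fn serialize_column(col: &[i64]) -> Vec<u8> {
    col.iter().flat_map(|v| v.to_le_bytes()).collect()
}

fn main() {
    let columns: Vec<Vec<i64>> = vec![vec![1, 2, 3], vec![10, 20, 30]];
    let mut write_size = 0u64;
    let mut columns_data = Vec::with_capacity(columns.len());
    let mut columns_layout = Vec::with_capacity(columns.len());

    for column in columns.iter() {
        let column_data = serialize_column(column);
        write_size += column_data.len() as u64;       // total bytes written
        columns_layout.push(column_data.len() as u64); // per-column length
        columns_data.push(column_data);
    }

    assert_eq!(columns_layout, vec![24, 24]);
    assert_eq!(write_size, 48);
    assert_eq!(columns_data.concat().len(), write_size as usize);
}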
@@ -277,8 +277,8 @@ fn agg_spilling_aggregate_payload<Method: HashMethodBounds>(
         let mut columns_layout = Vec::with_capacity(columns.len());
 
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column.to_column(data_block.num_rows());
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
@@ -398,8 +398,8 @@ fn spilling_aggregate_payload<Method: HashMethodBounds>(
         let mut columns_layout = Vec::with_capacity(columns.len());
 
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column.to_column(data_block.num_rows());
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
@@ -322,6 +322,7 @@ fn agg_spilling_group_by_payload<Method: HashMethodBounds>(
         }
 
         let data_block = payload.group_by_flush_all()?;
+        let num_rows = data_block.num_rows();
         rows += data_block.num_rows();
 
         let old_write_size = write_size;
@@ -330,8 +331,8 @@
         let mut columns_layout = Vec::with_capacity(columns.len());
 
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column.to_column(num_rows);
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
@@ -440,6 +441,7 @@ fn spilling_group_by_payload<Method: HashMethodBounds>(
         }
 
         let data_block = serialize_group_by(method, inner_table)?;
+        let num_rows = data_block.num_rows();
         rows += data_block.num_rows();
 
         let old_write_size = write_size;
@@ -448,8 +450,8 @@
         let mut columns_layout = Vec::with_capacity(columns.len());
 
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column.to_column(num_rows);
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
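
The recorded layout is what makes the spill file readable again: the reader slices the blob by the stored lengths and decodes each slice back into a column. A sketch of that read side, matching the toy 8-byte little-endian encoding above (the real on-disk format is Databend's, not shown here):

// Toy decoder for the toy serializer shown earlier.
fn deserialize_column(bytes: &[u8]) -> Vec<i64> {
    bytes
        .chunks_exact(8)
        .map(|c| i64::from_le_bytes(c.try_into().unwrap()))
        .collect()
}

fn main() {
    // Blob laid out as column 0's bytes followed by column 1's bytes.
    let blob: Vec<u8> = [1i64, 2, 3]
        .iter()
        .chain([10i64, 20, 30].iter())
        .flat_map(|v| v.to_le_bytes())
        .collect();
    let columns_layout = [24u64, 24];

    let mut begin = 0usize;
    let mut columns = Vec::new();
    for len in columns_layout {
        let end = begin + len as usize;
        columns.push(deserialize_column(&blob[begin..end]));
        begin = end;
    }
    assert_eq!(columns, vec![vec![1, 2, 3], vec![10, 20, 30]]);
}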
@@ -227,8 +227,10 @@ pub fn agg_spilling_group_by_payload<Method: HashMethodBounds>(
         let mut columns_data = Vec::with_capacity(columns.len());
         let mut columns_layout = Vec::with_capacity(columns.len());
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column
+                .value
+                .convert_to_full_column(&column.data_type, data_block.num_rows());
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
@@ -320,8 +322,8 @@ pub fn spilling_group_by_payload<Method: HashMethodBounds>(
         let mut columns_data = Vec::with_capacity(columns.len());
         let mut columns_layout = Vec::with_capacity(columns.len());
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column.to_column(data_block.num_rows());
+            let column_data = serialize_column(&column);
             write_size += column_data.len() as u64;
             columns_layout.push(column_data.len() as u64);
             columns_data.push(column_data);
@@ -228,14 +228,15 @@ impl<Method: HashMethodBounds> TransformPartialAggregate<Method> {
         let aggregate_functions = &self.params.aggregate_functions;
         let offsets_aggregate_states = &self.params.offsets_aggregate_states;
 
+        let num_rows = block.num_rows();
         for index in 0..aggregate_functions.len() {
             // Aggregation states are in the back of the block.
             let agg_index = block.num_columns() - aggregate_functions.len() + index;
             let function = &aggregate_functions[index];
             let offset = offsets_aggregate_states[index];
-            let agg_state = block.get_by_offset(agg_index).value.as_column().unwrap();
+            let agg_state = block.get_by_offset(agg_index).to_column(num_rows);
 
-            function.batch_merge(places, offset, agg_state)?;
+            function.batch_merge(places, offset, &agg_state)?;
         }
 
         Ok(())
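
For context on the hunk above: in partial aggregation, each aggregate function's serialized state arrives as a trailing column of the exchanged block, and batch_merge folds that column into the local state. A much-simplified sketch with a single sum aggregate (Databend's real states live at offsets inside arena-allocated places, which this toy omits):

// Toy aggregate-function trait: merge a batch of incoming partial states
// into one local state.
trait AggregateFunction {
    fn batch_merge(&self, place: &mut i64, states: &[i64]);
}

struct SumFunction;

impl AggregateFunction for SumFunction {
    fn batch_merge(&self, place: &mut i64, states: &[i64]) {
        for s in states {
            *place += s;
        }
    }
}

fn main() {
    // Trailing "state column" received from another node.
    let incoming_states = vec![3i64, 4, 5];
    let mut local_state = 10i64;
    SumFunction.batch_merge(&mut local_state, &incoming_states);
    assert_eq!(local_state, 22);
}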
@@ -48,7 +48,7 @@ impl HashJoinProbeState {
         let max_block_size = probe_state.max_block_size;
         // `probe_column` is the subquery result column.
         // For sql: select * from t1 where t1.a in (select t2.a from t2); t2.a is the `probe_column`,
-        let probe_column = input.get_by_offset(0).value.as_column().unwrap();
+        let probe_column = input.get_by_offset(0).to_column(input.num_rows());
         // Check if there is any null in the probe column.
         if matches!(probe_column.validity().1, Some(x) if x.unset_bits() > 0) {
             let mut has_null = self
@@ -148,7 +148,7 @@ impl HashJoinProbeState {
         let max_block_size = probe_state.max_block_size;
         // `probe_column` is the subquery result column.
         // For sql: select * from t1 where t1.a in (select t2.a from t2); t2.a is the `probe_column`,
-        let probe_column = input.get_by_offset(0).value.as_column().unwrap();
+        let probe_column = input.get_by_offset(0).to_column(input.num_rows());
         // Check if there is any null in the probe column.
         if matches!(probe_column.validity().1, Some(x) if x.unset_bits() > 0) {
             let mut has_null = self
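
The unset_bits() > 0 check right after the materialization matters for IN-subquery semantics: a NULL anywhere in the probe column changes the result of x IN (...). A toy model of the validity-bitmap check (real columns pair data with a packed bitmap; a Vec<bool> stands in here):

// Toy nullable column: values plus a validity bitmap (true = valid row).
struct NullableColumn {
    values: Vec<i64>,
    validity: Vec<bool>,
}

impl NullableColumn {
    // Mirrors the bitmap's unset_bits(): number of NULL rows.
    fn unset_bits(&self) -> usize {
        self.validity.iter().filter(|v| !**v).count()
    }
}

fn main() {
    let probe_column = NullableColumn {
        values: vec![1, 2, 0],
        validity: vec![true, true, false], // row 2 is NULL
    };
    assert_eq!(probe_column.values.len(), probe_column.validity.len());
    // Mirrors: matches!(probe_column.validity().1, Some(x) if x.unset_bits() > 0)
    let has_null = probe_column.unset_bits() > 0;
    assert!(has_null);
}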
@@ -118,10 +118,11 @@ impl HashJoinProbeState {
         let build_indexes_ptr = build_indexes.as_mut_ptr();
         let pointers = probe_state.hashes.as_slice();
         let selection = &probe_state.selection.as_slice()[0..probe_state.selection_count];
+        let num_rows = input.num_rows();
         let cols = input
             .columns()
             .iter()
-            .map(|c| (c.value.as_column().unwrap().clone(), c.data_type.clone()))
+            .map(|c| (c.to_column(num_rows), c.data_type.clone()))
             .collect::<Vec<_>>();
         let markers = probe_state.markers.as_mut().unwrap();
         self.hash_join_state
@@ -574,7 +574,7 @@ impl<T: Number> TransformWindow<T> {
             }
             if cur != self.frame_end {
                 let block = &self.blocks.get(cur.block - self.first_block).unwrap().block;
-                let col = block.get_by_offset(func.arg).value.as_column().unwrap();
+                let col = block.get_by_offset(func.arg).to_column(block.num_rows());
                 col.index(cur.row).unwrap().to_owned()
             } else {
                 // No such row
@@ -585,7 +585,7 @@
             let cur = self.goback_row(self.frame_end);
             debug_assert!(self.frame_start <= cur);
             let block = &self.blocks.get(cur.block - self.first_block).unwrap().block;
-            let col = block.get_by_offset(func.arg).value.as_column().unwrap();
+            let col = block.get_by_offset(func.arg).to_column(block.num_rows());
             col.index(cur.row).unwrap().to_owned()
         };
         let builder = &mut self.blocks[self.current_row.block - self.first_block].builder;
6 changes: 4 additions & 2 deletions src/query/service/src/spillers/spiller.rs
@@ -157,8 +157,10 @@ impl Spiller {
         let columns = data.columns().to_vec();
         let mut columns_data = Vec::with_capacity(columns.len());
         for column in columns.into_iter() {
-            let column = column.value.as_column().unwrap();
-            let column_data = serialize_column(column);
+            let column = column
+                .value
+                .convert_to_full_column(&column.data_type, data.num_rows());
+            let column_data = serialize_column(&column);
             self.columns_layout
                 .entry(location.to_string())
                 .and_modify(|layouts| {
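
The Spiller's per-location bookkeeping in this hunk is a plain entry-or-insert pattern: append the new column's byte length to the layout vector for that spill location, creating the vector on first write. A std-only sketch of the pattern (the location string is a made-up example, not a real Databend path):

use std::collections::HashMap;

// Record one column's serialized length under its spill-file location.
fn record_layout(layouts: &mut HashMap<String, Vec<u64>>, location: &str, len: u64) {
    layouts
        .entry(location.to_string())
        .and_modify(|layout| layout.push(len)) // location already seen
        .or_insert_with(|| vec![len]);         // first column for location
}

fn main() {
    let mut columns_layout: HashMap<String, Vec<u64>> = HashMap::new();
    // Hypothetical location string for illustration only.
    record_layout(&mut columns_layout, "spill/_query/0001", 24);
    record_layout(&mut columns_layout, "spill/_query/0001", 16);
    assert_eq!(columns_layout["spill/_query/0001"], vec![24, 16]);
}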
