Skip to content

Commit

Permalink
fix: fix segfault in list_domain (#16979)
Browse files Browse the repository at this point in the history
* fix list_domain

Signed-off-by: coldWater <forsaken628@gmail.com>

* fix

Signed-off-by: coldWater <forsaken628@gmail.com>

---------

Signed-off-by: coldWater <forsaken628@gmail.com>
Co-authored-by: TCeason <33082201+TCeason@users.noreply.github.com>
  • Loading branch information
forsaken628 and TCeason authored Dec 2, 2024
1 parent f40b8c3 commit 37de572
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,9 @@ where
}

fn calc_partition_point(&self) -> Partition {
let mut candidate = Candidate::new(&self.rows, EndDomain {
min: self.min_task,
max: self.max_task,
});
let mut candidate =
Candidate::new(&self.rows, EndDomain::new(self.min_task, self.max_task));

candidate.init();

// if candidate.is_small_task() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ where Self: Debug

#[derive(Debug)]
pub struct Partition {
pub ends: Vec<(usize, usize)>,
pub ends: Vec<(usize, usize)>, // index, partition point
pub total: usize,
}

Expand Down Expand Up @@ -117,8 +117,11 @@ where
T: List + 'a,
T::Item<'a>: Debug,
{
// The cut point value.
target: T::Item<'a>,
// The domain of partition point in each list, and also the domain of the generated task size in each list.
domains: Vec<EndDomain>,
// The size domain of the generated task if the current target is used as the cut point.
sum: EndDomain,
}

Expand All @@ -136,6 +139,7 @@ where T: List
}

pub fn init(&mut self) -> bool {
// Take the smallest first and the smallest last of all the lists as the initial target range.
let target: (Option<T::Item<'a>>, Option<T::Item<'a>>) =
self.all_list.iter().fold((None, None), |(min, max), ls| {
let min = match (min, ls.first()) {
Expand All @@ -151,9 +155,8 @@ where T: List

(min, max)
});
let (min_target, max_target) = if let (Some(min), Some(max)) = target {
(min, max)
} else {
let (Some(min_target), Some(max_target)) = target else {
// invalid empty input
return false;
};

Expand Down Expand Up @@ -189,7 +192,7 @@ where T: List

pub fn is_small_task(&mut self) -> bool {
loop {
let sum = self.do_search_max(Some(8));
let sum = self.reduce_max_domain(Some(8));
match self.expect.overlaps(sum) {
Overlap::Left => return true,
Overlap::Right => return false,
Expand All @@ -203,7 +206,7 @@ where T: List
for _ in 0..max_iter {
match self.overlaps() {
(_, _, Overlap::Cross) => {
let sum = self.do_search_max(Some(n));
let sum = self.reduce_max_domain(Some(n));
if self.is_finish(sum) {
return Partition::new(self.max_target.unwrap());
}
Expand All @@ -221,7 +224,7 @@ where T: List
Some(Overlap::Cross),
Overlap::Right,
) => {
let sum = self.do_search_mid(Some(n));
let sum = self.reduce_mid_domain(Some(n));
match self.expect.overlaps(sum) {
Overlap::Right => self.cut_right(),
Overlap::Left if matches!(min_overlap, Overlap::Left) => self.cut_left(),
Expand All @@ -232,7 +235,7 @@ where T: List
}
}
(Overlap::Cross, Some(Overlap::Left), Overlap::Right) => {
let sum = self.do_search_min(Some(n));
let sum = self.reduce_min_domain(Some(n));
match self.expect.overlaps(sum) {
Overlap::Left => self.cut_left(),
Overlap::Cross if sum.done() => {
Expand All @@ -251,19 +254,19 @@ where T: List
};
}

self.do_search_max(None);
self.reduce_max_domain(None);
Partition::new(self.max_target.unwrap())
}

fn do_search_max(&mut self, n: Option<usize>) -> EndDomain {
fn reduce_max_domain(&mut self, n: Option<usize>) -> EndDomain {
do_search(self.all_list, self.max_target.as_mut().unwrap(), n)
}

fn do_search_min(&mut self, n: Option<usize>) -> EndDomain {
fn reduce_min_domain(&mut self, n: Option<usize>) -> EndDomain {
do_search(self.all_list, self.min_target.as_mut().unwrap(), n)
}

fn do_search_mid(&mut self, n: Option<usize>) -> EndDomain {
fn reduce_mid_domain(&mut self, n: Option<usize>) -> EndDomain {
do_search(self.all_list, self.mid_target.as_mut().unwrap(), n)
}

Expand All @@ -290,11 +293,8 @@ where T: List
if max_domain.is_zero() {
continue;
}
let five = EndDomain {
min: min_domain.min,
max: max_domain.min,
}
.five_point();

let five = min_domain.merge(max_domain).five_point();
for v in five.into_iter().filter_map(|i| {
let v = ls.index(i);
if v >= *min_target && v <= *max_target {
Expand Down Expand Up @@ -336,6 +336,7 @@ where T: List
}

fn overlaps(&self) -> (Overlap, Option<Overlap>, Overlap) {
// Compare expect task size domain with min_target,mid_target and max_target task size domain.
(
self.expect.overlaps(self.min_target.as_ref().unwrap().sum),
self.mid_target
Expand Down Expand Up @@ -392,11 +393,16 @@ where

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct EndDomain {
pub min: usize,
pub max: usize,
min: usize,
max: usize,
}

impl EndDomain {
pub fn new(min: usize, max: usize) -> EndDomain {
assert!(min <= max);
EndDomain { min, max }
}

fn done(&self) -> bool {
self.min == self.max
}
Expand Down Expand Up @@ -453,6 +459,13 @@ impl EndDomain {
],
}
}

fn merge(&self, other: &EndDomain) -> EndDomain {
EndDomain {
min: self.min.min(other.min),
max: self.max.max(other.max),
}
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -481,15 +494,18 @@ impl std::iter::Sum for EndDomain {

impl From<std::ops::RangeInclusive<usize>> for EndDomain {
fn from(value: std::ops::RangeInclusive<usize>) -> Self {
EndDomain {
min: *value.start(),
max: *value.end(),
}
EndDomain::new(*value.start(), *value.end())
}
}

#[cfg(test)]
mod tests {
use std::iter::repeat_with;

use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;

use super::*;

impl List for &[i32] {
Expand Down Expand Up @@ -525,35 +541,67 @@ mod tests {
run_test(&all_list, (5..=10).into(), 10);
}

{
let all_list = issue_16923();

let all_list: Vec<_> = all_list.iter().map(|v| v.as_slice()).collect();
run_test(&all_list, (5..=100).into(), 20);
}

for _ in 0..100 {
let all_list = rand_data();
let all_list = rand_data(rand::random());
let all_list: Vec<_> = all_list.iter().map(|v| v.as_slice()).collect();

run_test(&all_list, (5..=10).into(), 10)
run_test(&all_list, (5..=100).into(), 20)
}
}

fn rand_data() -> Vec<Vec<i32>> {
use rand::Rng;
let mut rng = rand::thread_rng();
fn rand_data(seed: u64) -> Vec<Vec<i32>> {
let mut rng = StdRng::seed_from_u64(seed);

(0..5)
.map(|_| {
let rows: usize = rng.gen_range(0..=20);
let mut data = (0..rows)
.map(|_| rng.gen_range(0..=1000))
.collect::<Vec<_>>();
data.sort();
data
})
.collect::<Vec<_>>()
let list = rng.gen_range(1..=10);
repeat_with(|| {
let rows = rng.gen_range(0..=40);
let mut data = repeat_with(|| rng.gen_range(0..=1000))
.take(rows)
.collect::<Vec<_>>();
data.sort();
data
})
.take(list)
.collect::<Vec<_>>()
}

fn issue_16923() -> Vec<Vec<i32>> {
vec![
vec![6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
vec![
3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 18,
],
vec![6, 6, 6, 6, 6],
vec![
2, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 11, 12, 14, 15, 16, 19,
],
vec![
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
],
vec![
1, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
11, 12, 14, 15, 17, 18, 21, 22, 24, 25, 27,
],
vec![
0, 9, 10, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 20,
23, 26, 27, 27, 27, 27, 27, 27, 27, 28,
],
]
}

fn run_test(all_list: &[&[i32]], expect_size: EndDomain, max_iter: usize) {
let mut candidate = Candidate::new(all_list, expect_size);

let got = if candidate.init() {
candidate.calc_partition(3, max_iter)
candidate.calc_partition(4, max_iter)
} else {
let sum: usize = all_list.iter().map(|ls| ls.len()).sum();
assert_eq!(sum, 0);
Expand All @@ -574,11 +622,13 @@ mod tests {
(ls[..end].last(), ls[end..].first())
})
.fold((None, None), |acc, (end, start)| {
(acc.0.max(end), match (acc.1, start) {
let max_end = acc.0.max(end);
let min_start = match (acc.1, start) {
(None, None) => None,
(None, v @ Some(_)) | (v @ Some(_), None) => v,
(Some(a), Some(b)) => Some(a.min(b)),
})
};
(max_end, min_start)
});
match x {
(Some(a), Some(b)) => assert!(a < b, "all_list {all_list:?}"),
Expand Down

0 comments on commit 37de572

Please sign in to comment.