Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix segfault in list_domain #16979

Merged
merged 3 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,9 @@ where
}

fn calc_partition_point(&self) -> Partition {
let mut candidate = Candidate::new(&self.rows, EndDomain {
min: self.min_task,
max: self.max_task,
});
let mut candidate =
Candidate::new(&self.rows, EndDomain::new(self.min_task, self.max_task));

candidate.init();

// if candidate.is_small_task() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ where Self: Debug

#[derive(Debug)]
pub struct Partition {
pub ends: Vec<(usize, usize)>,
pub ends: Vec<(usize, usize)>, // index, partition point
pub total: usize,
}

Expand Down Expand Up @@ -117,8 +117,11 @@ where
T: List + 'a,
T::Item<'a>: Debug,
{
// The cut point value.
target: T::Item<'a>,
// The domain of partition point in each list, and also the domain of the generated task size in each list.
domains: Vec<EndDomain>,
// The size domain of the generated task if the current target is used as the cut point.
sum: EndDomain,
}

Expand All @@ -136,6 +139,7 @@ where T: List
}

pub fn init(&mut self) -> bool {
// Take the smallest first and the smallest last of all the lists as the initial target range.
let target: (Option<T::Item<'a>>, Option<T::Item<'a>>) =
self.all_list.iter().fold((None, None), |(min, max), ls| {
let min = match (min, ls.first()) {
Expand All @@ -151,9 +155,8 @@ where T: List

(min, max)
});
let (min_target, max_target) = if let (Some(min), Some(max)) = target {
(min, max)
} else {
let (Some(min_target), Some(max_target)) = target else {
// invalid empty input
return false;
};

Expand Down Expand Up @@ -189,7 +192,7 @@ where T: List

pub fn is_small_task(&mut self) -> bool {
loop {
let sum = self.do_search_max(Some(8));
let sum = self.reduce_max_domain(Some(8));
match self.expect.overlaps(sum) {
Overlap::Left => return true,
Overlap::Right => return false,
Expand All @@ -203,7 +206,7 @@ where T: List
for _ in 0..max_iter {
match self.overlaps() {
(_, _, Overlap::Cross) => {
let sum = self.do_search_max(Some(n));
let sum = self.reduce_max_domain(Some(n));
if self.is_finish(sum) {
return Partition::new(self.max_target.unwrap());
}
Expand All @@ -221,7 +224,7 @@ where T: List
Some(Overlap::Cross),
Overlap::Right,
) => {
let sum = self.do_search_mid(Some(n));
let sum = self.reduce_mid_domain(Some(n));
match self.expect.overlaps(sum) {
Overlap::Right => self.cut_right(),
Overlap::Left if matches!(min_overlap, Overlap::Left) => self.cut_left(),
Expand All @@ -232,7 +235,7 @@ where T: List
}
}
(Overlap::Cross, Some(Overlap::Left), Overlap::Right) => {
let sum = self.do_search_min(Some(n));
let sum = self.reduce_min_domain(Some(n));
match self.expect.overlaps(sum) {
Overlap::Left => self.cut_left(),
Overlap::Cross if sum.done() => {
Expand All @@ -251,19 +254,19 @@ where T: List
};
}

self.do_search_max(None);
self.reduce_max_domain(None);
Partition::new(self.max_target.unwrap())
}

fn do_search_max(&mut self, n: Option<usize>) -> EndDomain {
fn reduce_max_domain(&mut self, n: Option<usize>) -> EndDomain {
do_search(self.all_list, self.max_target.as_mut().unwrap(), n)
}

fn do_search_min(&mut self, n: Option<usize>) -> EndDomain {
fn reduce_min_domain(&mut self, n: Option<usize>) -> EndDomain {
do_search(self.all_list, self.min_target.as_mut().unwrap(), n)
}

fn do_search_mid(&mut self, n: Option<usize>) -> EndDomain {
fn reduce_mid_domain(&mut self, n: Option<usize>) -> EndDomain {
do_search(self.all_list, self.mid_target.as_mut().unwrap(), n)
}

Expand All @@ -290,11 +293,8 @@ where T: List
if max_domain.is_zero() {
continue;
}
let five = EndDomain {
min: min_domain.min,
max: max_domain.min,
}
.five_point();

let five = min_domain.merge(max_domain).five_point();
for v in five.into_iter().filter_map(|i| {
let v = ls.index(i);
if v >= *min_target && v <= *max_target {
Expand Down Expand Up @@ -336,6 +336,7 @@ where T: List
}

fn overlaps(&self) -> (Overlap, Option<Overlap>, Overlap) {
// Compare expect task size domain with min_target,mid_target and max_target task size domain.
(
self.expect.overlaps(self.min_target.as_ref().unwrap().sum),
self.mid_target
Expand Down Expand Up @@ -392,11 +393,16 @@ where

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct EndDomain {
pub min: usize,
pub max: usize,
min: usize,
max: usize,
}

impl EndDomain {
pub fn new(min: usize, max: usize) -> EndDomain {
assert!(min <= max);
EndDomain { min, max }
}

fn done(&self) -> bool {
self.min == self.max
}
Expand Down Expand Up @@ -453,6 +459,13 @@ impl EndDomain {
],
}
}

fn merge(&self, other: &EndDomain) -> EndDomain {
EndDomain {
forsaken628 marked this conversation as resolved.
Show resolved Hide resolved
min: self.min.min(other.min),
max: self.max.max(other.max),
}
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -481,15 +494,18 @@ impl std::iter::Sum for EndDomain {

impl From<std::ops::RangeInclusive<usize>> for EndDomain {
fn from(value: std::ops::RangeInclusive<usize>) -> Self {
EndDomain {
min: *value.start(),
max: *value.end(),
}
EndDomain::new(*value.start(), *value.end())
}
}

#[cfg(test)]
mod tests {
use std::iter::repeat_with;

use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;

use super::*;

impl List for &[i32] {
Expand Down Expand Up @@ -525,35 +541,67 @@ mod tests {
run_test(&all_list, (5..=10).into(), 10);
}

{
let all_list = issue_16923();

let all_list: Vec<_> = all_list.iter().map(|v| v.as_slice()).collect();
run_test(&all_list, (5..=100).into(), 20);
}

for _ in 0..100 {
let all_list = rand_data();
let all_list = rand_data(rand::random());
let all_list: Vec<_> = all_list.iter().map(|v| v.as_slice()).collect();

run_test(&all_list, (5..=10).into(), 10)
run_test(&all_list, (5..=100).into(), 20)
}
}

fn rand_data() -> Vec<Vec<i32>> {
use rand::Rng;
let mut rng = rand::thread_rng();
fn rand_data(seed: u64) -> Vec<Vec<i32>> {
let mut rng = StdRng::seed_from_u64(seed);

(0..5)
.map(|_| {
let rows: usize = rng.gen_range(0..=20);
let mut data = (0..rows)
.map(|_| rng.gen_range(0..=1000))
.collect::<Vec<_>>();
data.sort();
data
})
.collect::<Vec<_>>()
let list = rng.gen_range(1..=10);
repeat_with(|| {
let rows = rng.gen_range(0..=40);
let mut data = repeat_with(|| rng.gen_range(0..=1000))
.take(rows)
.collect::<Vec<_>>();
data.sort();
data
})
.take(list)
.collect::<Vec<_>>()
}

fn issue_16923() -> Vec<Vec<i32>> {
vec![
vec![6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
vec![
3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 18,
],
vec![6, 6, 6, 6, 6],
vec![
2, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 11, 12, 14, 15, 16, 19,
],
vec![
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
],
vec![
1, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
11, 12, 14, 15, 17, 18, 21, 22, 24, 25, 27,
],
vec![
0, 9, 10, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 20,
23, 26, 27, 27, 27, 27, 27, 27, 27, 28,
],
]
}

fn run_test(all_list: &[&[i32]], expect_size: EndDomain, max_iter: usize) {
let mut candidate = Candidate::new(all_list, expect_size);

let got = if candidate.init() {
candidate.calc_partition(3, max_iter)
candidate.calc_partition(4, max_iter)
} else {
let sum: usize = all_list.iter().map(|ls| ls.len()).sum();
assert_eq!(sum, 0);
Expand All @@ -574,11 +622,13 @@ mod tests {
(ls[..end].last(), ls[end..].first())
})
.fold((None, None), |acc, (end, start)| {
(acc.0.max(end), match (acc.1, start) {
let max_end = acc.0.max(end);
let min_start = match (acc.1, start) {
(None, None) => None,
(None, v @ Some(_)) | (v @ Some(_), None) => v,
(Some(a), Some(b)) => Some(a.min(b)),
})
};
(max_end, min_start)
});
match x {
(Some(a), Some(b)) => assert!(a < b, "all_list {all_list:?}"),
Expand Down