Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes API soundness issue in join() #81728

Merged
merged 2 commits into from
Mar 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions library/alloc/src/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ impl<S: Borrow<str>> Join<&str> for [S] {
}
}

macro_rules! spezialize_for_lengths {
($separator:expr, $target:expr, $iter:expr; $($num:expr),*) => {
macro_rules! specialize_for_lengths {
($separator:expr, $target:expr, $iter:expr; $($num:expr),*) => {{
let mut target = $target;
let iter = $iter;
let sep_bytes = $separator;
Expand All @@ -102,19 +102,22 @@ macro_rules! spezialize_for_lengths {
$num => {
for s in iter {
copy_slice_and_advance!(target, sep_bytes);
copy_slice_and_advance!(target, s.borrow().as_ref());
let content_bytes = s.borrow().as_ref();
copy_slice_and_advance!(target, content_bytes);
}
},
)*
_ => {
// arbitrary non-zero size fallback
for s in iter {
copy_slice_and_advance!(target, sep_bytes);
copy_slice_and_advance!(target, s.borrow().as_ref());
let content_bytes = s.borrow().as_ref();
copy_slice_and_advance!(target, content_bytes);
}
}
}
};
target
}}
}

macro_rules! copy_slice_and_advance {
Expand Down Expand Up @@ -153,30 +156,33 @@ where
// if the `len` calculation overflows, we'll panic
// we would have run out of memory anyway and the rest of the function requires
// the entire Vec pre-allocated for safety
let len = sep_len
let reserved_len = sep_len
.checked_mul(iter.len())
.and_then(|n| {
slice.iter().map(|s| s.borrow().as_ref().len()).try_fold(n, usize::checked_add)
})
.expect("attempt to join into collection with len > usize::MAX");

// crucial for safety
let mut result = Vec::with_capacity(len);
assert!(result.capacity() >= len);
// prepare an uninitialized buffer
let mut result = Vec::with_capacity(reserved_len);
debug_assert!(result.capacity() >= reserved_len);

result.extend_from_slice(first.borrow().as_ref());

unsafe {
{
let pos = result.len();
let target = result.get_unchecked_mut(pos..len);

// copy separator and slices over without bounds checks
// generate loops with hardcoded offsets for small separators
// massive improvements possible (~ x2)
spezialize_for_lengths!(sep, target, iter; 0, 1, 2, 3, 4);
}
result.set_len(len);
let pos = result.len();
let target = result.get_unchecked_mut(pos..reserved_len);

// copy separator and slices over without bounds checks
// generate loops with hardcoded offsets for small separators
// massive improvements possible (~ x2)
let remain = specialize_for_lengths!(sep, target, iter; 0, 1, 2, 3, 4);

// A weird borrow implementation may return different
// slices for the length calculation and the actual copy.
// Make sure we don't expose uninitialized bytes to the caller.
let result_len = reserved_len - remain.len();
result.set_len(result_len);
}
result
}
Expand Down
30 changes: 30 additions & 0 deletions library/alloc/tests/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,36 @@ fn test_join_for_different_lengths_with_long_separator() {
test_join!("~~~~~a~~~~~bc", ["", "a", "bc"], "~~~~~");
}

#[test]
fn test_join_isue_80335() {
use core::{borrow::Borrow, cell::Cell};

struct WeirdBorrow {
state: Cell<bool>,
}

impl Default for WeirdBorrow {
fn default() -> Self {
WeirdBorrow { state: Cell::new(false) }
}
}

impl Borrow<str> for WeirdBorrow {
fn borrow(&self) -> &str {
let state = self.state.get();
if state {
"0"
} else {
self.state.set(true);
"123456"
}
}
}

let arr: [WeirdBorrow; 3] = Default::default();
test_join!("0-0-0", arr, "-");
}

#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_unsafe_slice() {
Expand Down