Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

direct_mapping: fix iter_memory_pair_chunks in reverse mode #34204

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions programs/bpf_loader/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ assert_matches = { workspace = true }
memoffset = { workspace = true }
rand = { workspace = true }
solana-sdk = { workspace = true, features = ["dev-context-only-utils"] }
test-case = { workspace = true }

[lib]
crate-type = ["lib"]
Expand Down
254 changes: 155 additions & 99 deletions programs/bpf_loader/src/syscalls/mem_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ fn iter_memory_pair_chunks<T, F>(
src_access: AccessType,
src_addr: u64,
dst_access: AccessType,
mut dst_addr: u64,
n: u64,
dst_addr: u64,
n_bytes: u64,
memory_mapping: &MemoryMapping,
reverse: bool,
mut fun: F,
Expand All @@ -299,52 +299,90 @@ where
T: Default,
F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
{
let mut src_chunk_iter = MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n)
.map_err(EbpfError::from)?;
loop {
// iterate source chunks
let (src_region, src_vm_addr, mut src_len) = match if reverse {
src_chunk_iter.next_back()
} else {
src_chunk_iter.next()
} {
Some(item) => item?,
None => break,
};

let mut src_host_addr = Result::from(src_region.vm_to_host(src_vm_addr, src_len as u64))?;
let mut dst_chunk_iter = MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n)
let mut src_chunk_iter =
MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n_bytes)
.map_err(EbpfError::from)?;
let mut dst_chunk_iter =
MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n_bytes)
.map_err(EbpfError::from)?;
// iterate over destination chunks until this source chunk has been completely copied
while src_len > 0 {
loop {
let (dst_region, dst_vm_addr, dst_len) = match if reverse {
dst_chunk_iter.next_back()

let mut src_chunk = None;
let mut dst_chunk = None;

macro_rules! memory_chunk {
($chunk_iter:ident, $chunk:ident) => {
if let Some($chunk) = &mut $chunk {
// Keep processing the current chunk
$chunk
} else {
// This is either the first call or we've processed all the bytes in the current
// chunk. Move to the next one.
let chunk = match if reverse {
$chunk_iter.next_back()
} else {
dst_chunk_iter.next()
$chunk_iter.next()
} {
Some(item) => item?,
None => break,
};
let dst_host_addr =
Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
let chunk_len = src_len.min(dst_len);
fun(
src_host_addr as *const u8,
dst_host_addr as *const u8,
chunk_len,
)?;
src_len = src_len.saturating_sub(chunk_len);
if reverse {
dst_addr = dst_addr.saturating_sub(chunk_len as u64);
} else {
dst_addr = dst_addr.saturating_add(chunk_len as u64);
}
if src_len == 0 {
break;
}
src_host_addr = src_host_addr.saturating_add(chunk_len as u64);
$chunk.insert(chunk)
}
};
}

loop {
let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);

// We always process same-length pairs
let chunk_len = *src_remaining.min(dst_remaining);

let (src_host_addr, dst_host_addr) = {
let (src_addr, dst_addr) = if reverse {
// When scanning backwards not only we want to scan regions from the end,
// we want to process the memory within regions backwards as well.
(
src_chunk_addr
.saturating_add(*src_remaining as u64)
.saturating_sub(chunk_len as u64),
dst_chunk_addr
.saturating_add(*dst_remaining as u64)
.saturating_sub(chunk_len as u64),
)
} else {
(*src_chunk_addr, *dst_chunk_addr)
};

(
Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
)
};

fun(
src_host_addr as *const u8,
dst_host_addr as *const u8,
chunk_len,
)?;

// Update how many bytes we have left to scan in each chunk
*src_remaining = src_remaining.saturating_sub(chunk_len);
*dst_remaining = dst_remaining.saturating_sub(chunk_len);

if !reverse {
// We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
// mode we don't do this since we make progress by decreasing src_len and
// dst_len.
*src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
*dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
}

if *src_remaining == 0 {
src_chunk = None;
}

if *dst_remaining == 0 {
dst_chunk = None;
}
}

Expand Down Expand Up @@ -471,11 +509,13 @@ impl<'a> DoubleEndedIterator for MemoryChunkIterator<'a> {

#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::arithmetic_side_effects)]
mod tests {
use {
super::*,
assert_matches::assert_matches,
solana_rbpf::{ebpf::MM_PROGRAM_START, program::SBPFVersion},
test_case::test_case,
};

fn to_chunk_vec<'a>(
Expand Down Expand Up @@ -734,72 +774,59 @@ mod tests {
memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 8, 4, &memory_mapping).unwrap();
}

#[test]
fn test_overlapping_memmove_non_contiguous_right() {
#[test_case(&[], (0, 0, 0); "no regions")]
#[test_case(&[10], (1, 10, 0); "single region 0 len")]
#[test_case(&[10], (0, 5, 5); "single region no overlap")]
#[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
#[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
#[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
#[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
#[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
#[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
#[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
#[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
#[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
#[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
#[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
#[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
#[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
#[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
#[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
#[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
fn test_memmove_non_contiguous(
Lichtso marked this conversation as resolved.
Show resolved Hide resolved
regions: &[usize],
(src_offset, dst_offset, len): (usize, usize, usize),
) {
let config = Config {
aligned_memory_mapping: false,
..Config::default()
};
let mem1 = vec![0x11; 1];
let mut mem2 = vec![0x22; 2];
let mut mem3 = vec![0x33; 3];
let mut mem4 = vec![0x44; 4];
let memory_mapping = MemoryMapping::new(
vec![
MemoryRegion::new_readonly(&mem1, MM_PROGRAM_START),
MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
],
&config,
&SBPFVersion::V2,
)
.unwrap();

// overlapping memmove right - the implementation will copy backwards
assert_eq!(
memmove_non_contiguous(MM_PROGRAM_START + 1, MM_PROGRAM_START, 7, &memory_mapping)
.unwrap(),
0
);
assert_eq!(&mem1, &[0x11]);
assert_eq!(&mem2, &[0x11, 0x22]);
assert_eq!(&mem3, &[0x22, 0x33, 0x33]);
assert_eq!(&mem4, &[0x33, 0x44, 0x44, 0x44]);
}

#[test]
fn test_overlapping_memmove_non_contiguous_left() {
let config = Config {
aligned_memory_mapping: false,
..Config::default()
let (mem, memory_mapping) = build_memory_mapping(regions, &config);

// flatten the memory so we can memmove it with ptr::copy
let mut expected_memory = flatten_memory(&mem);
unsafe {
std::ptr::copy(
expected_memory.as_ptr().add(src_offset),
expected_memory.as_mut_ptr().add(dst_offset),
len,
)
};
let mut mem1 = vec![0x11; 1];
let mut mem2 = vec![0x22; 2];
let mut mem3 = vec![0x33; 3];
let mut mem4 = vec![0x44; 4];
let memory_mapping = MemoryMapping::new(
vec![
MemoryRegion::new_writable(&mut mem1, MM_PROGRAM_START),
MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
],
&config,
&SBPFVersion::V2,

// do our memmove
memmove_non_contiguous(
MM_PROGRAM_START + dst_offset as u64,
MM_PROGRAM_START + src_offset as u64,
len as u64,
&memory_mapping,
)
.unwrap();

// overlapping memmove left - the implementation will copy forward
assert_eq!(
memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 1, 7, &memory_mapping)
.unwrap(),
0
);
assert_eq!(&mem1, &[0x22]);
assert_eq!(&mem2, &[0x22, 0x33]);
assert_eq!(&mem3, &[0x33, 0x33, 0x44]);
assert_eq!(&mem4, &[0x44, 0x44, 0x44, 0x44]);
// flatten memory post our memmove
let memory = flatten_memory(&mem);

// compare libc's memmove with ours
assert_eq!(expected_memory, memory);
}

#[test]
Expand Down Expand Up @@ -910,4 +937,33 @@ mod tests {
unsafe { memcmp(b"oobar", b"obarb", 5) }
);
}

fn build_memory_mapping<'a>(
regions: &[usize],
config: &'a Config,
) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
let mut regs = vec![];
let mut mem = Vec::new();
let mut offset = 0;
for (i, region_len) in regions.iter().enumerate() {
mem.push(
(0..*region_len)
.map(|x| (i * 10 + x) as u8)
.collect::<Vec<_>>(),
);
regs.push(MemoryRegion::new_writable(
&mut mem[i],
MM_PROGRAM_START + offset as u64,
));
offset += *region_len;
}

let memory_mapping = MemoryMapping::new(regs, config, &SBPFVersion::V2).unwrap();

(mem, memory_mapping)
}

fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
mem.iter().flatten().copied().collect()
}
}