diff --git a/Cargo.lock b/Cargo.lock
index e230cb902c907d..706e107b6581d1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5521,6 +5521,7 @@ dependencies = [
  "solana-sdk",
  "solana-zk-token-sdk",
  "solana_rbpf",
+ "test-case",
  "thiserror",
 ]
diff --git a/programs/bpf_loader/Cargo.toml b/programs/bpf_loader/Cargo.toml
index 16a52c07928620..48d771b8656828 100644
--- a/programs/bpf_loader/Cargo.toml
+++ b/programs/bpf_loader/Cargo.toml
@@ -27,6 +27,7 @@ assert_matches = { workspace = true }
 memoffset = { workspace = true }
 rand = { workspace = true }
 solana-sdk = { workspace = true, features = ["dev-context-only-utils"] }
+test-case = { workspace = true }
 
 [lib]
 crate-type = ["lib"]
diff --git a/programs/bpf_loader/src/syscalls/mem_ops.rs b/programs/bpf_loader/src/syscalls/mem_ops.rs
index 7e9b69fc6f310c..a544cf58a286a7 100644
--- a/programs/bpf_loader/src/syscalls/mem_ops.rs
+++ b/programs/bpf_loader/src/syscalls/mem_ops.rs
@@ -289,8 +289,8 @@ fn iter_memory_pair_chunks<T, F>(
     src_access: AccessType,
     src_addr: u64,
     dst_access: AccessType,
-    mut dst_addr: u64,
-    n: u64,
+    dst_addr: u64,
+    n_bytes: u64,
     memory_mapping: &MemoryMapping,
     reverse: bool,
     mut fun: F,
@@ -299,52 +299,90 @@ where
     T: Default,
     F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
 {
-    let mut src_chunk_iter = MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n)
-        .map_err(EbpfError::from)?;
-    loop {
-        // iterate source chunks
-        let (src_region, src_vm_addr, mut src_len) = match if reverse {
-            src_chunk_iter.next_back()
-        } else {
-            src_chunk_iter.next()
-        } {
-            Some(item) => item?,
-            None => break,
-        };
-
-        let mut src_host_addr = Result::from(src_region.vm_to_host(src_vm_addr, src_len as u64))?;
-        let mut dst_chunk_iter = MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n)
+    let mut src_chunk_iter =
+        MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n_bytes)
+            .map_err(EbpfError::from)?;
+    let mut dst_chunk_iter =
+        MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n_bytes)
             .map_err(EbpfError::from)?;
-        // iterate over destination chunks until this source chunk has been completely copied
-        while src_len > 0 {
-            loop {
-                let (dst_region, dst_vm_addr, dst_len) = match if reverse {
-                    dst_chunk_iter.next_back()
+
+    let mut src_chunk = None;
+    let mut dst_chunk = None;
+
+    macro_rules! memory_chunk {
+        ($chunk_iter:ident, $chunk:ident) => {
+            if let Some($chunk) = &mut $chunk {
+                // Keep processing the current chunk
+                $chunk
+            } else {
+                // This is either the first call or we've processed all the bytes in the current
+                // chunk. Move to the next one.
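+                // In reverse mode next_back() yields chunks starting from the end of
+                // the range, so the copy proceeds back to front chunk by chunk.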
+                let chunk = match if reverse {
+                    $chunk_iter.next_back()
                 } else {
-                    dst_chunk_iter.next()
+                    $chunk_iter.next()
                 } {
                     Some(item) => item?,
                     None => break,
                 };
-                let dst_host_addr =
-                    Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
-                let chunk_len = src_len.min(dst_len);
-                fun(
-                    src_host_addr as *const u8,
-                    dst_host_addr as *const u8,
-                    chunk_len,
-                )?;
-                src_len = src_len.saturating_sub(chunk_len);
-                if reverse {
-                    dst_addr = dst_addr.saturating_sub(chunk_len as u64);
-                } else {
-                    dst_addr = dst_addr.saturating_add(chunk_len as u64);
-                }
-                if src_len == 0 {
-                    break;
-                }
-                src_host_addr = src_host_addr.saturating_add(chunk_len as u64);
+                $chunk.insert(chunk)
             }
+        };
+    }
+
+    loop {
+        let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
+        let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);
+
+        // We always process same-length pairs
+        let chunk_len = *src_remaining.min(dst_remaining);
+
+        let (src_host_addr, dst_host_addr) = {
+            let (src_addr, dst_addr) = if reverse {
+                // When scanning backwards, not only do we want to visit regions starting
+                // from the end, we want to process the memory within regions backwards
+                // as well.
+                (
+                    src_chunk_addr
+                        .saturating_add(*src_remaining as u64)
+                        .saturating_sub(chunk_len as u64),
+                    dst_chunk_addr
+                        .saturating_add(*dst_remaining as u64)
+                        .saturating_sub(chunk_len as u64),
+                )
+            } else {
+                (*src_chunk_addr, *dst_chunk_addr)
+            };
+
+            (
+                Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
+                Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
+            )
+        };
+
+        fun(
+            src_host_addr as *const u8,
+            dst_host_addr as *const u8,
+            chunk_len,
+        )?;
+
+        // Update how many bytes we have left to scan in each chunk
+        *src_remaining = src_remaining.saturating_sub(chunk_len);
+        *dst_remaining = dst_remaining.saturating_sub(chunk_len);
+
+        if !reverse {
+            // We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
+            // mode we don't do this since we make progress by decreasing src_remaining and
+            // dst_remaining instead.
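+            // In reverse mode the chunk's start address stays fixed; the copy offset
+            // is derived above from the shrinking remaining counters.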
+            *src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
+            *dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
+        }
+
+        if *src_remaining == 0 {
+            src_chunk = None;
+        }
+
+        if *dst_remaining == 0 {
+            dst_chunk = None;
         }
     }
 
@@ -471,11 +509,13 @@ impl<'a> DoubleEndedIterator for MemoryChunkIterator<'a> {
 
 #[cfg(test)]
 #[allow(clippy::indexing_slicing)]
+#[allow(clippy::arithmetic_side_effects)]
 mod tests {
     use {
         super::*,
         assert_matches::assert_matches,
         solana_rbpf::{ebpf::MM_PROGRAM_START, program::SBPFVersion},
+        test_case::test_case,
     };
 
     fn to_chunk_vec<'a>(
@@ -734,72 +774,59 @@ mod tests {
         memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 8, 4, &memory_mapping).unwrap();
     }
 
-    #[test]
-    fn test_overlapping_memmove_non_contiguous_right() {
+    #[test_case(&[], (0, 0, 0); "no regions")]
+    #[test_case(&[10], (1, 10, 0); "single region 0 len")]
+    #[test_case(&[10], (0, 5, 5); "single region no overlap")]
+    #[test_case(&[10], (0, 0, 10); "single region complete overlap")]
+    #[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
+    #[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
+    #[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
+    #[test_case(&[3, 5], (0, 5, 2); "two regions no overlap, single source region")]
+    #[test_case(&[4, 7], (0, 5, 5); "two regions no overlap, multiple source regions")]
+    #[test_case(&[3, 8], (0, 0, 11); "two regions complete overlap")]
+    #[test_case(&[2, 9], (3, 0, 5); "two regions partial overlap start")]
+    #[test_case(&[3, 9], (1, 2, 5); "two regions partial overlap middle")]
+    #[test_case(&[7, 3], (2, 6, 4); "two regions partial overlap end")]
+    #[test_case(&[2, 6, 3, 4], (0, 10, 2); "many regions no overlap, single source region")]
+    #[test_case(&[2, 1, 2, 5, 6], (2, 10, 4); "many regions no overlap, multiple source regions")]
+    #[test_case(&[8, 1, 3, 6], (0, 0, 18); "many regions complete overlap")]
+    #[test_case(&[7, 3, 1, 4, 5], (5, 0, 8); "many regions overlap start")]
+    #[test_case(&[1, 5, 2, 9, 3], (5, 4, 8); "many regions overlap middle")]
+    #[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8); "many regions overlap end")]
+    fn test_memmove_non_contiguous(
+        regions: &[usize],
+        (src_offset, dst_offset, len): (usize, usize, usize),
+    ) {
         let config = Config {
             aligned_memory_mapping: false,
             ..Config::default()
         };
-        let mem1 = vec![0x11; 1];
-        let mut mem2 = vec![0x22; 2];
-        let mut mem3 = vec![0x33; 3];
-        let mut mem4 = vec![0x44; 4];
-        let memory_mapping = MemoryMapping::new(
-            vec![
-                MemoryRegion::new_readonly(&mem1, MM_PROGRAM_START),
-                MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
-                MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
-                MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
-            ],
-            &config,
-            &SBPFVersion::V2,
-        )
-        .unwrap();
-
-        // overlapping memmove right - the implementation will copy backwards
-        assert_eq!(
-            memmove_non_contiguous(MM_PROGRAM_START + 1, MM_PROGRAM_START, 7, &memory_mapping)
-                .unwrap(),
-            0
-        );
-        assert_eq!(&mem1, &[0x11]);
-        assert_eq!(&mem2, &[0x11, 0x22]);
-        assert_eq!(&mem3, &[0x22, 0x33, 0x33]);
-        assert_eq!(&mem4, &[0x33, 0x44, 0x44, 0x44]);
-    }
-
-    #[test]
-    fn test_overlapping_memmove_non_contiguous_left() {
-        let config = Config {
-            aligned_memory_mapping: false,
-            ..Config::default()
+        let (mem, memory_mapping) = build_memory_mapping(regions, &config);
+
+        // flatten the memory so we can memmove it with ptr::copy
+        let mut expected_memory = flatten_memory(&mem);
+        unsafe {
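+            // ptr::copy allows the source and destination ranges to overlap (memmove
+            // semantics), so it gives us the reference result to compare against.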
+            std::ptr::copy(
+                expected_memory.as_ptr().add(src_offset),
+                expected_memory.as_mut_ptr().add(dst_offset),
+                len,
+            )
         };
-        let mut mem1 = vec![0x11; 1];
-        let mut mem2 = vec![0x22; 2];
-        let mut mem3 = vec![0x33; 3];
-        let mut mem4 = vec![0x44; 4];
-        let memory_mapping = MemoryMapping::new(
-            vec![
-                MemoryRegion::new_writable(&mut mem1, MM_PROGRAM_START),
-                MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
-                MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
-                MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
-            ],
-            &config,
-            &SBPFVersion::V2,
+
+        // do our memmove
+        memmove_non_contiguous(
+            MM_PROGRAM_START + dst_offset as u64,
+            MM_PROGRAM_START + src_offset as u64,
+            len as u64,
+            &memory_mapping,
         )
         .unwrap();
 
-        // overlapping memmove left - the implementation will copy forward
-        assert_eq!(
-            memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 1, 7, &memory_mapping)
-                .unwrap(),
-            0
-        );
-        assert_eq!(&mem1, &[0x22]);
-        assert_eq!(&mem2, &[0x22, 0x33]);
-        assert_eq!(&mem3, &[0x33, 0x33, 0x44]);
-        assert_eq!(&mem4, &[0x44, 0x44, 0x44, 0x44]);
+        // flatten the memory again after our memmove
+        let memory = flatten_memory(&mem);
+
+        // compare the reference copy (ptr::copy) with ours
+        assert_eq!(expected_memory, memory);
     }
 
     #[test]
@@ -910,4 +937,29 @@ mod tests {
             unsafe { memcmp(b"oobar", b"obarb", 5) }
         );
     }
+
+    fn build_memory_mapping<'a>(
+        regions: &[usize],
+        config: &'a Config,
+    ) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
+        let mut regs = vec![];
+        let mut mem = Vec::new();
+        let mut offset = 0;
+        for (i, region_len) in regions.iter().enumerate() {
+            mem.push(vec![i as u8; *region_len]);
+            regs.push(MemoryRegion::new_writable(
+                &mut mem[i],
+                MM_PROGRAM_START + offset as u64,
+            ));
+            offset += *region_len;
+        }
+
+        let memory_mapping = MemoryMapping::new(regs, config, &SBPFVersion::V2).unwrap();
+
+        (mem, memory_mapping)
+    }
+
+    fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
+        mem.iter().flatten().copied().collect()
+    }
 }
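
Note on the approach (editor's sketch, not part of the patch): the fix replaces the old nested loops, which restarted the destination chunk iterator for every source chunk and so mishandled overlapping copies, with a single loop that advances both chunk iterators in lockstep and always processes min(src_remaining, dst_remaining) bytes per step. The standalone Rust sketch below illustrates just that pairing technique over plain (offset, len) chunks; the names for_each_chunk_pair and fun are illustrative only and do not exist in the Solana codebase.

    // Walk two different chunkings of the same byte range in lockstep,
    // handing out same-length pairs, like iter_memory_pair_chunks does.
    fn for_each_chunk_pair(
        mut src_chunks: impl Iterator<Item = (usize, usize)>, // (offset, len)
        mut dst_chunks: impl Iterator<Item = (usize, usize)>, // (offset, len)
        mut fun: impl FnMut(usize, usize, usize),             // (src_off, dst_off, len)
    ) {
        let mut src: Option<(usize, usize)> = None; // (offset, remaining)
        let mut dst: Option<(usize, usize)> = None;
        loop {
            // Refill whichever side has been fully consumed
            if src.is_none() {
                src = src_chunks.next();
            }
            if dst.is_none() {
                dst = dst_chunks.next();
            }
            let (Some((src_off, src_rem)), Some((dst_off, dst_rem))) = (&mut src, &mut dst)
            else {
                break; // one side is exhausted
            };
            // Always process same-length pairs
            let len = (*src_rem).min(*dst_rem);
            fun(*src_off, *dst_off, len);
            *src_off += len;
            *dst_off += len;
            *src_rem -= len;
            *dst_rem -= len;
            if *src_rem == 0 {
                src = None;
            }
            if *dst_rem == 0 {
                dst = None;
            }
        }
    }

    fn main() {
        // A source split as [3, 5] against a destination split as [4, 4]
        // produces the same-length pairs 3, 1 and 4.
        for_each_chunk_pair(
            [(0, 3), (3, 5)].into_iter(),
            [(0, 4), (4, 4)].into_iter(),
            |s, d, n| println!("copy {n} bytes from offset {s} to offset {d}"),
        );
    }

The real implementation layers two extra concerns on top of this skeleton: each chunk carries its MemoryRegion so vm addresses can be translated to host addresses per pair, and in reverse mode the iterators are consumed with next_back() while the per-pair offset is computed from the end of each chunk.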