diff --git a/libraries/shared-memory-server/src/channel.rs b/libraries/shared-memory-server/src/channel.rs index 3f63c3c9f..ca8a8e220 100644 --- a/libraries/shared-memory-server/src/channel.rs +++ b/libraries/shared-memory-server/src/channel.rs @@ -3,7 +3,8 @@ use raw_sync_2::events::{Event, EventImpl, EventInit, EventState}; use serde::{Deserialize, Serialize}; use shared_memory_extended::Shmem; use std::{ - mem, slice, + mem::{self, align_of}, + slice, sync::atomic::{AtomicBool, AtomicU64}, time::Duration, }; @@ -27,7 +28,7 @@ impl ShmemChannel { unsafe { Event::new(memory.as_ptr().wrapping_add(server_event_len), true) } .map_err(|err| eyre!("failed to open raw client event: {err}"))?; let (disconnect_offset, len_offset, data_offset) = - offsets(server_event_len, client_event_len); + offsets(memory.as_ptr(), server_event_len, client_event_len); server_event .set(EventState::Clear) @@ -68,7 +69,7 @@ impl ShmemChannel { unsafe { Event::from_existing(memory.as_ptr().wrapping_add(server_event_len)) } .map_err(|err| eyre!("failed to open raw client event: {err}"))?; let (disconnect_offset, len_offset, data_offset) = - offsets(server_event_len, client_event_len); + offsets(memory.as_ptr(), server_event_len, client_event_len); Ok(Self { memory, @@ -188,11 +189,31 @@ impl ShmemChannel { } } -fn offsets(server_event_len: usize, client_event_len: usize) -> (usize, usize, usize) { - let disconnect_offset = server_event_len + client_event_len; - let len_offset = disconnect_offset + mem::size_of::(); - let data_offset = len_offset + mem::size_of::(); - (disconnect_offset, len_offset, data_offset) +fn offsets( + base_ptr: *mut u8, + server_event_len: usize, + client_event_len: usize, +) -> (usize, usize, usize) { + let (disconnect, len, data) = offset_ptrs( + base_ptr + .wrapping_add(server_event_len) + .wrapping_add(client_event_len), + ); + let base = base_ptr as usize; + ( + disconnect as usize - base, + len as usize - base, + data as usize - base, + ) +} + +fn offset_ptrs(next_free: *mut u8) -> (*mut AtomicBool, *mut AtomicU64, *mut u8) { + let disconnect_ptr = next_free.wrapping_add(next_free.align_offset(align_of::())); + let len_ptr_unaligned = disconnect_ptr.wrapping_add(mem::size_of::()); + let len_ptr = + len_ptr_unaligned.wrapping_add(len_ptr_unaligned.align_offset(align_of::())); + let data_ptr = len_ptr.wrapping_add(mem::size_of::()); + (disconnect_ptr.cast(), len_ptr.cast(), data_ptr) } unsafe impl Send for ShmemChannel {}