diff --git a/src/iface/socket_set.rs b/src/iface/socket_set.rs index be55fef5d..028e904fa 100644 --- a/src/iface/socket_set.rs +++ b/src/iface/socket_set.rs @@ -29,6 +29,11 @@ pub(crate) struct Item<'a> { #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct SocketHandle(usize); +#[cfg(test)] +pub(crate) fn new_handle(index: usize) -> SocketHandle { + SocketHandle(index) +} + impl fmt::Display for SocketHandle { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "#{}", self.0) @@ -43,6 +48,7 @@ impl fmt::Display for SocketHandle { #[derive(Debug)] pub struct SocketSet<'a> { sockets: ManagedSlice<'a, SocketStorage<'a>>, + first_empty_index: usize, } impl<'a> SocketSet<'a> { @@ -52,7 +58,10 @@ impl<'a> SocketSet<'a> { SocketsT: Into>>, { let sockets = sockets.into(); - SocketSet { sockets } + SocketSet { + sockets, + first_empty_index: 0, + } } /// Add a socket to the set, and return its handle. @@ -73,10 +82,22 @@ impl<'a> SocketSet<'a> { let socket = socket.upcast(); - for (index, slot) in self.sockets.iter_mut().enumerate() { - if slot.inner.is_none() { - return put(index, slot, socket); + if self.first_empty_index < self.sockets.len() { + let handle = put( + self.first_empty_index, + &mut self.sockets[self.first_empty_index], + socket, + ); + + for i in (self.first_empty_index + 1)..self.sockets.len() { + if self.sockets[i].inner.is_none() { + self.first_empty_index = i; + return handle; + } } + + self.first_empty_index = self.sockets.len(); + return handle; } match &mut self.sockets { @@ -85,6 +106,7 @@ impl<'a> SocketSet<'a> { ManagedSlice::Owned(sockets) => { sockets.push(SocketStorage { inner: None }); let index = sockets.len() - 1; + self.first_empty_index = sockets.len(); put(index, &mut sockets[index], socket) } } @@ -124,7 +146,13 @@ impl<'a> SocketSet<'a> { pub fn remove(&mut self, handle: SocketHandle) -> Socket<'a> { net_trace!("[{}]: removing", handle.0); match self.sockets[handle.0].inner.take() { - Some(item) => item.socket, + Some(item) => { + if handle.0 < self.first_empty_index { + self.first_empty_index = handle.0; + } + + item.socket + } None => panic!("handle does not refer to a valid socket"), } } @@ -149,3 +177,81 @@ impl<'a> SocketSet<'a> { self.sockets.iter_mut().filter_map(|x| x.inner.as_mut()) } } + +#[cfg(test)] +#[cfg(all(feature = "socket-tcp", any(feature = "std", feature = "alloc")))] +pub(crate) mod test { + use crate::iface::socket_set::new_handle; + use crate::iface::SocketSet; + use crate::socket::tcp; + use crate::socket::tcp::Socket; + use std::ptr; + + fn gen_owned_socket() -> Socket<'static> { + let rx = tcp::SocketBuffer::new(vec![0; 1]); + let tx = tcp::SocketBuffer::new(vec![0; 1]); + Socket::new(rx, tx) + } + + fn gen_owned_socket_set(size: usize) -> SocketSet<'static> { + let mut socket_set = SocketSet::new(Vec::with_capacity(size)); + for _ in 0..size { + socket_set.add(gen_owned_socket()); + } + + socket_set + } + + #[test] + fn test_add() { + let socket_set = gen_owned_socket_set(5); + assert_eq!(socket_set.first_empty_index, 5); + } + + #[test] + fn test_remove() { + let mut socket_set = gen_owned_socket_set(10); + + let removed_socket = socket_set.remove(new_handle(5)); + for socket in socket_set.iter() { + assert!(!ptr::eq(socket.1, &removed_socket)); + } + + assert_eq!(socket_set.first_empty_index, 5); + } + + #[test] + fn test_remove_add_integrity() { + let mut socket_set = gen_owned_socket_set(10); + + for remove_index in 0..10 { + let removed_socket = socket_set.remove(new_handle(remove_index)); + for socket in socket_set.iter() { + assert!(!ptr::eq(socket.1, &removed_socket)); + } + + let new_socket = gen_owned_socket(); + let handle = socket_set.add(new_socket); + assert_eq!(handle.0, remove_index); + } + + assert_eq!(socket_set.first_empty_index, 10); + } + + #[test] + fn test_full_reconstruct() { + let mut socket_set = gen_owned_socket_set(10); + + for index in 0..10 { + socket_set.remove(new_handle(index)); + } + + assert_eq!(socket_set.first_empty_index, 0); + + for _ in 0..10 { + socket_set.add(gen_owned_socket()); + } + + assert_eq!(socket_set.first_empty_index, 10); + } +}