Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SocketSet::add optimization #934

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 111 additions & 5 deletions src/iface/socket_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ pub(crate) struct Item<'a> {
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SocketHandle(usize);

#[cfg(test)]
pub(crate) fn new_handle(index: usize) -> SocketHandle {
SocketHandle(index)
}

impl fmt::Display for SocketHandle {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "#{}", self.0)
Expand All @@ -43,6 +48,7 @@ impl fmt::Display for SocketHandle {
#[derive(Debug)]
pub struct SocketSet<'a> {
sockets: ManagedSlice<'a, SocketStorage<'a>>,
first_empty_index: usize,
}

impl<'a> SocketSet<'a> {
Expand All @@ -52,7 +58,10 @@ impl<'a> SocketSet<'a> {
SocketsT: Into<ManagedSlice<'a, SocketStorage<'a>>>,
{
let sockets = sockets.into();
SocketSet { sockets }
SocketSet {
sockets,
first_empty_index: 0,
}
}

/// Add a socket to the set, and return its handle.
Expand All @@ -73,10 +82,22 @@ impl<'a> SocketSet<'a> {

let socket = socket.upcast();

for (index, slot) in self.sockets.iter_mut().enumerate() {
if slot.inner.is_none() {
return put(index, slot, socket);
if self.first_empty_index < self.sockets.len() {
let handle = put(
self.first_empty_index,
&mut self.sockets[self.first_empty_index],
socket,
);

for i in (self.first_empty_index + 1)..self.sockets.len() {
if self.sockets[i].inner.is_none() {
self.first_empty_index = i;
return handle;
}
}

self.first_empty_index = self.sockets.len();
return handle;
}

match &mut self.sockets {
Expand All @@ -85,6 +106,7 @@ impl<'a> SocketSet<'a> {
ManagedSlice::Owned(sockets) => {
sockets.push(SocketStorage { inner: None });
let index = sockets.len() - 1;
self.first_empty_index = sockets.len();
put(index, &mut sockets[index], socket)
}
}
Expand Down Expand Up @@ -124,7 +146,13 @@ impl<'a> SocketSet<'a> {
pub fn remove(&mut self, handle: SocketHandle) -> Socket<'a> {
net_trace!("[{}]: removing", handle.0);
match self.sockets[handle.0].inner.take() {
Some(item) => item.socket,
Some(item) => {
if handle.0 < self.first_empty_index {
self.first_empty_index = handle.0;
}

item.socket
}
None => panic!("handle does not refer to a valid socket"),
}
}
Expand All @@ -149,3 +177,81 @@ impl<'a> SocketSet<'a> {
self.sockets.iter_mut().filter_map(|x| x.inner.as_mut())
}
}

#[cfg(test)]
#[cfg(all(feature = "socket-tcp", any(feature = "std", feature = "alloc")))]
pub(crate) mod test {
use crate::iface::socket_set::new_handle;
use crate::iface::SocketSet;
use crate::socket::tcp;
use crate::socket::tcp::Socket;
use std::ptr;

fn gen_owned_socket() -> Socket<'static> {
let rx = tcp::SocketBuffer::new(vec![0; 1]);
let tx = tcp::SocketBuffer::new(vec![0; 1]);
Socket::new(rx, tx)
}

fn gen_owned_socket_set(size: usize) -> SocketSet<'static> {
let mut socket_set = SocketSet::new(Vec::with_capacity(size));
for _ in 0..size {
socket_set.add(gen_owned_socket());
}

socket_set
}

#[test]
fn test_add() {
let socket_set = gen_owned_socket_set(5);
assert_eq!(socket_set.first_empty_index, 5);
}

#[test]
fn test_remove() {
let mut socket_set = gen_owned_socket_set(10);

let removed_socket = socket_set.remove(new_handle(5));
for socket in socket_set.iter() {
assert!(!ptr::eq(socket.1, &removed_socket));
}

assert_eq!(socket_set.first_empty_index, 5);
}

#[test]
fn test_remove_add_integrity() {
let mut socket_set = gen_owned_socket_set(10);

for remove_index in 0..10 {
let removed_socket = socket_set.remove(new_handle(remove_index));
for socket in socket_set.iter() {
assert!(!ptr::eq(socket.1, &removed_socket));
}

let new_socket = gen_owned_socket();
let handle = socket_set.add(new_socket);
assert_eq!(handle.0, remove_index);
}

assert_eq!(socket_set.first_empty_index, 10);
}

#[test]
fn test_full_reconstruct() {
let mut socket_set = gen_owned_socket_set(10);

for index in 0..10 {
socket_set.remove(new_handle(index));
}

assert_eq!(socket_set.first_empty_index, 0);

for _ in 0..10 {
socket_set.add(gen_owned_socket());
}

assert_eq!(socket_set.first_empty_index, 10);
}
}
Loading