Skip to content

Commit

Permalink
Reorganize tests, address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mzohreva committed Feb 26, 2021
1 parent 7668388 commit de6b5f3
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 399 deletions.
1 change: 0 additions & 1 deletion async-usercalls/rustfmt.toml

This file was deleted.

65 changes: 65 additions & 0 deletions async-usercalls/src/io_bufs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,68 @@ impl ReadBuffer {
self.userbuf
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::os::fortanix_sgx::usercalls::alloc::User;

#[test]
fn write_buffer_basic() {
const LENGTH: usize = 1024;
let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024));

let buf = vec![0u8; LENGTH];
assert_eq!(write_buffer.write(&buf), LENGTH);
assert_eq!(write_buffer.write(&buf), 0);

let chunk = write_buffer.consumable_chunk().unwrap();
write_buffer.consume(chunk, 200);
assert_eq!(write_buffer.write(&buf), 200);
assert_eq!(write_buffer.write(&buf), 0);
}

#[test]
#[should_panic]
fn call_consumable_chunk_twice() {
const LENGTH: usize = 1024;
let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024));

let buf = vec![0u8; LENGTH];
assert_eq!(write_buffer.write(&buf), LENGTH);
assert_eq!(write_buffer.write(&buf), 0);

let chunk1 = write_buffer.consumable_chunk().unwrap();
let _ = write_buffer.consumable_chunk().unwrap();
drop(chunk1);
}

#[test]
#[should_panic]
fn consume_wrong_buf() {
const LENGTH: usize = 1024;
let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024));

let buf = vec![0u8; LENGTH];
assert_eq!(write_buffer.write(&buf), LENGTH);
assert_eq!(write_buffer.write(&buf), 0);

let unrelated_buf: UserBuf = User::<[u8]>::uninitialized(512).into();
write_buffer.consume(unrelated_buf, 100);
}

#[test]
fn read_buffer_basic() {
let mut buf = User::<[u8]>::uninitialized(64);
const DATA: &'static [u8] = b"hello";
buf[0..DATA.len()].copy_from_enclave(DATA);

let mut read_buffer = ReadBuffer::new(buf, DATA.len());
assert_eq!(read_buffer.len(), DATA.len());
assert_eq!(read_buffer.remaining_bytes(), DATA.len());
let mut buf = [0u8; 8];
assert_eq!(read_buffer.read(&mut buf), DATA.len());
assert_eq!(read_buffer.remaining_bytes(), 0);
assert_eq!(&buf, b"hello\0\0\0");
}
}
222 changes: 206 additions & 16 deletions async-usercalls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
use crossbeam_channel as mpmc;
use ipc_queue::Identified;
use std::collections::HashMap;
use std::panic;
use std::sync::Mutex;
use std::time::Duration;

Expand All @@ -20,7 +19,7 @@ mod provider_core;
mod queues;
mod raw;
#[cfg(test)]
mod tests;
mod test_support;

pub use self::batch_drop::batch_drop;
pub use self::callback::CbFn;
Expand Down Expand Up @@ -127,6 +126,8 @@ pub struct CallbackHandler {
}

impl CallbackHandler {
const RECV_BATCH_SIZE: usize = 1024;

// Returns an object that can be used to interrupt a blocked `self.poll()`.
pub fn waker(&self) -> CallbackHandlerWaker {
self.waker.clone()
Expand Down Expand Up @@ -161,10 +162,7 @@ impl CallbackHandler {
/// This can be interrupted using `CallbackHandlerWaker::wake()`.
pub fn poll(&self, timeout: Option<Duration>) -> usize {
// 1. wait for returns
let mut returns = [Identified {
id: 0,
data: Return(0, 0),
}; 1024];
let mut returns = [Identified::default(); Self::RECV_BATCH_SIZE];
let returns = match self.recv_returns(timeout, &mut returns) {
0 => return 0,
n => &returns[..n],
Expand All @@ -190,19 +188,211 @@ impl CallbackHandler {
let mut count = 0;
for (ret, cb) in ret_callbacks {
if let Some(cb) = cb {
let _r = panic::catch_unwind(panic::AssertUnwindSafe(move || {
cb.call(ret.data);
}));
cb.call(ret.data);
count += 1;
// if let Err(e) = _r {
// let msg = e
// .downcast_ref::<String>()
// .map(String::as_str)
// .or_else(|| e.downcast_ref::<&str>().map(|&s| s));
// println!("callback paniced: {:?}", msg);
// }
}
}
count
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::hacks::MakeSend;
use crate::test_support::*;
use crossbeam_channel as mpmc;
use std::io;
use std::net::{TcpListener, TcpStream};
use std::os::fortanix_sgx::io::AsRawFd;
use std::os::fortanix_sgx::usercalls::alloc::User;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

#[test]
fn cancel_accept() {
let provider = AutoPollingProvider::new();
let port = 6688;
let addr = format!("0.0.0.0:{}", port);
let (tx, rx) = mpmc::bounded(1);
provider.bind_stream(&addr, move |res| {
tx.send(res).unwrap();
});
let bind_res = rx.recv().unwrap();
let listener = bind_res.unwrap();
let fd = listener.as_raw_fd();
let accept_count = Arc::new(AtomicUsize::new(0));
let accept_count1 = Arc::clone(&accept_count);
let (tx, rx) = mpmc::bounded(1);
let accept = provider.accept_stream(fd, move |res| {
if let Ok(_) = res {
accept_count1.fetch_add(1, Ordering::Relaxed);
}
tx.send(()).unwrap();
});
accept.cancel();
thread::sleep(Duration::from_millis(10));
let _ = TcpStream::connect(&addr);
let _ = rx.recv();
assert_eq!(accept_count.load(Ordering::Relaxed), 0);
}

#[test]
fn connect() {
let listener = TcpListener::bind("0.0.0.0:0").unwrap();
let addr = listener.local_addr().unwrap().to_string();
let provider = AutoPollingProvider::new();
let (tx, rx) = mpmc::bounded(1);
provider.connect_stream(&addr, move |res| {
tx.send(res).unwrap();
});
let res = rx.recv().unwrap();
assert!(res.is_ok());
}

#[test]
fn safe_alloc_free() {
let provider = AutoPollingProvider::new();

const LEN: usize = 64 * 1024;
let (tx, rx) = mpmc::bounded(1);
provider.alloc_slice::<u8, _>(LEN, move |res| {
let buf = res.expect("failed to allocate memory");
tx.send(MakeSend::new(buf)).unwrap();
});
let user_buf = rx.recv().unwrap().into_inner();
assert_eq!(user_buf.len(), LEN);

let (tx, rx) = mpmc::bounded(1);
let cb = move || {
tx.send(()).unwrap();
};
provider.free(user_buf, Some(cb));
rx.recv().unwrap();
}

#[test]
fn callback_handler_waker() {
let (_provider, handler) = AsyncUsercallProvider::new();
let waker = handler.waker();
let (tx, rx) = mpmc::bounded(1);
let h = thread::spawn(move || {
let n1 = handler.poll(None);
tx.send(()).unwrap();
let n2 = handler.poll(Some(Duration::from_secs(3)));
tx.send(()).unwrap();
n1 + n2
});
for _ in 0..2 {
waker.wake();
rx.recv().unwrap();
}
assert_eq!(h.join().unwrap(), 0);
}

#[test]
#[ignore]
fn echo() {
println!();
let provider = Arc::new(AutoPollingProvider::new());
const ADDR: &'static str = "0.0.0.0:7799";
let (tx, rx) = mpmc::bounded(1);
provider.bind_stream(ADDR, move |res| {
tx.send(res).unwrap();
});
let bind_res = rx.recv().unwrap();
let listener = bind_res.unwrap();
println!("bind done: {:?}", listener);
let fd = listener.as_raw_fd();
let cb = KeepAccepting {
listener,
provider: Arc::clone(&provider),
};
provider.accept_stream(fd, cb);
thread::sleep(Duration::from_secs(60));
}

struct KeepAccepting {
listener: TcpListener,
provider: Arc<AutoPollingProvider>,
}

impl FnOnce<(io::Result<TcpStream>,)> for KeepAccepting {
type Output = ();

extern "rust-call" fn call_once(self, args: (io::Result<TcpStream>,)) -> Self::Output {
let res = args.0;
println!("accept result: {:?}", res);
if let Ok(stream) = res {
let fd = stream.as_raw_fd();
let cb = Echo {
stream,
read: true,
provider: self.provider.clone(),
};
self.provider
.read(fd, User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE), cb);
}
let provider = Arc::clone(&self.provider);
provider.accept_stream(self.listener.as_raw_fd(), self);
}
}

struct Echo {
stream: TcpStream,
read: bool,
provider: Arc<AutoPollingProvider>,
}

impl Echo {
const READ_BUF_SIZE: usize = 1024;

fn close(self) {
let fd = self.stream.as_raw_fd();
println!("connection closed, fd = {}", fd);
self.provider.close(fd, None::<Box<dyn FnOnce() + Send>>);
}
}

// read callback
impl FnOnce<(io::Result<usize>, User<[u8]>)> for Echo {
type Output = ();

extern "rust-call" fn call_once(mut self, args: (io::Result<usize>, User<[u8]>)) -> Self::Output {
let (res, user) = args;
assert!(self.read);
match res {
Ok(len) if len > 0 => {
self.read = false;
let provider = Arc::clone(&self.provider);
provider.write(self.stream.as_raw_fd(), (user, 0..len).into(), self);
}
_ => self.close(),
}
}
}

// write callback
impl FnOnce<(io::Result<usize>, UserBuf)> for Echo {
type Output = ();

extern "rust-call" fn call_once(mut self, args: (io::Result<usize>, UserBuf)) -> Self::Output {
let (res, _) = args;
assert!(!self.read);
match res {
Ok(len) if len > 0 => {
self.read = true;
let provider = Arc::clone(&self.provider);
provider.read(
self.stream.as_raw_fd(),
User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE),
self,
);
}
_ => self.close(),
}
}
}
}
12 changes: 5 additions & 7 deletions async-usercalls/src/queues.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct ReturnHandler {
}

impl ReturnHandler {
const N: usize = 1024;
const RECV_BATCH_SIZE: usize = 1024;

fn send(&self, returns: &[Identified<Return>]) {
// This should hold the lock only for a short amount of time
Expand All @@ -90,25 +90,23 @@ impl ReturnHandler {
let provider_map = self.provider_map.lock().unwrap();
for ret in returns {
let provider_id = (ret.id >> 32) as u32;
// NOTE: some providers might decide not to receive results of usercalls they send
// because the results are not interesting, e.g. BatchDropProvider.
if let Some(sender) = provider_map.get(provider_id).and_then(|entry| entry.as_ref()) {
let _ = sender.send(*ret);
}
}
}

fn run(self) {
const DEFAULT_RETURN: Identified<Return> = Identified {
id: 0,
data: Return(0, 0),
};
let mut returns = [Identified::default(); Self::RECV_BATCH_SIZE];
loop {
let mut returns = [DEFAULT_RETURN; Self::N];
let first = match self.return_queue_rx.recv() {
Ok(ret) => ret,
Err(RecvError::Closed) => break,
};
let mut count = 0;
for ret in iter::once(first).chain(self.return_queue_rx.try_iter().take(Self::N - 1)) {
for ret in iter::once(first).chain(self.return_queue_rx.try_iter().take(Self::RECV_BATCH_SIZE - 1)) {
assert!(ret.id != 0);
returns[count] = ret;
count += 1;
Expand Down
Loading

0 comments on commit de6b5f3

Please sign in to comment.