From de6b5f32761e391f03153968517e579d571eada7 Mon Sep 17 00:00:00 2001 From: Mohsen Zohrevandi Date: Wed, 27 Jan 2021 15:32:13 -0800 Subject: [PATCH] Reorganize tests, address review comments --- async-usercalls/rustfmt.toml | 1 - async-usercalls/src/io_bufs.rs | 65 +++++ async-usercalls/src/lib.rs | 222 ++++++++++++++-- async-usercalls/src/queues.rs | 12 +- async-usercalls/src/raw.rs | 89 +++++++ async-usercalls/src/test_support.rs | 47 ++++ async-usercalls/src/tests.rs | 375 ---------------------------- 7 files changed, 412 insertions(+), 399 deletions(-) delete mode 100644 async-usercalls/rustfmt.toml create mode 100644 async-usercalls/src/test_support.rs delete mode 100644 async-usercalls/src/tests.rs diff --git a/async-usercalls/rustfmt.toml b/async-usercalls/rustfmt.toml deleted file mode 100644 index 75306517..00000000 --- a/async-usercalls/rustfmt.toml +++ /dev/null @@ -1 +0,0 @@ -max_width = 120 diff --git a/async-usercalls/src/io_bufs.rs b/async-usercalls/src/io_bufs.rs index 825a7a8a..a8ede0de 100644 --- a/async-usercalls/src/io_bufs.rs +++ b/async-usercalls/src/io_bufs.rs @@ -257,3 +257,68 @@ impl ReadBuffer { self.userbuf } } + +#[cfg(test)] +mod tests { + use super::*; + use std::os::fortanix_sgx::usercalls::alloc::User; + + #[test] + fn write_buffer_basic() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let chunk = write_buffer.consumable_chunk().unwrap(); + write_buffer.consume(chunk, 200); + assert_eq!(write_buffer.write(&buf), 200); + assert_eq!(write_buffer.write(&buf), 0); + } + + #[test] + #[should_panic] + fn call_consumable_chunk_twice() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let chunk1 = write_buffer.consumable_chunk().unwrap(); + let _ = write_buffer.consumable_chunk().unwrap(); + drop(chunk1); + } + + #[test] + #[should_panic] + fn consume_wrong_buf() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let unrelated_buf: UserBuf = User::<[u8]>::uninitialized(512).into(); + write_buffer.consume(unrelated_buf, 100); + } + + #[test] + fn read_buffer_basic() { + let mut buf = User::<[u8]>::uninitialized(64); + const DATA: &'static [u8] = b"hello"; + buf[0..DATA.len()].copy_from_enclave(DATA); + + let mut read_buffer = ReadBuffer::new(buf, DATA.len()); + assert_eq!(read_buffer.len(), DATA.len()); + assert_eq!(read_buffer.remaining_bytes(), DATA.len()); + let mut buf = [0u8; 8]; + assert_eq!(read_buffer.read(&mut buf), DATA.len()); + assert_eq!(read_buffer.remaining_bytes(), 0); + assert_eq!(&buf, b"hello\0\0\0"); + } +} diff --git a/async-usercalls/src/lib.rs b/async-usercalls/src/lib.rs index ce58b8ad..eccc8f88 100644 --- a/async-usercalls/src/lib.rs +++ b/async-usercalls/src/lib.rs @@ -6,7 +6,6 @@ use crossbeam_channel as mpmc; use ipc_queue::Identified; use std::collections::HashMap; -use std::panic; use std::sync::Mutex; use std::time::Duration; @@ -20,7 +19,7 @@ mod provider_core; mod queues; mod raw; #[cfg(test)] -mod tests; +mod test_support; pub use self::batch_drop::batch_drop; pub use self::callback::CbFn; @@ -127,6 +126,8 @@ pub struct CallbackHandler { } impl CallbackHandler { + const RECV_BATCH_SIZE: usize = 1024; + // Returns an object that can be used to interrupt a blocked `self.poll()`. pub fn waker(&self) -> CallbackHandlerWaker { self.waker.clone() @@ -161,10 +162,7 @@ impl CallbackHandler { /// This can be interrupted using `CallbackHandlerWaker::wake()`. pub fn poll(&self, timeout: Option) -> usize { // 1. wait for returns - let mut returns = [Identified { - id: 0, - data: Return(0, 0), - }; 1024]; + let mut returns = [Identified::default(); Self::RECV_BATCH_SIZE]; let returns = match self.recv_returns(timeout, &mut returns) { 0 => return 0, n => &returns[..n], @@ -190,19 +188,211 @@ impl CallbackHandler { let mut count = 0; for (ret, cb) in ret_callbacks { if let Some(cb) = cb { - let _r = panic::catch_unwind(panic::AssertUnwindSafe(move || { - cb.call(ret.data); - })); + cb.call(ret.data); count += 1; - // if let Err(e) = _r { - // let msg = e - // .downcast_ref::() - // .map(String::as_str) - // .or_else(|| e.downcast_ref::<&str>().map(|&s| s)); - // println!("callback paniced: {:?}", msg); - // } } } count } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::hacks::MakeSend; + use crate::test_support::*; + use crossbeam_channel as mpmc; + use std::io; + use std::net::{TcpListener, TcpStream}; + use std::os::fortanix_sgx::io::AsRawFd; + use std::os::fortanix_sgx::usercalls::alloc::User; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::thread; + use std::time::Duration; + + #[test] + fn cancel_accept() { + let provider = AutoPollingProvider::new(); + let port = 6688; + let addr = format!("0.0.0.0:{}", port); + let (tx, rx) = mpmc::bounded(1); + provider.bind_stream(&addr, move |res| { + tx.send(res).unwrap(); + }); + let bind_res = rx.recv().unwrap(); + let listener = bind_res.unwrap(); + let fd = listener.as_raw_fd(); + let accept_count = Arc::new(AtomicUsize::new(0)); + let accept_count1 = Arc::clone(&accept_count); + let (tx, rx) = mpmc::bounded(1); + let accept = provider.accept_stream(fd, move |res| { + if let Ok(_) = res { + accept_count1.fetch_add(1, Ordering::Relaxed); + } + tx.send(()).unwrap(); + }); + accept.cancel(); + thread::sleep(Duration::from_millis(10)); + let _ = TcpStream::connect(&addr); + let _ = rx.recv(); + assert_eq!(accept_count.load(Ordering::Relaxed), 0); + } + + #[test] + fn connect() { + let listener = TcpListener::bind("0.0.0.0:0").unwrap(); + let addr = listener.local_addr().unwrap().to_string(); + let provider = AutoPollingProvider::new(); + let (tx, rx) = mpmc::bounded(1); + provider.connect_stream(&addr, move |res| { + tx.send(res).unwrap(); + }); + let res = rx.recv().unwrap(); + assert!(res.is_ok()); + } + + #[test] + fn safe_alloc_free() { + let provider = AutoPollingProvider::new(); + + const LEN: usize = 64 * 1024; + let (tx, rx) = mpmc::bounded(1); + provider.alloc_slice::(LEN, move |res| { + let buf = res.expect("failed to allocate memory"); + tx.send(MakeSend::new(buf)).unwrap(); + }); + let user_buf = rx.recv().unwrap().into_inner(); + assert_eq!(user_buf.len(), LEN); + + let (tx, rx) = mpmc::bounded(1); + let cb = move || { + tx.send(()).unwrap(); + }; + provider.free(user_buf, Some(cb)); + rx.recv().unwrap(); + } + + #[test] + fn callback_handler_waker() { + let (_provider, handler) = AsyncUsercallProvider::new(); + let waker = handler.waker(); + let (tx, rx) = mpmc::bounded(1); + let h = thread::spawn(move || { + let n1 = handler.poll(None); + tx.send(()).unwrap(); + let n2 = handler.poll(Some(Duration::from_secs(3))); + tx.send(()).unwrap(); + n1 + n2 + }); + for _ in 0..2 { + waker.wake(); + rx.recv().unwrap(); + } + assert_eq!(h.join().unwrap(), 0); + } + + #[test] + #[ignore] + fn echo() { + println!(); + let provider = Arc::new(AutoPollingProvider::new()); + const ADDR: &'static str = "0.0.0.0:7799"; + let (tx, rx) = mpmc::bounded(1); + provider.bind_stream(ADDR, move |res| { + tx.send(res).unwrap(); + }); + let bind_res = rx.recv().unwrap(); + let listener = bind_res.unwrap(); + println!("bind done: {:?}", listener); + let fd = listener.as_raw_fd(); + let cb = KeepAccepting { + listener, + provider: Arc::clone(&provider), + }; + provider.accept_stream(fd, cb); + thread::sleep(Duration::from_secs(60)); + } + + struct KeepAccepting { + listener: TcpListener, + provider: Arc, + } + + impl FnOnce<(io::Result,)> for KeepAccepting { + type Output = (); + + extern "rust-call" fn call_once(self, args: (io::Result,)) -> Self::Output { + let res = args.0; + println!("accept result: {:?}", res); + if let Ok(stream) = res { + let fd = stream.as_raw_fd(); + let cb = Echo { + stream, + read: true, + provider: self.provider.clone(), + }; + self.provider + .read(fd, User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE), cb); + } + let provider = Arc::clone(&self.provider); + provider.accept_stream(self.listener.as_raw_fd(), self); + } + } + + struct Echo { + stream: TcpStream, + read: bool, + provider: Arc, + } + + impl Echo { + const READ_BUF_SIZE: usize = 1024; + + fn close(self) { + let fd = self.stream.as_raw_fd(); + println!("connection closed, fd = {}", fd); + self.provider.close(fd, None::>); + } + } + + // read callback + impl FnOnce<(io::Result, User<[u8]>)> for Echo { + type Output = (); + + extern "rust-call" fn call_once(mut self, args: (io::Result, User<[u8]>)) -> Self::Output { + let (res, user) = args; + assert!(self.read); + match res { + Ok(len) if len > 0 => { + self.read = false; + let provider = Arc::clone(&self.provider); + provider.write(self.stream.as_raw_fd(), (user, 0..len).into(), self); + } + _ => self.close(), + } + } + } + + // write callback + impl FnOnce<(io::Result, UserBuf)> for Echo { + type Output = (); + + extern "rust-call" fn call_once(mut self, args: (io::Result, UserBuf)) -> Self::Output { + let (res, _) = args; + assert!(!self.read); + match res { + Ok(len) if len > 0 => { + self.read = true; + let provider = Arc::clone(&self.provider); + provider.read( + self.stream.as_raw_fd(), + User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE), + self, + ); + } + _ => self.close(), + } + } + } +} diff --git a/async-usercalls/src/queues.rs b/async-usercalls/src/queues.rs index 75fbf2eb..fc7bbd07 100644 --- a/async-usercalls/src/queues.rs +++ b/async-usercalls/src/queues.rs @@ -80,7 +80,7 @@ struct ReturnHandler { } impl ReturnHandler { - const N: usize = 1024; + const RECV_BATCH_SIZE: usize = 1024; fn send(&self, returns: &[Identified]) { // This should hold the lock only for a short amount of time @@ -90,6 +90,8 @@ impl ReturnHandler { let provider_map = self.provider_map.lock().unwrap(); for ret in returns { let provider_id = (ret.id >> 32) as u32; + // NOTE: some providers might decide not to receive results of usercalls they send + // because the results are not interesting, e.g. BatchDropProvider. if let Some(sender) = provider_map.get(provider_id).and_then(|entry| entry.as_ref()) { let _ = sender.send(*ret); } @@ -97,18 +99,14 @@ impl ReturnHandler { } fn run(self) { - const DEFAULT_RETURN: Identified = Identified { - id: 0, - data: Return(0, 0), - }; + let mut returns = [Identified::default(); Self::RECV_BATCH_SIZE]; loop { - let mut returns = [DEFAULT_RETURN; Self::N]; let first = match self.return_queue_rx.recv() { Ok(ret) => ret, Err(RecvError::Closed) => break, }; let mut count = 0; - for ret in iter::once(first).chain(self.return_queue_rx.try_iter().take(Self::N - 1)) { + for ret in iter::once(first).chain(self.return_queue_rx.try_iter().take(Self::RECV_BATCH_SIZE - 1)) { assert!(ret.id != 0); returns[count] = ret; count += 1; diff --git a/async-usercalls/src/raw.rs b/async-usercalls/src/raw.rs index fb2d4fac..d516bc69 100644 --- a/async-usercalls/src/raw.rs +++ b/async-usercalls/src/raw.rs @@ -153,3 +153,92 @@ impl RawApi for AsyncUsercallProvider { self.send_usercall(u, callback.map(|cb| Callback::Free(cb))); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_support::*; + use crossbeam_channel as mpmc; + use std::io; + use std::sync::atomic::{AtomicPtr, Ordering}; + use std::sync::Arc; + use std::thread; + use std::time::{Duration, UNIX_EPOCH}; + + #[test] + fn get_time_async_raw() { + fn run(tid: u32, provider: AutoPollingProvider) -> (u32, u32, Duration) { + let pid = provider.provider_id(); + const N: usize = 500; + let (tx, rx) = mpmc::bounded(N); + for _ in 0..N { + let tx = tx.clone(); + let cb = move |d| { + let system_time = UNIX_EPOCH + Duration::from_nanos(d); + tx.send(system_time).unwrap(); + }; + unsafe { + provider.raw_insecure_time(Some(cb.into())); + } + } + let mut all = Vec::with_capacity(N); + for _ in 0..N { + all.push(rx.recv().unwrap()); + } + + assert_eq!(all.len(), N); + // The results are returned in arbitrary order + all.sort(); + let t0 = *all.first().unwrap(); + let tn = *all.last().unwrap(); + let total = tn.duration_since(t0).unwrap(); + (tid, pid, total / N as u32) + } + + println!(); + const THREADS: usize = 4; + let mut providers = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + providers.push(AutoPollingProvider::new()); + } + let mut handles = Vec::with_capacity(THREADS); + for (i, provider) in providers.into_iter().enumerate() { + handles.push(thread::spawn(move || run(i as u32, provider))); + } + for h in handles { + let res = h.join().unwrap(); + println!("[{}/{}] (Tn - T0) / N = {:?}", res.0, res.1, res.2); + } + } + + #[test] + fn raw_alloc_free() { + let provider = AutoPollingProvider::new(); + let ptr: Arc> = Arc::new(AtomicPtr::new(0 as _)); + let ptr2 = Arc::clone(&ptr); + const SIZE: usize = 1024; + const ALIGN: usize = 8; + + let (tx, rx) = mpmc::bounded(1); + let cb_alloc = move |p: io::Result<*mut u8>| { + let p = p.unwrap(); + ptr2.store(p, Ordering::Relaxed); + tx.send(()).unwrap(); + }; + unsafe { + provider.raw_alloc(SIZE, ALIGN, Some(cb_alloc.into())); + } + rx.recv().unwrap(); + let p = ptr.load(Ordering::Relaxed); + assert!(!p.is_null()); + + let (tx, rx) = mpmc::bounded(1); + let cb_free = move |()| { + tx.send(()).unwrap(); + }; + unsafe { + provider.raw_free(p, SIZE, ALIGN, Some(cb_free.into())); + } + rx.recv().unwrap(); + } +} diff --git a/async-usercalls/src/test_support.rs b/async-usercalls/src/test_support.rs new file mode 100644 index 00000000..fa3b75bd --- /dev/null +++ b/async-usercalls/src/test_support.rs @@ -0,0 +1,47 @@ +use crate::AsyncUsercallProvider; +use std::ops::Deref; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; + +pub(crate) struct AutoPollingProvider { + provider: AsyncUsercallProvider, + shutdown: Arc, + join_handle: Option>, +} + +impl AutoPollingProvider { + pub fn new() -> Self { + let (provider, handler) = AsyncUsercallProvider::new(); + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown1 = shutdown.clone(); + let join_handle = Some(thread::spawn(move || loop { + handler.poll(None); + if shutdown1.load(Ordering::Relaxed) { + break; + } + })); + Self { + provider, + shutdown, + join_handle, + } + } +} + +impl Deref for AutoPollingProvider { + type Target = AsyncUsercallProvider; + + fn deref(&self) -> &Self::Target { + &self.provider + } +} + +impl Drop for AutoPollingProvider { + fn drop(&mut self) { + self.shutdown.store(true, Ordering::Relaxed); + // send a usercall to ensure thread wakes up + self.provider.insecure_time(|_| {}); + self.join_handle.take().unwrap().join().unwrap(); + } +} diff --git a/async-usercalls/src/tests.rs b/async-usercalls/src/tests.rs deleted file mode 100644 index ff838c48..00000000 --- a/async-usercalls/src/tests.rs +++ /dev/null @@ -1,375 +0,0 @@ -use super::*; -use crate::hacks::MakeSend; -use crossbeam_channel as mpmc; -use std::io; -use std::net::{TcpListener, TcpStream}; -use std::ops::Deref; -use std::os::fortanix_sgx::io::AsRawFd; -use std::os::fortanix_sgx::usercalls::alloc::User; -use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}; -use std::sync::Arc; -use std::thread; -use std::time::{Duration, UNIX_EPOCH}; - -struct AutoPollingProvider { - provider: AsyncUsercallProvider, - shutdown: Arc, - join_handle: Option>, -} - -impl AutoPollingProvider { - fn new() -> Self { - let (provider, handler) = AsyncUsercallProvider::new(); - let shutdown = Arc::new(AtomicBool::new(false)); - let shutdown1 = shutdown.clone(); - let join_handle = Some(thread::spawn(move || loop { - handler.poll(None); - if shutdown1.load(Ordering::Relaxed) { - break; - } - })); - Self { - provider, - shutdown, - join_handle, - } - } -} - -impl Deref for AutoPollingProvider { - type Target = AsyncUsercallProvider; - - fn deref(&self) -> &Self::Target { - &self.provider - } -} - -impl Drop for AutoPollingProvider { - fn drop(&mut self) { - self.shutdown.store(true, Ordering::Relaxed); - // send a usercall to ensure thread wakes up - self.provider.insecure_time(|_| {}); - self.join_handle.take().unwrap().join().unwrap(); - } -} - -#[test] -fn get_time_async_raw() { - fn run(tid: u32, provider: AutoPollingProvider) -> (u32, u32, Duration) { - let pid = provider.provider_id(); - const N: usize = 500; - let (tx, rx) = mpmc::bounded(N); - for _ in 0..N { - let tx = tx.clone(); - let cb = move |d| { - let system_time = UNIX_EPOCH + Duration::from_nanos(d); - tx.send(system_time).unwrap(); - }; - unsafe { - provider.raw_insecure_time(Some(cb.into())); - } - } - let mut all = Vec::with_capacity(N); - for _ in 0..N { - all.push(rx.recv().unwrap()); - } - - assert_eq!(all.len(), N); - // The results are returned in arbitrary order - all.sort(); - let t0 = *all.first().unwrap(); - let tn = *all.last().unwrap(); - let total = tn.duration_since(t0).unwrap(); - (tid, pid, total / N as u32) - } - - println!(); - const THREADS: usize = 4; - let mut providers = Vec::with_capacity(THREADS); - for _ in 0..THREADS { - providers.push(AutoPollingProvider::new()); - } - let mut handles = Vec::with_capacity(THREADS); - for (i, provider) in providers.into_iter().enumerate() { - handles.push(thread::spawn(move || run(i as u32, provider))); - } - for h in handles { - let res = h.join().unwrap(); - println!("[{}/{}] (Tn - T0) / N = {:?}", res.0, res.1, res.2); - } -} - -#[test] -fn raw_alloc_free() { - let provider = AutoPollingProvider::new(); - let ptr: Arc> = Arc::new(AtomicPtr::new(0 as _)); - let ptr2 = Arc::clone(&ptr); - const SIZE: usize = 1024; - const ALIGN: usize = 8; - - let (tx, rx) = mpmc::bounded(1); - let cb_alloc = move |p: io::Result<*mut u8>| { - let p = p.unwrap(); - ptr2.store(p, Ordering::Relaxed); - tx.send(()).unwrap(); - }; - unsafe { - provider.raw_alloc(SIZE, ALIGN, Some(cb_alloc.into())); - } - rx.recv().unwrap(); - let p = ptr.load(Ordering::Relaxed); - assert!(!p.is_null()); - - let (tx, rx) = mpmc::bounded(1); - let cb_free = move |()| { - tx.send(()).unwrap(); - }; - unsafe { - provider.raw_free(p, SIZE, ALIGN, Some(cb_free.into())); - } - rx.recv().unwrap(); -} - -#[test] -fn cancel_accept() { - let provider = AutoPollingProvider::new(); - let port = 6688; - let addr = format!("0.0.0.0:{}", port); - let (tx, rx) = mpmc::bounded(1); - provider.bind_stream(&addr, move |res| { - tx.send(res).unwrap(); - }); - let bind_res = rx.recv().unwrap(); - let listener = bind_res.unwrap(); - let fd = listener.as_raw_fd(); - let accept_count = Arc::new(AtomicUsize::new(0)); - let accept_count1 = Arc::clone(&accept_count); - let (tx, rx) = mpmc::bounded(1); - let accept = provider.accept_stream(fd, move |res| { - if let Ok(_) = res { - accept_count1.fetch_add(1, Ordering::Relaxed); - } - tx.send(()).unwrap(); - }); - accept.cancel(); - thread::sleep(Duration::from_millis(10)); - let _ = TcpStream::connect(&addr); - let _ = rx.recv(); - assert_eq!(accept_count.load(Ordering::Relaxed), 0); -} - -#[test] -fn connect() { - let listener = TcpListener::bind("0.0.0.0:0").unwrap(); - let addr = listener.local_addr().unwrap().to_string(); - let provider = AutoPollingProvider::new(); - let (tx, rx) = mpmc::bounded(1); - provider.connect_stream(&addr, move |res| { - tx.send(res).unwrap(); - }); - let res = rx.recv().unwrap(); - assert!(res.is_ok()); -} - -#[test] -fn safe_alloc_free() { - let provider = AutoPollingProvider::new(); - - const LEN: usize = 64 * 1024; - let (tx, rx) = mpmc::bounded(1); - provider.alloc_slice::(LEN, move |res| { - let buf = res.expect("failed to allocate memory"); - tx.send(MakeSend::new(buf)).unwrap(); - }); - let user_buf = rx.recv().unwrap().into_inner(); - assert_eq!(user_buf.len(), LEN); - - let (tx, rx) = mpmc::bounded(1); - let cb = move || { - tx.send(()).unwrap(); - }; - provider.free(user_buf, Some(cb)); - rx.recv().unwrap(); -} - -#[test] -fn write_buffer_basic() { - const LENGTH: usize = 1024; - let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); - - let buf = vec![0u8; LENGTH]; - assert_eq!(write_buffer.write(&buf), LENGTH); - assert_eq!(write_buffer.write(&buf), 0); - - let chunk = write_buffer.consumable_chunk().unwrap(); - write_buffer.consume(chunk, 200); - assert_eq!(write_buffer.write(&buf), 200); - assert_eq!(write_buffer.write(&buf), 0); -} - -#[test] -#[should_panic] -fn call_consumable_chunk_twice() { - const LENGTH: usize = 1024; - let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); - - let buf = vec![0u8; LENGTH]; - assert_eq!(write_buffer.write(&buf), LENGTH); - assert_eq!(write_buffer.write(&buf), 0); - - let chunk1 = write_buffer.consumable_chunk().unwrap(); - let _ = write_buffer.consumable_chunk().unwrap(); - drop(chunk1); -} - -#[test] -#[should_panic] -fn consume_wrong_buf() { - const LENGTH: usize = 1024; - let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); - - let buf = vec![0u8; LENGTH]; - assert_eq!(write_buffer.write(&buf), LENGTH); - assert_eq!(write_buffer.write(&buf), 0); - - let unrelated_buf: UserBuf = User::<[u8]>::uninitialized(512).into(); - write_buffer.consume(unrelated_buf, 100); -} - -#[test] -fn read_buffer_basic() { - let mut buf = User::<[u8]>::uninitialized(64); - const DATA: &'static [u8] = b"hello"; - buf[0..DATA.len()].copy_from_enclave(DATA); - - let mut read_buffer = ReadBuffer::new(buf, DATA.len()); - assert_eq!(read_buffer.len(), DATA.len()); - assert_eq!(read_buffer.remaining_bytes(), DATA.len()); - let mut buf = [0u8; 8]; - assert_eq!(read_buffer.read(&mut buf), DATA.len()); - assert_eq!(read_buffer.remaining_bytes(), 0); - assert_eq!(&buf, b"hello\0\0\0"); -} - -#[test] -fn callback_handler_waker() { - let (_provider, handler) = AsyncUsercallProvider::new(); - let waker = handler.waker(); - let (tx, rx) = mpmc::bounded(1); - let h = thread::spawn(move || { - let n1 = handler.poll(None); - tx.send(()).unwrap(); - let n2 = handler.poll(Some(Duration::from_secs(3))); - tx.send(()).unwrap(); - n1 + n2 - }); - for _ in 0..2 { - waker.wake(); - rx.recv().unwrap(); - } - assert_eq!(h.join().unwrap(), 0); -} - -#[test] -#[ignore] -fn echo() { - println!(); - let provider = Arc::new(AutoPollingProvider::new()); - const ADDR: &'static str = "0.0.0.0:7799"; - let (tx, rx) = mpmc::bounded(1); - provider.bind_stream(ADDR, move |res| { - tx.send(res).unwrap(); - }); - let bind_res = rx.recv().unwrap(); - let listener = bind_res.unwrap(); - println!("bind done: {:?}", listener); - let fd = listener.as_raw_fd(); - let cb = KeepAccepting { - listener, - provider: Arc::clone(&provider), - }; - provider.accept_stream(fd, cb); - thread::sleep(Duration::from_secs(60)); -} - -struct KeepAccepting { - listener: TcpListener, - provider: Arc, -} - -impl FnOnce<(io::Result,)> for KeepAccepting { - type Output = (); - - extern "rust-call" fn call_once(self, args: (io::Result,)) -> Self::Output { - let res = args.0; - println!("accept result: {:?}", res); - if let Ok(stream) = res { - let fd = stream.as_raw_fd(); - let cb = Echo { - stream, - read: true, - provider: self.provider.clone(), - }; - self.provider - .read(fd, User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE), cb); - } - let provider = Arc::clone(&self.provider); - provider.accept_stream(self.listener.as_raw_fd(), self); - } -} - -struct Echo { - stream: TcpStream, - read: bool, - provider: Arc, -} - -impl Echo { - const READ_BUF_SIZE: usize = 1024; - - fn close(self) { - let fd = self.stream.as_raw_fd(); - println!("connection closed, fd = {}", fd); - self.provider.close(fd, None::>); - } -} - -// read callback -impl FnOnce<(io::Result, User<[u8]>)> for Echo { - type Output = (); - - extern "rust-call" fn call_once(mut self, args: (io::Result, User<[u8]>)) -> Self::Output { - let (res, user) = args; - assert!(self.read); - match res { - Ok(len) if len > 0 => { - self.read = false; - let provider = Arc::clone(&self.provider); - provider.write(self.stream.as_raw_fd(), (user, 0..len).into(), self); - } - _ => self.close(), - } - } -} - -// write callback -impl FnOnce<(io::Result, UserBuf)> for Echo { - type Output = (); - - extern "rust-call" fn call_once(mut self, args: (io::Result, UserBuf)) -> Self::Output { - let (res, _) = args; - assert!(!self.read); - match res { - Ok(len) if len > 0 => { - self.read = true; - let provider = Arc::clone(&self.provider); - provider.read( - self.stream.as_raw_fd(), - User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE), - self, - ); - } - _ => self.close(), - } - } -}