Skip to content

Commit

Permalink
Merge pull request #108 from MaxVerevkin/msgread
Browse files Browse the repository at this point in the history
read the enitre message in one go without copies
  • Loading branch information
KillingSpark authored Feb 27, 2024
2 parents 28127b3 + 885e865 commit af9be8e
Show file tree
Hide file tree
Showing 16 changed files with 80 additions and 94 deletions.
6 changes: 3 additions & 3 deletions rustbus/benches/marshal_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ fn marsh(msg: &rustbus::message_builder::MarshalledMessage, buf: &mut Vec<u8>) {
}

fn unmarshal(buf: &[u8]) {
let (hdrbytes, header) = unmarshal_header(&buf, 0).unwrap();
let (dynhdrbytes, dynheader) = unmarshal_dynamic_header(&header, &buf, hdrbytes).unwrap();
let (hdrbytes, header) = unmarshal_header(buf, 0).unwrap();
let (dynhdrbytes, dynheader) = unmarshal_dynamic_header(&header, buf, hdrbytes).unwrap();
let (_, _unmarshed_msg) =
unmarshal_next_message(&header, dynheader, &buf, hdrbytes + dynhdrbytes).unwrap();
unmarshal_next_message(&header, dynheader, buf.to_vec(), hdrbytes + dynhdrbytes).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
Expand Down
7 changes: 2 additions & 5 deletions rustbus/examples/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ fn main() -> Result<(), rustbus::connection::Error> {
println!("\n");

let reqname_serial = rpc_con
.send_message(&mut standard_messages::request_name(
"io.killing.spark".into(),
0,
))?
.send_message(&mut standard_messages::request_name("io.killing.spark", 0))?
.write_all()
.unwrap();

Expand Down Expand Up @@ -73,7 +70,7 @@ fn main() -> Result<(), rustbus::connection::Error> {
println!("\n");
println!("\n");

let mut sig_listen_msg = standard_messages::add_match("type='signal'".into());
let mut sig_listen_msg = standard_messages::add_match("type='signal'");

//println!("Send message: {:?}", sig_listen_msg);
rpc_con
Expand Down
4 changes: 2 additions & 2 deletions rustbus/examples/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ fn main() {
con.send_hello(rustbus::connection::Timeout::Infinite)
.unwrap();

if std::env::args().find(|arg| "server".eq(arg)).is_some() {
if std::env::args().any(|arg| "server".eq(&arg)) {
con.send
.send_message(&mut rustbus::standard_messages::request_name(
"killing.spark.io".into(),
"killing.spark.io",
rustbus::standard_messages::DBUS_NAME_FLAG_REPLACE_EXISTING,
))
.unwrap()
Expand Down
7 changes: 3 additions & 4 deletions rustbus/examples/fd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn main() -> Result<(), rustbus::connection::Error> {
.write_all()
.unwrap();

con.send_message(&mut standard_messages::add_match("type='signal'".into()))?
con.send_message(&mut standard_messages::add_match("type='signal'"))?
.write_all()
.unwrap();

Expand All @@ -31,10 +31,9 @@ fn main() -> Result<(), rustbus::connection::Error> {
.dynheader
.interface
.eq(&Some("io.killing.spark".to_owned()))
&& signal.dynheader.member.eq(&Some("TestSignal".to_owned()))
{
if signal.dynheader.member.eq(&Some("TestSignal".to_owned())) {
break signal;
}
break signal;
}
};

Expand Down
5 changes: 1 addition & 4 deletions rustbus/examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ fn main() -> Result<(), rustbus::connection::Error> {
let mut rpc_con = RpcConn::session_conn(Timeout::Infinite)?;

let namereq_serial = rpc_con
.send_message(&mut standard_messages::request_name(
"io.killing.spark".into(),
0,
))?
.send_message(&mut standard_messages::request_name("io.killing.spark", 0))?
.write_all()
.unwrap();
let resp = rpc_con.wait_response(namereq_serial, Timeout::Infinite)?;
Expand Down
2 changes: 1 addition & 1 deletion rustbus/fuzz/fuzz_targets/fuzz_unmarshal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fuzz_target!(|data: &[u8]| {
let (_bytes_used, msg) = match rustbus::wire::unmarshal::unmarshal_next_message(
&header,
dynheader,
data,
data.to_vec(),
hdrbytes + dynhdrbytes,
) {
Ok(msg) => msg,
Expand Down
2 changes: 1 addition & 1 deletion rustbus/src/bin/fuzz_artifact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn run_artifact(path: &str) {
let (_bytes_used, msg) = match rustbus::wire::unmarshal::unmarshal_next_message(
&header,
dynheader,
data,
data.clone(),
hdrbytes + dynhdrbytes,
) {
Ok(msg) => msg,
Expand Down
41 changes: 22 additions & 19 deletions rustbus/src/connection/ll_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ pub struct RecvConn {
stream: UnixStream,

msg_buf_in: Vec<u8>,
msg_buf_filled: usize,
cmsgs_in: Vec<ControlMessageOwned>,
cmsgspace: Vec<u8>,
}

pub struct DuplexConn {
Expand All @@ -56,14 +58,13 @@ impl RecvConn {
/// Reads from the source once but takes care that the internal buffer only reaches at maximum max_buffer_size
/// so we can process messages separatly and avoid leaking file descriptors to wrong messages
fn refill_buffer(&mut self, max_buffer_size: usize, timeout: Timeout) -> Result<()> {
let bytes_to_read = max_buffer_size - self.msg_buf_in.len();

const BUFSIZE: usize = 512;
let mut tmpbuf = [0u8; BUFSIZE];
if self.msg_buf_in.len() != max_buffer_size {
self.msg_buf_in.resize(max_buffer_size, 0);
}

let iovec = IoSliceMut::new(&mut tmpbuf[..usize::min(bytes_to_read, BUFSIZE)]);
let iovec = IoSliceMut::new(&mut self.msg_buf_in[self.msg_buf_filled..max_buffer_size]);

let mut cmsgspace = cmsg_space!([RawFd; 10]);
self.cmsgspace.clear();
let flags = MsgFlags::empty();

let old_timeout = self.stream.read_timeout()?;
Expand All @@ -82,7 +83,7 @@ impl RecvConn {
let msg = recvmsg::<SockaddrStorage>(
self.stream.as_raw_fd(),
iovec_mut,
Some(&mut cmsgspace),
Some(&mut self.cmsgspace),
flags,
)
.map_err(|e| match e {
Expand All @@ -101,19 +102,18 @@ impl RecvConn {

self.cmsgs_in.extend(msg.cmsgs());
let bytes = msg.bytes;
self.msg_buf_in.extend_from_slice(&tmpbuf[..bytes]);
self.msg_buf_filled += bytes;
Ok(())
}

pub fn bytes_needed_for_current_message(&self) -> Result<usize> {
if self.msg_buf_in.len() < 16 {
if self.msg_buf_filled < 16 {
return Ok(16);
}
let (_, header) = unmarshal::unmarshal_header(&self.msg_buf_in, 0)?;
let (_, header_fields_len) = crate::wire::util::parse_u32(
&self.msg_buf_in[unmarshal::HEADER_LEN..],
header.byteorder,
)?;
let msg_buf_in = &self.msg_buf_in[..self.msg_buf_filled];
let (_, header) = unmarshal::unmarshal_header(msg_buf_in, 0)?;
let (_, header_fields_len) =
crate::wire::util::parse_u32(&msg_buf_in[unmarshal::HEADER_LEN..], header.byteorder)?;
let complete_header_size = unmarshal::HEADER_LEN + header_fields_len as usize + 4; // +4 because the length of the header fields does not count

let padding_between_header_and_body = 8 - ((complete_header_size) % 8);
Expand All @@ -130,7 +130,7 @@ impl RecvConn {

// Checks if the internal buffer currently holds a complete message
pub fn buffer_contains_whole_message(&self) -> Result<bool> {
if self.msg_buf_in.len() < 16 {
if self.msg_buf_filled < 16 {
return Ok(false);
}
let bytes_needed = self.bytes_needed_for_current_message();
Expand All @@ -142,7 +142,7 @@ impl RecvConn {
Err(e)
}
}
Ok(bytes_needed) => Ok(self.msg_buf_in.len() >= bytes_needed),
Ok(bytes_needed) => Ok(self.msg_buf_filled >= bytes_needed),
}
}
/// Blocks until a message has been read from the conn or the timeout has been reached
Expand Down Expand Up @@ -170,21 +170,22 @@ impl RecvConn {
/// Blocks until a message has been read from the conn or the timeout has been reached
pub fn get_next_message(&mut self, timeout: Timeout) -> Result<MarshalledMessage> {
self.read_whole_message(timeout)?;
debug_assert_eq!(self.msg_buf_filled, self.msg_buf_in.len());
let (hdrbytes, header) = unmarshal::unmarshal_header(&self.msg_buf_in, 0)?;
let (dynhdrbytes, dynheader) =
unmarshal::unmarshal_dynamic_header(&header, &self.msg_buf_in, hdrbytes)?;

let (bytes_used, mut msg) = unmarshal::unmarshal_next_message(
&header,
dynheader,
&self.msg_buf_in,
std::mem::take(&mut self.msg_buf_in),
hdrbytes + dynhdrbytes,
)?;

if self.msg_buf_in.len() != bytes_used + hdrbytes + dynhdrbytes {
if self.msg_buf_filled != bytes_used + hdrbytes + dynhdrbytes {
return Err(Error::UnmarshalError(UnmarshalError::NotAllBytesUsed));
}
self.msg_buf_in.clear();
self.msg_buf_filled = 0;

for cmsg in &self.cmsgs_in {
match cmsg {
Expand Down Expand Up @@ -481,7 +482,9 @@ impl DuplexConn {
},
recv: RecvConn {
msg_buf_in: Vec::new(),
msg_buf_filled: 0,
cmsgs_in: Vec::new(),
cmsgspace: cmsg_space!([RawFd; 10]),
stream,
},
})
Expand Down
3 changes: 2 additions & 1 deletion rustbus/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ fn test_marshal_unmarshal() {
let headers_plus_padding = hdrbytes + dynhdrbytes + (8 - ((hdrbytes + dynhdrbytes) % 8));
assert_eq!(headers_plus_padding, buf.len());

let (_, unmarshed_msg) = unmarshal_next_message(&header, dynheader, msg.get_buf(), 0).unwrap();
let (_, unmarshed_msg) =
unmarshal_next_message(&header, dynheader, msg.get_buf().to_vec(), 0).unwrap();

let msg = unmarshed_msg.unmarshall_all().unwrap();

Expand Down
16 changes: 8 additions & 8 deletions rustbus/src/tests/dbus_send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
// Request name
let reqname_serial = rpc_con
.send_message(&mut standard_messages::request_name(
"io.killing.spark.dbustest".into(),
"io.killing.spark.dbustest",
0,
))?
.write_all()
Expand All @@ -44,7 +44,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
)?;

let sig_serial = rpc_con
.send_message(&mut standard_messages::add_match("type='signal'".into()))?
.send_message(&mut standard_messages::add_match("type='signal'"))?
.write_all()
.map_err(force_finish_on_error)?;
let _msg = rpc_con.wait_response(
Expand All @@ -53,7 +53,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
)?;

std::process::Command::new("dbus-send")
.args(&[
.args([
"--dest=io.killing.spark.dbustest",
"/",
"io.killing.spark.dbustest.Member",
Expand All @@ -64,7 +64,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
.unwrap();

std::process::Command::new("dbus-send")
.args(&[
.args([
"--dest=io.killing.spark.dbustest",
"/",
"io.killing.spark.dbustest.Member",
Expand All @@ -76,7 +76,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
.unwrap();

std::process::Command::new("dbus-send")
.args(&[
.args([
"--dest=io.killing.spark.dbustest",
"/",
"io.killing.spark.dbustest.Member",
Expand All @@ -88,7 +88,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
.unwrap();

std::process::Command::new("dbus-send")
.args(&[
.args([
"--dest=io.killing.spark.dbustest",
"/",
"io.killing.spark.dbustest.Member",
Expand All @@ -100,7 +100,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
.unwrap();

std::process::Command::new("dbus-send")
.args(&[
.args([
"--dest=io.killing.spark.dbustest",
"/",
"io.killing.spark.dbustest.Member",
Expand All @@ -116,7 +116,7 @@ fn test_dbus_send_comp() -> Result<(), crate::connection::Error> {
.unwrap();

std::process::Command::new("dbus-send")
.args(&[
.args([
"--dest=io.killing.spark.dbustest",
"/",
"io.killing.spark.dbustest.Member",
Expand Down
15 changes: 6 additions & 9 deletions rustbus/src/tests/fdpassing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ fn test_fd_passing() {
.unwrap()
.write_all()
.unwrap();
con2.send_message(&mut crate::standard_messages::add_match(
"type='signal'".into(),
))
.unwrap()
.write_all()
.unwrap();
con2.send_message(&mut crate::standard_messages::add_match("type='signal'"))
.unwrap()
.write_all()
.unwrap();

std::thread::sleep(std::time::Duration::from_secs(1));

Expand All @@ -39,10 +37,9 @@ fn test_fd_passing() {
.dynheader
.interface
.eq(&Some("io.killing.spark".to_owned()))
&& signal.dynheader.member.eq(&Some("TestSignal".to_owned()))
{
if signal.dynheader.member.eq(&Some("TestSignal".to_owned())) {
break signal;
}
break signal;
}
};

Expand Down
2 changes: 1 addition & 1 deletion rustbus/src/tests/verify_marshalling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ fn verify_dict_marshalling() {
byteorder: ByteOrder::LittleEndian,
};
let ctx = &mut ctx;
(&map).marshal(ctx).unwrap();
map.marshal(ctx).unwrap();
assert_eq!(
ctx.buf,
// Note the longer \0 chain after the length. This is the needed padding after the u32 length and the dict-entry
Expand Down
14 changes: 6 additions & 8 deletions rustbus/src/wire/unmarshal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ pub fn unmarshal_body<'a, 'e>(
pub fn unmarshal_next_message(
header: &Header,
dynheader: DynamicHeader,
buf: &[u8],
mut buf: Vec<u8>,
offset: usize,
) -> UnmarshalResult<MarshalledMessage> {
let sig = dynheader.signature.clone().unwrap_or_else(|| "".to_owned());
let padding = align_offset(8, buf, offset)?;
let padding = align_offset(8, &buf, offset)?;

if header.body_len == 0 {
let msg = MarshalledMessage {
Expand All @@ -169,14 +169,12 @@ pub fn unmarshal_next_message(
return Err(UnmarshalError::NotEnoughBytes);
}

// TODO: keep the offset around instead of shifting the bytes.
drop(buf.drain(..offset));

let msg = MarshalledMessage {
dynheader,
body: MarshalledMessageBody::from_parts(
buf[offset..].to_vec(),
vec![],
sig,
header.byteorder,
),
body: MarshalledMessageBody::from_parts(buf, vec![], sig, header.byteorder),
typ: header.typ,
flags: header.flags,
};
Expand Down
Loading

0 comments on commit af9be8e

Please sign in to comment.