Skip to content

Commit

Permalink
lz4
Browse files Browse the repository at this point in the history
  • Loading branch information
james-rms committed Aug 14, 2024
1 parent 0f675fe commit bfe547a
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 8 deletions.
22 changes: 22 additions & 0 deletions rust_async/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust_async/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ thiserror = "1"
mcap = { path = "../rust" }
async-compression = { version = "*", features = ["tokio", "zstd"]}
tokio = { version = "1", features = ["full"] }
lz4 = "1"
binrw = "0.12.0"

Check failure on line 14 in rust_async/Cargo.toml

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (binrw)
libc = "0.2.44"

Check failure on line 15 in rust_async/Cargo.toml

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (libc)
2 changes: 2 additions & 0 deletions rust_async/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
mod lz4;

pub mod read;
131 changes: 131 additions & 0 deletions rust_async/src/lz4.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use lz4::liblz4::*;

Check failure on line 1 in rust_async/src/lz4.rs

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (liblz)
use std::io::{Error, ErrorKind, Result};
use std::ptr;
use tokio::io::{AsyncRead, ReadBuf};

/// Size of the internal compressed-input buffer (32 KiB).
const BUFFER_SIZE: usize = 32 * 1024;

/// Owning wrapper around the raw LZ4F decompression context.
/// The context is allocated in `DecoderContext::new` and released in the
/// `Drop` impl, so it cannot leak or be freed twice.
#[derive(Debug)]
struct DecoderContext {
    // Raw FFI handle passed to every `LZ4F_*` call.
    c: LZ4FDecompressionContext,
}

/// Async analogue of a blocking LZ4 frame decoder: wraps an `AsyncRead`
/// producing LZ4 frame-compressed bytes and yields decompressed bytes via
/// its own `AsyncRead` impl.
#[derive(Debug)]
pub struct Lz4Decoder<R> {
    c: DecoderContext,
    r: R,
    buf: Box<[u8]>, // 32kb input buffer. we repeatedly fill this and decompress into the output buffer.
    pos: usize, // cursor into `buf` that represents the start of data in `buf` that has not yet been decompressed.
    len: usize, // represents the number of valid bytes in `buf`.
    next: usize, // the number of bytes to read into `buf` on the next read. 0 means the frame has been fully decoded.
}

impl<R: AsyncRead> Lz4Decoder<R> {
/// Creates a new decoder which reads its input from the given
/// input stream. The input stream can be re-acquired by calling
/// `finish()`
pub fn new(r: R) -> Result<Lz4Decoder<R>> {
Ok(Lz4Decoder {
r,
c: DecoderContext::new()?,
buf: vec![0; BUFFER_SIZE].into_boxed_slice(),
pos: BUFFER_SIZE,
len: BUFFER_SIZE,
// Minimal LZ4 stream size
next: 11,
})
}

pub fn into_inner(self) -> R {
self.r
}

pub fn finish(self) -> (R, Result<()>) {
(
self.r,
match self.next {
0 => Ok(()),
_ => Err(Error::new(
ErrorKind::Interrupted,
"Finish runned before read end of compressed stream",

Check failure on line 50 in rust_async/src/lz4.rs

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (runned)
)),
},
)
}
}

impl<R: AsyncRead + std::marker::Unpin> AsyncRead for Lz4Decoder<R> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
if self.next == 0 || buf.remaining() == 0 {
return std::task::Poll::Ready(Ok(()));
}
let mut written_len: usize = 0;
let mself = self.get_mut();

Check failure on line 67 in rust_async/src/lz4.rs

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (mself)
while written_len == 0 {
if mself.pos >= mself.len {

Check failure on line 69 in rust_async/src/lz4.rs

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (mself)

Check failure on line 69 in rust_async/src/lz4.rs

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (mself)
let need = if mself.buf.len() < mself.next {

Check failure on line 70 in rust_async/src/lz4.rs

View workflow job for this annotation

GitHub Actions / spellcheck

Unknown word (mself)
mself.buf.len()
} else {
mself.next
};
{
let mut comp_buf = ReadBuf::new(&mut mself.buf[..need]);
let result = std::pin::pin!(&mut mself.r).poll_read(cx, &mut comp_buf);
match result {
std::task::Poll::Pending => return result,
std::task::Poll::Ready(Err(_)) => return result,
_ => {}
};
mself.len = comp_buf.filled().len();
}
if mself.len == 0 {
break;
}
mself.pos = 0;
mself.next -= mself.len;
}
while (written_len < buf.remaining()) && (mself.pos < mself.len) {
let mut src_size = mself.len - mself.pos;
let mut dst_size = buf.remaining() - written_len;
let len = check_error(unsafe {
LZ4F_decompress(
mself.c.c,
buf.initialize_unfilled().as_mut_ptr(),
&mut dst_size,
mself.buf[mself.pos..].as_ptr(),
&mut src_size,
ptr::null(),
)
})?;
mself.pos += src_size as usize;
written_len += dst_size as usize;
buf.set_filled(written_len);
if len == 0 {
mself.next = 0;
return std::task::Poll::Ready(Ok(()));
} else if mself.next < len {
mself.next = len;
}
}
}
return std::task::Poll::Ready(Ok(()));
}
}

impl DecoderContext {
    /// Allocates a fresh LZ4F decompression context, failing with an I/O
    /// error if liblz4 reports one.
    fn new() -> Result<DecoderContext> {
        let mut ctx = LZ4FDecompressionContext(ptr::null_mut());
        let rc = unsafe { LZ4F_createDecompressionContext(&mut ctx, LZ4F_VERSION) };
        check_error(rc)?;
        Ok(DecoderContext { c: ctx })
    }
}

impl Drop for DecoderContext {
    fn drop(&mut self) {
        // SAFETY: `self.c` was produced by LZ4F_createDecompressionContext in
        // `DecoderContext::new` and, since DecoderContext is neither Copy nor
        // Clone, is freed exactly once, here.
        unsafe { LZ4F_freeDecompressionContext(self.c) };
    }
}
83 changes: 75 additions & 8 deletions rust_async/src/read.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use byteorder::ByteOrder;
use tokio::io::{AsyncRead, AsyncReadExt, BufReader, Take};

use crate::lz4::Lz4Decoder;
use async_compression::tokio::bufread::ZstdDecoder;
use mcap::{McapError, McapResult, MAGIC};

/// Current position of the reader: either reading top-level records straight
/// from the underlying stream, or reading records out of a (possibly
/// compressed) chunk.
enum ReaderState<R> {
    /// Reading directly from the underlying reader, outside any chunk.
    Base(R),
    /// Reading an uncompressed chunk, limited to its compressed_size.
    UncompressedChunk(Take<R>),
    /// Reading a zstd-compressed chunk through an async zstd decoder.
    ZstdChunk(ZstdDecoder<BufReader<Take<R>>>),
    /// Reading an lz4-compressed chunk through the local async lz4 decoder.
    Lz4Chunk(Lz4Decoder<Take<R>>),
    /// Placeholder used only while swapping between variants; observing it
    /// outside a swap is a bug (the other match arms panic on it).
    Empty,
}

Expand All @@ -24,6 +26,7 @@ where
ReaderState::Base(r) => std::pin::pin!(r).poll_read(cx, buf),
ReaderState::UncompressedChunk(r) => std::pin::pin!(r).poll_read(cx, buf),
ReaderState::ZstdChunk(r) => std::pin::pin!(r).poll_read(cx, buf),
ReaderState::Lz4Chunk(r) => std::pin::pin!(r).poll_read(cx, buf),
ReaderState::Empty => {
panic!("invariant: reader is only set to empty while swapping with another valid variant")
}
Expand All @@ -39,13 +42,14 @@ where
ReaderState::Base(reader) => reader,
ReaderState::UncompressedChunk(take) => take.into_inner(),
ReaderState::ZstdChunk(decoder) => decoder.into_inner().into_inner().into_inner(),
ReaderState::Lz4Chunk(decoder) => decoder.into_inner().into_inner(),
ReaderState::Empty => {
panic!("invariant: reader is only set to empty while swapping with another valid variant")
}
}
}
}
pub struct LinearReader<R> {
pub struct RecordReader<R> {
reader: ReaderState<R>,
options: Options,
start_magic_seen: bool,
Expand All @@ -67,7 +71,7 @@ enum Cmd {
ExitChunk,
Stop,
}
impl<R> LinearReader<R>
impl<R> RecordReader<R>
where
R: AsyncRead + std::marker::Unpin,
{
Expand Down Expand Up @@ -104,6 +108,11 @@ where
rdr.into_inner().take(header.compressed_size),
)));
}
"lz4" => {
let decoder =
Lz4Decoder::new(rdr.into_inner().take(header.compressed_size))?;
self.reader = ReaderState::Lz4Chunk(decoder);
}
"" => {
self.reader = ReaderState::UncompressedChunk(
rdr.into_inner().take(header.compressed_size),
Expand Down Expand Up @@ -197,20 +206,25 @@ async fn read_chunk_header<R: AsyncRead + std::marker::Unpin>(
header.compression = match std::str::from_utf8(&scratch[..]) {
Ok(val) => val.to_owned(),
Err(err) => {
return Err(McapError::Parse(binrw::error::Error::Custom { pos: 32, err: Box::new(err) }));
return Err(McapError::Parse(binrw::error::Error::Custom {
pos: 32,
err: Box::new(err),
}));
}
};
scratch.resize(8, 0);
reader.read_exact(&mut scratch[..]).await?;
header.compressed_size = byteorder::LittleEndian::read_u64(&scratch[..]);
let available = record_len - (32 + compression_len as u64 + 8);
if available < header.compressed_size {
return Err(McapError::BadChunkLength { header: header.compressed_size , available });
return Err(McapError::BadChunkLength {
header: header.compressed_size,
available,
});
}
Ok(header)
}


#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
Expand All @@ -223,7 +237,7 @@ mod tests {
let mut writer = mcap::Writer::new(&mut buf)?;
writer.finish()?;
}
let mut reader = LinearReader::new(std::io::Cursor::new(buf.into_inner()));
let mut reader = RecordReader::new(std::io::Cursor::new(buf.into_inner()));
let mut record: Vec<u8> = Vec::new();
let mut opcodes: Vec<u8> = Vec::new();
loop {
Expand Down Expand Up @@ -268,7 +282,7 @@ mod tests {
})?;
writer.finish()?;
}
let mut reader = LinearReader::new(std::io::Cursor::new(buf.into_inner()));
let mut reader = RecordReader::new(std::io::Cursor::new(buf.into_inner()));
let mut record = Vec::new();
let mut opcodes: Vec<u8> = Vec::new();
loop {
Expand Down Expand Up @@ -321,7 +335,60 @@ mod tests {
})?;
writer.finish()?;
}
let mut reader = LinearReader::new(std::io::Cursor::new(buf.into_inner()));
let mut reader = RecordReader::new(std::io::Cursor::new(buf.into_inner()));
let mut record = Vec::new();
let mut opcodes: Vec<u8> = Vec::new();
loop {
let opcode = reader.next_record(&mut record).await?;
if let Some(opcode) = opcode {
opcodes.push(opcode);
} else {
break;
}
}
assert_eq!(
opcodes.as_slice(),
[
mcap::records::op::HEADER,
mcap::records::op::CHANNEL,
mcap::records::op::MESSAGE,
mcap::records::op::MESSAGE_INDEX,
mcap::records::op::DATA_END,
mcap::records::op::CHANNEL,
mcap::records::op::CHUNK_INDEX,
mcap::records::op::STATISTICS,
mcap::records::op::SUMMARY_OFFSET,
mcap::records::op::SUMMARY_OFFSET,
mcap::records::op::SUMMARY_OFFSET,
mcap::records::op::FOOTER,
]
);
Ok(())
}
#[tokio::test]
async fn test_reads_lz4_chunk() -> Result<(), mcap::McapError> {
let mut buf = std::io::Cursor::new(Vec::new());
{
let mut writer = mcap::WriteOptions::new()
.compression(Some(mcap::Compression::Lz4))
.create(&mut buf)?;
let channel = std::sync::Arc::new(mcap::Channel {
topic: "chat".to_owned(),
schema: None,
message_encoding: "json".to_owned(),
metadata: BTreeMap::new(),
});
writer.add_channel(&channel)?;
writer.write(&mcap::Message {
channel,
sequence: 0,
log_time: 0,
publish_time: 0,
data: (&[0, 1, 2]).into(),
})?;
writer.finish()?;
}
let mut reader = RecordReader::new(std::io::Cursor::new(buf.into_inner()));
let mut record = Vec::new();
let mut opcodes: Vec<u8> = Vec::new();
loop {
Expand Down

0 comments on commit bfe547a

Please sign in to comment.