Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize Rust decompression support #29

Merged
merged 6 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ crate-type = ["cdylib", "rlib"]

[dependencies]
flatbuffers = "24.3"
lz4_flex = { version = "0.11.3", default-features = false , features = ["frame"] }
numpy = "0.21"
pyo3 = "0.21"
rand = "0.8"
Expand Down
47 changes: 43 additions & 4 deletions rust/src/example_iteration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::io::Read;

use yoke::Yoke;

pub use super::parallel_map::parallel_map;
Expand All @@ -25,6 +27,30 @@ pub struct ExampleIterator {
example_iterator: Box<dyn Iterator<Item = Example> + Send>,
}

/// Compression applied to a shard file on disk.
#[derive(Clone, Copy, Debug)]
pub enum CompressionType {
    /// The shard file is stored without compression.
    Uncompressed,
    /// The shard file is LZ4-compressed.
    LZ4,
}

impl std::str::FromStr for CompressionType {
    type Err = String;

    /// Parses a compression name into a [`CompressionType`].
    ///
    /// The empty string means no compression and `"LZ4"` selects LZ4. The match is exact
    /// (case-sensitive).
    ///
    /// # Errors
    ///
    /// Returns `"<input> unimplemented"` for any other value.
    fn from_str(input: &str) -> Result<Self, Self::Err> {
        match input {
            "" => Ok(CompressionType::Uncompressed),
            "LZ4" => Ok(CompressionType::LZ4),
            // `format!` is needed here: a plain string literal would not interpolate `{input}`
            // and the caller would never see which value was rejected.
            _ => Err(format!("{input} unimplemented")),
        }
    }
}

/// Describes a single shard file: where it is and how its bytes are stored.
#[derive(Clone, Debug)]
pub struct ShardInfo {
/// Path of the shard file that is opened for reading.
pub file_path: String,
/// How the file content is compressed; selects the reader used to load the bytes.
pub compression_type: CompressionType,
}

impl ExampleIterator {
/// Takes a vector of file names of shards and creates an ExampleIterator over those. We assume
/// that all shard file names fit in memory. Alternatives to be re-evaluated:
Expand All @@ -33,7 +59,7 @@ impl ExampleIterator {
/// - Iterate over the shards in Rust. This would require having the shard filtering being
/// allowed to be called from Rust. But then we could pass an iterator of the following form:
/// `files: impl Iterator<Item = &str>`.
pub fn new(files: Vec<String>, repeat: bool, threads: usize) -> Self {
pub fn new(files: Vec<ShardInfo>, repeat: bool, threads: usize) -> Self {
assert!(!repeat, "Not implemented yet: repeat=true");
let example_iterator = Box::new(
parallel_map(|x| get_shard_progress(&x), files.into_iter(), threads).flatten(),
Expand All @@ -57,10 +83,23 @@ struct ShardProgress {
shard: LoadedShard,
}

/// Return a vector of bytes with the file content, decompressed when the shard is compressed.
///
/// `CompressionType::Uncompressed` files are read as-is; `CompressionType::LZ4` files are decoded
/// as an LZ4 frame stream.
///
/// # Panics
///
/// Panics when the file cannot be opened or read, or when LZ4 decoding fails. The panic message
/// names the shard file so a failure can be traced back to the offending shard.
fn get_file_bytes(shard_info: &ShardInfo) -> Vec<u8> {
    let path = &shard_info.file_path;
    match shard_info.compression_type {
        CompressionType::Uncompressed => std::fs::read(path)
            .unwrap_or_else(|err| panic!("failed to read shard file {path}: {err}")),
        CompressionType::LZ4 => {
            let file = std::fs::File::open(path)
                .unwrap_or_else(|err| panic!("failed to open shard file {path}: {err}"));
            let mut file_bytes = Vec::new();
            lz4_flex::frame::FrameDecoder::new(file)
                .read_to_end(&mut file_bytes)
                .unwrap_or_else(|err| {
                    panic!("failed to LZ4-decompress shard file {path}: {err}")
                });
            file_bytes
        }
    }
}

/// Get ShardProgress.
fn get_shard_progress(file_path: &str) -> ShardProgress {
// TODO compressed file support.
let file_bytes = std::fs::read(file_path).unwrap();
fn get_shard_progress(shard_info: &ShardInfo) -> ShardProgress {
let file_bytes = get_file_bytes(shard_info);

// A shard is a vector of examples (positive number -- invariant kept by Python code).
// An example is a vector of attributes (the same number of attributes in each example of each
Expand Down
12 changes: 9 additions & 3 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ mod shard_generated;
/// Python wrappers around `example_iteration`.
mod static_iter {
use std::collections::HashMap;
use std::str::FromStr;

use numpy::IntoPyArray;
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyRefMut};

use super::example_iteration::ExampleIterator;
use super::example_iteration::{CompressionType, ExampleIterator, ShardInfo};

/// Implementation details: The goal is to own the ExampleIterator in Rust and only send
/// examples to Python. This helps with concurrent reading and parsing of shard files.
Expand Down Expand Up @@ -92,10 +93,15 @@ mod static_iter {
#[pymethods]
impl RustIter {
#[new]
fn new(files: Vec<String>, repeat: bool, threads: usize) -> Self {
fn new(files: Vec<String>, repeat: bool, threads: usize, compression: String) -> Self {
let static_index = rand::random();
let mut hash_map = STATIC_ITERATORS.lock().unwrap();
hash_map.insert(static_index, ExampleIterator::new(files, repeat, threads));
let compression_type = CompressionType::from_str(&compression).unwrap();
let shard_infos = files
.into_iter()
.map(|file_path| ShardInfo { file_path: file_path.clone(), compression_type })
.collect();
hash_map.insert(static_index, ExampleIterator::new(shard_infos, repeat, threads));

RustIter { static_index, can_iterate: false }
}
Expand Down
47 changes: 12 additions & 35 deletions rust/src/shard_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ pub mod sedpack {
type Inner = Attribute<'a>;
#[inline]
unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
Self {
_tab: flatbuffers::Table::new(buf, loc),
}
Self { _tab: flatbuffers::Table::new(buf, loc) }
}
}

Expand Down Expand Up @@ -102,8 +100,7 @@ pub mod sedpack {
impl flatbuffers::Verifiable for Attribute<'_> {
#[inline]
fn run_verifier(
v: &mut flatbuffers::Verifier,
pos: usize,
v: &mut flatbuffers::Verifier, pos: usize,
) -> Result<(), flatbuffers::InvalidFlatbuffer> {
use self::flatbuffers::Verifiable;
v.visit_table(pos)?
Expand All @@ -119,9 +116,7 @@ pub mod sedpack {
impl<'a> Default for AttributeArgs<'a> {
#[inline]
fn default() -> Self {
AttributeArgs {
attribute_bytes: None,
}
AttributeArgs { attribute_bytes: None }
}
}

Expand All @@ -145,10 +140,7 @@ pub mod sedpack {
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
) -> AttributeBuilder<'a, 'b, A> {
let start = _fbb.start_table();
AttributeBuilder {
fbb_: _fbb,
start_: start,
}
AttributeBuilder { fbb_: _fbb, start_: start }
}
#[inline]
pub fn finish(self) -> flatbuffers::WIPOffset<Attribute<'a>> {
Expand All @@ -175,9 +167,7 @@ pub mod sedpack {
type Inner = Example<'a>;
#[inline]
unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
Self {
_tab: flatbuffers::Table::new(buf, loc),
}
Self { _tab: flatbuffers::Table::new(buf, loc) }
}
}

Expand Down Expand Up @@ -224,8 +214,7 @@ pub mod sedpack {
impl flatbuffers::Verifiable for Example<'_> {
#[inline]
fn run_verifier(
v: &mut flatbuffers::Verifier,
pos: usize,
v: &mut flatbuffers::Verifier, pos: usize,
) -> Result<(), flatbuffers::InvalidFlatbuffer> {
use self::flatbuffers::Verifiable;
v.visit_table(pos)?
Expand Down Expand Up @@ -272,10 +261,7 @@ pub mod sedpack {
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
) -> ExampleBuilder<'a, 'b, A> {
let start = _fbb.start_table();
ExampleBuilder {
fbb_: _fbb,
start_: start,
}
ExampleBuilder { fbb_: _fbb, start_: start }
}
#[inline]
pub fn finish(self) -> flatbuffers::WIPOffset<Example<'a>> {
Expand All @@ -293,7 +279,6 @@ pub mod sedpack {
}
pub enum ShardOffset {}


// Added Yokeable to the autogenerated code.
#[derive(Copy, Clone, PartialEq, yoke::Yokeable)]
pub struct Shard<'a> {
Expand All @@ -304,9 +289,7 @@ pub mod sedpack {
type Inner = Shard<'a>;
#[inline]
unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
Self {
_tab: flatbuffers::Table::new(buf, loc),
}
Self { _tab: flatbuffers::Table::new(buf, loc) }
}
}

Expand Down Expand Up @@ -353,8 +336,7 @@ pub mod sedpack {
impl flatbuffers::Verifiable for Shard<'_> {
#[inline]
fn run_verifier(
v: &mut flatbuffers::Verifier,
pos: usize,
v: &mut flatbuffers::Verifier, pos: usize,
) -> Result<(), flatbuffers::InvalidFlatbuffer> {
use self::flatbuffers::Verifiable;
v.visit_table(pos)?
Expand Down Expand Up @@ -401,10 +383,7 @@ pub mod sedpack {
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
) -> ShardBuilder<'a, 'b, A> {
let start = _fbb.start_table();
ShardBuilder {
fbb_: _fbb,
start_: start,
}
ShardBuilder { fbb_: _fbb, start_: start }
}
#[inline]
pub fn finish(self) -> flatbuffers::WIPOffset<Shard<'a>> {
Expand Down Expand Up @@ -450,8 +429,7 @@ pub mod sedpack {
/// previous, unchecked, behavior use
/// `root_as_shard_unchecked`.
pub fn root_as_shard_with_opts<'b, 'o>(
opts: &'o flatbuffers::VerifierOptions,
buf: &'b [u8],
opts: &'o flatbuffers::VerifierOptions, buf: &'b [u8],
) -> Result<Shard<'b>, flatbuffers::InvalidFlatbuffer> {
flatbuffers::root_with_opts::<Shard<'b>>(opts, buf)
}
Expand All @@ -463,8 +441,7 @@ pub mod sedpack {
/// previous, unchecked, behavior use
/// `root_as_shard_unchecked`.
pub fn size_prefixed_root_as_shard_with_opts<'b, 'o>(
opts: &'o flatbuffers::VerifierOptions,
buf: &'b [u8],
opts: &'o flatbuffers::VerifierOptions, buf: &'b [u8],
) -> Result<Shard<'b>, flatbuffers::InvalidFlatbuffer> {
flatbuffers::size_prefixed_root_with_opts::<Shard<'b>>(opts, buf)
}
Expand Down
16 changes: 10 additions & 6 deletions src/sedpack/io/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _shard_info_iterator(
for child in shard_list.children_shard_lists:
yield from self._shard_info_iterator(child)

def shard_info_iterator(self, split: SplitT) -> Iterator[ShardInfo]:
def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]:
"""Iterate all `ShardInfo` in the split.

Args:
Expand All @@ -151,10 +151,14 @@ def shard_info_iterator(self, split: SplitT) -> Iterator[ShardInfo]:
Raises: ValueError when the split is not present. A split not being
present is different from there not being any shard.
"""
if split not in self._dataset_info.splits:
# Split not present.
raise ValueError(f"There is no shard in {split}.")
if split:
if split not in self._dataset_info.splits:
# Split not present.
raise ValueError(f"There is no shard in {split}.")

shard_list_info: ShardListInfo = self._dataset_info.splits[split]
shard_list_info: ShardListInfo = self._dataset_info.splits[split]

yield from self._shard_info_iterator(shard_list_info)
yield from self._shard_info_iterator(shard_list_info)
else:
for shard_list_info in self._dataset_info.splits.values():
yield from self._shard_info_iterator(shard_list_info)
9 changes: 6 additions & 3 deletions src/sedpack/io/dataset_iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,9 +658,12 @@ def to_dict(example):
)
return result

with _sedpack_rs.RustIter(files=shard_paths,
repeat=repeat,
threads=file_parallelism) as rust_iter:
with _sedpack_rs.RustIter(
files=shard_paths,
repeat=repeat,
threads=file_parallelism,
compression=self.dataset_structure.compression,
) as rust_iter:
example_iterator = map(to_dict, iter(rust_iter))
if process_record:
yield from map(process_record, example_iterator)
Expand Down
9 changes: 9 additions & 0 deletions tests/io/test_rust_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,12 @@ def test_end2end_as_numpy_iterator_fb(tmpdir: Union[str, Path]) -> None:
shard_file_type="fb",
compression="",
)


def test_end2end_as_numpy_iterator_fb_lz4(tmpdir: Union[str, Path]) -> None:
    """Run the shared `end2end` check on "fb" shard files with LZ4 compression."""
    end2end(
        compression="LZ4",
        shard_file_type="fb",
        dtype="float32",
        tmpdir=tmpdir,
    )