Initialize Rust decompression support
wsxrdv committed Oct 7, 2024
1 parent 34b61df commit 4803e63
Showing 8 changed files with 109 additions and 23 deletions.
26 changes: 26 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions rust/Cargo.toml
@@ -20,6 +20,7 @@ crate-type = ["cdylib", "rlib"]

[dependencies]
flatbuffers = "24.3"
lz4_flex = { version = "0.11.3", default-features = false, features = ["frame"] }
numpy = "0.21"
pyo3 = "0.21"
rand = "0.8"
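The dependency enables only the `frame` feature of `lz4_flex` (default features off), since shards are decoded through the LZ4 frame format. A minimal round-trip sketch of that API, assuming nothing beyond the dependency above:

use std::io::{Read, Write};

fn main() {
    let payload = b"example shard bytes".to_vec();

    // Compress into an in-memory LZ4 frame.
    let mut encoder = lz4_flex::frame::FrameEncoder::new(Vec::new());
    encoder.write_all(&payload).unwrap();
    let compressed = encoder.finish().unwrap();

    // Decompress it back. FrameDecoder wraps any Read; here a byte slice
    // stands in for the std::fs::File used by the shard reader below.
    let mut decompressed = Vec::new();
    lz4_flex::frame::FrameDecoder::new(compressed.as_slice())
        .read_to_end(&mut decompressed)
        .unwrap();

    assert_eq!(payload, decompressed);
}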
41 changes: 36 additions & 5 deletions rust/src/example_iteration.rs
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::io::Read;
use yoke::Yoke;

pub use super::parallel_map::parallel_map;
@@ -25,6 +26,18 @@ pub struct ExampleIterator {
example_iterator: Box<dyn Iterator<Item = Example> + Send>,
}

#[derive(Clone, Copy, Debug)]
pub enum CompressionType {
Uncompressed,
LZ4,
}

#[derive(Clone, Debug)]
pub struct ShardInfo {
pub file_path: String,
pub compression_type: CompressionType,
}
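A sketch of how a caller can build the new input type; `to_shard_infos` is a hypothetical helper (the actual wiring lives in `rust/src/lib.rs` below). Since the paths vector is consumed, no cloning is needed, and `CompressionType` being `Copy` lets one value tag every shard:

// Hypothetical helper: pair each shard path with the dataset-wide
// compression type.
fn to_shard_infos(paths: Vec<String>, compression_type: CompressionType) -> Vec<ShardInfo> {
    paths
        .into_iter()
        .map(|file_path| ShardInfo { file_path, compression_type })
        .collect()
}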

impl ExampleIterator {
/// Takes a vector of shard infos (file path plus compression type) and creates an
/// ExampleIterator over those. We assume that all shard infos fit in memory. Alternatives to be re-evaluated:
@@ -33,7 +46,7 @@ impl ExampleIterator {
/// - Iterate over the shards in Rust. This would require having the shard filtering being
/// allowed to be called from Rust. But then we could pass an iterator of the following form:
/// `files: impl Iterator<Item = &str>`.
pub fn new(files: Vec<String>, repeat: bool, threads: usize) -> Self {
pub fn new(files: Vec<ShardInfo>, repeat: bool, threads: usize) -> Self {
assert!(!repeat, "Not implemented yet: repeat=true");
let example_iterator = Box::new(
parallel_map(|x| get_shard_progress(&x), files.into_iter(), threads).flatten(),
@@ -57,10 +70,28 @@ struct ShardProgress {
shard: LoadedShard,
}

/// Return a vector of bytes with the file content.
fn get_file_bytes(shard_info: &ShardInfo) -> Vec<u8> {
match shard_info.compression_type {
CompressionType::Uncompressed => std::fs::read(&shard_info.file_path).unwrap(),
CompressionType::LZ4 => {
let mut file_bytes = Vec::new();
let read_result = lz4_flex::frame::FrameDecoder::new(
std::fs::File::open(&shard_info.file_path).unwrap(),
)
.read_to_end(&mut file_bytes);
match read_result {
Err(err) => panic!("{}", err),
Ok(_bytes_read) => {}
};
file_bytes
}
}
}
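`get_file_bytes` panics on any I/O or decoding failure, matching the `unwrap` style of the surrounding module. Both `std::fs::read` and the frame decoder report failures as `std::io::Error`, so a fallible variant is a small change; a sketch, not part of this commit:

use std::io::Read;

// Hypothetical fallible variant: errors propagate to the caller via `?`
// instead of panicking inside the iterator machinery.
fn try_get_file_bytes(shard_info: &ShardInfo) -> std::io::Result<Vec<u8>> {
    match shard_info.compression_type {
        CompressionType::Uncompressed => std::fs::read(&shard_info.file_path),
        CompressionType::LZ4 => {
            let file = std::fs::File::open(&shard_info.file_path)?;
            let mut bytes = Vec::new();
            lz4_flex::frame::FrameDecoder::new(file).read_to_end(&mut bytes)?;
            Ok(bytes)
        }
    }
}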

/// Get ShardProgress.
fn get_shard_progress(file_path: &str) -> ShardProgress {
// TODO compressed file support.
let file_bytes = std::fs::read(file_path).unwrap();
fn get_shard_progress(shard_info: &ShardInfo) -> ShardProgress {
let file_bytes = get_file_bytes(shard_info);

// A shard is a vector of examples (positive number -- invariant kept by Python code).
// An example is a vector of attributes (the same number of attributes in each example of each
@@ -86,7 +117,7 @@ fn get_shard_progress(file_path: &str) -> ShardProgress {
/// * `shard_progress` - The shard file information to be used. Data is copied out of this memory.
/// `shard_progress.used_examples` is not modified, so that multiple threads can read concurrently.
fn get_example(id: usize, shard_progress: &ShardProgress) -> Example {
assert!((shard_progress.used_examples .. shard_progress.total_examples).contains(&id));
assert!((shard_progress.used_examples..shard_progress.total_examples).contains(&id));

let shard = shard_progress.shard.get();
let examples = shard.examples().unwrap();
15 changes: 13 additions & 2 deletions rust/src/lib.rs
@@ -36,7 +36,9 @@ mod static_iter {
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyRefMut};

use super::example_iteration::CompressionType;
use super::example_iteration::ExampleIterator;
use super::example_iteration::ShardInfo;

/// Implementation details: The goal is to own the ExampleIterator in Rust and only send
/// examples to Python. This helps with concurrent reading and parsing of shard files.
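The ownership scheme rests on a process-wide, mutex-guarded table: Python holds only an integer handle while the iterator state stays in Rust. A simplified sketch of the pattern with hypothetical names and a placeholder state type (the real module keeps `ExampleIterator` values in `STATIC_ITERATORS`; `std::sync::LazyLock` is used here, older toolchains would reach for `once_cell`):

use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};

// Placeholder for the Rust-owned state; the real table holds ExampleIterator.
type State = Vec<i32>;

static ITERATORS: LazyLock<Mutex<HashMap<u64, State>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));

// Register state under a random handle; only the handle crosses to Python.
fn register(state: State) -> u64 {
    let handle: u64 = rand::random();
    ITERATORS.lock().unwrap().insert(handle, state);
    handle
}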
@@ -92,10 +94,19 @@
#[pymethods]
impl RustIter {
#[new]
fn new(files: Vec<String>, repeat: bool, threads: usize) -> Self {
fn new(files: Vec<String>, repeat: bool, threads: usize, compression: String) -> Self {
let static_index = rand::random();
let mut hash_map = STATIC_ITERATORS.lock().unwrap();
hash_map.insert(static_index, ExampleIterator::new(files, repeat, threads));
let compression_type = match compression.as_str() {
"" => CompressionType::Uncompressed,
"LZ4" => CompressionType::LZ4,
&_ => panic!("Not implemented: {}", compression),
};
let shard_infos = files
.into_iter()
.map(|file_path| ShardInfo { file_path: file_path.clone(), compression_type })
.collect();
hash_map.insert(static_index, ExampleIterator::new(shard_infos, repeat, threads));

RustIter { static_index, can_iterate: false }
}
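Unrecognized compression strings panic with "Not implemented". As a hypothetical alternative shape for the same mapping, a `FromStr` impl would keep the parsing testable apart from the constructor:

impl std::str::FromStr for CompressionType {
    type Err = String;

    // Same mapping as the inline match above: the empty string selects
    // Uncompressed, "LZ4" selects the frame decoder.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "" => Ok(CompressionType::Uncompressed),
            "LZ4" => Ok(CompressionType::LZ4),
            other => Err(format!("Not implemented: {}", other)),
        }
    }
}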
14 changes: 7 additions & 7 deletions rust/src/parallel_map.rs
@@ -115,7 +115,7 @@ where
// Start the threads and send an item to each of them.
let mut communication = Vec::new();
let mut handles = Vec::new();
for t in 0 .. threads {
for t in 0..threads {
// Next task for the thread.
let next_task = match iter.next() {
None => break, // Not creating a thread for nothing.
@@ -152,10 +152,10 @@ mod tests {

#[test]
fn plus_one() {
for iterator_length in 0 .. 50 {
for threads in 1 .. 20 {
for iterator_length in 0..50 {
for threads in 1..20 {
let mut iterations = 0;
for (i, res) in parallel_map(add_one, 0 .. iterator_length, threads).enumerate() {
for (i, res) in parallel_map(add_one, 0..iterator_length, threads).enumerate() {
// The value is correct and deterministic.
assert_eq!((i + 1) as i32, res);
iterations += 1;
@@ -173,12 +173,12 @@

#[test]
fn sleepy_plus_one() {
for iterator_length in 0 .. 10 {
for threads in 1 .. 10 {
for iterator_length in 0..10 {
for threads in 1..10 {
let mut iterations = 0;
let now = std::time::Instant::now();
for (i, res) in
parallel_map(sleepy_add_one, 0 .. iterator_length, threads).enumerate()
parallel_map(sleepy_add_one, 0..iterator_length, threads).enumerate()
{
// The value is correct and deterministic.
assert_eq!((i + 1) as i32, res);
16 changes: 10 additions & 6 deletions src/sedpack/io/dataset_base.py
@@ -141,7 +141,7 @@ def _shard_info_iterator(
for child in shard_list.children_shard_lists:
yield from self._shard_info_iterator(child)

def shard_info_iterator(self, split: SplitT) -> Iterator[ShardInfo]:
def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]:
"""Iterate all `ShardInfo` in the split.
Args:
@@ -151,10 +151,14 @@ def shard_info_iterator(self, split: SplitT) -> Iterator[ShardInfo]:
Raises: ValueError when the split is not present. A split not being
present is different from there not being any shard.
"""
if split not in self._dataset_info.splits:
# Split not present.
raise ValueError(f"There is no shard in {split}.")
if split:
if split not in self._dataset_info.splits:
# Split not present.
raise ValueError(f"There is no shard in {split}.")

shard_list_info: ShardListInfo = self._dataset_info.splits[split]
shard_list_info: ShardListInfo = self._dataset_info.splits[split]

yield from self._shard_info_iterator(shard_list_info)
yield from self._shard_info_iterator(shard_list_info)
else:
for shard_list_info in self._dataset_info.splits.values():
yield from self._shard_info_iterator(shard_list_info)
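A short usage sketch of the widened signature, with a hypothetical `dataset` instance. Note the branch is truthiness-based, so any falsy `split` value takes the all-splits path:

# Shards of a single split; raises ValueError if "train" is absent.
for shard_info in dataset.shard_info_iterator("train"):
    print(shard_info)

# New in this commit: split=None yields the shards of every split.
for shard_info in dataset.shard_info_iterator(None):
    print(shard_info)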
10 changes: 7 additions & 3 deletions src/sedpack/io/dataset_iteration.py
@@ -22,6 +22,7 @@
Optional,
)
import os
import random

Check failure on line 25 in src/sedpack/io/dataset_iteration.py (GitHub Actions / pylint 3.10, 3.11, 3.12): Unused import random (unused-import)

import asyncstdlib
import tensorflow as tf
@@ -658,9 +659,12 @@ def to_dict(example):
)
return result

with _sedpack_rs.RustIter(files=shard_paths,
repeat=repeat,
threads=file_parallelism) as rust_iter:
with _sedpack_rs.RustIter(
files=shard_paths,
repeat=repeat,
threads=file_parallelism,
compression=self.dataset_structure.compression,
) as rust_iter:
example_iterator = map(to_dict, iter(rust_iter))
if process_record:
yield from map(process_record, example_iterator)
9 changes: 9 additions & 0 deletions tests/io/test_rust_iter.py
@@ -89,3 +89,12 @@ def test_end2end_as_numpy_iterator_fb(tmpdir: Union[str, Path]) -> None:
shard_file_type="fb",
compression="",
)


def test_end2end_as_numpy_iterator_fb_lz4(tmpdir: Union[str, Path]) -> None:
end2end(
tmpdir=tmpdir,
dtype="float32",
shard_file_type="fb",
compression="LZ4",
)
