refactor: extract record batch reader
sdd committed Feb 29, 2024
1 parent a4cf9b8 commit f01c062
Showing 3 changed files with 115 additions and 79 deletions.
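As a rough usage sketch of the refactored API (mirroring the updated test at the bottom of scan.rs; the `table` value, the surrounding async function, and the exact module paths are assumptions, not part of this commit):

    use futures::TryStreamExt;
    use iceberg::table::Table;

    // Hypothetical helper: scan a loaded table and collect all Arrow record batches.
    // Passing `None` lets the reader fall back to DEFAULT_BATCH_SIZE (1024 rows).
    async fn scan_all(table: &Table) -> iceberg::Result<usize> {
        let table_scan = table.scan().build()?;
        let batch_stream = table_scan.execute(None).await?;
        let batches: Vec<_> = batch_stream.try_collect().await?;
        Ok(batches.len())
    }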
94 changes: 94 additions & 0 deletions crates/iceberg/src/file_record_batch_reader.rs
@@ -0,0 +1,94 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Parquet file data reader

use crate::{Error, ErrorKind};
use async_stream::try_stream;
use futures::stream::StreamExt;
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};

use crate::io::FileIO;
use crate::scan::{ArrowRecordBatchStream, FileScanTask, FileScanTaskStream};
use crate::spec::SchemaRef;

/// Default arrow record batch size
const DEFAULT_BATCH_SIZE: usize = 1024;

/// Reads data from Parquet files
pub struct FileRecordBatchReader {
batch_size: Option<usize>,
#[allow(dead_code)]
schema: SchemaRef,
file_io: FileIO,
}

impl FileRecordBatchReader {
/// Constructs a new FileRecordBatchReader
pub fn new(file_io: FileIO, schema: SchemaRef, batch_size: Option<usize>) -> Self {
FileRecordBatchReader {
batch_size,
file_io,
schema,
}
}

/// Takes a stream of FileScanTasks and reads all the files.
/// Returns a stream of Arrow RecordBatches containing the data from those files.
pub fn read(self, mut tasks: FileScanTaskStream) -> crate::Result<ArrowRecordBatchStream> {
let file_io = self.file_io.clone();
let batch_size = self.batch_size.unwrap_or(DEFAULT_BATCH_SIZE);

Ok(
try_stream! {
while let Some(Ok(task)) = tasks.next().await {

let projection_mask = self.get_arrow_projection_mask(&task);

let parquet_reader = file_io
.new_input(task.data_file().file_path())?
.reader()
.await?;

let mut batch_stream = ParquetRecordBatchStreamBuilder::new(parquet_reader)
.await
.map_err(|err| map_parquet_error(err, "Failed to create record batch stream builder", task.data_file().file_path()))?
.with_batch_size(batch_size)
.with_projection(projection_mask)
.build()
.map_err(|err| map_parquet_error(err, "Failed to build record batch stream", task.data_file().file_path()))?;

while let Some(batch) = batch_stream.next().await {
yield batch
.map_err(|err| map_parquet_error(err, "Failed to read record batch", task.data_file().file_path()))?;
}
}
}.boxed()
)
}

fn get_arrow_projection_mask(&self, _task: &FileScanTask) -> ProjectionMask {
// TODO: full implementation
ProjectionMask::all()
}
}

fn map_parquet_error(err: parquet::errors::ParquetError, message: &str, file_path: &str) -> Error {
Error::new(ErrorKind::Unexpected, message)
.with_source(err)
.with_context("filename", file_path)
}
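
For illustration only, a minimal sketch of driving the extracted reader directly; the surrounding function, its parameters, and the chosen batch size are assumptions (in this commit the reader is only invoked from TableScan::execute):

    use futures::TryStreamExt;

    use crate::io::FileIO;
    use crate::scan::FileScanTaskStream;
    use crate::spec::SchemaRef;

    // Hypothetical driver: read every task in the stream into a Vec of record batches.
    async fn collect_batches(
        file_io: FileIO,
        schema: SchemaRef,
        tasks: FileScanTaskStream,
    ) -> crate::Result<Vec<arrow_array::RecordBatch>> {
        // An explicit batch size overrides DEFAULT_BATCH_SIZE; `None` would keep the default.
        let reader = FileRecordBatchReader::new(file_io, schema, Some(4096));
        let batch_stream = reader.read(tasks)?;
        // Each stream item is a crate::Result<RecordBatch>, so try_collect propagates errors.
        batch_stream.try_collect().await
    }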
1 change: 1 addition & 0 deletions crates/iceberg/src/lib.rs
@@ -52,4 +52,5 @@ pub mod expr;
pub mod transaction;
pub mod transform;

pub mod file_record_batch_reader;
pub mod writer;
99 changes: 20 additions & 79 deletions crates/iceberg/src/scan.rs
@@ -17,27 +17,21 @@

//! Table scan api.

use crate::file_record_batch_reader::FileRecordBatchReader;
use crate::io::FileIO;
use crate::spec::{DataContentType, ManifestEntryRef, SchemaRef, SnapshotRef, TableMetadataRef};
use crate::table::Table;
use crate::{Error, ErrorKind};
use arrow_array::RecordBatch;
use async_stream::try_stream;
use futures::stream::{iter, BoxStream};
use futures::{StreamExt, TryStreamExt};
use parquet::arrow::arrow_reader::RowSelection;
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};

/// Default arrow record batch size
const DEFAULT_BATCH_SIZE: usize = 1024;
use futures::StreamExt;

/// Builder to create table scan.
pub struct TableScanBuilder<'a> {
table: &'a Table,
// Empty column names means to select all columns
column_names: Vec<String>,
snapshot_id: Option<i64>,
batch_size: Option<usize>,
}

impl<'a> TableScanBuilder<'a> {
@@ -46,7 +40,6 @@ impl<'a> TableScanBuilder<'a> {
table,
column_names: vec![],
snapshot_id: None,
batch_size: None,
}
}

@@ -71,11 +64,6 @@ impl<'a> TableScanBuilder<'a> {
self
}

pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = Some(batch_size);
self
}

/// Build the table scan.
pub fn build(self) -> crate::Result<TableScan> {
let snapshot = match self.snapshot_id {
@@ -123,7 +111,6 @@ impl<'a> TableScanBuilder<'a> {
table_metadata: self.table.metadata_ref(),
column_names: self.column_names,
schema,
batch_size: self.batch_size,
})
}
}
@@ -137,7 +124,6 @@ pub struct TableScan {
file_io: FileIO,
column_names: Vec<String>,
schema: SchemaRef,
batch_size: Option<usize>,
}

/// A stream of [`FileScanTask`].
@@ -179,79 +165,32 @@ impl TableScan {
Ok(iter(file_scan_tasks).boxed())
}

/// Transforms a stream of FileScanTasks from plan_files into a stream of
/// Arrow RecordBatches.
pub fn open(&self, mut tasks: FileScanTaskStream) -> crate::Result<ArrowRecordBatchStream> {
let file_io = self.file_io.clone();
let batch_size = self.batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let projection_mask = self.get_arrow_projection_mask();
let row_selection = self.get_arrow_row_selection();

Ok(
try_stream! {
while let Some(Ok(task)) = tasks.next().await {
let parquet_reader = file_io
.new_input(task.data_file().file_path())?
.reader()
.await?;

let mut batch_stream = ParquetRecordBatchStreamBuilder::new(parquet_reader)
.await
.map_err(|err| {
Error::new(ErrorKind::Unexpected, "failed to load parquet file").with_source(err)
})?
.with_batch_size(batch_size)
.with_offset(task.start() as usize)
.with_limit(task.length() as usize)
.with_projection(projection_mask.clone())
.with_row_selection(row_selection.clone())
.build()
.unwrap()
.map_err(|err| Error::new(ErrorKind::Unexpected, "Fail to read data").with_source(err));

while let Some(batch) = batch_stream.next().await {
yield batch?;
}
}
}.boxed()
)
}

fn get_arrow_projection_mask(&self) -> ProjectionMask {
// TODO, dummy implementation
todo!()
}

fn get_arrow_row_selection(&self) -> RowSelection {
// TODO, dummy implementation
todo!()
pub async fn execute(
&self,
batch_size: Option<usize>,
) -> crate::Result<ArrowRecordBatchStream> {
FileRecordBatchReader::new(self.file_io.clone(), self.schema.clone(), batch_size)
.read(self.plan_files().await?)
}
}

/// A task to scan part of file.
#[derive(Debug)]
#[allow(dead_code)]
pub struct FileScanTask {
data_file: ManifestEntryRef,
#[allow(dead_code)]
start: u64,
#[allow(dead_code)]
length: u64,
}

/// A stream of arrow record batches.
pub type ArrowRecordBatchStream = BoxStream<'static, crate::Result<RecordBatch>>;

impl FileScanTask {
pub fn data_file(&self) -> ManifestEntryRef {
pub(crate) fn data_file(&self) -> ManifestEntryRef {
self.data_file.clone()
}

pub fn start(&self) -> u64 {
self.start
}

pub fn length(&self) -> u64 {
self.length
}
}

#[cfg(test)]
@@ -445,13 +384,16 @@ mod tests {
.set_compression(Compression::SNAPPY)
.build();

let file = File::create(format!("{}/1.parquet", &self.table_location)).unwrap();
let mut writer = ArrowWriter::try_new(file, to_write.schema(), Some(props)).unwrap();
for n in 1..=3 {
let file = File::create(format!("{}/{}.parquet", &self.table_location, n)).unwrap();
let mut writer =
ArrowWriter::try_new(file, to_write.schema(), Some(props.clone())).unwrap();

writer.write(&to_write).expect("Writing batch");
writer.write(&to_write).expect("Writing batch");

// writer must be closed to write footer
writer.close().unwrap();
// writer must be closed to write footer
writer.close().unwrap();
}
}
}

@@ -554,7 +496,6 @@ mod tests {
}

#[tokio::test]
#[ignore = "won't work yet as there are still some unimplemented methods"]
async fn test_open_parquet_no_deletions() {
let mut fixture = TableTestFixture::new();
fixture.setup_manifest_files().await;
@@ -563,7 +504,7 @@
let table_scan = fixture.table.scan().build().unwrap();
let tasks = table_scan.plan_files().await.unwrap();

let batch_stream = table_scan.open(tasks).unwrap();
let batch_stream = table_scan.execute(None).await.unwrap();

let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
