feat: cast record-batches on write
roeap authored and rtyler committed May 6, 2023
1 parent 4b4aac7 commit fc134d5
Showing 2 changed files with 51 additions and 25 deletions.
rust/Cargo.toml: 9 additions & 2 deletions
@@ -13,12 +13,18 @@ readme = "README.md"
 edition = "2021"
 
 [dependencies]
-arrow = { version = "37.0.0", optional = true }
+arrow = { version = "37", optional = true }
+arrow-array = { version = "37", optional = true }
+arrow-cast = { version = "37", optional = true }
+arrow-schema = { version = "37", optional = true }
 async-trait = "0.1"
 bytes = "1"
 chrono = { version = "0.4.22", default-features = false, features = ["clock"] }
 cfg-if = "1"
-datafusion-objectstore-hdfs = { version = "0.1.3", default-features = false, features = ["hdfs3", "try_spawn_blocking"], optional = true }
+datafusion-objectstore-hdfs = { version = "0.1.3", default-features = false, features = [
+    "hdfs3",
+    "try_spawn_blocking",
+], optional = true }
 errno = "0.3"
 futures = "0.3"
 itertools = "0.10"
@@ -96,6 +102,7 @@ glibc_version = { path = "../glibc_version", version = "0.1" }
 
 [features]
 azure = ["object_store/azure"]
+arrow = ["dep:arrow", "arrow-array", "arrow-cast", "arrow-schema"]
 default = ["arrow", "parquet"]
 datafusion = [
     "dep:datafusion",
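
A note on the feature wiring above: "dep:arrow" uses Cargo's namespaced-dependency syntax (stable since Cargo 1.60), so the optional arrow crate is activated only through the arrow feature, which also enables the new arrow-array, arrow-cast, and arrow-schema sub-crates in one step. A minimal sketch of gating code on that feature; the module and function names here are illustrative, not part of this commit:

// Compiled only when the crate is built with the `arrow` feature
// (on by default via `default = ["arrow", "parquet"]`).
#[cfg(feature = "arrow")]
pub mod arrow_support {
    use arrow_cast::can_cast_types;
    use arrow_schema::DataType;

    /// True when arrow's cast kernel can convert `from` into `to`.
    pub fn is_castable(from: &DataType, to: &DataType) -> bool {
        can_cast_types(from, to)
    }
}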
rust/src/operations/write.rs: 42 additions & 23 deletions
@@ -31,8 +31,9 @@ use crate::table_state::DeltaTableState;
 use crate::writer::record_batch::divide_by_partition_values;
 use crate::writer::utils::PartitionPath;
 
-use arrow::datatypes::{DataType, SchemaRef as ArrowSchemaRef};
-use arrow::record_batch::RecordBatch;
+use arrow_array::RecordBatch;
+use arrow_cast::{can_cast_types, cast};
+use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
 use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
 use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan};
 use futures::future::BoxFuture;
@@ -249,16 +250,13 @@ impl std::future::IntoFuture for WriteBuilder {
                 .snapshot
                 .physical_arrow_schema(this.store.clone())
                 .await
-                .or_else(|_| this.snapshot.arrow_schema());
-
-            // we cannot get a schema, if there is not data and or meta data in the table,
-            // i.e. not initialized
-            if let Ok(curr_schema) = table_schema {
-                if !schema_eq(curr_schema, schema.clone()) {
-                    return Err(DeltaTableError::Generic(
-                        "Updating table schema not yet implemented".to_string(),
-                    ));
-                }
+                .or_else(|_| this.snapshot.arrow_schema())
+                .unwrap_or(schema.clone());
+
+            if !can_cast_batch(schema.as_ref(), table_schema.as_ref()) {
+                return Err(DeltaTableError::Generic(
+                    "Updating table schema not yet implemented".to_string(),
+                ));
             };
 
             let data = if !partition_columns.is_empty() {
@@ -269,8 +267,7 @@ impl std::future::IntoFuture for WriteBuilder {
                         schema.clone(),
                         partition_columns.clone(),
                         &batch,
-                    )
-                    .unwrap();
+                    )?;
                     for part in divided {
                         let key = PartitionPath::from_hashmap(
                             &partition_columns,
@@ -326,13 +323,15 @@ impl std::future::IntoFuture for WriteBuilder {
                     );
                     let mut writer = DeltaWriter::new(this.store.clone(), config);
                     let checker_stream = checker.clone();
+                    let schema = inner_plan.schema().clone();
                     let mut stream = inner_plan.execute(i, task_ctx)?;
                     let handle: tokio::task::JoinHandle<DeltaResult<Vec<Add>>> =
                         tokio::task::spawn(async move {
                             while let Some(maybe_batch) = stream.next().await {
                                 let batch = maybe_batch?;
                                 checker_stream.check_batch(&batch).await?;
-                                writer.write(&batch).await?;
+                                let arr = cast_record_batch(&batch, schema.clone())?;
+                                writer.write(&arr).await?;
                             }
                             writer.close().await
                         });
@@ -423,16 +422,36 @@ impl std::future::IntoFuture for WriteBuilder {
     }
 }
 
-fn schema_to_vec_name_type(schema: ArrowSchemaRef) -> Vec<(String, DataType)> {
-    schema
-        .fields()
-        .iter()
-        .map(|f| (f.name().to_owned(), f.data_type().clone()))
-        .collect::<Vec<_>>()
+fn can_cast_batch(from_schema: &ArrowSchema, to_schema: &ArrowSchema) -> bool {
+    if from_schema.fields.len() != to_schema.fields.len() {
+        return false;
+    }
+    from_schema.all_fields().iter().all(|f| {
+        if let Ok(target_field) = to_schema.field_with_name(f.name()) {
+            can_cast_types(f.data_type(), target_field.data_type())
+        } else {
+            false
+        }
+    })
 }
 
-fn schema_eq(l: ArrowSchemaRef, r: ArrowSchemaRef) -> bool {
-    schema_to_vec_name_type(l) == schema_to_vec_name_type(r)
+fn cast_record_batch(
+    batch: &RecordBatch,
+    target_schema: ArrowSchemaRef,
+) -> DeltaResult<RecordBatch> {
+    let columns = target_schema
+        .all_fields()
+        .iter()
+        .map(|f| {
+            let col = batch.column_by_name(f.name()).unwrap();
+            if !col.data_type().equals_datatype(f.data_type()) {
+                cast(col, f.data_type())
+            } else {
+                Ok(col.clone())
+            }
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+    Ok(RecordBatch::try_new(target_schema, columns)?)
+}
 
 #[cfg(test)]
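
To make the new write behavior concrete, below is a self-contained sketch of the check-then-cast flow that can_cast_batch and cast_record_batch implement, built on the same arrow-cast kernels; the column name and data types are invented for the example:

use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_cast::{can_cast_types, cast};
use arrow_schema::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A batch as a user might hand it to the writer: "value" arrives as Int32.
    let input_schema = Arc::new(Schema::new(vec![Field::new(
        "value",
        DataType::Int32,
        true,
    )]));
    let batch = RecordBatch::try_new(
        input_schema,
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;

    // The table schema expects Int64 for the same column.
    let table_schema = Arc::new(Schema::new(vec![Field::new(
        "value",
        DataType::Int64,
        true,
    )]));

    // Mirrors can_cast_batch: every incoming field must be castable to its target.
    assert!(can_cast_types(&DataType::Int32, &DataType::Int64));

    // Mirrors cast_record_batch: cast each column, then rebuild the batch
    // against the target schema.
    let columns = table_schema
        .fields()
        .iter()
        .map(|f| cast(batch.column_by_name(f.name()).unwrap(), f.data_type()))
        .collect::<Result<Vec<_>, _>>()?;
    let casted = RecordBatch::try_new(table_schema, columns)?;
    assert_eq!(casted.column(0).data_type(), &DataType::Int64);
    Ok(())
}

Note that cast_record_batch looks up each field of the target schema in the incoming batch with .unwrap(), so it assumes every target column is present; the can_cast_batch guard (equal field counts plus per-field castability) is what establishes that on the write path above.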
